#include "AnalysisVisitor.hxx"
#include "ConstantVisitor.hxx"
+#include "double.hxx"
namespace analysis
{
isConstant = true;
}
}
+ else if (name == L"size")
+ {
+ if (parent->getAnalyzer(sym)->analyze(*parent, lhs, e))
+ {
+ switch (lhs)
+ {
+ case 1: // a = size(x)
+ {
+ std::vector<Result> & res = parent->getLHSContainer();
+ double row;
+ res.front().getConstant().getDblValue(row);
+
+ double col;
+ res.back().getConstant().getDblValue(col);
+
+ types::Double* pIT = new types::Double(1, 2);
+ pIT->get()[0] = row;
+ pIT->get()[1] = col;
+ e.replace(new ast::DoubleExp(e.getLocation(), pIT));
+ isConstant = true;
+ break;
+ }
+ case 2: // [a, b] = size(x)
+ {
+ double val;
+ ast::exps_t * exps = new ast::exps_t();
+ exps->reserve(2);
+ std::vector<Result> & res = parent->getLHSContainer();
+ res.front().getConstant().getDblValue(val);
+ exps->push_back(new ast::DoubleExp(e.getLocation(), val));
+ res.back().getConstant().getDblValue(val);
+ exps->push_back(new ast::DoubleExp(e.getLocation(), val));
+ e.replace(new ast::ArrayListExp(e.getLocation(), *exps));
+ isConstant = true;
+ break;
+ }
+ }
+ }
+ }
+ }
+ else if (parent && args.size() == 2)
+ {
+ if (name == L"size")
+ {
+ if (parent->getAnalyzer(sym)->analyze(*parent, lhs, e))
+ {
+ //a = size(x, "dims") or a = size(x, dim)
+ double val;
+ parent->getResult().getConstant().getDblValue(val);
+ e.replace(new ast::DoubleExp(e.getLocation(), val));
+ isConstant = true;
+ }
+ }
}
}
}
}
const ast::exps_t args = e.getArgs();
- enum Kind
- {
- ROWS, COLS, ROWSTIMESCOLS, ROWSCOLS, ONE, BOTH, DUNNO
- } kind = DUNNO;
+ SizeCall::Kind kind = SizeCall::DUNNO;
+
const std::size_t size = args.size();
if (size == 0 || size >= 3)
{
case 1:
if (lhs == 1)
{
- kind = BOTH;
+ kind = SizeCall::BOTH;
}
else if (lhs == 2)
{
- kind = ROWSCOLS;
+ kind = SizeCall::R_C;
}
break;
case 2:
const std::wstring & arg2 = static_cast<ast::StringExp *>(second)->getValue();
if (arg2 == L"r")
{
- kind = ROWS;
+ kind = SizeCall::R;
}
else if (arg2 == L"c")
{
- kind = COLS;
+ kind = SizeCall::C;
}
else if (arg2 == L"*")
{
- kind = ROWSTIMESCOLS;
+ kind = SizeCall::RC;
}
else
{
const double arg2 = static_cast<ast::DoubleExp *>(second)->getValue();
if (arg2 == 1)
{
- kind = ROWS;
+ kind = SizeCall::R;
}
else if (arg2 == 2)
{
- kind = COLS;
+ kind = SizeCall::C;
}
else if (arg2 >= 3)
{
// TODO: we should handle hypermatrix
- kind = ONE;
+ kind = SizeCall::ONE;
}
else
{
switch (kind)
{
- case ROWS:
+ case SizeCall::R:
{
SymbolicDimension & rows = res.getType().rows;
Result & _res = e.getDecorator().setResult(type);
visitor.setResult(_res);
break;
}
- case COLS:
+ case SizeCall::C:
{
SymbolicDimension & cols = res.getType().cols;
Result & _res = e.getDecorator().setResult(type);
visitor.setResult(_res);
break;
}
- case ROWSTIMESCOLS:
+ case SizeCall::RC:
{
SymbolicDimension & rows = res.getType().rows;
SymbolicDimension & cols = res.getType().cols;
SymbolicDimension prod = rows * cols;
+
Result & _res = e.getDecorator().setResult(type);
_res.getConstant() = prod.getValue();
e.getDecorator().setCall(new SizeCall(SizeCall::RC));
visitor.setResult(_res);
break;
}
- case ROWSCOLS:
+ case SizeCall::R_C:
+ case SizeCall::BOTH:
{
+ if (kind == SizeCall::BOTH)
+ {
+ TIType _type(visitor.getGVN(), TIType::DOUBLE, 1, 2);
+ Result & _res = e.getDecorator().setResult(_type);
+ }
+
SymbolicDimension & rows = res.getType().rows;
SymbolicDimension & cols = res.getType().cols;
std::vector<Result> & mlhs = visitor.getLHSContainer();
mlhs.emplace_back(type);
mlhs.back().getConstant() = cols.getValue();
- e.getDecorator().setCall(new SizeCall(SizeCall::R_C));
+ e.getDecorator().setCall(new SizeCall(kind));
break;
}
- case ONE:
+ case SizeCall::ONE:
{
Result & _res = e.getDecorator().setResult(type);
_res.getConstant() = new types::Double(1);
visitor.setResult(_res);
break;
}
- case BOTH:
- {
- TIType _type(visitor.getGVN(), TIType::DOUBLE, 1, 2);
- Result & _res = e.getDecorator().setResult(_type);
- e.getDecorator().setCall(new SizeCall(SizeCall::BOTH));
- visitor.setResult(_res);
- break;
- }
default:
return false;
}
if (e.getRightExp().isCallExp()) // A = foo(...)
{
- if (e.getRightExp().isCallExp())
- {
- visit(static_cast<ast::CallExp &>(e.getRightExp()), /* LHS */ 1);
- }
- else
- {
- e.getRightExp().accept(*this);
- }
+ visit(static_cast<ast::CallExp &>(e.getRightExp()), /* LHS */ 1);
}
else // A = 1 + 2
{
if (e.getRightExp().isCallExp())
{
const ast::exps_t & exps = ale.getExps();
+
+ // apply the ConstantVisitor
+ cv.setLHS(exps.size());
+ e.getRightExp().accept(cv);
+
visit(static_cast<ast::CallExp &>(e.getRightExp()), /* LHS */ exps.size());
std::vector<Result>::iterator j = multipleLHS.begin();
for (const auto exp : exps)