Use the reduced version of the function from the analysis during macro execution
[scilab.git] / scilab / modules / ast / src / cpp / analysis / AnalysisVisitor.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2014 - Scilab Enterprises - Calixte DENIZET
4  *
5  * Copyright (C) 2012 - 2016 - Scilab Enterprises
6  *
7  * This file is hereby licensed under the terms of the GNU GPL v2.0,
8  * pursuant to article 5.3.4 of the CeCILL v.2.1.
9  * This file was originally licensed under the terms of the CeCILL v2.1,
10  * and continues to be available under such terms.
11  * For more information, see the COPYING file which you should have received
12  * along with this program.
13  *
14  */
15
16 #include "symbol.hxx"
17
18 #include "AnalysisVisitor.hxx"
19 #include "analyzers/ArgnAnalyzer.hxx"
20 #include "analyzers/CeilAnalyzer.hxx"
21 #include "analyzers/DiagAnalyzer.hxx"
22 #include "analyzers/LengthAnalyzer.hxx"
23 #include "analyzers/MatrixAnalyzer.hxx"
24 #include "analyzers/MemInitAnalyzer.hxx"
25 #include "analyzers/SizeAnalyzer.hxx"
26 #include "analyzers/TypeAnalyzer.hxx"
27 #include "analyzers/TypeofAnalyzer.hxx"
28 #include "analyzers/InttypeAnalyzer.hxx"
29 #include "analyzers/IconvertAnalyzer.hxx"
30 #include "analyzers/IsrealAnalyzer.hxx"
31 #include "analyzers/IsscalarAnalyzer.hxx"
32 #include "analyzers/FindAnalyzer.hxx"
33
34 //#include "analyzers/SqrtAnalyzer.hxx"
35
36 namespace analysis
37 {
38 AnalysisVisitor::MapSymCall AnalysisVisitor::symscall = AnalysisVisitor::initCalls();//a=1:3;b=2;c=3;testAnalysis("repmat","a","b","c")
39
40 AnalysisVisitor::MapSymCall AnalysisVisitor::initCalls()
41 {
42     MapSymCall msc;
43
44     msc.emplace(L"zeros", std::shared_ptr<CallAnalyzer>(new ZerosAnalyzer()));
45     msc.emplace(L"ones", std::shared_ptr<CallAnalyzer>(new OnesAnalyzer()));
46     msc.emplace(L"rand", std::shared_ptr<CallAnalyzer>(new RandAnalyzer()));
47     msc.emplace(L"matrix", std::shared_ptr<CallAnalyzer>(new MatrixAnalyzer()));
48     msc.emplace(L"eye", std::shared_ptr<CallAnalyzer>(new EyeAnalyzer()));
49
50     std::shared_ptr<CallAnalyzer> ca(new CeilAnalyzer());
51     msc.emplace(L"ceil", ca);
52     msc.emplace(L"floor", ca);
53     msc.emplace(L"round", ca);
54     msc.emplace(L"fix", ca);
55     msc.emplace(L"int", ca);
56
57     //msc.emplace(L"sqrt", std::shared_ptr<CallAnalyzer>(new SqrtAnalyzer()));
58     msc.emplace(L"argn", std::shared_ptr<CallAnalyzer>(new ArgnAnalyzer()));
59     msc.emplace(L"size", std::shared_ptr<CallAnalyzer>(new SizeAnalyzer()));
60     msc.emplace(L"length", std::shared_ptr<CallAnalyzer>(new LengthAnalyzer()));
61     msc.emplace(L"diag", std::shared_ptr<CallAnalyzer>(new DiagAnalyzer()));
62     msc.emplace(L"type", std::shared_ptr<CallAnalyzer>(new TypeAnalyzer()));
63     msc.emplace(L"typeof", std::shared_ptr<CallAnalyzer>(new TypeofAnalyzer()));
64     msc.emplace(L"inttype", std::shared_ptr<CallAnalyzer>(new InttypeAnalyzer()));
65     msc.emplace(L"iconvert", std::shared_ptr<CallAnalyzer>(new IconvertAnalyzer()));
66     msc.emplace(L"isreal", std::shared_ptr<CallAnalyzer>(new IsrealAnalyzer()));
67     msc.emplace(L"isscalar", std::shared_ptr<CallAnalyzer>(new IsscalarAnalyzer()));
68     msc.emplace(L"find", std::shared_ptr<CallAnalyzer>(new FindAnalyzer()));
69
70     return msc;
71 }
72
73
74 AnalysisVisitor::AnalysisVisitor() : cv(*this), pv(std::wcerr, true, false), logger("/tmp/analysis.log")
75 {
76     start_chrono();
77 }
78
79 AnalysisVisitor::~AnalysisVisitor() { }
80
81 void AnalysisVisitor::reset()
82 {
83     _result = Result();
84     dm.reset();
85     multipleLHS.clear();
86     while (!loops.empty())
87     {
88         loops.pop();
89     }
90     start_chrono();
91 }
92
93 void AnalysisVisitor::print_info()
94 {
95     stop_chrono();
96
97     //std::wcout << getGVN() << std::endl << std::endl; function z=foo(x,y);z=argn(2);endfunction;jit("x=123;y=456;t=foo(x,y)")
98     std::wcerr << L"Analysis: " << *static_cast<Chrono *>(this) << std::endl;
99     //std::wcout << temp << std::endl;
100
101     std::wcerr << dm << std::endl;
102     std::wcerr << pmc << std::endl;
103
104     std::wcerr << std::endl;
105 }
106
107 logging::Logger & AnalysisVisitor::getLogger()
108 {
109     return logger;
110 }
111
112 bool AnalysisVisitor::asDouble(types::InternalType * pIT, double & out)
113 {
114     if (pIT && pIT->isDouble())
115     {
116         types::Double * pDbl = static_cast<types::Double *>(pIT);
117         if (!pDbl->isComplex() && pDbl->getSize() == 1)
118         {
119             out = pDbl->get()[0];
120             return true;
121         }
122     }
123
124     return false;
125 }
126
127 bool AnalysisVisitor::asDouble(ast::Exp & e, double & out)
128 {
129     if (e.isDoubleExp())
130     {
131         out = static_cast<ast::DoubleExp &>(e).getValue();
132         return true;
133     }
134     else if (e.isOpExp())
135     {
136         ast::OpExp & op = static_cast<ast::OpExp &>(e);
137         if (op.getOper() == ast::OpExp::unaryMinus)
138         {
139             if (op.getRight().isDoubleExp())
140             {
141                 out = -static_cast<ast::DoubleExp &>(op.getRight()).getValue();
142                 return true;
143             }
144         }
145         else if (op.getLeft().isDoubleExp() && op.getRight().isDoubleExp())
146         {
147             const double L = static_cast<ast::DoubleExp &>(op.getLeft()).getValue();
148             const double R = static_cast<ast::DoubleExp &>(op.getRight()).getValue();
149
150             switch (op.getOper())
151             {
152                 case ast::OpExp::minus:
153                     out = L - R;
154                     return true;
155                 case ast::OpExp::plus:
156                     out = L + R;
157                     return true;
158                 case ast::OpExp::times:
159                 case ast::OpExp::dottimes:
160                 case ast::OpExp::krontimes:
161                     out = L * R;
162                     return true;
163                 case ast::OpExp::rdivide:
164                 case ast::OpExp::dotrdivide:
165                 case ast::OpExp::kronrdivide:
166                     out = L / R;
167                     return true;
168                 case ast::OpExp::ldivide:
169                 case ast::OpExp::dotldivide:
170                 case ast::OpExp::kronldivide:
171                     out = R / L;
172                     return true;
173                 case ast::OpExp::power:
174                 case ast::OpExp::dotpower:
175                     out = std::pow(L, R);
176                     return true;
177                 default:
178                     return false;
179             }
180         }
181     }
182
183     return false;
184 }
185
186 bool AnalysisVisitor::isDoubleConstant(const ast::Exp & e)
187 {
188     if (e.isDoubleExp())
189     {
190         return true;
191     }
192     else if (e.isOpExp())
193     {
194         const ast::OpExp & oe = static_cast<const ast::OpExp &>(e);
195         if (!oe.isBooleanOp())
196         {
197             return isDoubleConstant(oe.getLeft()) && isDoubleConstant(oe.getRight());
198         }
199         return false;
200     }
201     else if (e.isMatrixExp())
202     {
203         const ast::MatrixExp & me = static_cast<const ast::MatrixExp &>(e);
204         const ast::exps_t & lines = me.getLines();
205         for (const auto line : lines)
206         {
207             const ast::exps_t & columns = static_cast<ast::MatrixLineExp *>(line)->getColumns();
208             for (const auto column : columns)
209             {
210                 if (column && !isDoubleConstant(*column))
211                 {
212                     return false;
213                 }
214             }
215         }
216         return true;
217     }
218     else if (e.isListExp())
219     {
220         const ast::ListExp & le = static_cast<const ast::ListExp &>(e);
221         return isDoubleConstant(le.getStart()) && isDoubleConstant(le.getStep()) && isDoubleConstant(le.getEnd());
222     }
223     else if (e.isSimpleVar())
224     {
225         const ast::SimpleVar & var = static_cast<const ast::SimpleVar &>(e);
226         const symbol::Symbol & sym = var.getSymbol();
227         const std::wstring & name = sym.getName();
228         return name == L"%i" || name == L"%inf" || name == L"%nan" || name == L"%eps" || name == L"%pi" || name == L"%e";
229     }
230     else if (e.isCallExp())
231     {
232         const ast::CallExp & ce = static_cast<const ast::CallExp &>(e);
233         const ast::SimpleVar & var = static_cast<const ast::SimpleVar &>(ce.getName());
234         const std::wstring & name = var.getSymbol().getName();
235
236         // TODO: check if 'ones' and 'zeros' are the expected functions
237         // ie: ones="abc"; ones(1) !!!
238         if (name == L"ones" || name == L"zeros")
239         {
240             const ast::exps_t args = ce.getArgs();
241             switch (args.size())
242             {
243                 case 0:
244                     return true;
245                 case 1:
246                     return isDoubleConstant(*args.front());
247                 case 2:
248                     return isDoubleConstant(*args.front()) && isDoubleConstant(**std::next(args.cbegin()));
249                 default:
250                     return false;
251             }
252         }
253     }
254
255     return false;
256 }
257
258 bool AnalysisVisitor::asDoubleMatrix(ast::Exp & e, types::Double *& data)
259 {
260     if (isDoubleConstant(e))
261     {
262         ast::ExecVisitor exec;
263         e.accept(exec);
264         types::InternalType * pIT = exec.getResult();
265         // TODO : handle complex case
266         if (pIT && pIT->isDouble() && !pIT->getAs<types::Double>()->isComplex())
267         {
268             pIT->IncreaseRef();
269             data = static_cast<types::Double *>(pIT);
270
271             return true;
272         }
273     }
274
275     return false;
276 }
277
278 void AnalysisVisitor::visitArguments(const std::wstring & name, const unsigned int lhs, const TIType & calltype, ast::CallExp & e, const ast::exps_t & args)
279 {
280     std::vector<Result> resargs;
281     std::vector<TIType> vargs;
282     vargs.reserve(args.size());
283     resargs.reserve(args.size());
284
285     const ast::SimpleVar & var = static_cast<ast::SimpleVar &>(e.getName());
286     const symbol::Symbol & sym = var.getSymbol();
287     argIndices.emplace(var, args.size(), 0);
288
289     for (auto arg : args)
290     {
291         argIndices.top().getIndex() += 1;
292         arg->accept(*this);
293         resargs.push_back(getResult());
294         vargs.push_back(getResult().getType());
295     }
296
297     argIndices.pop();
298     uint64_t functionId = 0;
299     std::vector<TIType> out = getDM().call(*this, lhs, sym, vargs, &e, functionId);
300     if (lhs > 1)
301     {
302         multipleLHS.clear();
303         multipleLHS.reserve(out.size());
304         for (const auto & type : out)
305         {
306             const int tempId = getDM().getTmpId(type, false);
307             multipleLHS.emplace_back(type, tempId);
308         }
309
310         auto i = args.begin();
311         for (const auto & resarg : resargs)
312         {
313             getDM().releaseTmp(resarg.getTempId(), *i);
314             ++i;
315         }
316     }
317     else if (lhs == 1)
318     {
319         int tempId = -1;
320         if (resargs.size() == 1)
321         {
322             const int id = resargs.back().getTempId();
323             if (id != -1 && Checkers::isElementWise(name) && out[0] == resargs.back().getType())
324             {
325                 tempId = id;
326             }
327         }
328         if (tempId == -1)
329         {
330             tempId = getDM().getTmpId(out[0], false);
331             auto i = args.begin();
332             for (const auto & resarg : resargs)
333             {
334                 getDM().releaseTmp(resarg.getTempId(), *i);
335                 ++i;
336             }
337         }
338
339         e.getDecorator().res = Result(out[0], tempId, functionId);
340         e.getDecorator().setCall(name, vargs);
341         setResult(e.getDecorator().res);
342     }
343 }
344
345 int AnalysisVisitor::getTmpIdForEWOp(const TIType & resT, const Result & LR, const Result & RR, ast::Exp * Lexp, ast::Exp * Rexp)
346 {
347     int tempId = -1;
348     if (resT.isknown() && resT.ismatrix())
349     {
350         if (LR.isTemp() || RR.isTemp())
351         {
352             const int Lid = LR.getTempId();
353             const int Rid = RR.getTempId();
354             const TIType & LT = LR.getType();
355             const TIType & RT = RR.getType();
356
357             if (LT.isscalar())
358             {
359                 if (RT.isscalar())
360                 {
361                     if (Lid == -1)
362                     {
363                         if (resT == LT)
364                         {
365                             tempId = Rid;
366                         }
367                         else
368                         {
369                             tempId = getDM().getTmpId(resT, false);
370                             getDM().releaseTmp(Rid, Rexp);
371                         }
372                     }
373                     else
374                     {
375                         if (resT == LT)
376                         {
377                             tempId = Lid;
378                             getDM().releaseTmp(Rid, Rexp);
379                         }
380                         else if (Rid != -1 && resT == RT)
381                         {
382                             tempId = Rid;
383                             getDM().releaseTmp(Lid, Lexp);
384                         }
385                         else
386                         {
387                             tempId = getDM().getTmpId(resT, false);
388                             getDM().releaseTmp(Lid, Lexp);
389                         }
390                     }
391                 }
392                 else
393                 {
394                     if (Rid == -1)
395                     {
396                         tempId = getDM().getTmpId(resT, false);
397                     }
398                     else
399                     {
400                         if (resT == RT)
401                         {
402                             tempId = Rid;
403                         }
404                         else if (Lid != -1 && resT == LT)
405                         {
406                             tempId = Lid;
407                             getDM().releaseTmp(Rid, Rexp);
408                         }
409                         else
410                         {
411                             tempId = getDM().getTmpId(resT, false);
412                             getDM().releaseTmp(Rid, Rexp);
413                         }
414                     }
415                     getDM().releaseTmp(Lid, Lexp);
416                 }
417             }
418             else
419             {
420                 if (RT.isscalar())
421                 {
422                     if (Lid == -1)
423                     {
424                         tempId = getDM().getTmpId(resT, false);
425                     }
426                     else
427                     {
428                         if (resT == LT)
429                         {
430                             tempId = Lid;
431                         }
432                         else if (Rid != -1 && resT == RT)
433                         {
434                             tempId = Rid;
435                             getDM().releaseTmp(Lid, Lexp);
436                         }
437                         else
438                         {
439                             tempId = getDM().getTmpId(resT, false);
440                             getDM().releaseTmp(Lid, Lexp);
441                         }
442                     }
443                     getDM().releaseTmp(Rid, Rexp);
444                 }
445                 else
446                 {
447                     if (Rid == -1)
448                     {
449                         if (resT == LT)
450                         {
451                             tempId = Lid;
452                         }
453                         else
454                         {
455                             tempId = getDM().getTmpId(resT, false);
456                             getDM().releaseTmp(Lid, Lexp);
457                         }
458                     }
459                     else
460                     {
461                         if (resT == RT)
462                         {
463                             tempId = Rid;
464                         }
465                         else if (Lid != -1 && resT == LT)
466                         {
467                             tempId = Lid;
468                             getDM().releaseTmp(Rid, Rexp);
469                         }
470                         else
471                         {
472                             tempId = getDM().getTmpId(resT, false);
473                             getDM().releaseTmp(Rid, Rexp);
474                         }
475                         getDM().releaseTmp(Lid, Lexp);
476                     }
477                 }
478             }
479         }
480         else
481         {
482             tempId = getDM().getTmpId(resT, false);
483         }
484     }
485
486     return tempId;
487 }
488 }