Analysis: add info about $ and fix bug
[scilab.git] / scilab / modules / ast / src / cpp / analysis / AnalysisVisitor.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2014 - Scilab Enterprises - Calixte DENIZET
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 #include "symbol.hxx"
14
15 #include "AnalysisVisitor.hxx"
16 #include "analyzers/ArgnAnalyzer.hxx"
17 #include "analyzers/CeilAnalyzer.hxx"
18 #include "analyzers/DiagAnalyzer.hxx"
19 #include "analyzers/LengthAnalyzer.hxx"
20 #include "analyzers/MatrixAnalyzer.hxx"
21 #include "analyzers/MemInitAnalyzer.hxx"
22 #include "analyzers/SizeAnalyzer.hxx"
23 #include "analyzers/TypeAnalyzer.hxx"
24 #include "analyzers/TypeofAnalyzer.hxx"
25 #include "analyzers/InttypeAnalyzer.hxx"
26 #include "analyzers/IconvertAnalyzer.hxx"
27 #include "analyzers/IsrealAnalyzer.hxx"
28 #include "analyzers/IsscalarAnalyzer.hxx"
29 #include "analyzers/FindAnalyzer.hxx"
30
31 //#include "analyzers/SqrtAnalyzer.hxx"
32
33 namespace analysis
34 {
35 AnalysisVisitor::MapSymCall AnalysisVisitor::symscall = AnalysisVisitor::initCalls();//a=1:3;b=2;c=3;testAnalysis("repmat","a","b","c")
36
37 AnalysisVisitor::MapSymCall AnalysisVisitor::initCalls()
38 {
39     MapSymCall msc;
40
41     msc.emplace(L"zeros", std::shared_ptr<CallAnalyzer>(new ZerosAnalyzer()));
42     msc.emplace(L"ones", std::shared_ptr<CallAnalyzer>(new OnesAnalyzer()));
43     msc.emplace(L"rand", std::shared_ptr<CallAnalyzer>(new RandAnalyzer()));
44     msc.emplace(L"matrix", std::shared_ptr<CallAnalyzer>(new MatrixAnalyzer()));
45     msc.emplace(L"eye", std::shared_ptr<CallAnalyzer>(new EyeAnalyzer()));
46
47     std::shared_ptr<CallAnalyzer> ca(new CeilAnalyzer());
48     msc.emplace(L"ceil", ca);
49     msc.emplace(L"floor", ca);
50     msc.emplace(L"round", ca);
51     msc.emplace(L"fix", ca);
52     msc.emplace(L"int", ca);
53
54     //msc.emplace(L"sqrt", std::shared_ptr<CallAnalyzer>(new SqrtAnalyzer()));
55     msc.emplace(L"argn", std::shared_ptr<CallAnalyzer>(new ArgnAnalyzer()));
56     msc.emplace(L"size", std::shared_ptr<CallAnalyzer>(new SizeAnalyzer()));
57     msc.emplace(L"length", std::shared_ptr<CallAnalyzer>(new LengthAnalyzer()));
58     msc.emplace(L"diag", std::shared_ptr<CallAnalyzer>(new DiagAnalyzer()));
59     msc.emplace(L"type", std::shared_ptr<CallAnalyzer>(new TypeAnalyzer()));
60     msc.emplace(L"typeof", std::shared_ptr<CallAnalyzer>(new TypeofAnalyzer()));
61     msc.emplace(L"inttype", std::shared_ptr<CallAnalyzer>(new InttypeAnalyzer()));
62     msc.emplace(L"iconvert", std::shared_ptr<CallAnalyzer>(new IconvertAnalyzer()));
63     msc.emplace(L"isreal", std::shared_ptr<CallAnalyzer>(new IsrealAnalyzer()));
64     msc.emplace(L"isscalar", std::shared_ptr<CallAnalyzer>(new IsscalarAnalyzer()));
65     msc.emplace(L"find", std::shared_ptr<CallAnalyzer>(new FindAnalyzer()));
66
67     return msc;
68 }
69
70
71 AnalysisVisitor::AnalysisVisitor() : cv(*this), pv(std::wcerr, true, false), logger("/tmp/analysis.log")
72 {
73     start_chrono();
74 }
75
76 AnalysisVisitor::~AnalysisVisitor() { }
77
78 void AnalysisVisitor::reset()
79 {
80     _result = Result();
81     dm.reset();
82     multipleLHS.clear();
83     while (!loops.empty())
84     {
85         loops.pop();
86     }
87     start_chrono();
88 }
89
90 void AnalysisVisitor::print_info()
91 {
92     stop_chrono();
93
94     //std::wcout << getGVN() << std::endl << std::endl; function z=foo(x,y);z=argn(2);endfunction;jit("x=123;y=456;t=foo(x,y)")
95     std::wcerr << L"Analysis: " << *static_cast<Chrono *>(this) << std::endl;
96     //std::wcout << temp << std::endl;
97
98     std::wcerr << dm << std::endl;
99     std::wcerr << pmc << std::endl;
100
101     std::wcerr << std::endl;
102 }
103
104 logging::Logger & AnalysisVisitor::getLogger()
105 {
106     return logger;
107 }
108
109 bool AnalysisVisitor::asDouble(types::InternalType * pIT, double & out)
110 {
111     if (pIT && pIT->isDouble())
112     {
113         types::Double * pDbl = static_cast<types::Double *>(pIT);
114         if (!pDbl->isComplex() && pDbl->getSize() == 1)
115         {
116             out = pDbl->get()[0];
117             return true;
118         }
119     }
120
121     return false;
122 }
123
124 bool AnalysisVisitor::asDouble(ast::Exp & e, double & out)
125 {
126     if (e.isDoubleExp())
127     {
128         out = static_cast<ast::DoubleExp &>(e).getValue();
129         return true;
130     }
131     else if (e.isOpExp())
132     {
133         ast::OpExp & op = static_cast<ast::OpExp &>(e);
134         if (op.getOper() == ast::OpExp::unaryMinus)
135         {
136             if (op.getRight().isDoubleExp())
137             {
138                 out = -static_cast<ast::DoubleExp &>(op.getRight()).getValue();
139                 return true;
140             }
141         }
142         else if (op.getLeft().isDoubleExp() && op.getRight().isDoubleExp())
143         {
144             const double L = static_cast<ast::DoubleExp &>(op.getLeft()).getValue();
145             const double R = static_cast<ast::DoubleExp &>(op.getRight()).getValue();
146
147             switch (op.getOper())
148             {
149                 case ast::OpExp::minus:
150                     out = L - R;
151                     return true;
152                 case ast::OpExp::plus:
153                     out = L + R;
154                     return true;
155                 case ast::OpExp::times:
156                 case ast::OpExp::dottimes:
157                 case ast::OpExp::krontimes:
158                     out = L * R;
159                     return true;
160                 case ast::OpExp::rdivide:
161                 case ast::OpExp::dotrdivide:
162                 case ast::OpExp::kronrdivide:
163                     out = L / R;
164                     return true;
165                 case ast::OpExp::ldivide:
166                 case ast::OpExp::dotldivide:
167                 case ast::OpExp::kronldivide:
168                     out = R / L;
169                     return true;
170                 case ast::OpExp::power:
171                 case ast::OpExp::dotpower:
172                     out = std::pow(L, R);
173                     return true;
174                 default:
175                     return false;
176             }
177         }
178     }
179
180     return false;
181 }
182
183 bool AnalysisVisitor::isDoubleConstant(const ast::Exp & e)
184 {
185     if (e.isDoubleExp())
186     {
187         return true;
188     }
189     else if (e.isOpExp())
190     {
191         const ast::OpExp & oe = static_cast<const ast::OpExp &>(e);
192         if (!oe.isBooleanOp())
193         {
194             return isDoubleConstant(oe.getLeft()) && isDoubleConstant(oe.getRight());
195         }
196         return false;
197     }
198     else if (e.isMatrixExp())
199     {
200         const ast::MatrixExp & me = static_cast<const ast::MatrixExp &>(e);
201         const ast::exps_t & lines = me.getLines();
202         for (const auto line : lines)
203         {
204             const ast::exps_t & columns = static_cast<ast::MatrixLineExp *>(line)->getColumns();
205             for (const auto column : columns)
206             {
207                 if (column && !isDoubleConstant(*column))
208                 {
209                     return false;
210                 }
211             }
212         }
213         return true;
214     }
215     else if (e.isListExp())
216     {
217         const ast::ListExp & le = static_cast<const ast::ListExp &>(e);
218         return isDoubleConstant(le.getStart()) && isDoubleConstant(le.getStep()) && isDoubleConstant(le.getEnd());
219     }
220     else if (e.isSimpleVar())
221     {
222         const ast::SimpleVar & var = static_cast<const ast::SimpleVar &>(e);
223         const symbol::Symbol & sym = var.getSymbol();
224         const std::wstring & name = sym.getName();
225         return name == L"%i" || name == L"%inf" || name == L"%nan" || name == L"%eps" || name == L"%pi" || name == L"%e";
226     }
227     else if (e.isCallExp())
228     {
229         const ast::CallExp & ce = static_cast<const ast::CallExp &>(e);
230         const ast::SimpleVar & var = static_cast<const ast::SimpleVar &>(ce.getName());
231         const std::wstring & name = var.getSymbol().getName();
232
233         // TODO: check if 'ones' and 'zeros' are the expected functions
234         // ie: ones="abc"; ones(1) !!!
235         if (name == L"ones" || name == L"zeros")
236         {
237             const ast::exps_t args = ce.getArgs();
238             switch (args.size())
239             {
240                 case 0:
241                     return true;
242                 case 1:
243                     return isDoubleConstant(*args.front());
244                 case 2:
245                     return isDoubleConstant(*args.front()) && isDoubleConstant(**std::next(args.cbegin()));
246                 default:
247                     return false;
248             }
249         }
250     }
251
252     return false;
253 }
254
255 bool AnalysisVisitor::asDoubleMatrix(ast::Exp & e, types::Double *& data)
256 {
257     if (isDoubleConstant(e))
258     {
259         ast::ExecVisitor exec;
260         e.accept(exec);
261         types::InternalType * pIT = exec.getResult();
262         // TODO : handle complex case
263         if (pIT && pIT->isDouble() && !pIT->getAs<types::Double>()->isComplex())
264         {
265             pIT->IncreaseRef();
266             data = static_cast<types::Double *>(pIT);
267
268             return true;
269         }
270     }
271
272     return false;
273 }
274
275 void AnalysisVisitor::visitArguments(const std::wstring & name, const unsigned int lhs, const TIType & calltype, ast::CallExp & e, const ast::exps_t & args)
276 {
277     std::vector<Result> resargs;
278     std::vector<TIType> vargs;
279     vargs.reserve(args.size());
280     resargs.reserve(args.size());
281
282     const symbol::Symbol & sym = static_cast<ast::SimpleVar &>(e.getName()).getSymbol();
283     argIndices.emplace(sym, 0);
284
285     for (auto arg : args)
286     {
287         argIndices.top().getIndex() += 1;
288         arg->accept(*this);
289         resargs.push_back(getResult());
290         vargs.push_back(getResult().getType());
291     }
292
293     argIndices.pop();
294     uint64_t functionId = 0;
295     std::vector<TIType> out = getDM().call(*this, lhs, sym, vargs, &e, functionId);
296     if (lhs > 1)
297     {
298         std::vector<Result> resargs;
299         std::vector<TIType> vargs;
300         vargs.reserve(args.size());
301         resargs.reserve(args.size());
302
303         for (auto arg : args)
304         {
305             arg->accept(*this);
306             resargs.push_back(getResult());
307             vargs.push_back(getResult().getType());
308         }
309
310         const symbol::Symbol & sym = static_cast<ast::SimpleVar &>(e.getName()).getSymbol();
311         uint64_t functionId = 0;
312         std::vector<TIType> out = getDM().call(*this, lhs, sym, vargs, &e, functionId);
313         if (lhs > 1)
314         {
315             multipleLHS.clear();
316             multipleLHS.reserve(out.size());
317             for (const auto & type : out)
318             {
319                 const int tempId = getDM().getTmpId(type, false);
320                 multipleLHS.emplace_back(type, tempId);
321             }
322
323             auto i = args.begin();
324             for (const auto & resarg : resargs)
325             {
326                 getDM().releaseTmp(resarg.getTempId(), *i);
327                 ++i;
328             }
329         }
330         else if (lhs == 1)
331         {
332             int tempId = -1;
333             if (resargs.size() == 1)
334             {
335                 const int id = resargs.back().getTempId();
336                 if (id != -1 && Checkers::isElementWise(name) && out[0] == resargs.back().getType())
337                 {
338                     tempId = id;
339                 }
340             }
341             if (tempId == -1)
342             {
343                 tempId = getDM().getTmpId(out[0], false);
344                 auto i = args.begin();
345                 for (const auto & resarg : resargs)
346                 {
347                     getDM().releaseTmp(resarg.getTempId(), *i);
348                     ++i;
349                 }
350             }
351
352             e.getDecorator().res = Result(out[0], tempId, functionId);
353             e.getDecorator().setCall(name, vargs);
354             setResult(e.getDecorator().res);
355         }
356     }
357 }
358
359 int AnalysisVisitor::getTmpIdForEWOp(const TIType & resT, const Result & LR, const Result & RR, ast::Exp * Lexp, ast::Exp * Rexp)
360 {
361     int tempId = -1;
362     if (resT.isknown() && resT.ismatrix())
363     {
364         if (LR.isTemp() || RR.isTemp())
365         {
366             const int Lid = LR.getTempId();
367             const int Rid = RR.getTempId();
368             const TIType & LT = LR.getType();
369             const TIType & RT = RR.getType();
370
371             if (LT.isscalar())
372             {
373                 if (RT.isscalar())
374                 {
375                     if (Lid == -1)
376                     {
377                         if (resT == LT)
378                         {
379                             tempId = Rid;
380                         }
381                         else
382                         {
383                             tempId = getDM().getTmpId(resT, false);
384                             getDM().releaseTmp(Rid, Rexp);
385                         }
386                     }
387                     else
388                     {
389                         if (resT == LT)
390                         {
391                             tempId = Lid;
392                             getDM().releaseTmp(Rid, Rexp);
393                         }
394                         else if (Rid != -1 && resT == RT)
395                         {
396                             tempId = Rid;
397                             getDM().releaseTmp(Lid, Lexp);
398                         }
399                         else
400                         {
401                             tempId = getDM().getTmpId(resT, false);
402                             getDM().releaseTmp(Lid, Lexp);
403                         }
404                     }
405                 }
406                 else
407                 {
408                     if (Rid == -1)
409                     {
410                         tempId = getDM().getTmpId(resT, false);
411                     }
412                     else
413                     {
414                         if (resT == RT)
415                         {
416                             tempId = Rid;
417                         }
418                         else if (Lid != -1 && resT == LT)
419                         {
420                             tempId = Lid;
421                             getDM().releaseTmp(Rid, Rexp);
422                         }
423                         else
424                         {
425                             tempId = getDM().getTmpId(resT, false);
426                             getDM().releaseTmp(Rid, Rexp);
427                         }
428                     }
429                     getDM().releaseTmp(Lid, Lexp);
430                 }
431             }
432             else
433             {
434                 if (RT.isscalar())
435                 {
436                     if (Lid == -1)
437                     {
438                         tempId = getDM().getTmpId(resT, false);
439                     }
440                     else
441                     {
442                         if (resT == LT)
443                         {
444                             tempId = Lid;
445                         }
446                         else if (Rid != -1 && resT == RT)
447                         {
448                             tempId = Rid;
449                             getDM().releaseTmp(Lid, Lexp);
450                         }
451                         else
452                         {
453                             tempId = getDM().getTmpId(resT, false);
454                             getDM().releaseTmp(Lid, Lexp);
455                         }
456                     }
457                     getDM().releaseTmp(Rid, Rexp);
458                 }
459                 else
460                 {
461                     if (Rid == -1)
462                     {
463                         if (resT == LT)
464                         {
465                             tempId = Lid;
466                         }
467                         else
468                         {
469                             tempId = getDM().getTmpId(resT, false);
470                             getDM().releaseTmp(Lid, Lexp);
471                         }
472                     }
473                     else
474                     {
475                         if (resT == RT)
476                         {
477                             tempId = Rid;
478                         }
479                         else if (Lid != -1 && resT == LT)
480                         {
481                             tempId = Lid;
482                             getDM().releaseTmp(Rid, Rexp);
483                         }
484                         else
485                         {
486                             tempId = getDM().getTmpId(resT, false);
487                             getDM().releaseTmp(Rid, Rexp);
488                         }
489                         getDM().releaseTmp(Lid, Lexp);
490                     }
491                 }
492             }
493         }
494         else
495         {
496             tempId = getDM().getTmpId(resT, false);
497         }
498     }
499
500     return tempId;
501 }
502 }