dfd048e05727a9c34815e47d0b68da29996d8340
[scilab.git] / scilab / modules / ast / src / cpp / analysis / AnalysisVisitor.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2014 - Scilab Enterprises - Calixte DENIZET
4  *
5  * Copyright (C) 2012 - 2016 - Scilab Enterprises
6  *
7  * This file is hereby licensed under the terms of the GNU GPL v2.0,
8  * pursuant to article 5.3.4 of the CeCILL v.2.1.
9  * This file was originally licensed under the terms of the CeCILL v2.1,
10  * and continues to be available under such terms.
11  * For more information, see the COPYING file which you should have received
12  * along with this program.
13  *
14  */
15
16 #include "symbol.hxx"
17
18 #include "AnalysisVisitor.hxx"
19 #include "analyzers/ArgnAnalyzer.hxx"
20 #include "analyzers/CeilAnalyzer.hxx"
21 #include "analyzers/DiagAnalyzer.hxx"
22 #include "analyzers/LengthAnalyzer.hxx"
23 #include "analyzers/MatrixAnalyzer.hxx"
24 #include "analyzers/MemInitAnalyzer.hxx"
25 #include "analyzers/SizeAnalyzer.hxx"
26 #include "analyzers/TypeAnalyzer.hxx"
27 #include "analyzers/TypeofAnalyzer.hxx"
28 #include "analyzers/InttypeAnalyzer.hxx"
29 #include "analyzers/IconvertAnalyzer.hxx"
30 #include "analyzers/IsrealAnalyzer.hxx"
31 #include "analyzers/IsscalarAnalyzer.hxx"
32 #include "analyzers/FindAnalyzer.hxx"
33
34 //#include "analyzers/SqrtAnalyzer.hxx"
35
36 namespace analysis
37 {
38 AnalysisVisitor* AnalysisVisitor::m_instance = nullptr;
39 AnalysisVisitor::MapSymCall AnalysisVisitor::symscall = AnalysisVisitor::initCalls();//a=1:3;b=2;c=3;testAnalysis("repmat","a","b","c")
40
41 AnalysisVisitor& AnalysisVisitor::getInstance()
42 {
43     if (m_instance == nullptr)
44     {
45         m_instance = new AnalysisVisitor();
46     }
47
48     return *m_instance;
49 }
50
51 void AnalysisVisitor::deleteInstance()
52 {
53     if (m_instance) 
54     {
55         delete m_instance;
56         m_instance = nullptr;
57     }
58 }
59
60 AnalysisVisitor::MapSymCall AnalysisVisitor::initCalls()
61 {
62     MapSymCall msc;
63
64     msc.emplace(L"zeros", std::shared_ptr<CallAnalyzer>(new ZerosAnalyzer()));
65     msc.emplace(L"ones", std::shared_ptr<CallAnalyzer>(new OnesAnalyzer()));
66     msc.emplace(L"rand", std::shared_ptr<CallAnalyzer>(new RandAnalyzer()));
67     msc.emplace(L"matrix", std::shared_ptr<CallAnalyzer>(new MatrixAnalyzer()));
68     msc.emplace(L"eye", std::shared_ptr<CallAnalyzer>(new EyeAnalyzer()));
69
70     std::shared_ptr<CallAnalyzer> ca(new CeilAnalyzer());
71     msc.emplace(L"ceil", ca);
72     msc.emplace(L"floor", ca);
73     msc.emplace(L"round", ca);
74     msc.emplace(L"fix", ca);
75     msc.emplace(L"int", ca);
76
77     //msc.emplace(L"sqrt", std::shared_ptr<CallAnalyzer>(new SqrtAnalyzer()));
78     msc.emplace(L"argn", std::shared_ptr<CallAnalyzer>(new ArgnAnalyzer()));
79     msc.emplace(L"size", std::shared_ptr<CallAnalyzer>(new SizeAnalyzer()));
80     msc.emplace(L"length", std::shared_ptr<CallAnalyzer>(new LengthAnalyzer()));
81     msc.emplace(L"diag", std::shared_ptr<CallAnalyzer>(new DiagAnalyzer()));
82     msc.emplace(L"type", std::shared_ptr<CallAnalyzer>(new TypeAnalyzer()));
83     msc.emplace(L"typeof", std::shared_ptr<CallAnalyzer>(new TypeofAnalyzer()));
84     msc.emplace(L"inttype", std::shared_ptr<CallAnalyzer>(new InttypeAnalyzer()));
85     msc.emplace(L"iconvert", std::shared_ptr<CallAnalyzer>(new IconvertAnalyzer()));
86     msc.emplace(L"isreal", std::shared_ptr<CallAnalyzer>(new IsrealAnalyzer()));
87     msc.emplace(L"isscalar", std::shared_ptr<CallAnalyzer>(new IsscalarAnalyzer()));
88     msc.emplace(L"find", std::shared_ptr<CallAnalyzer>(new FindAnalyzer()));
89
90     return msc;
91 }
92
93
94 AnalysisVisitor::AnalysisVisitor() : cv(*this), pv(std::wcerr, true, false), logger("/tmp/analysis.log")
95 {
96     start_chrono();
97 }
98
99 AnalysisVisitor::~AnalysisVisitor() { }
100
101 void AnalysisVisitor::reset()
102 {
103     _result = Result();
104     dm.reset();
105     multipleLHS.clear();
106     while (!loops.empty())
107     {
108         loops.pop();
109     }
110
111     fblockListeners.clear();
112     start_chrono();
113 }
114
115 void AnalysisVisitor::print_info()
116 {
117     stop_chrono();
118
119     //std::wcout << getGVN() << std::endl << std::endl; function z=foo(x,y);z=argn(2);endfunction;jit("x=123;y=456;t=foo(x,y)")
120     std::wcerr << L"Analysis: " << *static_cast<Chrono *>(this) << std::endl;
121     //std::wcout << temp << std::endl;
122
123     std::wcerr << dm << std::endl;
124     std::wcerr << pmc << std::endl;
125
126     std::wcerr << std::endl;
127 }
128
129 logging::Logger & AnalysisVisitor::getLogger()
130 {
131     return logger;
132 }
133
134 bool AnalysisVisitor::asDouble(types::InternalType * pIT, double & out)
135 {
136     if (pIT && pIT->isDouble())
137     {
138         types::Double * pDbl = static_cast<types::Double *>(pIT);
139         if (!pDbl->isComplex() && pDbl->getSize() == 1)
140         {
141             out = pDbl->get()[0];
142             return true;
143         }
144     }
145
146     return false;
147 }
148
149 bool AnalysisVisitor::asDouble(ast::Exp & e, double & out)
150 {
151     if (e.isDoubleExp())
152     {
153         out = static_cast<ast::DoubleExp &>(e).getValue();
154         return true;
155     }
156     else if (e.isOpExp())
157     {
158         ast::OpExp & op = static_cast<ast::OpExp &>(e);
159         if (op.getOper() == ast::OpExp::unaryMinus)
160         {
161             if (op.getRight().isDoubleExp())
162             {
163                 out = -static_cast<ast::DoubleExp &>(op.getRight()).getValue();
164                 return true;
165             }
166         }
167         else if (op.getLeft().isDoubleExp() && op.getRight().isDoubleExp())
168         {
169             const double L = static_cast<ast::DoubleExp &>(op.getLeft()).getValue();
170             const double R = static_cast<ast::DoubleExp &>(op.getRight()).getValue();
171
172             switch (op.getOper())
173             {
174                 case ast::OpExp::minus:
175                     out = L - R;
176                     return true;
177                 case ast::OpExp::plus:
178                     out = L + R;
179                     return true;
180                 case ast::OpExp::times:
181                 case ast::OpExp::dottimes:
182                 case ast::OpExp::krontimes:
183                     out = L * R;
184                     return true;
185                 case ast::OpExp::rdivide:
186                 case ast::OpExp::dotrdivide:
187                 case ast::OpExp::kronrdivide:
188                     out = L / R;
189                     return true;
190                 case ast::OpExp::ldivide:
191                 case ast::OpExp::dotldivide:
192                 case ast::OpExp::kronldivide:
193                     out = R / L;
194                     return true;
195                 case ast::OpExp::power:
196                 case ast::OpExp::dotpower:
197                     out = std::pow(L, R);
198                     return true;
199                 default:
200                     return false;
201             }
202         }
203     }
204
205     return false;
206 }
207
208 bool AnalysisVisitor::isDoubleConstant(const ast::Exp & e)
209 {
210     if (e.isDoubleExp())
211     {
212         return true;
213     }
214     else if (e.isOpExp())
215     {
216         const ast::OpExp & oe = static_cast<const ast::OpExp &>(e);
217         if (!oe.isBooleanOp())
218         {
219             return isDoubleConstant(oe.getLeft()) && isDoubleConstant(oe.getRight());
220         }
221         return false;
222     }
223     else if (e.isMatrixExp())
224     {
225         const ast::MatrixExp & me = static_cast<const ast::MatrixExp &>(e);
226         const ast::exps_t & lines = me.getLines();
227         for (const auto line : lines)
228         {
229             const ast::exps_t & columns = static_cast<ast::MatrixLineExp *>(line)->getColumns();
230             for (const auto column : columns)
231             {
232                 if (column && !isDoubleConstant(*column))
233                 {
234                     return false;
235                 }
236             }
237         }
238         return true;
239     }
240     else if (e.isListExp())
241     {
242         const ast::ListExp & le = static_cast<const ast::ListExp &>(e);
243         return isDoubleConstant(le.getStart()) && isDoubleConstant(le.getStep()) && isDoubleConstant(le.getEnd());
244     }
245     else if (e.isSimpleVar())
246     {
247         const ast::SimpleVar & var = static_cast<const ast::SimpleVar &>(e);
248         const symbol::Symbol & sym = var.getSymbol();
249         const std::wstring & name = sym.getName();
250         return name == L"%i" || name == L"%inf" || name == L"%nan" || name == L"%eps" || name == L"%pi" || name == L"%e";
251     }
252     else if (e.isCallExp())
253     {
254         const ast::CallExp & ce = static_cast<const ast::CallExp &>(e);
255         const ast::SimpleVar & var = static_cast<const ast::SimpleVar &>(ce.getName());
256         const std::wstring & name = var.getSymbol().getName();
257
258         // TODO: check if 'ones' and 'zeros' are the expected functions
259         // ie: ones="abc"; ones(1) !!!
260         if (name == L"ones" || name == L"zeros")
261         {
262             const ast::exps_t args = ce.getArgs();
263             switch (args.size())
264             {
265                 case 0:
266                     return true;
267                 case 1:
268                     return isDoubleConstant(*args.front());
269                 case 2:
270                     return isDoubleConstant(*args.front()) && isDoubleConstant(**std::next(args.cbegin()));
271                 default:
272                     return false;
273             }
274         }
275     }
276
277     return false;
278 }
279
280 bool AnalysisVisitor::asDoubleMatrix(ast::Exp & e, types::Double *& data)
281 {
282     if (isDoubleConstant(e))
283     {
284         ast::ExecVisitor exec;
285         e.accept(exec);
286         types::InternalType * pIT = exec.getResult();
287         // TODO : handle complex case
288         if (pIT && pIT->isDouble() && !pIT->getAs<types::Double>()->isComplex())
289         {
290             pIT->IncreaseRef();
291             data = static_cast<types::Double *>(pIT);
292
293             return true;
294         }
295     }
296
297     return false;
298 }
299
300 void AnalysisVisitor::visitArguments(const std::wstring & name, const unsigned int lhs, const TIType & calltype, ast::CallExp & e, const ast::exps_t & args)
301 {
302     std::vector<Result> resargs;
303     std::vector<TIType> vargs;
304     vargs.reserve(args.size());
305     resargs.reserve(args.size());
306
307     const ast::SimpleVar & var = static_cast<ast::SimpleVar &>(e.getName());
308     const symbol::Symbol & sym = var.getSymbol();
309     argIndices.emplace(var, args.size(), 0);
310
311     for (auto arg : args)
312     {
313         argIndices.top().getIndex() += 1;
314         arg->accept(*this);
315         resargs.push_back(getResult());
316         vargs.push_back(getResult().getType());
317     }
318
319     argIndices.pop();
320     uint64_t functionId = 0;
321     std::vector<TIType> out = getDM().call(*this, lhs, sym, vargs, &e, functionId);
322     if (lhs > 1)
323     {
324         multipleLHS.clear();
325         multipleLHS.reserve(out.size());
326         for (const auto & type : out)
327         {
328             const int tempId = getDM().getTmpId(type, false);
329             multipleLHS.emplace_back(type, tempId);
330         }
331
332         auto i = args.begin();
333         for (const auto & resarg : resargs)
334         {
335             getDM().releaseTmp(resarg.getTempId(), *i);
336             ++i;
337         }
338     }
339     else if (lhs == 1)
340     {
341         int tempId = -1;
342         if (resargs.size() == 1)
343         {
344             const int id = resargs.back().getTempId();
345             if (id != -1 && Checkers::isElementWise(name) && out[0] == resargs.back().getType())
346             {
347                 tempId = id;
348             }
349         }
350         if (tempId == -1)
351         {
352             tempId = getDM().getTmpId(out[0], false);
353             auto i = args.begin();
354             for (const auto & resarg : resargs)
355             {
356                 getDM().releaseTmp(resarg.getTempId(), *i);
357                 ++i;
358             }
359         }
360
361         e.getDecorator().res = Result(out[0], tempId, functionId);
362         e.getDecorator().setCall(name, vargs);
363         setResult(e.getDecorator().res);
364     }
365 }
366
367 int AnalysisVisitor::getTmpIdForEWOp(const TIType & resT, const Result & LR, const Result & RR, ast::Exp * Lexp, ast::Exp * Rexp)
368 {
369     int tempId = -1;
370     if (resT.isknown() && resT.ismatrix())
371     {
372         if (LR.isTemp() || RR.isTemp())
373         {
374             const int Lid = LR.getTempId();
375             const int Rid = RR.getTempId();
376             const TIType & LT = LR.getType();
377             const TIType & RT = RR.getType();
378
379             if (LT.isscalar())
380             {
381                 if (RT.isscalar())
382                 {
383                     if (Lid == -1)
384                     {
385                         if (resT == LT)
386                         {
387                             tempId = Rid;
388                         }
389                         else
390                         {
391                             tempId = getDM().getTmpId(resT, false);
392                             getDM().releaseTmp(Rid, Rexp);
393                         }
394                     }
395                     else
396                     {
397                         if (resT == LT)
398                         {
399                             tempId = Lid;
400                             getDM().releaseTmp(Rid, Rexp);
401                         }
402                         else if (Rid != -1 && resT == RT)
403                         {
404                             tempId = Rid;
405                             getDM().releaseTmp(Lid, Lexp);
406                         }
407                         else
408                         {
409                             tempId = getDM().getTmpId(resT, false);
410                             getDM().releaseTmp(Lid, Lexp);
411                         }
412                     }
413                 }
414                 else
415                 {
416                     if (Rid == -1)
417                     {
418                         tempId = getDM().getTmpId(resT, false);
419                     }
420                     else
421                     {
422                         if (resT == RT)
423                         {
424                             tempId = Rid;
425                         }
426                         else if (Lid != -1 && resT == LT)
427                         {
428                             tempId = Lid;
429                             getDM().releaseTmp(Rid, Rexp);
430                         }
431                         else
432                         {
433                             tempId = getDM().getTmpId(resT, false);
434                             getDM().releaseTmp(Rid, Rexp);
435                         }
436                     }
437                     getDM().releaseTmp(Lid, Lexp);
438                 }
439             }
440             else
441             {
442                 if (RT.isscalar())
443                 {
444                     if (Lid == -1)
445                     {
446                         tempId = getDM().getTmpId(resT, false);
447                     }
448                     else
449                     {
450                         if (resT == LT)
451                         {
452                             tempId = Lid;
453                         }
454                         else if (Rid != -1 && resT == RT)
455                         {
456                             tempId = Rid;
457                             getDM().releaseTmp(Lid, Lexp);
458                         }
459                         else
460                         {
461                             tempId = getDM().getTmpId(resT, false);
462                             getDM().releaseTmp(Lid, Lexp);
463                         }
464                     }
465                     getDM().releaseTmp(Rid, Rexp);
466                 }
467                 else
468                 {
469                     if (Rid == -1)
470                     {
471                         if (resT == LT)
472                         {
473                             tempId = Lid;
474                         }
475                         else
476                         {
477                             tempId = getDM().getTmpId(resT, false);
478                             getDM().releaseTmp(Lid, Lexp);
479                         }
480                     }
481                     else
482                     {
483                         if (resT == RT)
484                         {
485                             tempId = Rid;
486                         }
487                         else if (Lid != -1 && resT == LT)
488                         {
489                             tempId = Lid;
490                             getDM().releaseTmp(Rid, Rexp);
491                         }
492                         else
493                         {
494                             tempId = getDM().getTmpId(resT, false);
495                             getDM().releaseTmp(Rid, Rexp);
496                         }
497                         getDM().releaseTmp(Lid, Lexp);
498                     }
499                 }
500             }
501         }
502         else
503         {
504             tempId = getDM().getTmpId(resT, false);
505         }
506     }
507
508     return tempId;
509 }
510 }