8665212ed6334d46bb4f340ecf9fe1791ac0aebd
[scilab.git] / scilab / modules / ast / src / cpp / analysis / AnalysisVisitor.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2014 - Scilab Enterprises - Calixte DENIZET
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 #include "symbol.hxx"
14
15 #include "AnalysisVisitor.hxx"
16 #include "analyzers/ArgnAnalyzer.hxx"
17 #include "analyzers/CeilAnalyzer.hxx"
18 #include "analyzers/DiagAnalyzer.hxx"
19 #include "analyzers/LengthAnalyzer.hxx"
20 #include "analyzers/MatrixAnalyzer.hxx"
21 #include "analyzers/MemInitAnalyzer.hxx"
22 #include "analyzers/SizeAnalyzer.hxx"
23 #include "analyzers/TypeAnalyzer.hxx"
24 #include "analyzers/TypeofAnalyzer.hxx"
25 #include "analyzers/InttypeAnalyzer.hxx"
26 #include "analyzers/IconvertAnalyzer.hxx"
27 #include "analyzers/IsrealAnalyzer.hxx"
28 #include "analyzers/IsscalarAnalyzer.hxx"
29 #include "analyzers/FindAnalyzer.hxx"
30
31 //#include "analyzers/SqrtAnalyzer.hxx"
32
33 namespace analysis
34 {
35     AnalysisVisitor::MapSymCall AnalysisVisitor::symscall = AnalysisVisitor::initCalls();//a=1:3;b=2;c=3;testAnalysis("repmat","a","b","c")
36
37     AnalysisVisitor::MapSymCall AnalysisVisitor::initCalls()
38     {
39         MapSymCall msc;
40
41         msc.emplace(L"zeros", std::shared_ptr<CallAnalyzer>(new ZerosAnalyzer()));
42         msc.emplace(L"ones", std::shared_ptr<CallAnalyzer>(new OnesAnalyzer()));
43         msc.emplace(L"rand", std::shared_ptr<CallAnalyzer>(new RandAnalyzer()));
44         msc.emplace(L"matrix", std::shared_ptr<CallAnalyzer>(new MatrixAnalyzer()));
45         msc.emplace(L"eye", std::shared_ptr<CallAnalyzer>(new EyeAnalyzer()));
46
47         std::shared_ptr<CallAnalyzer> ca(new CeilAnalyzer());
48         msc.emplace(L"ceil", ca);
49         msc.emplace(L"floor", ca);
50         msc.emplace(L"round", ca);
51         msc.emplace(L"fix", ca);
52         msc.emplace(L"int", ca);
53
54         //msc.emplace(L"sqrt", std::shared_ptr<CallAnalyzer>(new SqrtAnalyzer()));
55         msc.emplace(L"argn", std::shared_ptr<CallAnalyzer>(new ArgnAnalyzer()));
56         msc.emplace(L"size", std::shared_ptr<CallAnalyzer>(new SizeAnalyzer()));
57         msc.emplace(L"length", std::shared_ptr<CallAnalyzer>(new LengthAnalyzer()));
58         msc.emplace(L"diag", std::shared_ptr<CallAnalyzer>(new DiagAnalyzer()));
59         msc.emplace(L"type", std::shared_ptr<CallAnalyzer>(new TypeAnalyzer()));
60         msc.emplace(L"typeof", std::shared_ptr<CallAnalyzer>(new TypeofAnalyzer()));
61         msc.emplace(L"inttype", std::shared_ptr<CallAnalyzer>(new InttypeAnalyzer()));
62         msc.emplace(L"iconvert", std::shared_ptr<CallAnalyzer>(new IconvertAnalyzer()));
63         msc.emplace(L"isreal", std::shared_ptr<CallAnalyzer>(new IsrealAnalyzer()));
64         msc.emplace(L"isscalar", std::shared_ptr<CallAnalyzer>(new IsscalarAnalyzer()));
65         msc.emplace(L"find", std::shared_ptr<CallAnalyzer>(new FindAnalyzer()));
66
67         return msc;
68     }
69
70
71     AnalysisVisitor::AnalysisVisitor() : cv(*this), pv(std::wcerr, true, false), logger("/tmp/analysis.log")
72     {
73         start_chrono();
74     }
75
76     AnalysisVisitor::~AnalysisVisitor() { }
77
78     void AnalysisVisitor::reset()
79     {
80         _result = Result();
81         dm.reset();
82         multipleLHS.clear();
83         while (!loops.empty())
84         {
85             loops.pop();
86         }
87         start_chrono();
88     }
89
90     void AnalysisVisitor::print_info()
91     {
92         stop_chrono();
93
94         //std::wcout << getGVN() << std::endl << std::endl; function z=foo(x,y);z=argn(2);endfunction;jit("x=123;y=456;t=foo(x,y)")
95         std::wcerr << L"Analysis: " << *static_cast<Chrono *>(this) << std::endl;
96         //std::wcout << temp << std::endl;
97
98         std::wcerr << dm << std::endl;
99         std::wcerr << pmc << std::endl;
100
101         std::wcerr << std::endl;
102     }
103
104     logging::Logger & AnalysisVisitor::getLogger()
105     {
106         return logger;
107     }
108
109     bool AnalysisVisitor::asDouble(types::InternalType * pIT, double & out)
110     {
111         if (pIT && pIT->isDouble())
112         {
113             types::Double * pDbl = static_cast<types::Double *>(pIT);
114             if (!pDbl->isComplex() && pDbl->getSize() == 1)
115             {
116                 out = pDbl->get()[0];
117                 return true;
118             }
119         }
120
121         return false;
122     }
123
124     bool AnalysisVisitor::asDouble(ast::Exp & e, double & out)
125     {
126         if (e.isDoubleExp())
127         {
128             out = static_cast<ast::DoubleExp &>(e).getValue();
129             return true;
130         }
131         else if (e.isOpExp())
132         {
133             ast::OpExp & op = static_cast<ast::OpExp &>(e);
134             if (op.getOper() == ast::OpExp::unaryMinus)
135             {
136                 if (op.getRight().isDoubleExp())
137                 {
138                     out = -static_cast<ast::DoubleExp &>(op.getRight()).getValue();
139                     return true;
140                 }
141             }
142             else if (op.getLeft().isDoubleExp() && op.getRight().isDoubleExp())
143             {
144                 const double L = static_cast<ast::DoubleExp &>(op.getLeft()).getValue();
145                 const double R = static_cast<ast::DoubleExp &>(op.getRight()).getValue();
146
147                 switch (op.getOper())
148                 {
149                 case ast::OpExp::minus:
150                     out = L - R;
151                     return true;
152                 case ast::OpExp::plus:
153                     out = L + R;
154                     return true;
155                 case ast::OpExp::times:
156                 case ast::OpExp::dottimes:
157                 case ast::OpExp::krontimes:
158                     out = L * R;
159                     return true;
160                 case ast::OpExp::rdivide:
161                 case ast::OpExp::dotrdivide:
162                 case ast::OpExp::kronrdivide:
163                     out = L / R;
164                     return true;
165                 case ast::OpExp::ldivide:
166                 case ast::OpExp::dotldivide:
167                 case ast::OpExp::kronldivide:
168                     out = R / L;
169                     return true;
170                 case ast::OpExp::power:
171                 case ast::OpExp::dotpower:
172                     out = std::pow(L, R);
173                     return true;
174                 default:
175                     return false;
176                 }
177             }
178         }
179
180         return false;
181     }
182
183     bool AnalysisVisitor::isDoubleConstant(const ast::Exp & e)
184     {
185         if (e.isDoubleExp())
186         {
187             return true;
188         }
189         else if (e.isOpExp())
190         {
191             const ast::OpExp & oe = static_cast<const ast::OpExp &>(e);
192             if (!oe.isBooleanOp())
193             {
194                 return isDoubleConstant(oe.getLeft()) && isDoubleConstant(oe.getRight());
195             }
196             return false;
197         }
198         else if (e.isMatrixExp())
199         {
200             const ast::MatrixExp & me = static_cast<const ast::MatrixExp &>(e);
201             const ast::exps_t & lines = me.getLines();
202             for (const auto line : lines)
203             {
204                 const ast::exps_t & columns = static_cast<ast::MatrixLineExp *>(line)->getColumns();
205                 for (const auto column : columns)
206                 {
207                     if (column && !isDoubleConstant(*column))
208                     {
209                         return false;
210                     }
211                 }
212             }
213             return true;
214         }
215         else if (e.isListExp())
216         {
217             const ast::ListExp & le = static_cast<const ast::ListExp &>(e);
218             return isDoubleConstant(le.getStart()) && isDoubleConstant(le.getStep()) && isDoubleConstant(le.getEnd());
219         }
220         else if (e.isSimpleVar())
221         {
222             const ast::SimpleVar & var = static_cast<const ast::SimpleVar &>(e);
223             const symbol::Symbol & sym = var.getSymbol();
224             const std::wstring & name = sym.getName();
225             return name == L"%i" || name == L"%inf" || name == L"%nan" || name == L"%eps" || name == L"%pi" || name == L"%e";
226         }
227         else if (e.isCallExp())
228         {
229             const ast::CallExp & ce = static_cast<const ast::CallExp &>(e);
230             const ast::SimpleVar & var = static_cast<const ast::SimpleVar &>(ce.getName());
231             const std::wstring & name = var.getSymbol().getName();
232
233             // TODO: check if 'ones' and 'zeros' are the expected functions
234             // ie: ones="abc"; ones(1) !!!
235             if (name == L"ones" || name == L"zeros")
236             {
237                 const ast::exps_t args = ce.getArgs();
238                 switch (args.size())
239                 {
240                 case 0:
241                     return true;
242                 case 1:
243                     return isDoubleConstant(*args.front());
244                 case 2:
245                     return isDoubleConstant(*args.front()) && isDoubleConstant(**std::next(args.cbegin()));
246                 default:
247                     return false;
248                 }
249             }
250         }
251
252         return false;
253     }
254
255     bool AnalysisVisitor::asDoubleMatrix(ast::Exp & e, types::Double *& data)
256     {
257         if (isDoubleConstant(e))
258         {
259             ast::ExecVisitor exec;
260             e.accept(exec);
261             types::InternalType * pIT = exec.getResult();
262             // TODO : handle complex case
263             if (pIT && pIT->isDouble() && !pIT->getAs<types::Double>()->isComplex())
264             {
265                 pIT->IncreaseRef();
266                 data = static_cast<types::Double *>(pIT);
267
268                 return true;
269             }
270         }
271
272         return false;
273     }
274
275     void AnalysisVisitor::visitArguments(const std::wstring & name, const unsigned int lhs, const TIType & calltype, ast::CallExp & e, const ast::exps_t & args)
276     {
277         std::vector<Result> resargs;
278         std::vector<TIType> vargs;
279         vargs.reserve(args.size());
280         resargs.reserve(args.size());
281
282         for (auto arg : args)
283         {
284             arg->accept(*this);
285             resargs.push_back(getResult());
286             vargs.push_back(getResult().getType());
287         }
288
289         const symbol::Symbol & sym = static_cast<ast::SimpleVar &>(e.getName()).getSymbol();
290         uint64_t functionId = 0;
291         std::vector<TIType> out = getDM().call(*this, lhs, sym, vargs, &e, functionId);
292         if (lhs > 1)
293         {
294             multipleLHS.clear();
295             multipleLHS.reserve(out.size());
296             for (const auto & type : out)
297             {
298                 const int tempId = getDM().getTmpId(type, false);
299                 multipleLHS.emplace_back(type, tempId);
300             }
301
302             auto i = args.begin();
303             for (const auto & resarg : resargs)
304             {
305                 getDM().releaseTmp(resarg.getTempId(), *i);
306                 ++i;
307             }
308         }
309         else if (lhs == 1)
310         {
311             int tempId = -1;
312             if (resargs.size() == 1)
313             {
314                 const int id = resargs.back().getTempId();
315                 if (id != -1 && Checkers::isElementWise(name) && out[0] == resargs.back().getType())
316                 {
317                     tempId = id;
318                 }
319             }
320             if (tempId == -1)
321             {
322                 tempId = getDM().getTmpId(out[0], false);
323                 auto i = args.begin();
324                 for (const auto & resarg : resargs)
325                 {
326                     getDM().releaseTmp(resarg.getTempId(), *i);
327                     ++i;
328                 }
329             }
330
331             e.getDecorator().res = Result(out[0], tempId, functionId);
332             e.getDecorator().setCall(name, vargs);
333             setResult(e.getDecorator().res);
334         }
335     }
336
337     int AnalysisVisitor::getTmpIdForEWOp(const TIType & resT, const Result & LR, const Result & RR, ast::Exp * Lexp, ast::Exp * Rexp)
338     {
339         int tempId = -1;
340         if (resT.isknown() && resT.ismatrix())
341         {
342             if (LR.isTemp() || RR.isTemp())
343             {
344                 const int Lid = LR.getTempId();
345                 const int Rid = RR.getTempId();
346                 const TIType & LT = LR.getType();
347                 const TIType & RT = RR.getType();
348
349                 if (LT.isscalar())
350                 {
351                     if (RT.isscalar())
352                     {
353                         if (Lid == -1)
354                         {
355                             if (resT == LT)
356                             {
357                                 tempId = Rid;
358                             }
359                             else
360                             {
361                                 tempId = getDM().getTmpId(resT, false);
362                                 getDM().releaseTmp(Rid, Rexp);
363                             }
364                         }
365                         else
366                         {
367                             if (resT == LT)
368                             {
369                                 tempId = Lid;
370                                 getDM().releaseTmp(Rid, Rexp);
371                             }
372                             else if (Rid != -1 && resT == RT)
373                             {
374                                 tempId = Rid;
375                                 getDM().releaseTmp(Lid, Lexp);
376                             }
377                             else
378                             {
379                                 tempId = getDM().getTmpId(resT, false);
380                                 getDM().releaseTmp(Lid, Lexp);
381                             }
382                         }
383                     }
384                     else
385                     {
386                         if (Rid == -1)
387                         {
388                             tempId = getDM().getTmpId(resT, false);
389                         }
390                         else
391                         {
392                             if (resT == RT)
393                             {
394                                 tempId = Rid;
395                             }
396                             else if (Lid != -1 && resT == LT)
397                             {
398                                 tempId = Lid;
399                                 getDM().releaseTmp(Rid, Rexp);
400                             }
401                             else
402                             {
403                                 tempId = getDM().getTmpId(resT, false);
404                                 getDM().releaseTmp(Rid, Rexp);
405                             }
406                         }
407                         getDM().releaseTmp(Lid, Lexp);
408                     }
409                 }
410                 else
411                 {
412                     if (RT.isscalar())
413                     {
414                         if (Lid == -1)
415                         {
416                             tempId = getDM().getTmpId(resT, false);
417                         }
418                         else
419                         {
420                             if (resT == LT)
421                             {
422                                 tempId = Lid;
423                             }
424                             else if (Rid != -1 && resT == RT)
425                             {
426                                 tempId = Rid;
427                                 getDM().releaseTmp(Lid, Lexp);
428                             }
429                             else
430                             {
431                                 tempId = getDM().getTmpId(resT, false);
432                                 getDM().releaseTmp(Lid, Lexp);
433                             }
434                         }
435                         getDM().releaseTmp(Rid, Rexp);
436                     }
437                     else
438                     {
439                         if (Rid == -1)
440                         {
441                             if (resT == LT)
442                             {
443                                 tempId = Lid;
444                             }
445                             else
446                             {
447                                 tempId = getDM().getTmpId(resT, false);
448                                 getDM().releaseTmp(Lid, Lexp);
449                             }
450                         }
451                         else
452                         {
453                             if (resT == RT)
454                             {
455                                 tempId = Rid;
456                             }
457                             else if (Lid != -1 && resT == LT)
458                             {
459                                 tempId = Lid;
460                                 getDM().releaseTmp(Rid, Rexp);
461                             }
462                             else
463                             {
464                                 tempId = getDM().getTmpId(resT, false);
465                                 getDM().releaseTmp(Rid, Rexp);
466                             }
467                             getDM().releaseTmp(Lid, Lexp);
468                         }
469                     }
470                 }
471             }
472             else
473             {
474                 tempId = getDM().getTmpId(resT, false);
475             }
476         }
477
478         return tempId;
479     }
480 }