Analysis: fix bug in OperAnalyzer
[scilab.git] / scilab / modules / ast / src / cpp / analysis / OperAnalyzer.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2015 - Scilab Enterprises - Calixte DENIZET
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 #include "AnalysisVisitor.hxx"
14 #include "analyzers/OperAnalyzer.hxx"
15 #include "allexp.hxx"
16 #include "allvar.hxx"
17 #include "alltypes.hxx"
18
19 namespace analysis
20 {
21 bool OperAnalyzer::analyze(AnalysisVisitor & visitor, ast::Exp & e)
22 {
23     ast::OpExp & oe = static_cast<ast::OpExp &>(e);
24     const ast::OpExp::Oper oper = oe.getOper();
25     if (oper == ast::OpExp::plus || oper == ast::OpExp::minus || oper == ast::OpExp::times)
26     {
27         if (ast::MemfillExp * mfe = analyzeMemfill(visitor, oe))
28         {
29             mfe->setVerbose(e.isVerbose());
30             e.replace(mfe);
31
32             return true;
33         }
34     }
35
36     /*if (ast::ExtendedOpExp * eoe = analyzeMemfill(visitor, oe))
37       {
38       eoe->setVerbose(e.isVerbose());
39       e.replace(eoe);
40
41       return true;
42       }*/
43
44     return false;
45 }
46
47 /*ast::ExtendedOpExp * OperAnalyzer::analyzeTransposedArgs(ast::OpExp & oe)
48   {
49   ast::Exp & L = oe.getLeft();
50   ast::Exp & R = oe.getRight();
51   ast::ExtendedOpExp::OP Lop, Rop;
52   ast::Exp * Le = &L;
53   ast::Exp * Re = &R;
54
55   if (L.isTransposeExp())
56   {
57   ast::TransposeExp & te = static_cast<ast::TransposeExp &>(L);
58   if (te.getConjugate() == ast::TransposeExp::_Conjugate_)
59   {
60   Lop = ast::ExtendedOpExp::ADJOINT;
61   }
62   else
63   {
64   Lop = ast::ExtendedOpExp::TRANSP;
65   }
66   Le = &te.getExp();
67   }
68
69   if (R.isTransposeExp())
70   {
71   ast::TransposeExp & te = static_cast<ast::TransposeExp &>(R);
72   if (te.getConjugate() == ast::TransposeExp::_Conjugate_)
73   {
74   Rop = ast::ExtendedOpExp::ADJOINT;
75   }
76   else
77   {
78   Rop = ast::ExtendedOpExp::TRANSP;
79   }
80   Re = &te.getExp();
81   }
82
83   if (Lop != ast::ExtendedOpExp::NONE || Rop != ast::ExtendedOpExp::NONE)
84   {
85   return new ast::ExtendedOpExp(oe.getLocation(), *Le, Lop, oe.getOper(), *Re, Rop);
86   }
87
88   return nullptr;
89   }*/
90
91 ast::MemfillExp * OperAnalyzer::analyzeMemfill(AnalysisVisitor & visitor, ast::OpExp & oe)
92 {
93     const ast::OpExp::Oper oper = oe.getOper();
94     ast::Exp & L = oe.getLeft();
95     ast::Exp & R = oe.getRight();
96
97     ast::Exp * constant = nullptr;
98     ast::MemfillExp * me = nullptr;
99     ast::exps_t args;
100     double value;
101     bool callAtLeft;
102
103     // We try to match something like A +* ones(...) or A +* zeros(...)
104     if (L.isCallExp())
105     {
106         ast::CallExp & ce = static_cast<ast::CallExp &>(L);
107         if (ce.getName().isSimpleVar())
108         {
109             const std::wstring & name = static_cast<ast::SimpleVar &>(ce.getName()).getSymbol().getName();
110             if (name == L"ones")
111             {
112                 value = 1;
113                 args = ce.getArgs();
114                 constant = &R;
115                 callAtLeft = true;
116             }
117             else if (name == L"zeros")
118             {
119                 value = 0;
120                 args = ce.getArgs();
121                 constant = &R;
122                 callAtLeft = true;
123             }
124         }
125     }
126
127     if (!constant)
128     {
129         if (R.isCallExp())
130         {
131             ast::CallExp & ce = static_cast<ast::CallExp &>(R);
132             if (ce.getName().isSimpleVar())
133             {
134                 const std::wstring & name = static_cast<ast::SimpleVar &>(ce.getName()).getSymbol().getName();
135                 if (name == L"ones")
136                 {
137                     value = 1;
138                     args = ce.getArgs();
139                     constant = &L;
140                     callAtLeft = false;
141                 }
142                 else if (name == L"zeros")
143                 {
144                     value = 0;
145                     args = ce.getArgs();
146                     constant = &L;
147                     callAtLeft = false;
148                 }
149             }
150         }
151     }
152
153     if (constant && (oper == ast::OpExp::plus || oper == ast::OpExp::minus || oper == ast::OpExp::times))
154     {
155         Result & res = constant->getDecorator().getResult();
156         if (res.getType().ismatrix() && res.getType().isscalar())
157         {
158             TIType ty(visitor.getGVN(), TIType::DOUBLE);
159             ast::exps_t cloneArgs;
160             cloneArgs.reserve(args.size());
161             for (auto arg : args)
162             {
163                 ast::Exp * cl = arg->clone();
164                 cl->getDecorator().setResult(arg->getDecorator().getResult());
165                 cloneArgs.push_back(cl);
166             }
167
168             switch (oper)
169             {
170                 case ast::OpExp::plus :
171                 {
172                     // plus is commutative so callAtLeft is ignored
173                     // we have something like x + ones(...) => it is a fill with x+1
174                     const Location & loc = oe.getLocation();
175                     ast::Exp * valExp;
176                     double x;
177                     if (res.getConstant().getDblValue(x))
178                     {
179                         valExp = new ast::DoubleExp(loc, new types::Double(x + value));
180                         valExp->getDecorator().setResult(Result(ty));
181                     }
182                     else
183                     {
184                         ast::Exp * cl = constant->clone();
185                         cl->getDecorator().setResult(constant->getDecorator().getResult());
186                         valExp = new ast::DoubleExp(loc, new types::Double(value));
187                         valExp->getDecorator().setResult(Result(ty));
188                         valExp = new ast::OpExp(loc, *cl, ast::OpExp::plus, *valExp);
189                         valExp->getDecorator().setResult(Result(Checkers::check_____add____(visitor.getGVN(), cl->getDecorator().getResult().getType(), ty)));
190                     }
191                     me = new ast::MemfillExp(loc, *valExp, cloneArgs);
192                     break;
193                 }
194                 case ast::OpExp::minus :
195                 {
196                     // we have something like x - ones(...) => it is a fill with x-1
197                     const Location & loc = oe.getLocation();
198                     ast::Exp * valExp;
199                     double x;
200                     if (res.getConstant().getDblValue(x))
201                     {
202                         valExp = new ast::DoubleExp(loc, new types::Double(callAtLeft ? value - x : x - value));
203                         valExp->getDecorator().setResult(Result(ty));
204                     }
205                     else
206                     {
207                         ast::Exp * cl = constant->clone();
208                         cl->getDecorator().setResult(constant->getDecorator().getResult());
209                         valExp = new ast::DoubleExp(loc, new types::Double(value));
210                         valExp->getDecorator().setResult(Result(ty));
211                         if (callAtLeft)
212                         {
213                             valExp = new ast::OpExp(loc, *valExp, ast::OpExp::minus, *cl);
214                             valExp->getDecorator().setResult(Result(Checkers::check_____sub____(visitor.getGVN(), ty, cl->getDecorator().getResult().getType())));
215                         }
216                         else
217                         {
218                             valExp = new ast::OpExp(loc, *cl, ast::OpExp::minus, *valExp);
219                             valExp->getDecorator().setResult(Result(Checkers::check_____sub____(visitor.getGVN(), cl->getDecorator().getResult().getType(), ty)));
220                         }
221                     }
222                     me = new ast::MemfillExp(loc, *valExp, cloneArgs);
223                     break;
224                 }
225                 case ast::OpExp::times :
226                 {
227                     // times is commutative so callAtLeft is ignored
228                     // we have something like x * ones(...) => it is a fill with x
229                     const Location & loc = oe.getLocation();
230                     ast::Exp * valExp;
231                     double x = 0;
232                     if (value == 0 || res.getConstant().getDblValue(x))
233                     {
234                         valExp = new ast::DoubleExp(loc, new types::Double(x * value));
235                         valExp->getDecorator().setResult(Result(ty));
236                     }
237                     else
238                     {
239                         valExp = constant->clone();
240                         valExp->getDecorator().setResult(constant->getDecorator().getResult());
241                     }
242                     me = new ast::MemfillExp(loc, *valExp, cloneArgs);
243                     break;
244                 }
245             }
246
247             if (me)
248             {
249                 if (callAtLeft)
250                 {
251                     ast::CallExp & ce = static_cast<ast::CallExp &>(L);
252                     me->getDecorator().setResult(ce.getDecorator().getResult());
253                 }
254                 else
255                 {
256                     ast::CallExp & ce = static_cast<ast::CallExp &>(R);
257                     me->getDecorator().setResult(ce.getDecorator().getResult());
258                 }
259             }
260         }
261     }
262
263     return me;
264 }
265 }