Analysis: always a WIP, but we would expect the end to be nearer than the beginning
[scilab.git] / scilab / modules / ast / src / cpp / analysis / OperAnalyzer.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2015 - Scilab Enterprises - Calixte DENIZET
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 #include "AnalysisVisitor.hxx"
14 #include "analyzers/OperAnalyzer.hxx"
15 #include "allexp.hxx"
16 #include "allvar.hxx"
17 #include "alltypes.hxx"
18
19 namespace analysis
20 {
21     bool OperAnalyzer::analyze(AnalysisVisitor & visitor, ast::Exp & e)
22     {
23         ast::OpExp & oe = static_cast<ast::OpExp &>(e);
24         const ast::OpExp::Oper oper = oe.getOper();
25         if (oper == ast::OpExp::plus || oper == ast::OpExp::minus || oper == ast::OpExp::times)
26         {
27             if (ast::MemfillExp * mfe = analyzeMemfill(visitor, oe))
28             {
29                 mfe->setVerbose(e.isVerbose());
30                 e.replace(mfe);
31
32                 return true;
33             }
34         }
35
36         /*if (ast::ExtendedOpExp * eoe = analyzeMemfill(visitor, oe))
37         {
38             eoe->setVerbose(e.isVerbose());
39             e.replace(eoe);
40             
41             return true;
42             }*/
43         
44         return false;
45     }
46
47     /*ast::ExtendedOpExp * OperAnalyzer::analyzeTransposedArgs(ast::OpExp & oe)
48     {
49         ast::Exp & L = oe.getLeft();
50         ast::Exp & R = oe.getRight();
51         ast::ExtendedOpExp::OP Lop, Rop;
52         ast::Exp * Le = &L;
53         ast::Exp * Re = &R;
54
55         if (L.isTransposeExp())
56         {
57             ast::TransposeExp & te = static_cast<ast::TransposeExp &>(L);
58             if (te.getConjugate() == ast::TransposeExp::_Conjugate_)
59             {
60                 Lop = ast::ExtendedOpExp::ADJOINT;
61             }
62             else
63             {
64                 Lop = ast::ExtendedOpExp::TRANSP;
65             }
66             Le = &te.getExp();
67         }
68
69         if (R.isTransposeExp())
70         {
71             ast::TransposeExp & te = static_cast<ast::TransposeExp &>(R);
72             if (te.getConjugate() == ast::TransposeExp::_Conjugate_)
73             {
74                 Rop = ast::ExtendedOpExp::ADJOINT;
75             }
76             else
77             {
78                 Rop = ast::ExtendedOpExp::TRANSP;
79             }
80             Re = &te.getExp();
81         }
82
83         if (Lop != ast::ExtendedOpExp::NONE || Rop != ast::ExtendedOpExp::NONE)
84         {
85             return new ast::ExtendedOpExp(oe.getLocation(), *Le, Lop, oe.getOper(), *Re, Rop);
86         }
87
88         return nullptr;
89         }*/
90     
91     ast::MemfillExp * OperAnalyzer::analyzeMemfill(AnalysisVisitor & visitor, ast::OpExp & oe)
92     {
93         const ast::OpExp::Oper oper = oe.getOper();
94         ast::Exp & L = oe.getLeft();
95         ast::Exp & R = oe.getRight();
96
97         ast::Exp * constant = nullptr;
98         ast::exps_t args;
99         double value;
100         bool callAtLeft;
101
102         // We try to match something like A +* ones(...) or A +* zeros(...)
103         if (L.isCallExp())
104         {
105             ast::CallExp & ce = static_cast<ast::CallExp &>(L);
106             if (ce.getName().isSimpleVar())
107             {
108                 const std::wstring & name = static_cast<ast::SimpleVar &>(ce.getName()).getSymbol().getName();
109                 if (name == L"ones")
110                 {
111                     value = 1;
112                     args = ce.getArgs();
113                     constant = &R;
114                     callAtLeft = true;
115                 }
116                 else if (name == L"zeros")
117                 {
118                     value = 0;
119                     args = ce.getArgs();
120                     constant = &R;
121                     callAtLeft = true;
122                 }
123             }
124         }
125
126         if (!constant)
127         {
128             if (R.isCallExp())
129             {
130                 ast::CallExp & ce = static_cast<ast::CallExp &>(R);
131                 if (ce.getName().isSimpleVar())
132                 {
133                     const std::wstring & name = static_cast<ast::SimpleVar &>(ce.getName()).getSymbol().getName();
134                     if (name == L"ones")
135                     {
136                         value = 1;
137                         args = ce.getArgs();
138                         constant = &L;
139                         callAtLeft = false;
140                     }
141                     else if (name == L"zeros")
142                     {
143                         value = 0;
144                         args = ce.getArgs();
145                         constant = &L;
146                         callAtLeft = false;
147                     }
148                 }
149             }
150         }
151
152         if (constant)
153         {
154             ast::exps_t cloneArgs;
155             cloneArgs.reserve(args.size());
156             for (auto arg : args)
157             {
158                 cloneArgs.push_back(arg->clone());
159             }
160             
161             switch (oper)
162             {
163             case ast::OpExp::plus :
164             {
165                 // plus is commutative so callAtLeft is ignored
166                 Result & res = constant->getDecorator().getResult();
167
168                 if (res.getType().ismatrix() && res.getType().isscalar())
169                 {
170                     // we have something like x + ones(...) => it is a fill with x+1
171                     const Location & loc = oe.getLocation();
172                     double x;
173                     if (res.getConstant().getDblValue(x))
174                     {
175                         return new ast::MemfillExp(loc, *new ast::DoubleExp(loc, new types::Double(x + value)), cloneArgs);
176                     }
177                     else
178                     {
179                         return new ast::MemfillExp(loc, *new ast::OpExp(loc, *constant->clone(), ast::OpExp::plus, *new ast::DoubleExp(loc, new types::Double(value))), cloneArgs);
180                     }
181                 }
182                 break;
183             }
184             case ast::OpExp::minus :
185             {
186                 Result & res = constant->getDecorator().getResult();
187
188                 if (res.getType().ismatrix() && res.getType().isscalar())
189                 {
190                     // we have something like x - ones(...) => it is a fill with x-1
191                     const Location & loc = oe.getLocation();
192                     double x;
193                     if (res.getConstant().getDblValue(x))
194                     {
195                         return new ast::MemfillExp(loc, *new ast::DoubleExp(loc, new types::Double(callAtLeft ? value - x : x - value)), cloneArgs);
196                     }
197                     else if (callAtLeft)
198                     {
199                         return new ast::MemfillExp(loc, *new ast::OpExp(loc, *new ast::DoubleExp(loc, new types::Double(value)), ast::OpExp::minus, *constant->clone()), cloneArgs);
200                     }
201                     else
202                     {
203                         return new ast::MemfillExp(loc, *new ast::OpExp(loc, *constant->clone(), ast::OpExp::minus, *new ast::DoubleExp(loc, new types::Double(value))), cloneArgs);
204                     }
205                 }
206
207                 break;
208             }
209             case ast::OpExp::times :
210             {
211                 // times is commutative so callAtLeft is ignored
212                 Result & res = constant->getDecorator().getResult();
213
214                 if (res.getType().ismatrix() && res.getType().isscalar())
215                 {
216                     // we have something like x + ones(...) => it is a fill with x+1
217                     const Location & loc = oe.getLocation();
218                     double x;
219                     if (res.getConstant().getDblValue(x))
220                     {
221                         return new ast::MemfillExp(loc, *new ast::DoubleExp(loc, new types::Double(x * value)), cloneArgs);
222                     }
223                     else if (value == 0)
224                     {
225                         return new ast::MemfillExp(loc, *new ast::DoubleExp(loc, new types::Double(0)), cloneArgs);
226                     }
227                     else
228                     {
229                         return new ast::MemfillExp(loc, *constant->clone(), cloneArgs);
230                     }
231                 }
232                 break;
233             }
234             }
235         }
236
237         return nullptr;
238     }
239 }