Merge remote-tracking branch 'origin/master' into jit
[scilab.git] / scilab / modules / ast / src / cpp / ast / run_MatrixExp.hpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2008-2008 - DIGITEO - Antoine ELIAS
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 //file included in runvisitor.cpp
14 namespace ast {
15
16 /*
17     [1,2;3,4] with/without special character $ and :
18     */
19 template<class T>
20 void RunVisitorT<T>::visitprivate(const MatrixExp &e)
21 {
22     CoverageInstance::invokeAndStartChrono((void*)&e);
23     try
24     {
25         exps_t::const_iterator row;
26         exps_t::const_iterator col;
27         types::InternalType *poResult = NULL;
28         std::list<types::InternalType*> rowList;
29
30         exps_t lines = e.getLines();
31         if (lines.size() == 0)
32         {
33             setResult(types::Double::Empty());
34             CoverageInstance::invokeAndStartChrono((void*)&e);
35             return;
36         }
37
38         //special case for 1x1 matrix
39         if (lines.size() == 1)
40         {
41             exps_t cols = lines[0]->getAs<MatrixLineExp>()->getColumns();
42             if (cols.size() == 1)
43             {
44                 setResult(NULL); // Reset value on loop re-start
45
46                 cols[0]->accept(*this);
47                 //manage evstr('//xxx') for example
48                 if (getResult() == NULL)
49                 {
50                     setResult(types::Double::Empty());
51                 }
52                 CoverageInstance::invokeAndStartChrono((void*)&e);
53                 return;
54             }
55         }
56
57         //do all [x,x]
58         for (row = lines.begin(); row != lines.end(); row++)
59         {
60             types::InternalType* poRow = NULL;
61             exps_t cols = (*row)->getAs<MatrixLineExp>()->getColumns();
62             for (col = cols.begin(); col != cols.end(); col++)
63             {
64                 setResult(NULL); // Reset value on loop re-start
65
66                 try
67                 {
68                     (*col)->accept(*this);
69                 }
70                 catch (const InternalError& error)
71                 {
72                     if (poRow)
73                     {
74                         poRow->killMe();
75                     }
76                     if (poResult)
77                     {
78                         poResult->killMe();
79                     }
80
81                     throw error;
82                 }
83
84                 types::InternalType *pIT = getResult();
85                 if (pIT == NULL)
86                 {
87                     continue;
88                 }
89
90                 //reset result but without delete the value
91                 clearResultButFirst();
92
93                 if (pIT->isImplicitList())
94                 {
95                     types::ImplicitList *pIL = pIT->getAs<types::ImplicitList>();
96                     if (pIL->isComputable())
97                     {
98                         types::InternalType* pIT2 = pIL->extractFullMatrix();
99                         pIT->killMe();
100                         pIT = pIT2;
101                     }
102                     else
103                     {
104                         if (poRow == NULL)
105                         {
106                             //first loop
107                             poRow = pIT;
108                         }
109                         else
110                         {
111                             try
112                             {
113                                 poRow = callOverloadMatrixExp(L"c", poRow, pIT);
114                             }
115                             catch (const InternalError& error)
116                             {
117                                 if (poResult)
118                                 {
119                                     poResult->killMe();
120                                 }
121                                 throw error;
122                             }
123                         }
124
125                         continue;
126                     }
127                 }
128
129                 if (pIT->isGenericType() == false)
130                 {
131                     pIT->killMe();
132                     std::wostringstream os;
133                     os << _W("unable to concatenate\n");
134                     throw ast::InternalError(os.str(), 999, (*col)->getLocation());
135                 }
136
137                 types::GenericType* pGT = pIT->getAs<types::GenericType>();
138
139                 if (poRow == NULL)
140                 {
141                     //first loop
142                     if (poResult == NULL && pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
143                     {
144                         pGT->killMe();
145                         continue;
146                     }
147
148                     if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
149                     {
150                         if (poResult && (poResult->isList() || poResult->isStruct()))
151                         {
152                             //in case of [list(); [], ...]
153
154                             //we don't know what to do with [], keep it as "normal" value and continue process
155                             poRow = pGT;
156                             continue;
157                         }
158
159                         pGT->killMe();
160                         continue;
161                     }
162
163                     poRow = pGT;
164                     continue;
165                 }
166
167                 //manage overload on list/struct/implicitlist and hypermatrix before management of []
168                 if (pGT->isList() || poRow->isList() || pGT->isStruct() || poRow->isStruct() || poRow->isImplicitList() || pGT->getDims() > 2)
169                 {
170                     try
171                     {
172                         poRow = callOverloadMatrixExp(L"c", poRow, pGT);
173                     }
174                     catch (const InternalError& error)
175                     {
176                         if (poResult)
177                         {
178                             poResult->killMe();
179                         }
180                         throw error;
181                     }
182
183                     continue;
184                 }
185
186                 if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
187                 {
188                     pGT->killMe();
189                     continue;
190                 }
191
192                 types::GenericType* pGTResult = poRow->getAs<types::GenericType>();
193
194                 //check dimension
195                 if (pGT->getDims() != 2 || pGT->getRows() != pGTResult->getRows())
196                 {
197                     poRow->killMe();
198                     if (poRow != pGT)
199                     {
200                         pGT->killMe();
201                     }
202                     std::wostringstream os;
203                     os << _W("inconsistent row/column dimensions\n");
204                     throw ast::InternalError(os.str(), 999, (*row)->getLocation());
205                 }
206
207                 // if we concatenate [Double Sparse], transform the Double to Sparse and perform [Sparse Sparse]
208                 // this avoids to allocate a Double result of size of Double+Sparse and initialize all elements.
209                 if (pGT->isSparse() && pGTResult->isDouble())
210                 {
211                     poRow = new types::Sparse(*pGTResult->getAs<types::Double>());
212                     pGTResult->killMe();
213                     pGTResult = poRow->getAs<types::GenericType>();
214                 }
215                 else if (pGT->isSparseBool() && pGTResult->isBool()) // [Bool SparseBool] => [SparseBool SparseBool]
216                 {
217                     poRow = new types::SparseBool(*pGTResult->getAs<types::Bool>());
218                     pGTResult->killMe();
219                     pGTResult = poRow->getAs<types::GenericType>();
220                 }
221                 else if (pGT->isDollar() && pGTResult->isDouble())
222                 {
223                     int _iRows = pGTResult->getRows();
224                     int _iCols = pGTResult->getCols();
225                     int* piRank = new int[_iRows * _iCols];
226                     memset(piRank, 0x00, _iRows * _iCols * sizeof(int));
227                     poRow = new types::Polynom(pGT->getAs<types::Polynom>()->getVariableName(), _iRows, _iCols, piRank);
228                     types::Polynom* pP = poRow->getAs<types::Polynom>();
229                     types::SinglePoly** pSS = pP->get();
230                     types::Double* pDb = pGTResult->getAs<types::Double>();
231                     double* pdblR = pDb->get();
232                     if (pDb->isComplex())
233                     {
234                         double* pdblI = pDb->getImg();
235                         pP->setComplex(true);
236                         for (int i = 0; i < pDb->getSize(); i++)
237                         {
238                             pSS[i]->setRank(0);
239                             pSS[i]->setCoef(pdblR + i, pdblI + i);
240                         }
241                     }
242                     else
243                     {
244                         for (int i = 0; i < pDb->getSize(); i++)
245                         {
246                             pSS[i]->setRank(0);
247                             pSS[i]->setCoef(pdblR + i, NULL);
248                         }
249                     }
250
251                     delete[] piRank;
252                 }
253
254                 types::InternalType *pNewSize = AddElementToVariable(NULL, poRow, pGTResult->getRows(), pGTResult->getCols() + pGT->getCols());
255                 types::InternalType* p = AddElementToVariable(pNewSize, pGT, 0, pGTResult->getCols());
256                 if (p != pNewSize)
257                 {
258                     pNewSize->killMe();
259                 }
260                 // call overload
261                 if (p == NULL)
262                 {
263                     try
264                     {
265                         poRow = callOverloadMatrixExp(L"c", pGTResult, pGT);
266                     }
267                     catch (const InternalError& error)
268                     {
269                         if (poResult)
270                         {
271                             poResult->killMe();
272                         }
273                         throw error;
274                     }
275                     continue;
276                 }
277
278                 if (poRow != pGT)
279                 {
280                     pGT->killMe();
281                 }
282
283                 if (p != poRow)
284                 {
285                     poRow->killMe();
286                     poRow = p;
287                 }
288             }
289
290             if (poRow == NULL)
291             {
292                 continue;
293             }
294
295             if (poResult == NULL)
296             {
297                 poResult = poRow;
298                 continue;
299             }
300
301             // management of concatenation with 1:$
302             if (poRow->isImplicitList() || poResult->isImplicitList())
303             {
304                 try
305                 {
306                     poResult = callOverloadMatrixExp(L"f", poResult, poRow);
307                 }
308                 catch (const InternalError& error)
309                 {
310                     throw error;
311                 }
312                 continue;
313             }
314
315             types::GenericType* pGT = poRow->getAs<types::GenericType>();
316
317             //check dimension
318             types::GenericType* pGTResult = poResult->getAs<types::GenericType>();
319
320             if (pGT->isList() || pGTResult->isList() || pGT->isStruct() || pGTResult->isStruct() || pGT->getDims() > 2)
321             {
322                 try
323                 {
324                     poResult = callOverloadMatrixExp(L"f", pGTResult, pGT);
325                 }
326                 catch (const InternalError& error)
327                 {
328                     throw error;
329                 }
330
331                 continue;
332             }
333             else
334             {//[]
335                 if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
336                 {
337                     pGT->killMe();
338                     continue;
339                 }
340             }
341
342             //check dimension
343             if (pGT->getCols() != pGTResult->getCols())
344             {
345                 poRow->killMe();
346                 if (poResult)
347                 {
348                     poResult->killMe();
349                 }
350                 std::wostringstream os;
351                 os << _W("inconsistent row/column dimensions\n");
352                 throw ast::InternalError(os.str(), 999, (*e.getLines().begin())->getLocation());
353             }
354
355             // if we concatenate [Double Sparse], transform the Double to Sparse and perform [Sparse Sparse]
356             // this avoids to allocate a Double result of size of Double+Sparse and initialize all elements.
357             if (pGT->isSparse() && pGTResult->isDouble())
358             {
359                 poResult = new types::Sparse(*pGTResult->getAs<types::Double>());
360                 pGTResult->killMe();
361                 pGTResult = poResult->getAs<types::GenericType>();
362             }
363             else if (pGT->isSparseBool() && pGTResult->isBool()) // [Bool SparseBool] => [SparseBool SparseBool]
364             {
365                 poResult = new types::SparseBool(*pGTResult->getAs<types::Bool>());
366                 pGTResult->killMe();
367                 pGTResult = poResult->getAs<types::GenericType>();
368             }
369
370             types::InternalType* pNewSize = AddElementToVariable(NULL, poResult, pGTResult->getRows() + pGT->getRows(), pGT->getCols());
371             types::InternalType* p = AddElementToVariable(pNewSize, pGT, pGTResult->getRows(), 0);
372             if (p != pNewSize)
373             {
374                 pNewSize->killMe();
375             }
376
377             // call overload
378             if (p == NULL)
379             {
380                 try
381                 {
382                     poResult = callOverloadMatrixExp(L"f", pGTResult, pGT);
383                 }
384                 catch (const InternalError& error)
385                 {
386                     throw error;
387                 }
388                 continue;
389             }
390
391             if (poResult != poRow)
392             {
393                 poRow->killMe();
394             }
395
396             if (p != poResult)
397             {
398                 poResult->killMe();
399                 poResult = p;
400             }
401         }
402
403         if (poResult)
404         {
405             setResult(poResult);
406         }
407         else
408         {
409             setResult(types::Double::Empty());
410         }
411     }
412     catch (const InternalError& error)
413     {
414         setResult(NULL);
415         CoverageInstance::invokeAndStartChrono((void*)&e);
416         throw error;
417     }
418     CoverageInstance::invokeAndStartChrono((void*)&e);
419 }
420
421 template<class T>
422 types::InternalType* RunVisitorT<T>::callOverloadMatrixExp(const std::wstring& strType, types::InternalType* _paramL, types::InternalType* _paramR)
423 {
424     types::typed_list in;
425     types::typed_list out;
426     types::Callable::ReturnValue Ret;
427
428     _paramL->IncreaseRef();
429     _paramR->IncreaseRef();
430
431     in.push_back(_paramL);
432     in.push_back(_paramR);
433
434     try
435     {
436         if (_paramR->isGenericType() && _paramR->getAs<types::GenericType>()->getDims() > 2)
437         {
438             Ret = Overload::call(L"%hm_" + strType + L"_hm", in, 1, out, true);
439         }
440         else
441         {
442             Ret = Overload::call(L"%" + _paramL->getAs<types::List>()->getShortTypeStr() + L"_" + strType + L"_" + _paramR->getAs<types::List>()->getShortTypeStr(), in, 1, out, true);
443         }
444     }
445     catch (const InternalError& error)
446     {
447         cleanInOut(in, out);
448         throw error;
449     }
450
451     if (Ret != types::Callable::OK)
452     {
453         cleanInOut(in, out);
454         throw InternalError(ConfigVariable::getLastErrorMessage());
455     }
456
457     cleanIn(in, out);
458
459     if (out.empty())
460     {
461         // TODO: avoid crash if out is empty but must return an error...
462         return NULL;
463     }
464
465     return out[0];
466 }
467
468 } /* namespace ast */