fix [1;[]] after https://codereview.scilab.org/#/c/17221
[scilab.git] / scilab / modules / ast / src / cpp / ast / run_MatrixExp.hpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2008-2008 - DIGITEO - Antoine ELIAS
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 //file included in runvisitor.cpp
14 namespace ast {
15
16 /*
17     [1,2;3,4] with/without special character $ and :
18 */
19 template<class T>
20 void RunVisitorT<T>::visitprivate(const MatrixExp &e)
21 {
22     try
23     {
24         exps_t::const_iterator row;
25         exps_t::const_iterator col;
26         types::InternalType *poResult = NULL;
27         std::list<types::InternalType*> rowList;
28
29         exps_t lines = e.getLines();
30         if (lines.size() == 0)
31         {
32             setResult(types::Double::Empty());
33             return;
34         }
35
36         //special case for 1x1 matrix
37         if (lines.size() == 1)
38         {
39             exps_t cols = lines[0]->getAs<MatrixLineExp>()->getColumns();
40             if (cols.size() == 1)
41             {
42                 setResult(NULL); // Reset value on loop re-start
43
44                 cols[0]->accept(*this);
45                 //manage evstr('//xxx') for example
46                 if (getResult() == NULL)
47                 {
48                     setResult(types::Double::Empty());
49                 }
50                 return;
51             }
52         }
53
54         //do all [x,x]
55         for (row = lines.begin() ; row != lines.end() ; row++)
56         {
57             types::InternalType* poRow = NULL;
58             exps_t cols = (*row)->getAs<MatrixLineExp>()->getColumns();
59             for (col = cols.begin() ; col != cols.end() ; col++)
60             {
61                 setResult(NULL); // Reset value on loop re-start
62
63                 try
64                 {
65                     (*col)->accept(*this);
66                 }
67                 catch (const InternalError& error)
68                 {
69                     if (poRow)
70                     {
71                         poRow->killMe();
72                     }
73                     if (poResult)
74                     {
75                         poResult->killMe();
76                     }
77
78                     throw error;
79                 }
80
81                 types::InternalType *pIT = getResult();
82                 if (pIT == NULL)
83                 {
84                     continue;
85                 }
86
87                 //reset result but without delete the value
88                 clearResultButFirst();
89
90                 if (pIT->isImplicitList())
91                 {
92                     types::ImplicitList *pIL = pIT->getAs<types::ImplicitList>();
93                     if (pIL->isComputable())
94                     {
95                         types::InternalType* pIT2 = pIL->extractFullMatrix();
96                         pIT->killMe();
97                         pIT = pIT2;
98                     }
99                     else
100                     {
101                         if (poRow == NULL)
102                         {
103                             //first loop
104                             poRow = pIT;
105                         }
106                         else
107                         {
108                             try
109                             {
110                                 poRow = callOverloadMatrixExp(L"c", poRow, pIT);
111                             }
112                             catch (const InternalError& error)
113                             {
114                                 if (poResult)
115                                 {
116                                     poResult->killMe();
117                                 }
118                                 throw error;
119                             }
120                         }
121
122                         continue;
123                     }
124                 }
125
126                 if (pIT->isGenericType() == false)
127                 {
128                     pIT->killMe();
129                     std::wostringstream os;
130                     os << _W("unable to concatenate\n");
131                     throw ast::InternalError(os.str(), 999, (*col)->getLocation());
132                 }
133
134                 types::GenericType* pGT = pIT->getAs<types::GenericType>();
135
136                 if (poRow == NULL)
137                 {
138                     //first loop
139                     if (poResult == NULL && pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
140                     {
141                         pGT->killMe();
142                         continue;
143                     }
144                     
145                     poRow = pGT;
146                     continue;
147                 }
148
149                 //manage overload on list/struct/implicitlist and hypermatrix before management of []
150                 if (pGT->isList() || poRow->isList() || pGT->isStruct() || poRow->isStruct() || poRow->isImplicitList() || pGT->getDims() > 2)
151                 {
152                     try
153                     {
154                         poRow = callOverloadMatrixExp(L"c", poRow, pGT);
155                     }
156                     catch (const InternalError& error)
157                     {
158                         if (poResult)
159                         {
160                             poResult->killMe();
161                         }
162                         throw error;
163                     }
164
165                     continue;
166                 }
167
168                 if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
169                 {
170                     pGT->killMe();
171                     continue;
172                 }
173
174                 types::GenericType* pGTResult = poRow->getAs<types::GenericType>();
175
176                 //check dimension
177                 if (pGT->getDims() != 2 || pGT->getRows() != pGTResult->getRows())
178                 {
179                     poRow->killMe();
180                     if (poRow != pGT)
181                     {
182                         pGT->killMe();
183                     }
184                     std::wostringstream os;
185                     os << _W("inconsistent row/column dimensions\n");
186                     throw ast::InternalError(os.str(), 999, (*row)->getLocation());
187                 }
188
189                 // if we concatenate [Double Sparse], transform the Double to Sparse and perform [Sparse Sparse]
190                 // this avoids to allocate a Double result of size of Double+Sparse and initialize all elements.
191                 if (pGT->isSparse() && pGTResult->isDouble())
192                 {
193                     poRow = new types::Sparse(*pGTResult->getAs<types::Double>());
194                     pGTResult->killMe();
195                     pGTResult = poRow->getAs<types::GenericType>();
196                 }
197                 else if (pGT->isSparseBool() && pGTResult->isBool()) // [Bool SparseBool] => [SparseBool SparseBool]
198                 {
199                     poRow = new types::SparseBool(*pGTResult->getAs<types::Bool>());
200                     pGTResult->killMe();
201                     pGTResult = poRow->getAs<types::GenericType>();
202                 }
203                 else if (pGT->isDollar() && pGTResult->isDouble())
204                 {
205                     int _iRows = pGTResult->getRows();
206                     int _iCols = pGTResult->getCols();
207                     int* piRank = new int[_iRows * _iCols];
208                     memset(piRank, 0x00, _iRows * _iCols * sizeof(int));
209                     poRow = new types::Polynom(pGT->getAs<types::Polynom>()->getVariableName(), _iRows, _iCols, piRank);
210                     types::Polynom* pP = poRow->getAs<types::Polynom>();
211                     types::SinglePoly** pSS = pP->get();
212                     types::Double* pDb = pGTResult->getAs<types::Double>();
213                     double* pdblR = pDb->get();
214                     if (pDb->isComplex())
215                     {
216                         double* pdblI = pDb->getImg();
217                         pP->setComplex(true);
218                         for (int i = 0; i < pDb->getSize(); i++)
219                         {
220                             pSS[i]->setRank(0);
221                             pSS[i]->setCoef(pdblR + i, pdblI + i);
222                         }
223                     }
224                     else
225                     {
226                         for (int i = 0; i < pDb->getSize(); i++)
227                         {
228                             pSS[i]->setRank(0);
229                             pSS[i]->setCoef(pdblR + i, NULL);
230                         }
231                     }
232
233                     delete[] piRank;
234                 }
235
236                 types::InternalType *pNewSize = AddElementToVariable(NULL, poRow, pGTResult->getRows(), pGTResult->getCols() + pGT->getCols());
237                 types::InternalType* p = AddElementToVariable(pNewSize, pGT, 0, pGTResult->getCols());
238                 if (p != pNewSize)
239                 {
240                     pNewSize->killMe();
241                 }
242                 // call overload
243                 if (p == NULL)
244                 {
245                     try
246                     {
247                         poRow = callOverloadMatrixExp(L"c", pGTResult, pGT);
248                     }
249                     catch (const InternalError& error)
250                     {
251                         if (poResult)
252                         {
253                             poResult->killMe();
254                         }
255                         throw error;
256                     }
257                     continue;
258                 }
259
260                 if (poRow != pGT)
261                 {
262                     pGT->killMe();
263                 }
264
265                 if (p != poRow)
266                 {
267                     poRow->killMe();
268                     poRow = p;
269                 }
270             }
271
272             if (poRow == NULL)
273             {
274                 continue;
275             }
276
277             if (poResult == NULL)
278             {
279                 poResult = poRow;
280                 continue;
281             }
282
283             // management of concatenation with 1:$
284             if (poRow->isImplicitList() || poResult->isImplicitList())
285             {
286                 try
287                 {
288                     poResult = callOverloadMatrixExp(L"f", poResult, poRow);
289                 }
290                 catch (const InternalError& error)
291                 {
292                     throw error;
293                 }
294                 continue;
295             }
296
297             types::GenericType* pGT = poRow->getAs<types::GenericType>();
298
299             //check dimension
300             types::GenericType* pGTResult = poResult->getAs<types::GenericType>();
301
302             if (pGT->isList() || pGTResult->isList() || pGT->isStruct() || pGTResult->isStruct() || pGT->getDims() > 2)
303             {
304                 try
305                 {
306                     poResult = callOverloadMatrixExp(L"f", pGTResult, pGT);
307                 }
308                 catch (const InternalError& error)
309                 {
310                     throw error;
311                 }
312
313                 continue;
314             }
315             else
316             {//[]
317                 if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
318                 {
319                     pGT->killMe();
320                     continue;
321                 }
322             }
323
324             //check dimension
325             if (pGT->getCols() != pGTResult->getCols())
326             {
327                 poRow->killMe();
328                 if (poResult)
329                 {
330                     poResult->killMe();
331                 }
332                 std::wostringstream os;
333                 os << _W("inconsistent row/column dimensions\n");
334                 throw ast::InternalError(os.str(), 999, (*e.getLines().begin())->getLocation());
335             }
336
337             // if we concatenate [Double Sparse], transform the Double to Sparse and perform [Sparse Sparse]
338             // this avoids to allocate a Double result of size of Double+Sparse and initialize all elements.
339             if (pGT->isSparse() && pGTResult->isDouble())
340             {
341                 poResult = new types::Sparse(*pGTResult->getAs<types::Double>());
342                 pGTResult->killMe();
343                 pGTResult = poResult->getAs<types::GenericType>();
344             }
345             else if (pGT->isSparseBool() && pGTResult->isBool()) // [Bool SparseBool] => [SparseBool SparseBool]
346             {
347                 poResult = new types::SparseBool(*pGTResult->getAs<types::Bool>());
348                 pGTResult->killMe();
349                 pGTResult = poResult->getAs<types::GenericType>();
350             }
351
352             types::InternalType* pNewSize = AddElementToVariable(NULL, poResult, pGTResult->getRows() + pGT->getRows(), pGT->getCols());
353             types::InternalType* p = AddElementToVariable(pNewSize, pGT, pGTResult->getRows(), 0);
354             if (p != pNewSize)
355             {
356                 pNewSize->killMe();
357             }
358
359             // call overload
360             if (p == NULL)
361             {
362                 try
363                 {
364                    poResult = callOverloadMatrixExp(L"f", pGTResult, pGT);
365                 }
366                 catch (const InternalError& error)
367                 {
368                     throw error;
369                 }
370                 continue;
371             }
372
373             if (poResult != poRow)
374             {
375                 poRow->killMe();
376             }
377
378             if (p != poResult)
379             {
380                 poResult->killMe();
381                 poResult = p;
382             }
383         }
384
385         if (poResult)
386         {
387             setResult(poResult);
388         }
389         else
390         {
391             setResult(types::Double::Empty());
392         }
393     }
394     catch (const InternalError& error)
395     {
396         setResult(NULL);
397         throw error;
398     }
399 }
400
401 template<class T>
402 types::InternalType* RunVisitorT<T>::callOverloadMatrixExp(const std::wstring& strType, types::InternalType* _paramL, types::InternalType* _paramR)
403 {
404     types::typed_list in;
405     types::typed_list out;
406     types::Callable::ReturnValue Ret;
407
408     _paramL->IncreaseRef();
409     _paramR->IncreaseRef();
410
411     in.push_back(_paramL);
412     in.push_back(_paramR);
413
414     try
415     {
416         if (_paramR->isGenericType() && _paramR->getAs<types::GenericType>()->getDims() > 2)
417         {
418             Ret = Overload::call(L"%hm_" + strType + L"_hm", in, 1, out, true);
419         }
420         else
421         {
422             Ret = Overload::call(L"%" + _paramL->getAs<types::List>()->getShortTypeStr() + L"_" + strType + L"_" + _paramR->getAs<types::List>()->getShortTypeStr(), in, 1, out, true);
423         }
424     }
425     catch (const InternalError& error)
426     {
427         cleanInOut(in, out);
428         throw error;
429     }
430
431     if (Ret != types::Callable::OK)
432     {
433         cleanInOut(in, out);
434         throw InternalError(ConfigVariable::getLastErrorMessage());
435     }
436
437     cleanIn(in, out);
438
439     if (out.empty())
440     {
441         // TODO: avoid crash if out is empty but must return an error...
442         return NULL;
443     }
444
445     return out[0];
446 }
447
448 } /* namespace ast */