[bug_9571] concatenation fixed
[scilab.git] / scilab / modules / ast / src / cpp / ast / run_MatrixExp.hpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2008-2008 - DIGITEO - Antoine ELIAS
4  *
5  * Copyright (C) 2012 - 2016 - Scilab Enterprises
6  *
7  * This file is hereby licensed under the terms of the GNU GPL v2.0,
8  * pursuant to article 5.3.4 of the CeCILL v.2.1.
9  * This file was originally licensed under the terms of the CeCILL v2.1,
10  * and continues to be available under such terms.
11  * For more information, see the COPYING file which you should have received
12  * along with this program.
13  *
14  */
15
16 //file included in runvisitor.cpp
17 namespace ast {
18
19 /*
20     [1,2;3,4] with/without special character $ and :
21     */
22 template<class T>
23 void RunVisitorT<T>::visitprivate(const MatrixExp &e)
24 {
25     CoverageInstance::invokeAndStartChrono((void*)&e);
26     try
27     {
28         exps_t::const_iterator row;
29         exps_t::const_iterator col;
30         types::InternalType *poResult = NULL;
31         std::list<types::InternalType*> rowList;
32
33         exps_t lines = e.getLines();
34         if (lines.size() == 0)
35         {
36             setResult(types::Double::Empty());
37             CoverageInstance::invokeAndStartChrono((void*)&e);
38             return;
39         }
40
41         //special case for 1x1 matrix
42         if (lines.size() == 1)
43         {
44             exps_t cols = lines[0]->getAs<MatrixLineExp>()->getColumns();
45             if (cols.size() == 1)
46             {
47                 setResult(NULL); // Reset value on loop re-start
48
49                 cols[0]->accept(*this);
50                 //manage evstr('//xxx') for example
51                 if (getResult() == NULL)
52                 {
53                     setResult(types::Double::Empty());
54                 }
55                 CoverageInstance::invokeAndStartChrono((void*)&e);
56                 return;
57             }
58         }
59
60         //do all [x,x]
61         for (row = lines.begin(); row != lines.end(); row++)
62         {
63             types::InternalType* poRow = NULL;
64             exps_t cols = (*row)->getAs<MatrixLineExp>()->getColumns();
65             for (col = cols.begin(); col != cols.end(); col++)
66             {
67                 setResult(NULL); // Reset value on loop re-start
68
69                 try
70                 {
71                     (*col)->accept(*this);
72                 }
73                 catch (const InternalError& error)
74                 {
75                     if (poRow)
76                     {
77                         poRow->killMe();
78                     }
79                     if (poResult)
80                     {
81                         poResult->killMe();
82                     }
83
84                     throw error;
85                 }
86
87                 types::InternalType *pIT = getResult();
88                 if (pIT == NULL)
89                 {
90                     continue;
91                 }
92
93                 //reset result but without delete the value
94                 clearResultButFirst();
95
96                 if (pIT->isImplicitList())
97                 {
98                     types::ImplicitList *pIL = pIT->getAs<types::ImplicitList>();
99                     if (pIL->isComputable())
100                     {
101                         types::InternalType* pIT2 = pIL->extractFullMatrix();
102                         pIT->killMe();
103                         pIT = pIT2;
104                     }
105                     else
106                     {
107                         if (poRow == NULL)
108                         {
109                             //first loop
110                             poRow = pIT;
111                         }
112                         else
113                         {
114                             try
115                             {
116                                 poRow = callOverloadMatrixExp(L"c", poRow, pIT);
117                             }
118                             catch (const InternalError& error)
119                             {
120                                 if (poResult)
121                                 {
122                                     poResult->killMe();
123                                 }
124                                 throw error;
125                             }
126                         }
127
128                         continue;
129                     }
130                 }
131
132                 if (pIT->isGenericType() == false)
133                 {
134                     if (poRow == NULL)
135                     {
136                         //first loop
137                         poRow = pIT;
138                     }
139                     else
140                     {
141                         try
142                         {
143                             poRow = callOverloadMatrixExp(L"c", poRow, pIT);
144                         }
145                         catch (const InternalError& error)
146                         {
147                             if (poResult)
148                             {
149                                 poResult->killMe();
150                             }
151
152                             pIT->killMe();
153                             throw error;
154                         }
155                     }
156
157                     continue;
158                 }
159
160                 types::GenericType* pGT = pIT->getAs<types::GenericType>();
161
162                 if (poRow == NULL)
163                 {
164                     //first loop
165                     if (poResult == NULL && pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
166                     {
167                         pGT->killMe();
168                         continue;
169                     }
170
171                     if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
172                     {
173                         if (poResult && (poResult->isList() || poResult->isStruct()))
174                         {
175                             //in case of [list(); [], ...]
176
177                             //we don't know what to do with [], keep it as "normal" value and continue process
178                             poRow = pGT;
179                             continue;
180                         }
181
182                         pGT->killMe();
183                         continue;
184                     }
185
186                     poRow = pGT;
187                     continue;
188                 }
189
190                 //manage overload on list/struct/implicitlist and hypermatrix before management of []
191                 if (pGT->isList() || poRow->isList() || pGT->isStruct() || poRow->isStruct() || poRow->isImplicitList() || pGT->getDims() > 2)
192                 {
193                     try
194                     {
195                         poRow = callOverloadMatrixExp(L"c", poRow, pGT);
196                     }
197                     catch (const InternalError& error)
198                     {
199                         if (poResult)
200                         {
201                             poResult->killMe();
202                         }
203                         throw error;
204                     }
205
206                     continue;
207                 }
208
209                 if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
210                 {
211                     pGT->killMe();
212                     continue;
213                 }
214
215                 types::GenericType* pGTResult = poRow->getAs<types::GenericType>();
216
217                 //check dimension
218                 if (pGT->getDims() != 2 || pGT->getRows() != pGTResult->getRows())
219                 {
220                     poRow->killMe();
221                     if (poRow != pGT)
222                     {
223                         pGT->killMe();
224                     }
225                     std::wostringstream os;
226                     os << _W("inconsistent row/column dimensions\n");
227                     throw ast::InternalError(os.str(), 999, (*row)->getLocation());
228                 }
229
230                 // if we concatenate [Double Sparse], transform the Double to Sparse and perform [Sparse Sparse]
231                 // this avoids to allocate a Double result of size of Double+Sparse and initialize all elements.
232                 if (pGT->isSparse() && pGTResult->isDouble())
233                 {
234                     poRow = new types::Sparse(*pGTResult->getAs<types::Double>());
235                     pGTResult->killMe();
236                     pGTResult = poRow->getAs<types::GenericType>();
237                 }
238                 else if (pGT->isSparseBool() && pGTResult->isBool()) // [Bool SparseBool] => [SparseBool SparseBool]
239                 {
240                     poRow = new types::SparseBool(*pGTResult->getAs<types::Bool>());
241                     pGTResult->killMe();
242                     pGTResult = poRow->getAs<types::GenericType>();
243                 }
244                 else if (pGT->isDollar() && pGTResult->isDouble())
245                 {
246                     int _iRows = pGTResult->getRows();
247                     int _iCols = pGTResult->getCols();
248                     int* piRank = new int[_iRows * _iCols];
249                     memset(piRank, 0x00, _iRows * _iCols * sizeof(int));
250                     poRow = new types::Polynom(pGT->getAs<types::Polynom>()->getVariableName(), _iRows, _iCols, piRank);
251                     types::Polynom* pP = poRow->getAs<types::Polynom>();
252                     types::SinglePoly** pSS = pP->get();
253                     types::Double* pDb = pGTResult->getAs<types::Double>();
254                     double* pdblR = pDb->get();
255                     if (pDb->isComplex())
256                     {
257                         double* pdblI = pDb->getImg();
258                         pP->setComplex(true);
259                         for (int i = 0; i < pDb->getSize(); i++)
260                         {
261                             pSS[i]->setRank(0);
262                             pSS[i]->setCoef(pdblR + i, pdblI + i);
263                         }
264                     }
265                     else
266                     {
267                         for (int i = 0; i < pDb->getSize(); i++)
268                         {
269                             pSS[i]->setRank(0);
270                             pSS[i]->setCoef(pdblR + i, NULL);
271                         }
272                     }
273
274                     delete[] piRank;
275                 }
276
277                 types::InternalType *pNewSize = AddElementToVariable(NULL, poRow, pGTResult->getRows(), pGTResult->getCols() + pGT->getCols());
278                 types::InternalType* p = AddElementToVariable(pNewSize, pGT, 0, pGTResult->getCols());
279                 if (p != pNewSize)
280                 {
281                     pNewSize->killMe();
282                 }
283                 // call overload
284                 if (p == NULL)
285                 {
286                     try
287                     {
288                         poRow = callOverloadMatrixExp(L"c", pGTResult, pGT);
289                     }
290                     catch (const InternalError& error)
291                     {
292                         if (poResult)
293                         {
294                             poResult->killMe();
295                         }
296                         throw error;
297                     }
298                     continue;
299                 }
300
301                 if (poRow != pGT)
302                 {
303                     pGT->killMe();
304                 }
305
306                 if (p != poRow)
307                 {
308                     poRow->killMe();
309                     poRow = p;
310                 }
311             }
312
313             if (poRow == NULL)
314             {
315                 continue;
316             }
317
318             if (poResult == NULL)
319             {
320                 poResult = poRow;
321                 continue;
322             }
323
324             // management of concatenation with 1:$
325             if (poRow->isImplicitList() || poResult->isImplicitList())
326             {
327                 try
328                 {
329                     poResult = callOverloadMatrixExp(L"f", poResult, poRow);
330                 }
331                 catch (const InternalError& error)
332                 {
333                     throw error;
334                 }
335                 continue;
336             }
337
338             types::GenericType* pGT = poRow->getAs<types::GenericType>();
339
340             //check dimension
341             types::GenericType* pGTResult = poResult->getAs<types::GenericType>();
342
343             if (pGT->isList() || pGTResult->isList() || pGT->isStruct() || pGTResult->isStruct() || pGT->getDims() > 2)
344             {
345                 try
346                 {
347                     poResult = callOverloadMatrixExp(L"f", pGTResult, pGT);
348                 }
349                 catch (const InternalError& error)
350                 {
351                     throw error;
352                 }
353
354                 continue;
355             }
356             else
357             {//[]
358                 if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
359                 {
360                     pGT->killMe();
361                     continue;
362                 }
363             }
364
365             //check dimension
366             if (pGT->getCols() != pGTResult->getCols())
367             {
368                 poRow->killMe();
369                 if (poResult)
370                 {
371                     poResult->killMe();
372                 }
373                 std::wostringstream os;
374                 os << _W("inconsistent row/column dimensions\n");
375                 throw ast::InternalError(os.str(), 999, (*e.getLines().begin())->getLocation());
376             }
377
378             // if we concatenate [Double Sparse], transform the Double to Sparse and perform [Sparse Sparse]
379             // this avoids to allocate a Double result of size of Double+Sparse and initialize all elements.
380             if (pGT->isSparse() && pGTResult->isDouble())
381             {
382                 poResult = new types::Sparse(*pGTResult->getAs<types::Double>());
383                 pGTResult->killMe();
384                 pGTResult = poResult->getAs<types::GenericType>();
385             }
386             else if (pGT->isSparseBool() && pGTResult->isBool()) // [Bool SparseBool] => [SparseBool SparseBool]
387             {
388                 poResult = new types::SparseBool(*pGTResult->getAs<types::Bool>());
389                 pGTResult->killMe();
390                 pGTResult = poResult->getAs<types::GenericType>();
391             }
392
393             types::InternalType* pNewSize = AddElementToVariable(NULL, poResult, pGTResult->getRows() + pGT->getRows(), pGT->getCols());
394             types::InternalType* p = AddElementToVariable(pNewSize, pGT, pGTResult->getRows(), 0);
395             if (p != pNewSize)
396             {
397                 pNewSize->killMe();
398             }
399
400             // call overload
401             if (p == NULL)
402             {
403                 try
404                 {
405                     poResult = callOverloadMatrixExp(L"f", pGTResult, pGT);
406                 }
407                 catch (const InternalError& error)
408                 {
409                     throw error;
410                 }
411                 continue;
412             }
413
414             if (poResult != poRow)
415             {
416                 poRow->killMe();
417             }
418
419             if (p != poResult)
420             {
421                 poResult->killMe();
422                 poResult = p;
423             }
424         }
425
426         if (poResult)
427         {
428             setResult(poResult);
429         }
430         else
431         {
432             setResult(types::Double::Empty());
433         }
434     }
435     catch (const InternalError& error)
436     {
437         setResult(NULL);
438         CoverageInstance::invokeAndStartChrono((void*)&e);
439         throw error;
440     }
441     CoverageInstance::invokeAndStartChrono((void*)&e);
442 }
443
444 template<class T>
445 types::InternalType* RunVisitorT<T>::callOverloadMatrixExp(const std::wstring& strType, types::InternalType* _paramL, types::InternalType* _paramR)
446 {
447     types::typed_list in;
448     types::typed_list out;
449     types::Callable::ReturnValue Ret;
450
451     _paramL->IncreaseRef();
452     _paramR->IncreaseRef();
453
454     in.push_back(_paramL);
455     in.push_back(_paramR);
456
457     try
458     {
459         if (_paramR->isGenericType() && _paramR->getAs<types::GenericType>()->getDims() > 2)
460         {
461             Ret = Overload::call(L"%hm_" + strType + L"_hm", in, 1, out, true);
462         }
463         else
464         {
465             Ret = Overload::call(L"%" + _paramL->getAs<types::List>()->getShortTypeStr() + L"_" + strType + L"_" + _paramR->getAs<types::List>()->getShortTypeStr(), in, 1, out, true);
466         }
467     }
468     catch (const InternalError& error)
469     {
470         cleanInOut(in, out);
471         throw error;
472     }
473
474     if (Ret != types::Callable::OK)
475     {
476         cleanInOut(in, out);
477         throw InternalError(ConfigVariable::getLastErrorMessage());
478     }
479
480     cleanIn(in, out);
481
482     if (out.empty())
483     {
484         // TODO: avoid crash if out is empty but must return an error...
485         return NULL;
486     }
487
488     return out[0];
489 }
490
491 } /* namespace ast */