fix concatenation of matrix with [] and list()
[scilab.git] / scilab / modules / ast / src / cpp / ast / run_MatrixExp.hpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2008-2008 - DIGITEO - Antoine ELIAS
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 //file included in runvisitor.cpp
14 namespace ast {
15
16 /*
17     [1,2;3,4] with/without special character $ and :
18 */
19 template<class T>
20 void RunVisitorT<T>::visitprivate(const MatrixExp &e)
21 {
22     try
23     {
24         exps_t::const_iterator row;
25         exps_t::const_iterator col;
26         types::InternalType *poResult = NULL;
27         std::list<types::InternalType*> rowList;
28
29         exps_t lines = e.getLines();
30         if (lines.size() == 0)
31         {
32             setResult(types::Double::Empty());
33             return;
34         }
35
36         //special case for 1x1 matrix
37         if (lines.size() == 1)
38         {
39             exps_t cols = lines[0]->getAs<MatrixLineExp>()->getColumns();
40             if (cols.size() == 1)
41             {
42                 setResult(NULL); // Reset value on loop re-start
43
44                 cols[0]->accept(*this);
45                 //manage evstr('//xxx') for example
46                 if (getResult() == NULL)
47                 {
48                     setResult(types::Double::Empty());
49                 }
50                 return;
51             }
52         }
53
54         //do all [x,x]
55         for (row = lines.begin() ; row != lines.end() ; row++)
56         {
57             types::InternalType* poRow = NULL;
58             exps_t cols = (*row)->getAs<MatrixLineExp>()->getColumns();
59             for (col = cols.begin() ; col != cols.end() ; col++)
60             {
61                 setResult(NULL); // Reset value on loop re-start
62
63                 try
64                 {
65                     (*col)->accept(*this);
66                 }
67                 catch (const InternalError& error)
68                 {
69                     if (poRow)
70                     {
71                         poRow->killMe();
72                     }
73                     if (poResult)
74                     {
75                         poResult->killMe();
76                     }
77
78                     throw error;
79                 }
80
81                 types::InternalType *pIT = getResult();
82                 if (pIT == NULL)
83                 {
84                     continue;
85                 }
86
87                 //reset result but without delete the value
88                 clearResultButFirst();
89
90                 if (pIT->isImplicitList())
91                 {
92                     types::ImplicitList *pIL = pIT->getAs<types::ImplicitList>();
93                     if (pIL->isComputable())
94                     {
95                         types::InternalType* pIT2 = pIL->extractFullMatrix();
96                         pIT->killMe();
97                         pIT = pIT2;
98                     }
99                     else
100                     {
101                         if (poRow == NULL)
102                         {
103                             //first loop
104                             poRow = pIT;
105                         }
106                         else
107                         {
108                             try
109                             {
110                                 poRow = callOverloadMatrixExp(L"c", poRow, pIT);
111                             }
112                             catch (const InternalError& error)
113                             {
114                                 if (poResult)
115                                 {
116                                     poResult->killMe();
117                                 }
118                                 throw error;
119                             }
120                         }
121
122                         continue;
123                     }
124                 }
125
126                 if (pIT->isGenericType() == false)
127                 {
128                     pIT->killMe();
129                     std::wostringstream os;
130                     os << _W("unable to concatenate\n");
131                     throw ast::InternalError(os.str(), 999, (*col)->getLocation());
132                 }
133
134                 types::GenericType* pGT = pIT->getAs<types::GenericType>();
135
136                 if (poRow == NULL)
137                 {
138                     //first loop
139                     if (poResult == NULL && pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
140                     {
141                         pGT->killMe();
142                         continue;
143                     }
144
145                     if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
146                     {
147                         pGT->killMe();
148                         continue;
149                     }
150
151                     poRow = pGT;
152                     continue;
153                 }
154
155                 //manage overload on list/struct/implicitlist and hypermatrix before management of []
156                 if (pGT->isList() || poRow->isList() || pGT->isStruct() || poRow->isStruct() || poRow->isImplicitList() || pGT->getDims() > 2)
157                 {
158                     try
159                     {
160                         poRow = callOverloadMatrixExp(L"c", poRow, pGT);
161                     }
162                     catch (const InternalError& error)
163                     {
164                         if (poResult)
165                         {
166                             poResult->killMe();
167                         }
168                         throw error;
169                     }
170
171                     continue;
172                 }
173
174                 if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
175                 {
176                     pGT->killMe();
177                     continue;
178                 }
179
180                 types::GenericType* pGTResult = poRow->getAs<types::GenericType>();
181
182                 //check dimension
183                 if (pGT->getDims() != 2 || pGT->getRows() != pGTResult->getRows())
184                 {
185                     poRow->killMe();
186                     if (poRow != pGT)
187                     {
188                         pGT->killMe();
189                     }
190                     std::wostringstream os;
191                     os << _W("inconsistent row/column dimensions\n");
192                     throw ast::InternalError(os.str(), 999, (*row)->getLocation());
193                 }
194
195                 // if we concatenate [Double Sparse], transform the Double to Sparse and perform [Sparse Sparse]
196                 // this avoids to allocate a Double result of size of Double+Sparse and initialize all elements.
197                 if (pGT->isSparse() && pGTResult->isDouble())
198                 {
199                     poRow = new types::Sparse(*pGTResult->getAs<types::Double>());
200                     pGTResult->killMe();
201                     pGTResult = poRow->getAs<types::GenericType>();
202                 }
203                 else if (pGT->isSparseBool() && pGTResult->isBool()) // [Bool SparseBool] => [SparseBool SparseBool]
204                 {
205                     poRow = new types::SparseBool(*pGTResult->getAs<types::Bool>());
206                     pGTResult->killMe();
207                     pGTResult = poRow->getAs<types::GenericType>();
208                 }
209                 else if (pGT->isDollar() && pGTResult->isDouble())
210                 {
211                     int _iRows = pGTResult->getRows();
212                     int _iCols = pGTResult->getCols();
213                     int* piRank = new int[_iRows * _iCols];
214                     memset(piRank, 0x00, _iRows * _iCols * sizeof(int));
215                     poRow = new types::Polynom(pGT->getAs<types::Polynom>()->getVariableName(), _iRows, _iCols, piRank);
216                     types::Polynom* pP = poRow->getAs<types::Polynom>();
217                     types::SinglePoly** pSS = pP->get();
218                     types::Double* pDb = pGTResult->getAs<types::Double>();
219                     double* pdblR = pDb->get();
220                     if (pDb->isComplex())
221                     {
222                         double* pdblI = pDb->getImg();
223                         pP->setComplex(true);
224                         for (int i = 0; i < pDb->getSize(); i++)
225                         {
226                             pSS[i]->setRank(0);
227                             pSS[i]->setCoef(pdblR + i, pdblI + i);
228                         }
229                     }
230                     else
231                     {
232                         for (int i = 0; i < pDb->getSize(); i++)
233                         {
234                             pSS[i]->setRank(0);
235                             pSS[i]->setCoef(pdblR + i, NULL);
236                         }
237                     }
238
239                     delete[] piRank;
240                 }
241
242                 types::InternalType *pNewSize = AddElementToVariable(NULL, poRow, pGTResult->getRows(), pGTResult->getCols() + pGT->getCols());
243                 types::InternalType* p = AddElementToVariable(pNewSize, pGT, 0, pGTResult->getCols());
244                 if (p != pNewSize)
245                 {
246                     pNewSize->killMe();
247                 }
248                 // call overload
249                 if (p == NULL)
250                 {
251                     try
252                     {
253                         poRow = callOverloadMatrixExp(L"c", pGTResult, pGT);
254                     }
255                     catch (const InternalError& error)
256                     {
257                         if (poResult)
258                         {
259                             poResult->killMe();
260                         }
261                         throw error;
262                     }
263                     continue;
264                 }
265
266                 if (poRow != pGT)
267                 {
268                     pGT->killMe();
269                 }
270
271                 if (p != poRow)
272                 {
273                     poRow->killMe();
274                     poRow = p;
275                 }
276             }
277
278             if (poRow == NULL)
279             {
280                 continue;
281             }
282
283             if (poResult == NULL)
284             {
285                 poResult = poRow;
286                 continue;
287             }
288
289             // management of concatenation with 1:$
290             if (poRow->isImplicitList() || poResult->isImplicitList())
291             {
292                 try
293                 {
294                     poResult = callOverloadMatrixExp(L"f", poResult, poRow);
295                 }
296                 catch (const InternalError& error)
297                 {
298                     throw error;
299                 }
300                 continue;
301             }
302
303             types::GenericType* pGT = poRow->getAs<types::GenericType>();
304
305             //check dimension
306             types::GenericType* pGTResult = poResult->getAs<types::GenericType>();
307
308             if (pGT->isList() || pGTResult->isList() || pGT->isStruct() || pGTResult->isStruct() || pGT->getDims() > 2)
309             {
310                 try
311                 {
312                     poResult = callOverloadMatrixExp(L"f", pGTResult, pGT);
313                 }
314                 catch (const InternalError& error)
315                 {
316                     throw error;
317                 }
318
319                 continue;
320             }
321             else
322             {//[]
323                 if (pGT->isDouble() && pGT->getAs<types::Double>()->isEmpty())
324                 {
325                     pGT->killMe();
326                     continue;
327                 }
328             }
329
330             //check dimension
331             if (pGT->getCols() != pGTResult->getCols())
332             {
333                 poRow->killMe();
334                 if (poResult)
335                 {
336                     poResult->killMe();
337                 }
338                 std::wostringstream os;
339                 os << _W("inconsistent row/column dimensions\n");
340                 throw ast::InternalError(os.str(), 999, (*e.getLines().begin())->getLocation());
341             }
342
343             // if we concatenate [Double Sparse], transform the Double to Sparse and perform [Sparse Sparse]
344             // this avoids to allocate a Double result of size of Double+Sparse and initialize all elements.
345             if (pGT->isSparse() && pGTResult->isDouble())
346             {
347                 poResult = new types::Sparse(*pGTResult->getAs<types::Double>());
348                 pGTResult->killMe();
349                 pGTResult = poResult->getAs<types::GenericType>();
350             }
351             else if (pGT->isSparseBool() && pGTResult->isBool()) // [Bool SparseBool] => [SparseBool SparseBool]
352             {
353                 poResult = new types::SparseBool(*pGTResult->getAs<types::Bool>());
354                 pGTResult->killMe();
355                 pGTResult = poResult->getAs<types::GenericType>();
356             }
357
358             types::InternalType* pNewSize = AddElementToVariable(NULL, poResult, pGTResult->getRows() + pGT->getRows(), pGT->getCols());
359             types::InternalType* p = AddElementToVariable(pNewSize, pGT, pGTResult->getRows(), 0);
360             if (p != pNewSize)
361             {
362                 pNewSize->killMe();
363             }
364
365             // call overload
366             if (p == NULL)
367             {
368                 try
369                 {
370                    poResult = callOverloadMatrixExp(L"f", pGTResult, pGT);
371                 }
372                 catch (const InternalError& error)
373                 {
374                     throw error;
375                 }
376                 continue;
377             }
378
379             if (poResult != poRow)
380             {
381                 poRow->killMe();
382             }
383
384             if (p != poResult)
385             {
386                 poResult->killMe();
387                 poResult = p;
388             }
389         }
390
391         if (poResult)
392         {
393             setResult(poResult);
394         }
395         else
396         {
397             setResult(types::Double::Empty());
398         }
399     }
400     catch (const InternalError& error)
401     {
402         setResult(NULL);
403         throw error;
404     }
405 }
406
407 template<class T>
408 types::InternalType* RunVisitorT<T>::callOverloadMatrixExp(const std::wstring& strType, types::InternalType* _paramL, types::InternalType* _paramR)
409 {
410     types::typed_list in;
411     types::typed_list out;
412     types::Callable::ReturnValue Ret;
413
414     _paramL->IncreaseRef();
415     _paramR->IncreaseRef();
416
417     in.push_back(_paramL);
418     in.push_back(_paramR);
419
420     try
421     {
422         if (_paramR->isGenericType() && _paramR->getAs<types::GenericType>()->getDims() > 2)
423         {
424             Ret = Overload::call(L"%hm_" + strType + L"_hm", in, 1, out, true);
425         }
426         else
427         {
428             Ret = Overload::call(L"%" + _paramL->getAs<types::List>()->getShortTypeStr() + L"_" + strType + L"_" + _paramR->getAs<types::List>()->getShortTypeStr(), in, 1, out, true);
429         }
430     }
431     catch (const InternalError& error)
432     {
433         cleanInOut(in, out);
434         throw error;
435     }
436
437     if (Ret != types::Callable::OK)
438     {
439         cleanInOut(in, out);
440         throw InternalError(ConfigVariable::getLastErrorMessage());
441     }
442
443     cleanIn(in, out);
444
445     if (out.empty())
446     {
447         // TODO: avoid crash if out is empty but must return an error...
448         return NULL;
449     }
450
451     return out[0];
452 }
453
454 } /* namespace ast */