[differential_equations] bugs and tests fixed
[scilab.git] / scilab / modules / ast / src / cpp / types / macro.cpp
1 /*
2 *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3 *  Copyright (C) 2009-2009 - DIGITEO - Bruno JOFRET
4 *
5  * Copyright (C) 2012 - 2016 - Scilab Enterprises
6  *
7  * This file is hereby licensed under the terms of the GNU GPL v2.0,
8  * pursuant to article 5.3.4 of the CeCILL v.2.1.
9  * This file was originally licensed under the terms of the CeCILL v2.1,
10  * and continues to be available under such terms.
11  * For more information, see the COPYING file which you should have received
12  * along with this program.
13 *
14 */
15
16 #include <memory>
17 #include <sstream>
18 #include <cstdio>
19
20 #include "macro.hxx"
21 #include "list.hxx"
22 #include "context.hxx"
23 #include "symbol.hxx"
24 #include "scilabWrite.hxx"
25 #include "configvariable.hxx"
26 #include "serializervisitor.hxx"
27
28 extern "C"
29 {
30 #include "localization.h"
31 #include "Scierror.h"
32 #include "sciprint.h"
33 #include "sci_malloc.h"
34 #include "os_string.h"
35 }
36
37 namespace types
38 {
39 Macro::Macro(const std::wstring& _stName, std::list<symbol::Variable*>& _inputArgs, std::list<symbol::Variable*>& _outputArgs, ast::SeqExp &_body, const std::wstring& _stModule):
40     Callable(),
41     m_inputArgs(&_inputArgs), m_outputArgs(&_outputArgs), m_body(_body.clone()),
42     m_Nargin(symbol::Context::getInstance()->getOrCreate(symbol::Symbol(L"nargin"))),
43     m_Nargout(symbol::Context::getInstance()->getOrCreate(symbol::Symbol(L"nargout"))),
44     m_Varargin(symbol::Context::getInstance()->getOrCreate(symbol::Symbol(L"varargin"))),
45     m_Varargout(symbol::Context::getInstance()->getOrCreate(symbol::Symbol(L"varargout")))
46 {
47     setName(_stName);
48     setModule(_stModule);
49     bAutoAlloc = false;
50     m_pDblArgIn = new Double(1);
51     m_pDblArgIn->IncreaseRef(); //never delete
52     m_pDblArgOut = new Double(1);
53     m_pDblArgOut->IncreaseRef(); //never delete
54
55     m_body->setReturnable();
56     m_stPath = L"";
57 }
58
59 Macro::~Macro()
60 {
61     delete m_body;
62     m_pDblArgIn->DecreaseRef();
63     m_pDblArgIn->killMe();
64     m_pDblArgOut->DecreaseRef();
65     m_pDblArgOut->killMe();
66
67     if (m_inputArgs)
68     {
69         delete m_inputArgs;
70     }
71
72     if (m_outputArgs)
73     {
74         delete m_outputArgs;
75     }
76
77     for (const auto & sub : m_submacro)
78     {
79         sub.second->DecreaseRef();
80         sub.second->killMe();
81     }
82
83     m_submacro.clear();
84 }
85
86 void Macro::cleanCall(symbol::Context * pContext, int oldPromptMode)
87 {
88     //restore previous prompt mode
89     ConfigVariable::setPromptMode(oldPromptMode);
90     //close the current scope
91     pContext->scope_end();
92     ConfigVariable::macroFirstLine_end();
93 }
94
95 Macro* Macro::clone()
96 {
97     IncreaseRef();
98     return this;
99 }
100
101 void Macro::whoAmI()
102 {
103     std::cout << "types::Macro";
104 }
105
106 ast::SeqExp* Macro::getBody(void)
107 {
108     return m_body;
109 }
110
111 bool Macro::toString(std::wostringstream& ostr)
112 {
113     // get macro name
114     wchar_t* wcsVarName = NULL;
115     if (ostr.str() == SPACES_LIST)
116     {
117         wcsVarName = os_wcsdup(getName().c_str());
118     }
119     else
120     {
121         wcsVarName = os_wcsdup(ostr.str().c_str());
122     }
123
124     ostr.str(L"");
125     ostr << L"[";
126
127     // output arguments [a,b,c] = ....
128     if (m_outputArgs->empty() == false)
129     {
130         std::list<symbol::Variable*>::iterator OutArg = m_outputArgs->begin();
131         std::list<symbol::Variable*>::iterator OutArgfter = OutArg;
132         OutArgfter++;
133
134         for (; OutArgfter != m_outputArgs->end(); OutArgfter++)
135         {
136             ostr << (*OutArg)->getSymbol().getName();
137             ostr << ",";
138             OutArg++;
139         }
140
141         ostr << (*OutArg)->getSymbol().getName();
142     }
143
144     ostr << L"]";
145
146     // function name
147     ostr << L"=" << wcsVarName << L"(";
148
149     // input arguments function(a,b,c)
150     if (m_inputArgs->empty() == false)
151     {
152         std::list<symbol::Variable*>::iterator inArg = m_inputArgs->begin();
153         std::list<symbol::Variable*>::iterator inRagAfter = inArg;
154         inRagAfter++;
155
156         for (; inRagAfter != m_inputArgs->end(); inRagAfter++)
157         {
158             ostr << (*inArg)->getSymbol().getName();
159             ostr << ",";
160             inArg++;
161         }
162
163         ostr << (*inArg)->getSymbol().getName();
164     }
165
166     ostr << L")" << std::endl;
167
168     FREE(wcsVarName);
169     return true;
170 }
171
172 Callable::ReturnValue Macro::call(typed_list &in, optional_list &opt, int _iRetCount, typed_list &out)
173 {
174     bool bVarargout = false;
175     ReturnValue RetVal = Callable::OK;
176     symbol::Context *pContext = symbol::Context::getInstance();
177
178     //open a new scope
179     pContext->scope_begin();
180     //store the line number where is stored this macro in file.
181     ConfigVariable::macroFirstLine_begin(getFirstLine());
182
183     //check excepted and input/output parameters numbers
184     // Scilab Macro can be called with less than prototyped arguments,
185     // but not more execpts with varargin
186
187     // varargin management
188     if (m_inputArgs->size() > 0 && m_inputArgs->back()->getSymbol().getName() == L"varargin")
189     {
190         int iVarPos = static_cast<int>(in.size());
191         if (iVarPos > static_cast<int>(m_inputArgs->size()) - 1)
192         {
193             iVarPos = static_cast<int>(m_inputArgs->size()) - 1;
194         }
195
196         //add all standard variable in function context but not varargin
197         std::list<symbol::Variable*>::iterator itName = m_inputArgs->begin();
198         typed_list::const_iterator itValue = in.begin();
199         while (iVarPos > 0)
200         {
201             pContext->put(*itName, *itValue);
202             iVarPos--;
203             ++itName;
204             ++itValue;
205         }
206
207         //create varargin only if previous variable are assigned
208         optional_list::const_iterator it = opt.begin();
209         if (in.size() >= m_inputArgs->size() - 1)
210         {
211             //create and fill varargin
212             List* pL = new List();
213             while (itValue != in.end())
214             {
215                 if (*itValue != NULL)
216                 {
217                     pL->append(*itValue);
218                 }
219                 else
220                 {
221                     pL->append(it->second);
222                     it++;
223                 }
224
225                 itValue++;
226             }
227             pContext->put(m_Varargin, pL);
228         }
229     }
230     else if (in.size() > m_inputArgs->size())
231     {
232         if (m_inputArgs->size() == 0)
233         {
234             Scierror(999, _("Wrong number of input arguments: This function has no input argument.\n"));
235         }
236         else
237         {
238             Scierror(999, _("Wrong number of input arguments.\n"));
239         }
240
241         pContext->scope_end();
242         ConfigVariable::macroFirstLine_end();
243         return Callable::Error;
244     }
245     else
246     {
247         //assign value to variable in the new context
248         std::list<symbol::Variable*>::iterator i;
249         typed_list::const_iterator j;
250
251         for (i = m_inputArgs->begin(), j = in.begin(); j != in.end(); ++j, ++i)
252         {
253             if (*j)
254             {
255                 //prevent assignation of NULL value
256                 pContext->put(*i, *j);
257             }
258         }
259
260         //add optional paramter in current scope
261         optional_list::const_iterator it;
262         for (it = opt.begin() ; it != opt.end() ; it++)
263         {
264             pContext->put(symbol::Symbol(it->first), it->second);
265         }
266
267
268     }
269
270     // varargout management
271     //rules :
272     // varargout must be alone
273     // varargout is a list
274     // varargout can containt more items than caller need
275     // varargout must containt at leat caller needs
276     if (m_outputArgs->size() == 1 && m_outputArgs->back()->getSymbol().getName() == L"varargout")
277     {
278         bVarargout = true;
279         List* pL = new List();
280         pContext->put(m_Varargout, pL);
281     }
282
283     //common part with or without varargin/varargout
284
285     // Declare nargin & nargout in function context.
286     if (m_pDblArgIn->getRef() > 1)
287     {
288         m_pDblArgIn->DecreaseRef();
289         m_pDblArgIn = (Double*)m_pDblArgIn->clone();
290         m_pDblArgIn->IncreaseRef();
291     }
292     m_pDblArgIn->set(0, static_cast<double>(in.size()));
293
294     if (m_pDblArgOut->getRef() > 1)
295     {
296         m_pDblArgOut->DecreaseRef();
297         m_pDblArgOut = (Double*)m_pDblArgOut->clone();
298         m_pDblArgOut->IncreaseRef();
299     }
300     m_pDblArgOut->set(0, _iRetCount);
301
302     pContext->put(m_Nargin, m_pDblArgIn);
303     pContext->put(m_Nargout, m_pDblArgOut);
304
305
306     //add sub macro in current context
307     for (const auto & sub : m_submacro)
308     {
309         pContext->put(sub.first, sub.second);
310     }
311
312     //save current prompt mode
313     int oldVal = ConfigVariable::getPromptMode();
314     std::unique_ptr<ast::ConstVisitor> exec (ConfigVariable::getDefaultVisitor());
315     try
316     {
317         ConfigVariable::setPromptMode(-1);
318         m_body->accept(*exec);
319         //restore previous prompt mode
320         ConfigVariable::setPromptMode(oldVal);
321     }
322     catch (const ast::InternalError& ie)
323     {
324         cleanCall(pContext, oldVal);
325         throw ie;
326     }
327     catch (const ast::InternalAbort& ia)
328     {
329         cleanCall(pContext, oldVal);
330         throw ia;
331     }
332     // Normally, seqexp throws only SM so no need to catch SErr
333
334     //varargout management
335     if (bVarargout)
336     {
337         InternalType* pOut = pContext->get(m_Varargout);
338         if (pOut == NULL)
339         {
340             cleanCall(pContext, oldVal);
341             Scierror(999, _("Invalid index.\n"));
342             return Callable::Error;
343         }
344
345         if (pOut->isList() == false || pOut->getAs<List>()->getSize() == 0)
346         {
347             cleanCall(pContext, oldVal);
348             Scierror(999, _("Invalid index.\n"));
349             return Callable::Error;
350         }
351
352         List* pVarOut = pOut->getAs<List>();
353         const int size = std::min(pVarOut->getSize(), _iRetCount);
354         for (int i = 0 ; i < size ; ++i)
355         {
356             InternalType* pIT = pVarOut->get(i);
357             if (pIT->isListUndefined())
358             {
359                 for (int j = 0; j < i; ++j)
360                 {
361                     out[j]->DecreaseRef();
362                     out[j]->killMe();
363                 }
364                 out.clear();
365                 cleanCall(pContext, oldVal);
366
367                 Scierror(999, _("List element number %d is Undefined.\n"), i + 1);
368                 return Callable::Error;
369             }
370
371             pIT->IncreaseRef();
372             out.push_back(pIT);
373         }
374     }
375     else
376     {
377         //normal output management
378         for (std::list<symbol::Variable*>::iterator i = m_outputArgs->begin(); i != m_outputArgs->end() && _iRetCount; ++i, --_iRetCount)
379         {
380             InternalType * pIT = pContext->get(*i);
381             if (pIT)
382             {
383                 out.push_back(pIT);
384                 pIT->IncreaseRef();
385             }
386             else
387             {
388                 const int size = (const int)out.size();
389                 for (int j = 0; j < size; ++j)
390                 {
391                     out[j]->DecreaseRef();
392                     out[j]->killMe();
393                 }
394                 out.clear();
395                 cleanCall(pContext, oldVal);
396
397                 char* pstArgName = wide_string_to_UTF8((*i)->getSymbol().getName().c_str());
398                 char* pstMacroName = wide_string_to_UTF8(getName().c_str());
399                 Scierror(999, _("Undefined variable '%s' in function '%s'.\n"), pstArgName, pstMacroName);
400                 FREE(pstArgName);
401                 FREE(pstMacroName);
402                 return Callable::Error;
403             }
404         }
405     }
406
407     //close the current scope
408     cleanCall(pContext, oldVal);
409
410     for (typed_list::iterator i = out.begin(), end = out.end(); i != end; ++i)
411     {
412         (*i)->DecreaseRef();
413     }
414
415     return RetVal;
416 }
417
418 std::list<symbol::Variable*>* Macro::getInputs()
419 {
420     return m_inputArgs;
421 }
422
423 std::list<symbol::Variable*>* Macro::getOutputs()
424 {
425     return m_outputArgs;
426 }
427
428 int Macro::getNbInputArgument(void)
429 {
430     return (int)m_inputArgs->size();
431 }
432
433 int Macro::getNbOutputArgument(void)
434 {
435     if (m_outputArgs->size() == 1 && m_outputArgs->back()->getSymbol().getName() == L"varargout")
436     {
437         return -1;
438     }
439
440     return (int)m_outputArgs->size();
441 }
442
443 bool Macro::operator==(const InternalType& it)
444 {
445     if (const_cast<InternalType &>(it).isMacro() == false)
446     {
447         return false;
448     }
449
450     std::list<symbol::Variable*>* pInput = NULL;
451     std::list<symbol::Variable*>* pOutput = NULL;
452     types::Macro* pRight = const_cast<InternalType &>(it).getAs<types::Macro>();
453
454     //check inputs
455     pInput = pRight->getInputs();
456     if (pInput->size() != m_inputArgs->size())
457     {
458         return false;
459     }
460
461     std::list<symbol::Variable*>::iterator itOld = pInput->begin();
462     std::list<symbol::Variable*>::iterator itEndOld = pInput->end();
463     std::list<symbol::Variable*>::iterator itMacro = m_inputArgs->begin();
464
465     for (; itOld != itEndOld ; ++itOld, ++itMacro)
466     {
467         if ((*itOld)->getSymbol() != (*itMacro)->getSymbol())
468         {
469             return false;
470         }
471     }
472
473     //check outputs
474     pOutput = pRight->getOutputs();
475     if (pOutput->size() != m_outputArgs->size())
476     {
477         return false;
478     }
479
480     itOld = pOutput->begin();
481     itEndOld = pOutput->end();
482     itMacro = m_outputArgs->begin();
483
484     for (; itOld != itEndOld ; ++itOld, ++itMacro)
485     {
486         if ((*itOld)->getSymbol() != (*itMacro)->getSymbol())
487         {
488             return false;
489         }
490     }
491
492     ast::Exp* pExp = pRight->getBody();
493     ast::SerializeVisitor serialOld(pExp);
494     unsigned char* oldSerial = serialOld.serialize(false, false);
495     ast::SerializeVisitor serialMacro(m_body);
496     unsigned char* macroSerial = serialMacro.serialize(false, false);
497
498     //check buffer length
499     unsigned int oldSize = *((unsigned int*)oldSerial);
500     unsigned int macroSize = *((unsigned int*)macroSerial);
501     if (oldSize != macroSize)
502     {
503         free(oldSerial);
504         free(macroSerial);
505         return false;
506     }
507
508     bool ret = (memcmp(oldSerial, macroSerial, oldSize) == 0);
509
510     free(oldSerial);
511     free(macroSerial);
512
513     return ret;
514 }
515
516 void Macro::add_submacro(const symbol::Symbol& s, Macro* macro)
517 {
518     macro->IncreaseRef();
519     symbol::Context* ctx = symbol::Context::getInstance();
520     symbol::Variable* var = ctx->getOrCreate(s);
521     m_submacro[var] = macro;
522 }
523 }