JIT: improve forexp
[scilab.git] / scilab / modules / ast / src / cpp / jit / JITForExp.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2015 - Scilab Enterprises - Calixte DENIZET
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 #include <cstdlib>
14
15 #include "JITScalars.hxx"
16 #include "JITArrayofs.hxx"
17 #include "JITVisitor.hxx"
18
19 namespace jit
20 {
21 void JITVisitor::visit(const ast::VarDec & e)
22 {
23
24 }
25
26 /*void JITVisitor::visit(const ast::ForExp & e)
27   {
28   const ast::VarDec & vd = static_cast<const ast::VarDec &>(e.getVardec());
29   // TODO : handle for with an iterator
30   if (vd.getInit().isListExp())
31   {
32   const symbol::Symbol & symIterator = vd.getSymbol();
33   const ast::ListExp & le = static_cast<const ast::ListExp &>(vd.getInit());
34   if (le.getDecorator().getResult().getRange().isValid())
35   {
36   // for i = start:step:end...
37   // is equivalent to for (int64_t i = start; i < end + step; i += step)...
38
39   le.getStart().accept(*this);
40   llvm::Value * start = Cast::cast<int64_t>(getResult()->loadData(*this), false, *this);
41   le.getStep().accept(*this);
42   llvm::Value * step = Cast::cast<int64_t>(getResult()->loadData(*this), false, *this);
43   le.getEnd().accept(*this);
44   llvm::Value * end = Cast::cast<int64_t>(getResult()->loadData(*this), false, *this);
45   end = builder.CreateAdd(end, step);
46
47   llvm::BasicBlock * cur_block = builder.GetInsertBlock();
48   llvm::BasicBlock * condBlock = llvm::BasicBlock::Create(context, "for_cond", function);
49   llvm::BasicBlock * loopBlock = llvm::BasicBlock::Create(context, "for_loop", function);
50   llvm::BasicBlock * afterBlock = llvm::BasicBlock::Create(context, "for_after", function);
51
52   blocks.emplace(condBlock, afterBlock);
53
54   llvm::Value * cmp_i1 = builder.CreateICmpSLT(start, end);
55   builder.CreateCondBr(cmp_i1, loopBlock, afterBlock);
56
57   builder.SetInsertPoint(loopBlock);
58   llvm::PHINode * i = builder.CreatePHI(getTy<int64_t>(), 2);
59   i->addIncoming(start, cur_block);
60   JITScilabPtr & it = variables.find(symIterator)->second;
61   it->storeData(*this, i);
62
63   e.getBody().accept(*this);
64   builder.CreateBr(condBlock);
65
66   builder.SetInsertPoint(condBlock);
67   llvm::Value * ipstp_i64 = builder.CreateAdd(i, step);
68   i->addIncoming(ipstp_i64, condBlock);
69   cmp_i1 = builder.CreateICmpSLT(ipstp_i64, end);
70   builder.CreateCondBr(cmp_i1, loopBlock, afterBlock);
71
72   builder.SetInsertPoint(afterBlock);
73   }
74   else
75   {
76   const ast::Exp & startE = le.getStart();
77   const ast::Exp & stepE = le.getStep();
78   const ast::Exp & endE = le.getEnd();
79
80   // a lot of for loops are like this: for i=1:N....
81   // start is an integer and step is by default 1 so we can use a classical
82   // loop with integer counter.
83
84   if (startE.isDoubleExp() && stepE.isDoubleExp())
85   {
86   const ast::DoubleExp & startDE = static_cast<const ast::DoubleExp &>(startE);
87   const ast::DoubleExp & stepDE = static_cast<const ast::DoubleExp &>(stepE);
88   int64_t start_i;
89   int64_t step_i;
90   if (analysis::tools::asInteger(startDE.getValue(), start_i) && analysis::tools::asInteger(stepDE.getValue(), step_i))
91   {
92   llvm::Value * start = getConstant(start_i);
93   llvm::Value * step = getConstant(std::abs(step_i));
94   le.getEnd().accept(*this);
95   llvm::Value * end = Cast::cast<int64_t>(getResult()->loadData(*this), false, *this);
96
97   if (step_i > 0)
98   {
99   end = builder.CreateAdd(end, step);
100   }
101   else
102   {
103   end = builder.CreateSub(end, step);
104   }
105
106   llvm::BasicBlock * cur_block = builder.GetInsertBlock();
107   llvm::BasicBlock * condBlock = llvm::BasicBlock::Create(context, "for_cond", function);
108   llvm::BasicBlock * loopBlock = llvm::BasicBlock::Create(context, "for_loop", function);
109   llvm::BasicBlock * afterBlock = llvm::BasicBlock::Create(context, "for_after", function);
110
111   blocks.emplace(condBlock, afterBlock);
112
113   llvm::Value * cmp_i1;
114   if (step_i > 0)
115   {
116   cmp_i1 = builder.CreateICmpSLT(start, end);
117   }
118   else
119   {
120   cmp_i1 = builder.CreateICmpSGT(start, end);
121   }
122   builder.CreateCondBr(cmp_i1, loopBlock, afterBlock);
123
124   builder.SetInsertPoint(loopBlock);
125   llvm::PHINode * i = builder.CreatePHI(getTy<int64_t>(), 2);
126   i->addIncoming(start, cur_block);
127   JITScilabPtr & it = variables.find(symIterator)->second;
128   it->storeData(*this, i);
129
130   e.getBody().accept(*this);
131   builder.CreateBr(condBlock);
132
133   builder.SetInsertPoint(condBlock);
134   llvm::Value * ipstp_i64;
135   if (step_i > 0)
136   {
137   ipstp_i64 = builder.CreateAdd(i, step);
138   }
139   else
140   {
141   ipstp_i64 = builder.CreateSub(i, step);
142   }
143
144   i->addIncoming(ipstp_i64, condBlock);
145   if (step_i > 0)
146   {
147   cmp_i1 = builder.CreateICmpSLT(ipstp_i64, end);
148   }
149   else
150   {
151   cmp_i1 = builder.CreateICmpSGT(ipstp_i64, end);
152   }
153   builder.CreateCondBr(cmp_i1, loopBlock, afterBlock);
154
155   builder.SetInsertPoint(afterBlock);
156   }
157   }
158   }
159
160   }
161
162
163   // e.getBody().accept(*this);
164   }*/
165
166 void JITVisitor::visit(const ast::ForExp & e)
167 {
168     const ast::VarDec & vd = static_cast<const ast::VarDec &>(e.getVardec());
169     // TODO : handle for with an iterator
170     if (vd.getInit().isListExp())
171     {
172         const symbol::Symbol & symIterator = vd.getSymbol();
173         const ast::ListExp & le = static_cast<const ast::ListExp &>(vd.getInit());
174
175         if (le.getDecorator().getResult().getType().type != analysis::TIType::EMPTY)
176         {
177             llvm::Value * start = nullptr;
178             llvm::Value * step = nullptr;
179             llvm::Value * end = nullptr;
180             const ast::Exp & startE = le.getStart();
181             const ast::Exp & stepE = le.getStep();
182             const ast::Exp & endE = le.getEnd();
183             const analysis::TIType & startTy = startE.getDecorator().getResult().getType();
184             const analysis::TIType & stepTy = stepE.getDecorator().getResult().getType();
185             const analysis::TIType & endTy = endE.getDecorator().getResult().getType();
186             const bool startTyIsSigned = startTy.isintegral() && startTy.issigned();
187             const bool stepTyIsSigned = stepTy.isintegral() && stepTy.issigned();
188             const bool endTyIsSigned = endTy.isintegral() && endTy.issigned();
189             bool constantStart = false;
190             bool constantStep = false;
191             bool constantEnd = false;
192             bool signStep;
193             bool integralStep = false;
194             bool integerIterator = false;
195             int64_t step_i;
196             int64_t start_i;
197             int64_t end_i;
198             double step_d;
199             double start_d;
200             double end_d;
201
202             if (stepE.isDoubleExp())
203             {
204                 // Step is constant
205                 constantStep = true;
206                 const ast::DoubleExp & stepDE = static_cast<const ast::DoubleExp &>(stepE);
207                 step_d = stepDE.getValue();
208                 signStep = step_d > 0;
209                 if (!signStep)
210                 {
211                     step_d = -step_d;
212                 }
213                 if (analysis::tools::asInteger(step_d, step_i))
214                 {
215                     integralStep = true;
216                 }
217             }
218             else
219             {
220                 // Step is not constant
221                 // Check if step is zero or not
222                 const analysis::TIType & ty = stepE.getDecorator().getResult().getType();
223                 if (ty.isintegral() || ty.isreal())
224                 {
225                     stepE.accept(*this);
226                     step = getResult()->loadData(*this);
227                 }
228                 else
229                 {
230                     // TODO: error
231                     return;
232                 }
233             }
234
235             if (startE.isDoubleExp())
236             {
237                 constantStart = true;
238                 const ast::DoubleExp & startDE = static_cast<const ast::DoubleExp &>(startE);
239                 start_d = startDE.getValue();
240                 if (analysis::tools::asInteger(start_d, start_i))
241                 {
242                     integerIterator = integralStep;
243                     if (integerIterator)
244                     {
245                         start = getConstant(start_i);
246                         step = getConstant(step_i);
247                     }
248                 }
249                 if (!start)
250                 {
251                     start = getConstant(start_d);
252                     if (!step)
253                     {
254                         step = getConstant(step_d);
255                     }
256                 }
257             }
258             else
259             {
260                 startE.accept(*this);
261                 start = getResult()->loadData(*this);
262                 if (!step)
263                 {
264                     step = getConstant(step_d);
265                 }
266             }
267
268             if (endE.isDoubleExp())
269             {
270                 constantEnd = true;
271                 const ast::DoubleExp & endDE = static_cast<const ast::DoubleExp &>(endE);
272                 end_d = endDE.getValue();
273                 if (analysis::tools::asInteger(end_d, end_i))
274                 {
275                     end = getConstant(end_i);
276                 }
277                 else
278                 {
279                     end = getConstant(end_d);
280                 }
281             }
282             else
283             {
284                 endE.accept(*this);
285                 end = getResult()->loadData(*this);
286             }
287
288             if (le.getDecorator().getResult().getRange().isValid())
289             {
290                 integerIterator = true;
291             }
292
293             if (integerIterator)
294             {
295                 start = Cast::cast<int64_t>(start, startTyIsSigned, *this);
296                 step = Cast::cast<int64_t>(step, stepTyIsSigned, *this);
297                 end = Cast::cast<int64_t>(end, endTyIsSigned, *this);
298                 llvm::Value * one = getConstant<int64_t>(1);
299                 if (constantStep)
300                 {
301                     if (signStep)
302                     {
303                         end = builder.CreateAdd(end, one);
304                     }
305                     else
306                     {
307                         end = builder.CreateSub(end, one);
308                     }
309                 }
310             }
311             else
312             {
313                 start = Cast::cast<double>(start, startTyIsSigned, *this);
314                 step = Cast::cast<double>(step, stepTyIsSigned, *this);
315                 end = Cast::cast<double>(end, endTyIsSigned, *this);
316             }
317
318             // integers values: for i = start:step:end...
319             // is equivalent to for (int64_t i = start; i < end + 1; i += step)...
320
321             // flaoting values: for i = start:step:end...
322             // is equivalent to for (double i = start; i <= end; i += step)...
323
324             llvm::BasicBlock * cond = llvm::BasicBlock::Create(context, "for_cond", function);
325             llvm::BasicBlock * loop = llvm::BasicBlock::Create(context, "for_loop", function);
326             llvm::BasicBlock * after = llvm::BasicBlock::Create(context, "for_after", function);
327
328             // To know how to break or continue the loop
329             blocks.emplace(cond, after);
330
331             llvm::Value * sign_step = nullptr;
332             llvm::Value * zero_i64 = getConstant<int64_t>(0);
333             llvm::Value * zero_dbl = getConstant<double>(0);
334             llvm::Value * abs = llvm::Intrinsic::getDeclaration(&getModule(), llvm::Intrinsic::fabs, getTy<double>());
335             llvm::Value * abs_arg[1];
336             llvm::Value * cmp_i1 = nullptr;
337
338             if (!constantStep)
339             {
340                 llvm::Value * cmp_zero;
341                 if (integerIterator)
342                 {
343                     cmp_zero = builder.CreateICmpEQ(step, zero_i64);
344                 }
345                 else
346                 {
347                     cmp_zero = builder.CreateFCmpOEQ(step, zero_dbl);
348                 }
349                 llvm::BasicBlock * precond = llvm::BasicBlock::Create(context, "", function);
350                 builder.CreateCondBr(cmp_zero, after, precond);
351                 builder.SetInsertPoint(precond);
352             }
353
354             if (!integerIterator)
355             {
356                 if (!constantStart)
357                 {
358                     abs_arg[0] = start;
359                     llvm::CallInst * abs_start = builder.CreateCall(abs, abs_arg);
360                     abs_start->setTailCall(true);
361                     cmp_i1 = builder.CreateFCmpUEQ(abs_start, getConstant<double>(std::numeric_limits<double>::infinity()));
362                     llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
363                     builder.CreateCondBr(cmp_i1, after, bb);
364                     builder.SetInsertPoint(bb);
365                 }
366
367                 if (!constantEnd)
368                 {
369                     abs_arg[0] = end;
370                     llvm::CallInst * abs_end = builder.CreateCall(abs, abs_arg);
371                     abs_end->setTailCall(true);
372                     cmp_i1 = builder.CreateFCmpUEQ(abs_end, getConstant<double>(std::numeric_limits<double>::infinity()));
373                     llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
374                     builder.CreateCondBr(cmp_i1, after, bb);
375                     builder.SetInsertPoint(bb);
376                 }
377             }
378
379             if (integerIterator)
380             {
381                 if (constantStep)
382                 {
383                     if (signStep)
384                     {
385                         cmp_i1 = builder.CreateICmpSLT(start, end);
386                     }
387                     else
388                     {
389                         cmp_i1 = builder.CreateICmpSGT(start, end);
390                     }
391                 }
392                 else
393                 {
394                     sign_step = builder.CreateICmpSGT(step, zero_i64);
395                     llvm::Value * cmp1 = builder.CreateICmpSLT(start, end);
396                     llvm::Value * cmp2 = builder.CreateICmpSGT(start, end);
397                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
398                 }
399             }
400             else
401             {
402                 if (constantStep)
403                 {
404                     if (signStep)
405                     {
406                         cmp_i1 = builder.CreateFCmpOLE(start, end);
407                     }
408                     else
409                     {
410                         cmp_i1 = builder.CreateFCmpOGE(start, end);
411                     }
412                 }
413                 else
414                 {
415                     abs_arg[0] = step;
416                     llvm::CallInst * abs_step = builder.CreateCall(abs, abs_arg);
417                     abs_step->setTailCall(true);
418                     cmp_i1 = builder.CreateFCmpUEQ(abs_step, getConstant<double>(std::numeric_limits<double>::infinity()));
419                     llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
420                     builder.CreateCondBr(cmp_i1, after, bb);
421                     builder.SetInsertPoint(bb);
422                     sign_step = builder.CreateFCmpOGT(step, zero_dbl);
423                     llvm::Value * cmp1 = builder.CreateFCmpOLE(start, end);
424                     llvm::Value * cmp2 = builder.CreateFCmpOGE(start, end);
425                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
426                 }
427             }
428
429             builder.CreateCondBr(cmp_i1, loop, after);
430
431             llvm::BasicBlock * cur_block = builder.GetInsertBlock();
432
433             builder.SetInsertPoint(loop);
434             llvm::PHINode * i;
435             if (integerIterator)
436             {
437                 i = builder.CreatePHI(getTy<int64_t>(), 2);
438             }
439             else
440             {
441                 i = builder.CreatePHI(getTy<double>(), 2);
442             }
443
444             i->addIncoming(start, cur_block);
445             JITScilabPtr & it = variables.find(symIterator)->second;
446             //it->storeData(*this, i);
447
448             // Visit the loop body
449             e.getBody().accept(*this);
450
451             builder.CreateBr(cond);
452
453             builder.SetInsertPoint(cond);
454             llvm::Value * i_step;
455             if (integerIterator)
456             {
457                 if (constantStep)
458                 {
459                     if (signStep)
460                     {
461                         i_step = builder.CreateAdd(i, step);
462                         cmp_i1 = builder.CreateICmpSLT(i_step, end);
463                     }
464                     else
465                     {
466                         i_step = builder.CreateSub(i, step);
467                         cmp_i1 = builder.CreateICmpSGT(i_step, end);
468                     }
469                 }
470                 else
471                 {
472                     i_step = builder.CreateAdd(i, step);
473                     llvm::Value * cmp1 = builder.CreateICmpSLT(i_step, end);
474                     llvm::Value * cmp2 = builder.CreateICmpSGT(i_step, end);
475                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
476                 }
477             }
478             else
479             {
480                 if (constantStep)
481                 {
482                     if (signStep)
483                     {
484                         i_step = builder.CreateFAdd(i, step);
485                         cmp_i1 = builder.CreateFCmpOLE(i_step, end);
486                     }
487                     else
488                     {
489                         i_step = builder.CreateFSub(i, step);
490                         cmp_i1 = builder.CreateFCmpOGE(i_step, end);
491                     }
492                 }
493                 else
494                 {
495                     i_step = builder.CreateFAdd(i, step);
496                     llvm::Value * cmp1 = builder.CreateFCmpOLE(i_step, end);
497                     llvm::Value * cmp2 = builder.CreateFCmpOGE(i_step, end);
498                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
499                 }
500             }
501
502             i->addIncoming(i_step, cond);
503
504             builder.CreateCondBr(cmp_i1, loop, after);
505
506             builder.SetInsertPoint(after);
507         }
508     }
509 }
510 }