Merge remote-tracking branch 'origin/master' into jit
[scilab.git] / scilab / modules / ast / src / cpp / jit / JITForExp.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2015 - Scilab Enterprises - Calixte DENIZET
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
#include <cmath>
#include <cstdlib>
#include <limits>

#include "JITScalars.hxx"
#include "JITArrayofs.hxx"
#include "JITVisitor.hxx"
18
19 namespace jit
20 {
21 void JITVisitor::visit(const ast::VarDec & e)
22 {
23
24 }
25
26 void JITVisitor::visit(const ast::ForExp & e)
27 {
28     const ast::VarDec & vd = static_cast<const ast::VarDec &>(e.getVardec());
29     // TODO : handle for with an iterator
30     if (vd.getInit().isListExp())
31     {
32         const symbol::Symbol & symIterator = vd.getSymbol();
33         const ast::ListExp & le = static_cast<const ast::ListExp &>(vd.getInit());
34
35         if (le.getDecorator().getResult().getType().type != analysis::TIType::EMPTY)
36         {
37             llvm::Value * start = nullptr;
38             llvm::Value * step = nullptr;
39             llvm::Value * end = nullptr;
40             const ast::Exp & startE = le.getStart();
41             const ast::Exp & stepE = le.getStep();
42             const ast::Exp & endE = le.getEnd();
43             const analysis::TIType & startTy = startE.getDecorator().getResult().getType();
44             const analysis::TIType & stepTy = stepE.getDecorator().getResult().getType();
45             const analysis::TIType & endTy = endE.getDecorator().getResult().getType();
46             const bool startTyIsSigned = startTy.isintegral() && startTy.issigned();
47             const bool stepTyIsSigned = stepTy.isintegral() && stepTy.issigned();
48             const bool endTyIsSigned = endTy.isintegral() && endTy.issigned();
49             bool constantStart = false;
50             bool constantStep = false;
51             bool constantEnd = false;
52             bool signStep;
53             bool integralStep = false;
54             bool integerIterator = false;
55             int64_t step_i;
56             int64_t start_i;
57             int64_t end_i;
58             double step_d;
59             double start_d;
60             double end_d;
61
62             if (stepE.isDoubleExp())
63             {
64                 // Step is constant
65                 constantStep = true;
66                 const ast::DoubleExp & stepDE = static_cast<const ast::DoubleExp &>(stepE);
67                 step_d = stepDE.getValue();
68                 signStep = step_d > 0;
69                 if (!signStep)
70                 {
71                     step_d = -step_d;
72                 }
73                 if (analysis::tools::asInteger(step_d, step_i))
74                 {
75                     integralStep = true;
76                 }
77             }
78             else
79             {
80                 // Step is not constant
81                 // Check if step is zero or not
82                 const analysis::TIType & ty = stepE.getDecorator().getResult().getType();
83                 if (ty.isintegral() || ty.isreal())
84                 {
85                     stepE.accept(*this);
86                     step = getResult()->loadData(*this);
87                 }
88                 else
89                 {
90                     // TODO: error
91                     return;
92                 }
93             }
94
95             if (startE.isDoubleExp())
96             {
97                 constantStart = true;
98                 const ast::DoubleExp & startDE = static_cast<const ast::DoubleExp &>(startE);
99                 start_d = startDE.getValue();
100                 if (analysis::tools::asInteger(start_d, start_i))
101                 {
102                     integerIterator = integralStep;
103                     if (integerIterator)
104                     {
105                         start = getConstant(start_i);
106                         step = getConstant(step_i);
107                     }
108                 }
109                 if (!start)
110                 {
111                     start = getConstant(start_d);
112                     if (!step)
113                     {
114                         step = getConstant(step_d);
115                     }
116                 }
117             }
118             else
119             {
120                 startE.accept(*this);
121                 start = getResult()->loadData(*this);
122                 if (!step)
123                 {
124                     step = getConstant(step_d);
125                 }
126             }
127
128             if (endE.isDoubleExp())
129             {
130                 constantEnd = true;
131                 const ast::DoubleExp & endDE = static_cast<const ast::DoubleExp &>(endE);
132                 end_d = endDE.getValue();
133                 if (analysis::tools::asInteger(end_d, end_i))
134                 {
135                     end = getConstant(end_i);
136                 }
137                 else
138                 {
139                     end = getConstant(end_d);
140                 }
141             }
142             else
143             {
144                 endE.accept(*this);
145                 end = getResult()->loadData(*this);
146             }
147
148             if (le.getDecorator().getResult().getRange().isValid())
149             {
150                 integerIterator = true;
151             }
152
153             if (integerIterator)
154             {
155                 start = Cast::cast<int64_t>(start, startTyIsSigned, *this);
156                 step = Cast::cast<int64_t>(step, stepTyIsSigned, *this);
157                 end = Cast::cast<int64_t>(end, endTyIsSigned, *this);
158                 llvm::Value * one = getConstant<int64_t>(1);
159                 if (constantStep)
160                 {
161                     if (signStep)
162                     {
163                         end = builder.CreateAdd(end, one);
164                     }
165                     else
166                     {
167                         end = builder.CreateSub(end, one);
168                     }
169                 }
170             }
171             else
172             {
173                 start = Cast::cast<double>(start, startTyIsSigned, *this);
174                 step = Cast::cast<double>(step, stepTyIsSigned, *this);
175                 end = Cast::cast<double>(end, endTyIsSigned, *this);
176             }
177
178             // integers values: for i = start:step:end...
179             // is equivalent to for (int64_t i = start; i < end + 1; i += step)...
180
181             // flaoting values: for i = start:step:end...
182             // is equivalent to for (double i = start; i <= end; i += step)...
183
184             llvm::BasicBlock * cond = llvm::BasicBlock::Create(context, "for_cond", function);
185             llvm::BasicBlock * loop = llvm::BasicBlock::Create(context, "for_loop", function);
186             llvm::BasicBlock * after = llvm::BasicBlock::Create(context, "for_after", function);
187
188             // To know how to break or continue the loop
189             blocks.emplace(cond, after);
190
191             llvm::Value * sign_step = nullptr;
192             llvm::Value * zero_i64 = getConstant<int64_t>(0);
193             llvm::Value * zero_dbl = getConstant<double>(0);
194             llvm::Value * abs = llvm::Intrinsic::getDeclaration(&getModule(), llvm::Intrinsic::fabs, getTy<double>());
195             llvm::Value * abs_arg[1];
196             llvm::Value * cmp_i1 = nullptr;
197
198             if (!constantStep)
199             {
200                 llvm::Value * cmp_zero;
201                 if (integerIterator)
202                 {
203                     cmp_zero = builder.CreateICmpEQ(step, zero_i64);
204                 }
205                 else
206                 {
207                     cmp_zero = builder.CreateFCmpOEQ(step, zero_dbl);
208                 }
209                 llvm::BasicBlock * precond = llvm::BasicBlock::Create(context, "", function);
210                 builder.CreateCondBr(cmp_zero, after, precond);
211                 builder.SetInsertPoint(precond);
212             }
213
214             if (!integerIterator)
215             {
216                 if (!constantStart)
217                 {
218                     abs_arg[0] = start;
219                     llvm::CallInst * abs_start = builder.CreateCall(abs, abs_arg);
220                     abs_start->setTailCall(true);
221                     cmp_i1 = builder.CreateFCmpUEQ(abs_start, getConstant<double>(std::numeric_limits<double>::infinity()));
222                     llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
223                     builder.CreateCondBr(cmp_i1, after, bb);
224                     builder.SetInsertPoint(bb);
225                 }
226
227                 if (!constantEnd)
228                 {
229                     abs_arg[0] = end;
230                     llvm::CallInst * abs_end = builder.CreateCall(abs, abs_arg);
231                     abs_end->setTailCall(true);
232                     cmp_i1 = builder.CreateFCmpUEQ(abs_end, getConstant<double>(std::numeric_limits<double>::infinity()));
233                     llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
234                     builder.CreateCondBr(cmp_i1, after, bb);
235                     builder.SetInsertPoint(bb);
236                 }
237             }
238
239             if (integerIterator)
240             {
241                 if (constantStep)
242                 {
243                     if (signStep)
244                     {
245                         cmp_i1 = builder.CreateICmpSLT(start, end);
246                     }
247                     else
248                     {
249                         cmp_i1 = builder.CreateICmpSGT(start, end);
250                     }
251                 }
252                 else
253                 {
254                     sign_step = builder.CreateICmpSGT(step, zero_i64);
255                     llvm::Value * cmp1 = builder.CreateICmpSLT(start, end);
256                     llvm::Value * cmp2 = builder.CreateICmpSGT(start, end);
257                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
258                 }
259             }
260             else
261             {
262                 if (constantStep)
263                 {
264                     if (signStep)
265                     {
266                         cmp_i1 = builder.CreateFCmpOLE(start, end);
267                     }
268                     else
269                     {
270                         cmp_i1 = builder.CreateFCmpOGE(start, end);
271                     }
272                 }
273                 else
274                 {
275                     abs_arg[0] = step;
276                     llvm::CallInst * abs_step = builder.CreateCall(abs, abs_arg);
277                     abs_step->setTailCall(true);
278                     cmp_i1 = builder.CreateFCmpUEQ(abs_step, getConstant<double>(std::numeric_limits<double>::infinity()));
279                     llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
280                     builder.CreateCondBr(cmp_i1, after, bb);
281                     builder.SetInsertPoint(bb);
282                     sign_step = builder.CreateFCmpOGT(step, zero_dbl);
283                     llvm::Value * cmp1 = builder.CreateFCmpOLE(start, end);
284                     llvm::Value * cmp2 = builder.CreateFCmpOGE(start, end);
285                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
286                 }
287             }
288
289             bool hasPre = false;
290             if (const analysis::LoopDecoration * ld = e.getDecorator().getLoopDecoration())
291             {
292                 if (!ld->getClone().empty() || !ld->getPromotion().empty())
293                 {
294                     llvm::BasicBlock * pre = llvm::BasicBlock::Create(context, "for_pre", function);
295                     builder.CreateCondBr(cmp_i1, pre, after);
296                     builder.SetInsertPoint(pre);
297                     if (!ld->getClone().empty())
298                     {
299                         cloneSyms(e);
300                     }
301                     else
302                     {
303                         //promoteSyms(e);
304                     }
305                     builder.CreateBr(loop);
306                     hasPre = true;
307                 }
308             }
309
310             if (!hasPre)
311             {
312                 builder.CreateCondBr(cmp_i1, loop, after);
313             }
314
315             llvm::BasicBlock * cur_block = builder.GetInsertBlock();
316
317             builder.SetInsertPoint(loop);
318             llvm::PHINode * i;
319             if (integerIterator)
320             {
321                 i = builder.CreatePHI(getTy<int64_t>(), 2);
322             }
323             else
324             {
325                 i = builder.CreatePHI(getTy<double>(), 2);
326             }
327
328             i->addIncoming(start, cur_block);
329             JITScilabPtr & it = variables.find(symIterator)->second;
330             it->storeData(*this, i);
331
332             // Visit the loop body
333             e.getBody().accept(*this);
334
335             builder.CreateBr(cond);
336
337             builder.SetInsertPoint(cond);
338             llvm::Value * i_step;
339             if (integerIterator)
340             {
341                 if (constantStep)
342                 {
343                     if (signStep)
344                     {
345                         i_step = builder.CreateAdd(i, step);
346                         cmp_i1 = builder.CreateICmpSLT(i_step, end);
347                     }
348                     else
349                     {
350                         i_step = builder.CreateSub(i, step);
351                         cmp_i1 = builder.CreateICmpSGT(i_step, end);
352                     }
353                 }
354                 else
355                 {
356                     i_step = builder.CreateAdd(i, step);
357                     llvm::Value * cmp1 = builder.CreateICmpSLT(i_step, end);
358                     llvm::Value * cmp2 = builder.CreateICmpSGT(i_step, end);
359                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
360                 }
361             }
362             else
363             {
364                 if (constantStep)
365                 {
366                     if (signStep)
367                     {
368                         i_step = builder.CreateFAdd(i, step);
369                         cmp_i1 = builder.CreateFCmpOLE(i_step, end);
370                     }
371                     else
372                     {
373                         i_step = builder.CreateFSub(i, step);
374                         cmp_i1 = builder.CreateFCmpOGE(i_step, end);
375                     }
376                 }
377                 else
378                 {
379                     i_step = builder.CreateFAdd(i, step);
380                     llvm::Value * cmp1 = builder.CreateFCmpOLE(i_step, end);
381                     llvm::Value * cmp2 = builder.CreateFCmpOGE(i_step, end);
382                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
383                 }
384             }
385
386             i->addIncoming(i_step, cond);
387
388             builder.CreateCondBr(cmp_i1, loop, after);
389
390             builder.SetInsertPoint(after);
391         }
392     }
393 }
394
395 void JITVisitor::cloneSyms(const ast::Exp & e)
396 {
397     llvm::Type * types[] = { getTy<int8_t *>(), getTy<int8_t *>(), getTy<int64_t>(), getTy<int32_t>(), getTy<bool>() };
398     llvm::Value * __memcpy = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::memcpy, types);
399     llvm::Function * __new = static_cast<llvm::Function *>(module->getOrInsertFunction("new", llvm::FunctionType::get(getTy<int8_t *>(), llvm::ArrayRef<llvm::Type *>(getTy<uint64_t>()), false)));
400
401     for (const auto & sym : e.getDecorator().getLoopDecoration()->getClone())
402     {
403         // TODO: bad stuff here: we must add the mangling to the name !!
404         JITScilabPtr & ptr = variables.find(sym)->second;
405         llvm::Value * x = ptr->loadData(*this);
406         llvm::Value * r = ptr->loadRows(*this);
407         llvm::Value * c = ptr->loadCols(*this);
408         llvm::Value * rc = builder.CreateMul(r, c);
409         llvm::Value * size = builder.CreateMul(rc, getConstant<int64_t>(getTySizeInBytes(x)));
410         llvm::CallInst * dest = builder.CreateCall(__new, size);
411         dest->addAttribute(0, llvm::Attribute::NoAlias);
412         llvm::Value * src = builder.CreateBitCast(x, getTy<int8_t *>());
413         llvm::Value * memcpy_args[] = { dest, src, size, getConstant<int64_t>(getTySizeInBytes(x)), getBool(false) };
414         builder.CreateCall(__memcpy, memcpy_args);
415         ptr->storeData(*this, dest);
416     }
417
418     /*    if (const analysis::Clone * cl = e.getDecorator().getClone())
419         {
420             if (!cl->get().empty())
421             {
422                 llvm::Type * types[] = { getTy<int8_t *>(), getTy<int8_t *>(), getTy<int64_t>(), getTy<int32_t>(), getTy<bool>() };
423                 llvm::Value * __memcpy = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::memcpy, types);
424                 llvm::Function * __new = static_cast<llvm::Function *>(module->getOrInsertFunction("new", llvm::FunctionType::get(getTy<int8_t *>(), llvm::ArrayRef<llvm::Type *>(getTy<uint64_t>()), false)));
425                 __new->addAttribute(0, llvm::Attribute::NoAlias);
426
427                 for (const auto & sym : cl->get())
428                 {
429                     JITScilabPtr & ptr = variables.find(sym)->second;
430                     llvm::Value * x = ptr->loadData(*this);
431                     llvm::Value * r = ptr->loadRows(*this);
432                     llvm::Value * c = ptr->loadCols(*this);
433                     llvm::Value * rc = builder.CreateMul(r, c);
434                     llvm::Value * size = builder.CreateMul(rc, getConstant<int64_t>(getTySizeInBytes(x)));
435                     llvm::CallInst * dest = builder.CreateCall(__new, size);
436                     dest->addAttribute(0, llvm::Attribute::NoAlias);
437                     llvm::Value * src = builder.CreateBitCast(x, getTy<int8_t *>());
438                     llvm::Value * memcpy_args[] = { dest, src, size, getConstant<int64_t>(getTySizeInBytes(x)), getBool(false) };
439                     builder.CreateCall(__memcpy, memcpy_args);
440                     ptr->storeData(*this, dest);
441                 }
442             }
443         }*/
444
445 }
446 }