JIT: allow code vectorization and add a copy for shared vars
[scilab.git] / scilab / modules / ast / src / cpp / jit / JITForExp.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2015 - Scilab Enterprises - Calixte DENIZET
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 #include <cstdlib>
14
15 #include "JITScalars.hxx"
16 #include "JITArrayofs.hxx"
17 #include "JITVisitor.hxx"
18
19 namespace jit
20 {
21 void JITVisitor::visit(const ast::VarDec & e)
22 {
23
24 }
25
26 void JITVisitor::visit(const ast::ForExp & e)
27 {
28     const ast::VarDec & vd = static_cast<const ast::VarDec &>(e.getVardec());
29     // TODO : handle for with an iterator
30     if (vd.getInit().isListExp())
31     {
32         const symbol::Symbol & symIterator = vd.getSymbol();
33         const ast::ListExp & le = static_cast<const ast::ListExp &>(vd.getInit());
34
35         if (le.getDecorator().getResult().getType().type != analysis::TIType::EMPTY)
36         {
37             llvm::Value * start = nullptr;
38             llvm::Value * step = nullptr;
39             llvm::Value * end = nullptr;
40             const ast::Exp & startE = le.getStart();
41             const ast::Exp & stepE = le.getStep();
42             const ast::Exp & endE = le.getEnd();
43             const analysis::TIType & startTy = startE.getDecorator().getResult().getType();
44             const analysis::TIType & stepTy = stepE.getDecorator().getResult().getType();
45             const analysis::TIType & endTy = endE.getDecorator().getResult().getType();
46             const bool startTyIsSigned = startTy.isintegral() && startTy.issigned();
47             const bool stepTyIsSigned = stepTy.isintegral() && stepTy.issigned();
48             const bool endTyIsSigned = endTy.isintegral() && endTy.issigned();
49             bool constantStart = false;
50             bool constantStep = false;
51             bool constantEnd = false;
52             bool signStep;
53             bool integralStep = false;
54             bool integerIterator = false;
55             int64_t step_i;
56             int64_t start_i;
57             int64_t end_i;
58             double step_d;
59             double start_d;
60             double end_d;
61
62             if (stepE.isDoubleExp())
63             {
64                 // Step is constant
65                 constantStep = true;
66                 const ast::DoubleExp & stepDE = static_cast<const ast::DoubleExp &>(stepE);
67                 step_d = stepDE.getValue();
68                 signStep = step_d > 0;
69                 if (!signStep)
70                 {
71                     step_d = -step_d;
72                 }
73                 if (analysis::tools::asInteger(step_d, step_i))
74                 {
75                     integralStep = true;
76                 }
77             }
78             else
79             {
80                 // Step is not constant
81                 // Check if step is zero or not
82                 const analysis::TIType & ty = stepE.getDecorator().getResult().getType();
83                 if (ty.isintegral() || ty.isreal())
84                 {
85                     stepE.accept(*this);
86                     step = getResult()->loadData(*this);
87                 }
88                 else
89                 {
90                     // TODO: error
91                     return;
92                 }
93             }
94
95             if (startE.isDoubleExp())
96             {
97                 constantStart = true;
98                 const ast::DoubleExp & startDE = static_cast<const ast::DoubleExp &>(startE);
99                 start_d = startDE.getValue();
100                 if (analysis::tools::asInteger(start_d, start_i))
101                 {
102                     integerIterator = integralStep;
103                     if (integerIterator)
104                     {
105                         start = getConstant(start_i);
106                         step = getConstant(step_i);
107                     }
108                 }
109                 if (!start)
110                 {
111                     start = getConstant(start_d);
112                     if (!step)
113                     {
114                         step = getConstant(step_d);
115                     }
116                 }
117             }
118             else
119             {
120                 startE.accept(*this);
121                 start = getResult()->loadData(*this);
122                 if (!step)
123                 {
124                     step = getConstant(step_d);
125                 }
126             }
127
128             if (endE.isDoubleExp())
129             {
130                 constantEnd = true;
131                 const ast::DoubleExp & endDE = static_cast<const ast::DoubleExp &>(endE);
132                 end_d = endDE.getValue();
133                 if (analysis::tools::asInteger(end_d, end_i))
134                 {
135                     end = getConstant(end_i);
136                 }
137                 else
138                 {
139                     end = getConstant(end_d);
140                 }
141             }
142             else
143             {
144                 endE.accept(*this);
145                 end = getResult()->loadData(*this);
146             }
147
148             if (le.getDecorator().getResult().getRange().isValid())
149             {
150                 integerIterator = true;
151             }
152
153             if (integerIterator)
154             {
155                 start = Cast::cast<int64_t>(start, startTyIsSigned, *this);
156                 step = Cast::cast<int64_t>(step, stepTyIsSigned, *this);
157                 end = Cast::cast<int64_t>(end, endTyIsSigned, *this);
158                 llvm::Value * one = getConstant<int64_t>(1);
159                 if (constantStep)
160                 {
161                     if (signStep)
162                     {
163                         end = builder.CreateAdd(end, one);
164                     }
165                     else
166                     {
167                         end = builder.CreateSub(end, one);
168                     }
169                 }
170             }
171             else
172             {
173                 start = Cast::cast<double>(start, startTyIsSigned, *this);
174                 step = Cast::cast<double>(step, stepTyIsSigned, *this);
175                 end = Cast::cast<double>(end, endTyIsSigned, *this);
176             }
177
178             // integers values: for i = start:step:end...
179             // is equivalent to for (int64_t i = start; i < end + 1; i += step)...
180
181             // flaoting values: for i = start:step:end...
182             // is equivalent to for (double i = start; i <= end; i += step)...
183
184             llvm::BasicBlock * cond = llvm::BasicBlock::Create(context, "for_cond", function);
185             llvm::BasicBlock * loop = llvm::BasicBlock::Create(context, "for_loop", function);
186             llvm::BasicBlock * after = llvm::BasicBlock::Create(context, "for_after", function);
187
188             // To know how to break or continue the loop
189             blocks.emplace(cond, after);
190
191             llvm::Value * sign_step = nullptr;
192             llvm::Value * zero_i64 = getConstant<int64_t>(0);
193             llvm::Value * zero_dbl = getConstant<double>(0);
194             llvm::Value * abs = llvm::Intrinsic::getDeclaration(&getModule(), llvm::Intrinsic::fabs, getTy<double>());
195             llvm::Value * abs_arg[1];
196             llvm::Value * cmp_i1 = nullptr;
197
198             if (!constantStep)
199             {
200                 llvm::Value * cmp_zero;
201                 if (integerIterator)
202                 {
203                     cmp_zero = builder.CreateICmpEQ(step, zero_i64);
204                 }
205                 else
206                 {
207                     cmp_zero = builder.CreateFCmpOEQ(step, zero_dbl);
208                 }
209                 llvm::BasicBlock * precond = llvm::BasicBlock::Create(context, "", function);
210                 builder.CreateCondBr(cmp_zero, after, precond);
211                 builder.SetInsertPoint(precond);
212             }
213
214             if (!integerIterator)
215             {
216                 if (!constantStart)
217                 {
218                     abs_arg[0] = start;
219                     llvm::CallInst * abs_start = builder.CreateCall(abs, abs_arg);
220                     abs_start->setTailCall(true);
221                     cmp_i1 = builder.CreateFCmpUEQ(abs_start, getConstant<double>(std::numeric_limits<double>::infinity()));
222                     llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
223                     builder.CreateCondBr(cmp_i1, after, bb);
224                     builder.SetInsertPoint(bb);
225                 }
226
227                 if (!constantEnd)
228                 {
229                     abs_arg[0] = end;
230                     llvm::CallInst * abs_end = builder.CreateCall(abs, abs_arg);
231                     abs_end->setTailCall(true);
232                     cmp_i1 = builder.CreateFCmpUEQ(abs_end, getConstant<double>(std::numeric_limits<double>::infinity()));
233                     llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
234                     builder.CreateCondBr(cmp_i1, after, bb);
235                     builder.SetInsertPoint(bb);
236                 }
237             }
238
239             if (integerIterator)
240             {
241                 if (constantStep)
242                 {
243                     if (signStep)
244                     {
245                         cmp_i1 = builder.CreateICmpSLT(start, end);
246                     }
247                     else
248                     {
249                         cmp_i1 = builder.CreateICmpSGT(start, end);
250                     }
251                 }
252                 else
253                 {
254                     sign_step = builder.CreateICmpSGT(step, zero_i64);
255                     llvm::Value * cmp1 = builder.CreateICmpSLT(start, end);
256                     llvm::Value * cmp2 = builder.CreateICmpSGT(start, end);
257                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
258                 }
259             }
260             else
261             {
262                 if (constantStep)
263                 {
264                     if (signStep)
265                     {
266                         cmp_i1 = builder.CreateFCmpOLE(start, end);
267                     }
268                     else
269                     {
270                         cmp_i1 = builder.CreateFCmpOGE(start, end);
271                     }
272                 }
273                 else
274                 {
275                     abs_arg[0] = step;
276                     llvm::CallInst * abs_step = builder.CreateCall(abs, abs_arg);
277                     abs_step->setTailCall(true);
278                     cmp_i1 = builder.CreateFCmpUEQ(abs_step, getConstant<double>(std::numeric_limits<double>::infinity()));
279                     llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
280                     builder.CreateCondBr(cmp_i1, after, bb);
281                     builder.SetInsertPoint(bb);
282                     sign_step = builder.CreateFCmpOGT(step, zero_dbl);
283                     llvm::Value * cmp1 = builder.CreateFCmpOLE(start, end);
284                     llvm::Value * cmp2 = builder.CreateFCmpOGE(start, end);
285                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
286                 }
287             }
288
289             bool hasPre = false;
290             if (const analysis::Clone * cl = e.getDecorator().getClone())
291             {
292                 if (!cl->get().empty())
293                 {
294                     llvm::BasicBlock * pre = llvm::BasicBlock::Create(context, "for_pre", function);
295                     builder.CreateCondBr(cmp_i1, pre, after);
296                     builder.SetInsertPoint(pre);
297                     cloneSyms(e);
298                     builder.CreateBr(loop);
299                     hasPre = true;
300                 }
301             }
302
303             if (!hasPre)
304             {
305                 builder.CreateCondBr(cmp_i1, loop, after);
306             }
307
308             llvm::BasicBlock * cur_block = builder.GetInsertBlock();
309
310             builder.SetInsertPoint(loop);
311             llvm::PHINode * i;
312             if (integerIterator)
313             {
314                 i = builder.CreatePHI(getTy<int64_t>(), 2);
315             }
316             else
317             {
318                 i = builder.CreatePHI(getTy<double>(), 2);
319             }
320
321             i->addIncoming(start, cur_block);
322             JITScilabPtr & it = variables.find(symIterator)->second;
323             it->storeData(*this, i);
324
325             // Visit the loop body
326             e.getBody().accept(*this);
327
328             builder.CreateBr(cond);
329
330             builder.SetInsertPoint(cond);
331             llvm::Value * i_step;
332             if (integerIterator)
333             {
334                 if (constantStep)
335                 {
336                     if (signStep)
337                     {
338                         i_step = builder.CreateAdd(i, step);
339                         cmp_i1 = builder.CreateICmpSLT(i_step, end);
340                     }
341                     else
342                     {
343                         i_step = builder.CreateSub(i, step);
344                         cmp_i1 = builder.CreateICmpSGT(i_step, end);
345                     }
346                 }
347                 else
348                 {
349                     i_step = builder.CreateAdd(i, step);
350                     llvm::Value * cmp1 = builder.CreateICmpSLT(i_step, end);
351                     llvm::Value * cmp2 = builder.CreateICmpSGT(i_step, end);
352                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
353                 }
354             }
355             else
356             {
357                 if (constantStep)
358                 {
359                     if (signStep)
360                     {
361                         i_step = builder.CreateFAdd(i, step);
362                         cmp_i1 = builder.CreateFCmpOLE(i_step, end);
363                     }
364                     else
365                     {
366                         i_step = builder.CreateFSub(i, step);
367                         cmp_i1 = builder.CreateFCmpOGE(i_step, end);
368                     }
369                 }
370                 else
371                 {
372                     i_step = builder.CreateFAdd(i, step);
373                     llvm::Value * cmp1 = builder.CreateFCmpOLE(i_step, end);
374                     llvm::Value * cmp2 = builder.CreateFCmpOGE(i_step, end);
375                     cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
376                 }
377             }
378
379             i->addIncoming(i_step, cond);
380
381             builder.CreateCondBr(cmp_i1, loop, after);
382
383             builder.SetInsertPoint(after);
384         }
385     }
386 }
387
388 void JITVisitor::cloneSyms(const ast::Exp & e)
389 {
390     if (const analysis::Clone * cl = e.getDecorator().getClone())
391     {
392         if (!cl->get().empty())
393         {
394             llvm::Type * types[] = { getTy<int8_t *>(), getTy<int8_t *>(), getTy<int64_t>(), getTy<int32_t>(), getTy<bool>() };
395             llvm::Value * __memcpy = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::memcpy, types);
396             llvm::Function * __new = static_cast<llvm::Function *>(module->getOrInsertFunction("new", llvm::FunctionType::get(getTy<int8_t *>(), llvm::ArrayRef<llvm::Type *>(getTy<uint64_t>()), false)));
397             __new->addAttribute(0, llvm::Attribute::NoAlias);
398
399             for (const auto & sym : cl->get())
400             {
401                 JITScilabPtr & ptr = variables.find(sym)->second;
402                 llvm::Value * x = ptr->loadData(*this);
403                 llvm::Value * r = ptr->loadRows(*this);
404                 llvm::Value * c = ptr->loadCols(*this);
405                 llvm::Value * rc = builder.CreateMul(r, c);
406                 llvm::Value * size = builder.CreateMul(rc, getConstant<int64_t>(getTySizeInBytes(x)));
407                 llvm::CallInst * dest = builder.CreateCall(__new, size);
408                 dest->addAttribute(0, llvm::Attribute::NoAlias);
409                 llvm::Value * src = builder.CreateBitCast(x, getTy<int8_t *>());
410                 llvm::Value * memcpy_args[] = { dest, src, size, getConstant<int64_t>(getTySizeInBytes(x)), getBool(false) };
411                 builder.CreateCall(__memcpy, memcpy_args);
412                 ptr->storeData(*this, dest);
413             }
414         }
415     }
416
417 }
418 }