remove debug msg and remove NDEBUG flag to llvm CFLAGS
[scilab.git] / scilab / modules / ast / src / cpp / jit / JITVisitor.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2014 - Scilab Enterprises - Calixte DENIZET
4  *
5  *  This file must be used under the terms of the CeCILL.
6  *  This source file is licensed as described in the file COPYING, which
7  *  you should have received as part of this distribution.  The terms
8  *  are also available at
9  *  http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
10  *
11  */
12
13 #include "JITValues.hxx"
14 #include "JITVisitor.hxx"
15 #include "jit_operations.hxx"
16
17 namespace jit
18 {
19 const bool JITVisitor::__init__ = InitializeLLVM();
20 llvm::Value * const JITVisitor::ONE = llvm::ConstantInt::get(getLLVMTy<int>(), int(1));
21 llvm::Value * const JITVisitor::TWO = llvm::ConstantInt::get(getLLVMTy<int>(), int(2));
22 llvm::Value * const JITVisitor::THREE = llvm::ConstantInt::get(getLLVMTy<int>(), int(3));
23 llvm::Value * const JITVisitor::FOUR = llvm::ConstantInt::get(getLLVMTy<int>(), int(4));
24
25 JITVisitor::JITVisitor(const analysis::AnalysisVisitor & _analysis) : ast::ConstVisitor(),
26     analysis(_analysis),
27     context(llvm::getGlobalContext()),
28     module("JIT module", context),
29     engine(InitializeEngine(&module)),
30     FPM(initFPM(&module, engine)),
31     function(llvm::cast<llvm::Function>(module.getOrInsertFunction("jit_main", getLLVMTy<void>(context), nullptr))),
32     builder(context),
33     uintptrType(getPtrAsIntTy(module, context)),
34     _result(nullptr)
35 {
36     module.setDataLayout(engine->getDataLayout()->getStringRepresentation());
37     llvm::BasicBlock * BB = llvm::BasicBlock::Create(context, "EntryBlock", function);
38     builder.SetInsertPoint(BB);
39
40     symbol::Context * ctxt = symbol::Context::getInstance();
41     const std::set<symbol::Symbol> & SymRead = analysis.get_read();
42     for (std::set<symbol::Symbol>::const_iterator it = SymRead.cbegin(), end = SymRead.cend(); it != end; ++it)
43     {
44         const std::wstring & name = it->name_get();
45         const std::string _name(name.begin(), name.end());
46         symbol::Variable * var = ctxt->getOrCreate(*it);
47         types::InternalType * pIT = symbol::Context::getInstance()->get(var);
48
49         if (pIT)
50         {
51             symMap3.emplace(*it, std::shared_ptr<JITVal>(JITVal::get(*this, pIT, false, _name)));
52         }
53     }
54
55     const std::set<symbol::Symbol> & SymWrite = analysis.get_write();
56     for (std::set<symbol::Symbol>::const_iterator it = SymWrite.cbegin(), end = SymWrite.cend(); it != end; ++it)
57     {
58         const std::wstring & name = it->name_get();
59         const std::string _name(name.begin(), name.end());
60         symbol::Variable * var = ctxt->getOrCreate(*it);
61         types::InternalType * pIT = symbol::Context::getInstance()->get(var);
62
63         if (pIT)
64         {
65             symMap3.emplace(*it, std::shared_ptr<JITVal>(JITVal::get(*this, pIT, true, _name)));
66         }
67     }
68
69 }
70
71 void JITVisitor::run()
72 {
73     // on reinjecte les resultats ds l'environnement a=1;jit("a=2");
74     symbol::Context * ctxt = symbol::Context::getInstance();
75     llvm::Value * llvmCtxt = getPointer(ctxt);
76     //llvm::Value * toCall_S = getPointer(reinterpret_cast<void *>(&jit::putInContext_S<Double, double>), getLLVMPtrFuncTy<void, char *, char *, double>(context));
77     llvm::Value * toCall_M = module.getOrInsertFunction("putInContext_M_D_ds", getLLVMFuncTy<void, char *, char *, double *, int , int>(context));
78     //llvm::Value * toCall_M = getPointer(reinterpret_cast<void *>(&jit::putInContext_M<Double, double*>), getLLVMPtrFuncTy<void, char *, char *, double *, int, int>(context));
79
80     llvm::Value * toCall_S = module.getOrInsertFunction("putInContext_S_D_d", getLLVMFuncTy<void, char *, char *, double>(context));
81
82     for (JITSymbolMap::const_iterator i = symMap3.begin(), end = symMap3.end(); i != end; ++i)
83     {
84         symbol::Variable * var = ctxt->getOrCreate(i->first);
85         llvm::Value * llvmVar = getPointer(var);
86         if (i->second.get()->is_scalar())
87         {
88             builder.CreateCall3(toCall_S, llvmCtxt, llvmVar, i->second.get()->load(*this));
89         }
90         else
91         {
92             i->second.get()->load(*this)->dump();
93             builder.CreateCall5(toCall_M, llvmCtxt, llvmVar, i->second.get()->load(*this), i->second.get()->loadR(*this), i->second.get()->loadC(*this));
94         }
95     }
96
97     builder.CreateRetVoid();
98
99     //dump();
100
101     for (llvm::Module::iterator it = module.begin(), end = module.end(); it != end; ++it)
102     {
103         FPM.run(*it);
104     }
105
106     //dump();
107
108     engine->finalizeObject();
109
110     //getLLVMTy<const char * const>(context)->dump();
111     //getLLVMTy<int (*)(double)>(context)->dump();
112
113     //foo<int (*) (double)>(context)->dump();
114
115     reinterpret_cast<void (*)()>(engine->getFunctionAddress("jit_main"))();
116 }
117
118 void JITVisitor::dump() const
119 {
120     module.dump();
121     //function->dump();
122 }
123
124 template<>
125 llvm::Value * JITVisitor::getConstant<double>(const double val)
126 {
127     llvm::Value * v = llvm::ConstantFP::get(context, llvm::APFloat(val));
128     return v;
129 }
130
131 void JITVisitor::visit(const ast::SimpleVar &e)
132 {
133     /*                symbol::Symbol & sym = e.name_get();
134                       std::map<symbol::Symbol, llvm::Value *>::iterator i = symMap.find(sym);
135                       if (i != symMap.end())
136                       {
137                       if (llvm::isa<llvm::AllocaInst>(i->second))
138                       {
139                       llvm::LoadInst * tmp = builder.CreateLoad(llvm::cast<llvm::AllocaInst>(i->second));
140                       tmp->setAlignment(sizeof(double));
141                       result_set(tmp);
142                       }
143                       else
144                       {
145                       result_set(i->second);
146                       }
147                       }
148                       else
149                       {
150                       std::wcout << L"que faire...=" << sym.name_get() << std::endl;
151                       }
152     */
153     /*              symbol::Symbol & sym = e.name_get();
154                     std::map<symbol::Symbol, JITVal>::iterator i = symMap2.find(sym);
155                     if (i != symMap2.end())
156                     {
157                     llvm::Value * r = llvm::ConstantInt::get(getLLVMTy<int>(context), 1);
158                     result_set(JITVal(r, r, i->second.load(builder)));
159                     }
160                     else
161                     {
162                     types::Double * pIT = static_cast<Double *>(symbol::Context::getInstance()->get(((ast::SimpleVar&)e).stack_get()));
163                     llvm::Value * r = llvm::ConstantInt::get(getLLVMTy<int>(context), pIT->getRows());
164                     llvm::Value * c = llvm::Cou onstantInt::get(getLLVMTy<int>(context), pIT->getCols());
165                     llvm::Value * ptr = getPointer(pIT->get(), getLLVMTy<double *>(context));
166
167                     result_set(JITVal(r, c, ptr));
168
169                     //std::wcout << L"que faire...=" << sym.name_get() << std::endl;
170                     }
171     */
172     symbol::Symbol & sym = e.name_get();
173     JITSymbolMap::iterator i = symMap3.find(sym);
174     if (i != symMap3.end())
175     {
176         result_set(i->second);
177     }
178     else
179     {
180         const std::wstring & name = sym.name_get();
181         const std::string _name(name.begin(), name.end());
182         /*types::InternalType * pIT = symbol::Context::getInstance()->get(((ast::SimpleVar&)e).stack_get());
183
184         result_set(std::shared_ptr<JITVal>(JITVal::get(*this, pIT, _name)));*/
185         throw ast::ScilabError("Variable not declared before JIT: " + _name);
186     }
187 }
188
189 void JITVisitor::visit(const ast::DollarVar &e) //a=[1 2;3 4];b=[5 6;7 8];jit("a/b")
190 {
191
192 }
193
194 void JITVisitor::visit(const ast::ColonVar &e)
195 {
196
197 }
198
199 void JITVisitor::visit(const ast::ArrayListVar &e)
200 {
201
202 }
203
204 void JITVisitor::visit(const ast::IntExp &e)
205 {
206
207 }
208
209 void JITVisitor::visit(const ast::FloatExp &e)
210 {
211
212 }
213
214 void JITVisitor::visit(const ast::DoubleExp &e)
215 {
216     result_set(std::shared_ptr<JITVal>(new JITScalarVal<double>(*this, e.value_get(), false)));
217 }
218
219 void JITVisitor::visit(const ast::BoolExp &e)
220 {
221
222 }
223
224 void JITVisitor::visit(const ast::StringExp &e)
225 {
226
227 }
228
229 void JITVisitor::visit(const ast::CommentExp &e)
230 {
231     // ignored
232 }
233
234 void JITVisitor::visit(const ast::NilExp &e)
235 {
236
237 }
238
239 void JITVisitor::visit(const ast::CallExp &e)
240 {
241
242 }
243
244 void JITVisitor::visit(const ast::CellCallExp &e)
245 {
246
247 }
248
249 void JITVisitor::visit(const ast::OpExp &e)
250 {
251     e.left_get().accept(*this);
252     std::shared_ptr<JITVal> pITL = result_get();
253
254     /*getting what to assign*/
255     e.right_get().accept(*this);
256     std::shared_ptr<JITVal> & pITR = result_get();
257
258     llvm::Value * pResult = NULL;
259
260     switch (e.oper_get())
261     {
262         case ast::OpExp::plus:
263         {
264             if (pITL.get()->is_scalar())
265             {
266                 result_set(add_D_D(pITL, pITR, *this));
267             }
268             else
269             {
270                 result_set(add_M_M(pITL, pITR, *this));
271             }
272             return;
273         }
274         case ast::OpExp::minus:
275         {
276             if (pITL.get()->is_scalar())
277             {
278                 result_set(sub_D_D(pITL, pITR, *this));
279             }
280             else
281             {
282                 result_set(sub_M_M(pITL, pITR, *this));
283             }
284             return;//a=1;b=1;for i=1:2:23;a=a+i*3+b;b=b-i*a;end;
285         }
286         case ast::OpExp::times:
287         {
288             if (pITL.get()->is_scalar())
289             {
290                 result_set(dotmul_D_D(pITL, pITR, *this));
291             }
292             else
293             {
294                 result_set(dotmul_M_M(pITL, pITR, *this));
295             }
296             return;
297         }
298         default:
299             if (pITL.get()->is_scalar())
300             {
301                 result_set(add_D_D(pITL, pITR, *this));
302             }
303             else
304             {
305                 result_set(add_M_M(pITL, pITR, *this));
306             }
307             return;
308     }
309
310     //llvm::Value * r = llvm::ConstantInt::get(getLLVMTy<int>(context), 1);
311     //result_set(JITVal(r, r, pResult));
312 }
313
314 void JITVisitor::visit(const ast::LogicalOpExp &e)
315 {
316
317 }
318
319 void JITVisitor::visit(const ast::AssignExp &e)
320 {
321     if (e.left_exp_get().is_simple_var())
322     {
323         ast::SimpleVar & pVar = static_cast<ast::SimpleVar &>(e.left_exp_get());
324
325         e.right_exp_get().accept(*this);
326         std::shared_ptr<JITVal> & pITR = result_get();
327         llvm::Value * alloca = nullptr;
328         JITSymbolMap::const_iterator i = symMap3.find(pVar.name_get());
329
330         if (i != symMap3.end())
331         {
332             i->second.get()->store(*pITR.get(), *this);
333         }
334         else
335         {
336             const std::wstring & name = pVar.name_get().name_get();
337             const std::string _name(name.begin(), name.end());
338             // TODO: virer ce truc... le param <double> est force...
339             JITVal * jitV = new JITScalarVal<double>(*this, pITR.get(), _name);
340             symMap3.emplace(pVar.name_get(), std::shared_ptr<JITVal>(jitV));
341         }
342
343         result_set(std::shared_ptr<JITVal>(nullptr));
344     }
345 }
346
347 void JITVisitor::visit(const ast::IfExp &e)
348 {
349
350 }
351
352 void JITVisitor::visit(const ast::WhileExp &e)
353 {
354
355 }
356
357 void JITVisitor::visit(const ast::ForExp &e)
358 {
359     //e.vardec_get().accept(*this);
360     const ast::VarDec & vardec = e.vardec_get();
361     symbol::Symbol & varName = vardec.name_get();
362     const ast::Exp & init = vardec.init_get();
363
364     if (init.is_list_exp())
365     {
366         const ast::ListExp & list = static_cast<const ast::ListExp &>(init);
367         const double * list_values = list.get_values();
368         llvm::Value * start = nullptr, * step, * end;
369         bool use_int = false;
370         bool use_uint = false;
371         bool inc = true;
372         bool known_step = false;
373
374         if (!ISNAN(list_values[0]) && !ISNAN(list_values[1]) && !ISNAN(list_values[2]))
375         {
376             const double tstart = std::trunc(list_values[0]);
377             const double tstep = std::trunc(list_values[1]);
378             const double tend = std::trunc(list_values[2]);
379
380             inc = list_values[1] >= 0;
381             known_step = true;
382
383             if ((tstart == list_values[0]) && (tstep == list_values[1]))
384             {
385                 if (tstart >= 0 && tstep >= 0 && tstart <= list_values[2])
386                 {
387                     // we can use an unsigned int but take care to overflow...
388                     double k = std::floor(((double)std::numeric_limits<uint64_t>::max() - tstart) / tstep);
389                     if ((k * tstep + tstart) >= tend)
390                     {
391                         // no overflow
392                         start = getConstant((uint64_t)tstart);
393                         step = getConstant((uint64_t)tstep);
394                         end = getConstant((uint64_t)tend);
395
396                         use_int = true;
397                         use_uint = true;
398                     }
399                 }
400                 else
401                 {
402                     double k = std::floor(((double)(inc ? std::numeric_limits<int64_t>::max() : std::numeric_limits<int64_t>::min()) - tstart) / tstep);
403                     if ((inc && (k * tstep + tstart) >= tend) || (!inc && (k * tstep + tstart) <= tend))
404                     {
405                         // no overflow
406                         start = getConstant((int64_t)tstart);
407                         step = getConstant((int64_t)tstep);
408                         end = getConstant((int64_t)tend);
409
410                         use_int = true;
411                     }
412                 }
413             }
414             else
415             {
416                 start = getConstant(list_values[0]);
417                 step = getConstant(list_values[1]);
418                 end = getConstant(list_values[2]);
419             }
420         }
421
422         if (!start)
423         {
424             if (!ISNAN(list_values[0]))
425             {
426                 start = getConstant(list_values[0]);
427             }
428             else
429             {
430                 list.start_get().accept(*this);
431                 start = result_get().get()->load(*this);
432             }
433
434             if (!ISNAN(list_values[1]))
435             {
436                 step = getConstant(list_values[1]);
437                 inc = list_values[1] >= 0;
438                 known_step = true;
439             }
440             else
441             {
442                 list.step_get().accept(*this);
443                 step = result_get().get()->load(*this);
444             }
445
446             if (!ISNAN(list_values[2]))
447             {
448                 end = getConstant(list_values[2]);
449             }
450             else
451             {
452                 list.end_get().accept(*this);
453                 end = result_get().get()->load(*this);
454             }
455         }
456
457         llvm::BasicBlock * BBBody = llvm::BasicBlock::Create(context, "for_body", function);
458         llvm::BasicBlock * BBAfter = llvm::BasicBlock::Create(context, "for_after", function);
459
460         llvm::BasicBlock * cur_block = builder.GetInsertBlock();
461         llvm::Value * tmp;
462
463         if (known_step)
464         {
465             if (inc)
466             {
467                 tmp = use_int ? (use_uint ? builder.CreateICmpULE(start, end) : builder.CreateICmpSLE(start, end)) : builder.CreateFCmpOLE(start, end);
468             }
469             else
470             {
471                 tmp = use_int ? (use_uint ? builder.CreateICmpUGE(start, end) : builder.CreateICmpSGE(start, end)) : builder.CreateFCmpOGE(start, end);
472             }
473         }
474         else
475         {
476             //TODO: add something to handle this case
477         }
478
479         builder.CreateCondBr(tmp, BBBody, BBAfter);
480
481         builder.SetInsertPoint(BBBody);
482         llvm::PHINode * phi = use_int ? builder.CreatePHI(getLLVMTy<int64_t>(context), 2) : builder.CreatePHI(getLLVMTy<double>(context), 2);
483
484         JITSymbolMap::const_iterator i = symMap3.find(varName);
485         tmp = use_int ? (use_uint ? builder.CreateUIToFP(phi, getLLVMTy<double>(context)) : builder.CreateSIToFP(phi, getLLVMTy<double>(context))) : phi;
486         i->second.get()->store(tmp, *this);
487
488         phi->addIncoming(start, cur_block);
489
490         builder.SetInsertPoint(BBBody);
491         e.body_get().accept(*this);
492
493         tmp = use_int ? builder.CreateAdd(phi, step) : builder.CreateFAdd(phi, step);
494         phi->addIncoming(tmp, builder.GetInsertBlock());
495
496         if (known_step)
497         {
498             if (inc)
499             {
500                 tmp = use_int ? (use_uint ? builder.CreateICmpULE(tmp, end) : builder.CreateICmpSLE(tmp, end)) : builder.CreateFCmpOLE(tmp, end);
501             }
502             else
503             {
504                 tmp = use_int ? (use_uint ? builder.CreateICmpUGE(tmp, end) : builder.CreateICmpSGE(tmp, end)) : builder.CreateFCmpOGE(tmp, end);
505             }
506         }
507         else
508         {
509             //TODO: add something to handle this case
510         }
511
512         builder.CreateCondBr(tmp, BBBody, BBAfter);
513
514         builder.SetInsertPoint(BBAfter);
515
516         //llvm::AllocaInst * cur = builder.CreateAlloca(getLLVMTy<double>(context));
517         //llvm::StoreInst * cur_store = builder.CreateAlignedStore(phi, cur, sizeof(double));
518
519         //symMap3.emplace(varName, std::shared_ptr<JITVal>(new JITScalarVal<double>(*this, cur)));
520     }
521     else
522     {
523         // Should not occured...
524         // Normally, if the init is an iterator the for exp itself is not jittable
525         // but take care of the case for i=int32(1:2:123)...
526     }
527
528     //function->dump();
529 }
530
531 void JITVisitor::visit(const ast::BreakExp &e)
532 {
533
534 }
535
536 void JITVisitor::visit(const ast::ContinueExp &e)
537 {
538
539 }
540
541 void JITVisitor::visit(const ast::TryCatchExp &e)
542 {
543
544 }
545
546 void JITVisitor::visit(const ast::SelectExp &e)
547 {
548
549 }
550
551 void JITVisitor::visit(const ast::CaseExp &e)
552 {
553
554 }
555
556 void JITVisitor::visit(const ast::ReturnExp &e)
557 {
558
559 }
560
561 void JITVisitor::visit(const ast::FieldExp &e)
562 {
563
564 }
565
566 void JITVisitor::visit(const ast::NotExp &e)
567 {
568
569 }
570
571 void JITVisitor::visit(const ast::TransposeExp &e)
572 {
573
574 }
575
576 void JITVisitor::visit(const ast::MatrixExp &e)
577 {
578
579 }
580
581 void JITVisitor::visit(const ast::MatrixLineExp &e)
582 {
583
584 }
585
586 void JITVisitor::visit(const ast::CellExp &e)
587 {
588
589 }
590
591 void JITVisitor::visit(const ast::SeqExp &e)
592 {
593     for (std::list<ast::Exp *>::const_iterator i = e.exps_get().begin(), end = e.exps_get().end(); i != end; ++i)
594     {
595         result_set(std::shared_ptr<JITVal>(nullptr));
596         (*i)->accept(*this);
597     }
598 }
599
600 void JITVisitor::visit(const ast::ArrayListExp &e)
601 {
602
603 }
604
605 void JITVisitor::visit(const ast::AssignListExp &e)
606 {
607
608 }
609
610 void JITVisitor::visit(const ast::VarDec &e)
611 {
612     e.init_get().accept(*this);
613 }
614
615 void JITVisitor::visit(const ast::FunctionDec &e)
616 {
617
618 }
619
620 void JITVisitor::visit(const ast::ListExp &e)
621 {
622
623 }
624 }