2 * Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3 * Copyright (C) 2014 - Scilab Enterprises - Calixte DENIZET
5 * This file must be used under the terms of the CeCILL.
6 * This source file is licensed as described in the file COPYING, which
7 * you should have received as part of this distribution. The terms
8 * are also available at
9 * http://www.cecill.info/licences/Licence_CeCILL_V2-en.txt
13 #include "JITValues.hxx"
14 #include "JITVisitor.hxx"
15 #include "jit_operations.hxx"
// Definitions of JITVisitor's static members.
// __init__ is initialised by calling InitializeLLVM(), so that call happens when the statics
// of this translation unit are initialised.
// ONE, TWO, THREE and FOUR are integer constants (1..4) built with llvm::ConstantInt::get on
// the type returned by getLLVMTy<int>().
19 const bool JITVisitor::__init__ = InitializeLLVM();
20 llvm::Value * const JITVisitor::ONE = llvm::ConstantInt::get(getLLVMTy<int>(), int(1));
21 llvm::Value * const JITVisitor::TWO = llvm::ConstantInt::get(getLLVMTy<int>(), int(2));
22 llvm::Value * const JITVisitor::THREE = llvm::ConstantInt::get(getLLVMTy<int>(), int(3));
23 llvm::Value * const JITVisitor::FOUR = llvm::ConstantInt::get(getLLVMTy<int>(), int(4));
// Builds a visitor from the result of the analysis pass.
// The initialiser list sets up the LLVM context, the module "JIT module", the execution
// engine, the function pass manager and the entry function "jit_main" (obtained through
// module.getOrInsertFunction with a void return type).
// The body creates the basic block "EntryBlock", points the IR builder at it, and registers
// one JITVal in symMap3 for every symbol the analysis reports as read (third argument of
// JITVal::get is false) and for every symbol reported as written (third argument is true).
// NOTE(review): the parameter is named _analysis while the body uses `analysis`; presumably
// a member initialised from the parameter - confirm.
25 JITVisitor::JITVisitor(const analysis::AnalysisVisitor & _analysis) : ast::ConstVisitor(),
27 context(llvm::getGlobalContext()),
28 module("JIT module", context),
29 engine(InitializeEngine(&module)),
30 FPM(initFPM(&module, engine)),
31 function(llvm::cast<llvm::Function>(module.getOrInsertFunction("jit_main", getLLVMTy<void>(context), nullptr))),
33 uintptrType(getPtrAsIntTy(module, context)),
// Give the module the data layout string reported by the execution engine.
36 module.setDataLayout(engine->getDataLayout()->getStringRepresentation());
37 llvm::BasicBlock * BB = llvm::BasicBlock::Create(context, "EntryBlock", function);
38 builder.SetInsertPoint(BB);
40 symbol::Context * ctxt = symbol::Context::getInstance();
// First pass: symbols that the analysed code reads.
41 const std::set<symbol::Symbol> & SymRead = analysis.get_read();
42 for (std::set<symbol::Symbol>::const_iterator it = SymRead.cbegin(), end = SymRead.cend(); it != end; ++it)
44 const std::wstring & name = it->name_get();
// Narrow copy of the wide name: each wchar_t is converted to a char individually.
45 const std::string _name(name.begin(), name.end());
46 symbol::Variable * var = ctxt->getOrCreate(*it);
// NOTE(review): ctxt already holds symbol::Context::getInstance(); it is fetched again here.
47 types::InternalType * pIT = symbol::Context::getInstance()->get(var);
51 symMap3.emplace(*it, std::shared_ptr<JITVal>(JITVal::get(*this, pIT, false, _name)));
// Second pass: symbols that the analysed code writes.
55 const std::set<symbol::Symbol> & SymWrite = analysis.get_write();
56 for (std::set<symbol::Symbol>::const_iterator it = SymWrite.cbegin(), end = SymWrite.cend(); it != end; ++it)
58 const std::wstring & name = it->name_get();
59 const std::string _name(name.begin(), name.end());
60 symbol::Variable * var = ctxt->getOrCreate(*it);
61 types::InternalType * pIT = symbol::Context::getInstance()->get(var);
// NOTE(review): if JITSymbolMap::emplace behaves like std::map::emplace, a symbol that is both
// read and written keeps the entry created by the read loop above - confirm this is intended.
65 symMap3.emplace(*it, std::shared_ptr<JITVal>(JITVal::get(*this, pIT, true, _name)));
// Finishes the generated function and executes it.
// For every entry of symMap3 a call is emitted that hands the value back to the Scilab
// context: "putInContext_S_D_d"(context, variable, value) when the JITVal reports is_scalar(),
// otherwise "putInContext_M_D_ds"(context, variable, load(), loadR(), loadC()).
// A `ret void` is then appended, the engine object is finalised, and the address of
// "jit_main" is fetched from the engine and called as a void() function.
// NOTE(review): loadR/loadC are presumably the row and column counts (the callee signature is
// double*, int, int) - confirm.
71 void JITVisitor::run()
73 // re-inject the results into the environment, e.g. a=1;jit("a=2");
74 symbol::Context * ctxt = symbol::Context::getInstance();
// The context pointer is embedded in the generated code as a value.
75 llvm::Value * llvmCtxt = getPointer(ctxt);
76 //llvm::Value * toCall_S = getPointer(reinterpret_cast<void *>(&jit::putInContext_S<Double, double>), getLLVMPtrFuncTy<void, char *, char *, double>(context));
77 llvm::Value * toCall_M = module.getOrInsertFunction("putInContext_M_D_ds", getLLVMFuncTy<void, char *, char *, double *, int , int>(context));
78 //llvm::Value * toCall_M = getPointer(reinterpret_cast<void *>(&jit::putInContext_M<Double, double*>), getLLVMPtrFuncTy<void, char *, char *, double *, int, int>(context));
80 llvm::Value * toCall_S = module.getOrInsertFunction("putInContext_S_D_d", getLLVMFuncTy<void, char *, char *, double>(context));
82 for (JITSymbolMap::const_iterator i = symMap3.begin(), end = symMap3.end(); i != end; ++i)
84 symbol::Variable * var = ctxt->getOrCreate(i->first);
85 llvm::Value * llvmVar = getPointer(var);
86 if (i->second.get()->is_scalar())
88 builder.CreateCall3(toCall_S, llvmCtxt, llvmVar, i->second.get()->load(*this));
// Debug output of the loaded value before the matrix call is emitted.
// NOTE(review): load(*this) is invoked once here and again on the next line - confirm that
// calling it twice is harmless.
92 i->second.get()->load(*this)->dump();
93 builder.CreateCall5(toCall_M, llvmCtxt, llvmVar, i->second.get()->load(*this), i->second.get()->loadR(*this), i->second.get()->loadC(*this));
97 builder.CreateRetVoid();
// NOTE(review): the body of this loop over the module's functions is not present in this listing.
101 for (llvm::Module::iterator it = module.begin(), end = module.end(); it != end; ++it)
108 engine->finalizeObject();
110 //getLLVMTy<const char * const>(context)->dump();
111 //getLLVMTy<int (*)(double)>(context)->dump();
113 //foo<int (*) (double)>(context)->dump();
// Look up the generated entry point by name and call it.
115 reinterpret_cast<void (*)()>(engine->getFunctionAddress("jit_main"))();
// Debug helper.
// NOTE(review): no statement of the body appears in this listing; presumably it dumps the
// LLVM module - confirm.
118 void JITVisitor::dump() const
// getConstant for double: wraps val in an llvm::APFloat and builds the matching
// llvm::ConstantFP in this visitor's context.
// NOTE(review): the return statement is not present in this listing; presumably v is returned.
125 llvm::Value * JITVisitor::getConstant<double>(const double val)
127 llvm::Value * v = llvm::ConstantFP::get(context, llvm::APFloat(val));
// Resolves a variable reference.
// The live code looks the symbol up in symMap3: when found, the stored JITVal becomes the
// visitor result; otherwise an ast::ScilabError("Variable not declared before JIT: <name>")
// is thrown.
// Two earlier implementations are kept below as commented-out code.
// NOTE(review): in this listing the closing markers of those block comments are not present,
// so the boundary between disabled and live code should be checked against the real file.
131 void JITVisitor::visit(const ast::SimpleVar &e)
133 /* symbol::Symbol & sym = e.name_get();
134 std::map<symbol::Symbol, llvm::Value *>::iterator i = symMap.find(sym);
135 if (i != symMap.end())
137 if (llvm::isa<llvm::AllocaInst>(i->second))
139 llvm::LoadInst * tmp = builder.CreateLoad(llvm::cast<llvm::AllocaInst>(i->second));
140 tmp->setAlignment(sizeof(double));
145 result_set(i->second);
150 std::wcout << L"que faire...=" << sym.name_get() << std::endl;
153 /* symbol::Symbol & sym = e.name_get();
154 std::map<symbol::Symbol, JITVal>::iterator i = symMap2.find(sym);
155 if (i != symMap2.end())
157 llvm::Value * r = llvm::ConstantInt::get(getLLVMTy<int>(context), 1);
158 result_set(JITVal(r, r, i->second.load(builder)));
162 types::Double * pIT = static_cast<Double *>(symbol::Context::getInstance()->get(((ast::SimpleVar&)e).stack_get()));
163 llvm::Value * r = llvm::ConstantInt::get(getLLVMTy<int>(context), pIT->getRows());
164 llvm::Value * c = llvm::ConstantInt::get(getLLVMTy<int>(context), pIT->getCols());
165 llvm::Value * ptr = getPointer(pIT->get(), getLLVMTy<double *>(context));
167 result_set(JITVal(r, c, ptr));
169 //std::wcout << L"que faire...=" << sym.name_get() << std::endl;
172 symbol::Symbol & sym = e.name_get();
173 JITSymbolMap::iterator i = symMap3.find(sym);
174 if (i != symMap3.end())
176 result_set(i->second);
180 const std::wstring & name = sym.name_get();
181 const std::string _name(name.begin(), name.end());
182 /*types::InternalType * pIT = symbol::Context::getInstance()->get(((ast::SimpleVar&)e).stack_get());
184 result_set(std::shared_ptr<JITVal>(JITVal::get(*this, pIT, _name)));*/
185 throw ast::ScilabError("Variable not declared before JIT: " + _name);
// Visitor overloads for DollarVar, ColonVar, ArrayListVar, IntExp and FloatExp.
// NOTE(review): no statement of their bodies appears in this listing; presumably these node
// kinds are not handled by the JIT yet - confirm.
189 void JITVisitor::visit(const ast::DollarVar &e) //a=[1 2;3 4];b=[5 6;7 8];jit("a/b")
194 void JITVisitor::visit(const ast::ColonVar &e)
199 void JITVisitor::visit(const ast::ArrayListVar &e)
204 void JITVisitor::visit(const ast::IntExp &e)
209 void JITVisitor::visit(const ast::FloatExp &e)
// A double literal: its value (e.value_get()) is wrapped in a new JITScalarVal<double>, which
// becomes the visitor result.
// NOTE(review): the meaning of the trailing `false` constructor argument is not visible here.
214 void JITVisitor::visit(const ast::DoubleExp &e)
216 result_set(std::shared_ptr<JITVal>(new JITScalarVal<double>(*this, e.value_get(), false)));
// Visitor overloads for BoolExp, StringExp, CommentExp, NilExp, CallExp and CellCallExp.
// NOTE(review): no statement of their bodies appears in this listing; presumably these node
// kinds are not handled by the JIT yet - confirm.
219 void JITVisitor::visit(const ast::BoolExp &e)
224 void JITVisitor::visit(const ast::StringExp &e)
229 void JITVisitor::visit(const ast::CommentExp &e)
234 void JITVisitor::visit(const ast::NilExp &e)
239 void JITVisitor::visit(const ast::CallExp &e)
244 void JITVisitor::visit(const ast::CellCallExp &e)
// Binary operator: both operands are visited, then the operation is emitted according to
// e.oper_get(). For plus, minus and times the *_D_D helper is used when the LEFT operand
// reports is_scalar(), otherwise the *_M_M helper; `times` maps to the dotmul helpers.
// NOTE(review): only the left operand's shape is tested in the visible code; mixed
// scalar/matrix operands are not distinguished.
// NOTE(review): the else/break/default lines of the switch are not present in this listing,
// so the exact branch structure should be confirmed.
249 void JITVisitor::visit(const ast::OpExp &e)
251 e.left_get().accept(*this);
// The left result is copied because visiting the right operand replaces the visitor result.
252 std::shared_ptr<JITVal> pITL = result_get();
254 /*getting what to assign*/
255 e.right_get().accept(*this);
256 std::shared_ptr<JITVal> & pITR = result_get();
// NOTE(review): pResult is not used by any statement visible in this listing.
258 llvm::Value * pResult = NULL;
260 switch (e.oper_get())
262 case ast::OpExp::plus:
264 if (pITL.get()->is_scalar())
266 result_set(add_D_D(pITL, pITR, *this));
270 result_set(add_M_M(pITL, pITR, *this));
274 case ast::OpExp::minus:
276 if (pITL.get()->is_scalar())
278 result_set(sub_D_D(pITL, pITR, *this));
282 result_set(sub_M_M(pITL, pITR, *this));
284 return;//a=1;b=1;for i=1:2:23;a=a+i*3+b;b=b-i*a;end;
286 case ast::OpExp::times:
288 if (pITL.get()->is_scalar())
290 result_set(dotmul_D_D(pITL, pITR, *this));
294 result_set(dotmul_M_M(pITL, pITR, *this));
// NOTE(review): this last branch emits an addition again; presumably it is the fallback for
// every other operator, which would silently compute a sum - confirm.
299 if (pITL.get()->is_scalar())
301 result_set(add_D_D(pITL, pITR, *this));
305 result_set(add_M_M(pITL, pITR, *this));
310 //llvm::Value * r = llvm::ConstantInt::get(getLLVMTy<int>(context), 1);
311 //result_set(JITVal(r, r, pResult));
// Visitor overload for LogicalOpExp.
// NOTE(review): no statement of its body appears in this listing; presumably not handled yet.
314 void JITVisitor::visit(const ast::LogicalOpExp &e)
// Assignment. Only a simple variable on the left-hand side is handled in the visible code.
// The right-hand side is visited first; if the target symbol already exists in symMap3 the
// value is stored into its JITVal, otherwise a new JITScalarVal<double> is created from the
// right-hand result and registered under the symbol.
// The visitor result is reset to a null JITVal afterwards.
// NOTE(review): whether that reset happens inside or after the simple-variable branch cannot
// be told from this listing.
319 void JITVisitor::visit(const ast::AssignExp &e)
321 if (e.left_exp_get().is_simple_var())
323 ast::SimpleVar & pVar = static_cast<ast::SimpleVar &>(e.left_exp_get());
325 e.right_exp_get().accept(*this);
326 std::shared_ptr<JITVal> & pITR = result_get();
// NOTE(review): `alloca` is not used by any statement visible in this listing.
327 llvm::Value * alloca = nullptr;
328 JITSymbolMap::const_iterator i = symMap3.find(pVar.name_get());
330 if (i != symMap3.end())
332 i->second.get()->store(*pITR.get(), *this);
336 const std::wstring & name = pVar.name_get().name_get();
337 const std::string _name(name.begin(), name.end());
338 // TODO: get rid of this... the <double> template parameter is hard-coded...
339 JITVal * jitV = new JITScalarVal<double>(*this, pITR.get(), _name);
340 symMap3.emplace(pVar.name_get(), std::shared_ptr<JITVal>(jitV));
343 result_set(std::shared_ptr<JITVal>(nullptr));
// Visitor overloads for IfExp and WhileExp.
// NOTE(review): no statement of their bodies appears in this listing; presumably not handled yet.
347 void JITVisitor::visit(const ast::IfExp &e)
352 void JITVisitor::visit(const ast::WhileExp &e)
// Lowers a `for` loop whose loop variable is initialised from a list expression
// (start:step:end) to LLVM IR.
// list.get_values() provides three doubles (start, step, end); a NaN entry is treated as
// "not a constant" and the corresponding sub-expression is visited to obtain the value.
// When all three are constants and start and step are integral, 64-bit integer bounds are
// used after a range check (the first branch builds uint64_t constants, the second int64_t
// constants); otherwise double constants are used.
// The loop is emitted as two blocks, "for_body" and "for_after": a comparison of start
// against end guards the entry, a PHI node in for_body carries the induction value, the body
// is visited, the step is added and the comparison against end decides whether to loop again.
// NOTE(review): several short lines of this function are not present in this listing (else
// branches, the declarations of `inc` and `tmp`, assignments to use_int / use_uint /
// known_step); the structure described here is inferred and should be confirmed.
357 void JITVisitor::visit(const ast::ForExp &e)
359 //e.vardec_get().accept(*this);
360 const ast::VarDec & vardec = e.vardec_get();
361 symbol::Symbol & varName = vardec.name_get();
362 const ast::Exp & init = vardec.init_get();
364 if (init.is_list_exp())
366 const ast::ListExp & list = static_cast<const ast::ListExp &>(init);
367 const double * list_values = list.get_values();
368 llvm::Value * start = nullptr, * step, * end;
369 bool use_int = false;
370 bool use_uint = false;
372 bool known_step = false;
// Case 1: start, step and end are all compile-time constants.
374 if (!ISNAN(list_values[0]) && !ISNAN(list_values[1]) && !ISNAN(list_values[2]))
// Truncated copies are compared with the originals below to test for integral values.
376 const double tstart = std::trunc(list_values[0]);
377 const double tstep = std::trunc(list_values[1]);
378 const double tend = std::trunc(list_values[2]);
// inc is true for a non-negative step.
380 inc = list_values[1] >= 0;
383 if ((tstart == list_values[0]) && (tstep == list_values[1]))
385 if (tstart >= 0 && tstep >= 0 && tstart <= list_values[2])
387 // we can use an unsigned int but take care to overflow...
// NOTE(review): tstep == 0 passes the test above and makes this division a division by zero
// (inf/NaN in floating point) - confirm that a zero step is rejected earlier.
388 double k = std::floor(((double)std::numeric_limits<uint64_t>::max() - tstart) / tstep);
389 if ((k * tstep + tstart) >= tend)
392 start = getConstant((uint64_t)tstart);
393 step = getConstant((uint64_t)tstep);
394 end = getConstant((uint64_t)tend);
// Signed variant of the same range check, against int64_t max (increasing) or min (decreasing).
402 double k = std::floor(((double)(inc ? std::numeric_limits<int64_t>::max() : std::numeric_limits<int64_t>::min()) - tstart) / tstep);
403 if ((inc && (k * tstep + tstart) >= tend) || (!inc && (k * tstep + tstart) <= tend))
406 start = getConstant((int64_t)tstart);
407 step = getConstant((int64_t)tstep);
408 end = getConstant((int64_t)tend);
// Constant bounds kept as doubles.
416 start = getConstant(list_values[0]);
417 step = getConstant(list_values[1]);
418 end = getConstant(list_values[2]);
// Case 2: each bound is either a constant or the value of a visited sub-expression.
424 if (!ISNAN(list_values[0]))
426 start = getConstant(list_values[0]);
430 list.start_get().accept(*this);
431 start = result_get().get()->load(*this);
434 if (!ISNAN(list_values[1]))
436 step = getConstant(list_values[1]);
437 inc = list_values[1] >= 0;
442 list.step_get().accept(*this);
443 step = result_get().get()->load(*this);
446 if (!ISNAN(list_values[2]))
448 end = getConstant(list_values[2]);
452 list.end_get().accept(*this);
453 end = result_get().get()->load(*this);
457 llvm::BasicBlock * BBBody = llvm::BasicBlock::Create(context, "for_body", function);
458 llvm::BasicBlock * BBAfter = llvm::BasicBlock::Create(context, "for_after", function);
// Remember the block that branches into the loop; it is the PHI's first predecessor.
460 llvm::BasicBlock * cur_block = builder.GetInsertBlock();
// Entry test: "<=" form and ">=" form (presumably selected by the step direction - confirm).
467 tmp = use_int ? (use_uint ? builder.CreateICmpULE(start, end) : builder.CreateICmpSLE(start, end)) : builder.CreateFCmpOLE(start, end);
471 tmp = use_int ? (use_uint ? builder.CreateICmpUGE(start, end) : builder.CreateICmpSGE(start, end)) : builder.CreateFCmpOGE(start, end);
476 //TODO: add something to handle this case
479 builder.CreateCondBr(tmp, BBBody, BBAfter);
481 builder.SetInsertPoint(BBBody);
482 llvm::PHINode * phi = use_int ? builder.CreatePHI(getLLVMTy<int64_t>(context), 2) : builder.CreatePHI(getLLVMTy<double>(context), 2);
// NOTE(review): no end() check is visible before `i` is dereferenced two lines below.
484 JITSymbolMap::const_iterator i = symMap3.find(varName);
// The loop variable is stored as a double: integer induction values are converted first.
485 tmp = use_int ? (use_uint ? builder.CreateUIToFP(phi, getLLVMTy<double>(context)) : builder.CreateSIToFP(phi, getLLVMTy<double>(context))) : phi;
486 i->second.get()->store(tmp, *this);
488 phi->addIncoming(start, cur_block);
490 builder.SetInsertPoint(BBBody);
491 e.body_get().accept(*this);
// Advance the induction value; the block current after visiting the body is the PHI's second
// predecessor.
493 tmp = use_int ? builder.CreateAdd(phi, step) : builder.CreateFAdd(phi, step);
494 phi->addIncoming(tmp, builder.GetInsertBlock());
// Loop-back test on the incremented value, in the same two forms as the entry test.
500 tmp = use_int ? (use_uint ? builder.CreateICmpULE(tmp, end) : builder.CreateICmpSLE(tmp, end)) : builder.CreateFCmpOLE(tmp, end);
504 tmp = use_int ? (use_uint ? builder.CreateICmpUGE(tmp, end) : builder.CreateICmpSGE(tmp, end)) : builder.CreateFCmpOGE(tmp, end);
509 //TODO: add something to handle this case
512 builder.CreateCondBr(tmp, BBBody, BBAfter);
514 builder.SetInsertPoint(BBAfter);
516 //llvm::AllocaInst * cur = builder.CreateAlloca(getLLVMTy<double>(context));
517 //llvm::StoreInst * cur_store = builder.CreateAlignedStore(phi, cur, sizeof(double));
519 //symMap3.emplace(varName, std::shared_ptr<JITVal>(new JITScalarVal<double>(*this, cur)));
523 // Should not occur...
524 // Normally, if the init is an iterator the for exp itself is not jittable
525 // but take care of the case for i=int32(1:2:123)...
// Visitor overloads for BreakExp, ContinueExp, TryCatchExp, SelectExp, CaseExp, ReturnExp,
// FieldExp, NotExp, TransposeExp, MatrixExp, MatrixLineExp and CellExp.
// NOTE(review): no statement of their bodies appears in this listing; presumably these node
// kinds are not handled by the JIT yet - confirm.
531 void JITVisitor::visit(const ast::BreakExp &e)
536 void JITVisitor::visit(const ast::ContinueExp &e)
541 void JITVisitor::visit(const ast::TryCatchExp &e)
546 void JITVisitor::visit(const ast::SelectExp &e)
551 void JITVisitor::visit(const ast::CaseExp &e)
556 void JITVisitor::visit(const ast::ReturnExp &e)
561 void JITVisitor::visit(const ast::FieldExp &e)
566 void JITVisitor::visit(const ast::NotExp &e)
571 void JITVisitor::visit(const ast::TransposeExp &e)
576 void JITVisitor::visit(const ast::MatrixExp &e)
581 void JITVisitor::visit(const ast::MatrixLineExp &e)
586 void JITVisitor::visit(const ast::CellExp &e)
// Sequence of expressions: iterates over e.exps_get() and resets the visitor result to a null
// JITVal for each element.
// NOTE(review): the statement that visits each child expression is not present in this
// listing; presumably it follows the reset - confirm.
591 void JITVisitor::visit(const ast::SeqExp &e)
593 for (std::list<ast::Exp *>::const_iterator i = e.exps_get().begin(), end = e.exps_get().end(); i != end; ++i)
595 result_set(std::shared_ptr<JITVal>(nullptr));
// Visitor overloads for ArrayListExp and AssignListExp.
// NOTE(review): no statement of their bodies appears in this listing; presumably not handled yet.
600 void JITVisitor::visit(const ast::ArrayListExp &e)
605 void JITVisitor::visit(const ast::AssignListExp &e)
// Variable declaration: only the initialiser expression is visited; the visitor result is
// whatever that visit leaves behind.
610 void JITVisitor::visit(const ast::VarDec &e)
612 e.init_get().accept(*this);
// Visitor overload for FunctionDec.
// NOTE(review): no statement of its body appears in this listing; presumably not handled yet.
615 void JITVisitor::visit(const ast::FunctionDec &e)
620 void JITVisitor::visit(const ast::ListExp &e)