Merge remote-tracking branch 'origin/master' into jit
[scilab.git] / scilab / modules / ast / src / cpp / jit / JITForExp.cpp
index 9c5d2f5..8116994 100644 (file)
  *
  */
 
+#include <cstdlib>
+
 #include "JITScalars.hxx"
 #include "JITArrayofs.hxx"
 #include "JITVisitor.hxx"
 
 namespace jit
 {
-    void JITVisitor::visit(const ast::VarDec & e)
+void JITVisitor::visit(const ast::VarDec & e)
+{
+
+}
+
+void JITVisitor::visit(const ast::ForExp & e)
+{
+    const ast::VarDec & vd = static_cast<const ast::VarDec &>(e.getVardec());
+    // TODO : handle for with an iterator
+    if (vd.getInit().isListExp())
     {
+        const symbol::Symbol & symIterator = vd.getSymbol();
+        const ast::ListExp & le = static_cast<const ast::ListExp &>(vd.getInit());
+
+        if (le.getDecorator().getResult().getType().type != analysis::TIType::EMPTY)
+        {
+            llvm::Value * start = nullptr;
+            llvm::Value * step = nullptr;
+            llvm::Value * end = nullptr;
+            const ast::Exp & startE = le.getStart();
+            const ast::Exp & stepE = le.getStep();
+            const ast::Exp & endE = le.getEnd();
+            const analysis::TIType & startTy = startE.getDecorator().getResult().getType();
+            const analysis::TIType & stepTy = stepE.getDecorator().getResult().getType();
+            const analysis::TIType & endTy = endE.getDecorator().getResult().getType();
+            const bool startTyIsSigned = startTy.isintegral() && startTy.issigned();
+            const bool stepTyIsSigned = stepTy.isintegral() && stepTy.issigned();
+            const bool endTyIsSigned = endTy.isintegral() && endTy.issigned();
+            bool constantStart = false;
+            bool constantStep = false;
+            bool constantEnd = false;
+            bool signStep;
+            bool integralStep = false;
+            bool integerIterator = false;
+            int64_t step_i;
+            int64_t start_i;
+            int64_t end_i;
+            double step_d;
+            double start_d;
+            double end_d;
+
+            if (stepE.isDoubleExp())
+            {
+                // Step is constant
+                constantStep = true;
+                const ast::DoubleExp & stepDE = static_cast<const ast::DoubleExp &>(stepE);
+                step_d = stepDE.getValue();
+                signStep = step_d > 0;
+                if (!signStep)
+                {
+                    step_d = -step_d;
+                }
+                if (analysis::tools::asInteger(step_d, step_i))
+                {
+                    integralStep = true;
+                }
+            }
+            else
+            {
+                // Step is not constant
+                // Check if step is zero or not
+                const analysis::TIType & ty = stepE.getDecorator().getResult().getType();
+                if (ty.isintegral() || ty.isreal())
+                {
+                    stepE.accept(*this);
+                    step = getResult()->loadData(*this);
+                }
+                else
+                {
+                    // TODO: error
+                    return;
+                }
+            }
+
+            if (startE.isDoubleExp())
+            {
+                constantStart = true;
+                const ast::DoubleExp & startDE = static_cast<const ast::DoubleExp &>(startE);
+                start_d = startDE.getValue();
+                if (analysis::tools::asInteger(start_d, start_i))
+                {
+                    integerIterator = integralStep;
+                    if (integerIterator)
+                    {
+                        start = getConstant(start_i);
+                        step = getConstant(step_i);
+                    }
+                }
+                if (!start)
+                {
+                    start = getConstant(start_d);
+                    if (!step)
+                    {
+                        step = getConstant(step_d);
+                    }
+                }
+            }
+            else
+            {
+                startE.accept(*this);
+                start = getResult()->loadData(*this);
+                if (!step)
+                {
+                    step = getConstant(step_d);
+                }
+            }
+
+            if (endE.isDoubleExp())
+            {
+                constantEnd = true;
+                const ast::DoubleExp & endDE = static_cast<const ast::DoubleExp &>(endE);
+                end_d = endDE.getValue();
+                if (analysis::tools::asInteger(end_d, end_i))
+                {
+                    end = getConstant(end_i);
+                }
+                else
+                {
+                    end = getConstant(end_d);
+                }
+            }
+            else
+            {
+                endE.accept(*this);
+                end = getResult()->loadData(*this);
+            }
+
+            if (le.getDecorator().getResult().getRange().isValid())
+            {
+                integerIterator = true;
+            }
+
+            if (integerIterator)
+            {
+                start = Cast::cast<int64_t>(start, startTyIsSigned, *this);
+                step = Cast::cast<int64_t>(step, stepTyIsSigned, *this);
+                end = Cast::cast<int64_t>(end, endTyIsSigned, *this);
+                llvm::Value * one = getConstant<int64_t>(1);
+                if (constantStep)
+                {
+                    if (signStep)
+                    {
+                        end = builder.CreateAdd(end, one);
+                    }
+                    else
+                    {
+                        end = builder.CreateSub(end, one);
+                    }
+                }
+            }
+            else
+            {
+                start = Cast::cast<double>(start, startTyIsSigned, *this);
+                step = Cast::cast<double>(step, stepTyIsSigned, *this);
+                end = Cast::cast<double>(end, endTyIsSigned, *this);
+            }
+
+            // integers values: for i = start:step:end...
+            // is equivalent to for (int64_t i = start; i < end + 1; i += step)...
+
+            // flaoting values: for i = start:step:end...
+            // is equivalent to for (double i = start; i <= end; i += step)...
+
+            llvm::BasicBlock * cond = llvm::BasicBlock::Create(context, "for_cond", function);
+            llvm::BasicBlock * loop = llvm::BasicBlock::Create(context, "for_loop", function);
+            llvm::BasicBlock * after = llvm::BasicBlock::Create(context, "for_after", function);
+
+            // To know how to break or continue the loop
+            blocks.emplace(cond, after);
+
+            llvm::Value * sign_step = nullptr;
+            llvm::Value * zero_i64 = getConstant<int64_t>(0);
+            llvm::Value * zero_dbl = getConstant<double>(0);
+            llvm::Value * abs = llvm::Intrinsic::getDeclaration(&getModule(), llvm::Intrinsic::fabs, getTy<double>());
+            llvm::Value * abs_arg[1];
+            llvm::Value * cmp_i1 = nullptr;
+
+            if (!constantStep)
+            {
+                llvm::Value * cmp_zero;
+                if (integerIterator)
+                {
+                    cmp_zero = builder.CreateICmpEQ(step, zero_i64);
+                }
+                else
+                {
+                    cmp_zero = builder.CreateFCmpOEQ(step, zero_dbl);
+                }
+                llvm::BasicBlock * precond = llvm::BasicBlock::Create(context, "", function);
+                builder.CreateCondBr(cmp_zero, after, precond);
+                builder.SetInsertPoint(precond);
+            }
+
+            if (!integerIterator)
+            {
+                if (!constantStart)
+                {
+                    abs_arg[0] = start;
+                    llvm::CallInst * abs_start = builder.CreateCall(abs, abs_arg);
+                    abs_start->setTailCall(true);
+                    cmp_i1 = builder.CreateFCmpUEQ(abs_start, getConstant<double>(std::numeric_limits<double>::infinity()));
+                    llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
+                    builder.CreateCondBr(cmp_i1, after, bb);
+                    builder.SetInsertPoint(bb);
+                }
 
+                if (!constantEnd)
+                {
+                    abs_arg[0] = end;
+                    llvm::CallInst * abs_end = builder.CreateCall(abs, abs_arg);
+                    abs_end->setTailCall(true);
+                    cmp_i1 = builder.CreateFCmpUEQ(abs_end, getConstant<double>(std::numeric_limits<double>::infinity()));
+                    llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
+                    builder.CreateCondBr(cmp_i1, after, bb);
+                    builder.SetInsertPoint(bb);
+                }
+            }
+
+            if (integerIterator)
+            {
+                if (constantStep)
+                {
+                    if (signStep)
+                    {
+                        cmp_i1 = builder.CreateICmpSLT(start, end);
+                    }
+                    else
+                    {
+                        cmp_i1 = builder.CreateICmpSGT(start, end);
+                    }
+                }
+                else
+                {
+                    sign_step = builder.CreateICmpSGT(step, zero_i64);
+                    llvm::Value * cmp1 = builder.CreateICmpSLT(start, end);
+                    llvm::Value * cmp2 = builder.CreateICmpSGT(start, end);
+                    cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
+                }
+            }
+            else
+            {
+                if (constantStep)
+                {
+                    if (signStep)
+                    {
+                        cmp_i1 = builder.CreateFCmpOLE(start, end);
+                    }
+                    else
+                    {
+                        cmp_i1 = builder.CreateFCmpOGE(start, end);
+                    }
+                }
+                else
+                {
+                    abs_arg[0] = step;
+                    llvm::CallInst * abs_step = builder.CreateCall(abs, abs_arg);
+                    abs_step->setTailCall(true);
+                    cmp_i1 = builder.CreateFCmpUEQ(abs_step, getConstant<double>(std::numeric_limits<double>::infinity()));
+                    llvm::BasicBlock * bb = llvm::BasicBlock::Create(context, "", function);
+                    builder.CreateCondBr(cmp_i1, after, bb);
+                    builder.SetInsertPoint(bb);
+                    sign_step = builder.CreateFCmpOGT(step, zero_dbl);
+                    llvm::Value * cmp1 = builder.CreateFCmpOLE(start, end);
+                    llvm::Value * cmp2 = builder.CreateFCmpOGE(start, end);
+                    cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
+                }
+            }
+
+            bool hasPre = false;
+            if (const analysis::LoopDecoration * ld = e.getDecorator().getLoopDecoration())
+            {
+                if (!ld->getClone().empty() || !ld->getPromotion().empty())
+                {
+                    llvm::BasicBlock * pre = llvm::BasicBlock::Create(context, "for_pre", function);
+                    builder.CreateCondBr(cmp_i1, pre, after);
+                    builder.SetInsertPoint(pre);
+                    if (!ld->getClone().empty())
+                    {
+                        cloneSyms(e);
+                    }
+                    else
+                    {
+                        //promoteSyms(e);
+                    }
+                    builder.CreateBr(loop);
+                    hasPre = true;
+                }
+            }
+
+            if (!hasPre)
+            {
+                builder.CreateCondBr(cmp_i1, loop, after);
+            }
+
+            llvm::BasicBlock * cur_block = builder.GetInsertBlock();
+
+            builder.SetInsertPoint(loop);
+            llvm::PHINode * i;
+            if (integerIterator)
+            {
+                i = builder.CreatePHI(getTy<int64_t>(), 2);
+            }
+            else
+            {
+                i = builder.CreatePHI(getTy<double>(), 2);
+            }
+
+            i->addIncoming(start, cur_block);
+            JITScilabPtr & it = variables.find(symIterator)->second;
+            it->storeData(*this, i);
+
+            // Visit the loop body
+            e.getBody().accept(*this);
+
+            builder.CreateBr(cond);
+
+            builder.SetInsertPoint(cond);
+            llvm::Value * i_step;
+            if (integerIterator)
+            {
+                if (constantStep)
+                {
+                    if (signStep)
+                    {
+                        i_step = builder.CreateAdd(i, step);
+                        cmp_i1 = builder.CreateICmpSLT(i_step, end);
+                    }
+                    else
+                    {
+                        i_step = builder.CreateSub(i, step);
+                        cmp_i1 = builder.CreateICmpSGT(i_step, end);
+                    }
+                }
+                else
+                {
+                    i_step = builder.CreateAdd(i, step);
+                    llvm::Value * cmp1 = builder.CreateICmpSLT(i_step, end);
+                    llvm::Value * cmp2 = builder.CreateICmpSGT(i_step, end);
+                    cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
+                }
+            }
+            else
+            {
+                if (constantStep)
+                {
+                    if (signStep)
+                    {
+                        i_step = builder.CreateFAdd(i, step);
+                        cmp_i1 = builder.CreateFCmpOLE(i_step, end);
+                    }
+                    else
+                    {
+                        i_step = builder.CreateFSub(i, step);
+                        cmp_i1 = builder.CreateFCmpOGE(i_step, end);
+                    }
+                }
+                else
+                {
+                    i_step = builder.CreateFAdd(i, step);
+                    llvm::Value * cmp1 = builder.CreateFCmpOLE(i_step, end);
+                    llvm::Value * cmp2 = builder.CreateFCmpOGE(i_step, end);
+                    cmp_i1 = builder.CreateSelect(sign_step, cmp1, cmp2);
+                }
+            }
+
+            i->addIncoming(i_step, cond);
+
+            builder.CreateCondBr(cmp_i1, loop, after);
+
+            builder.SetInsertPoint(after);
+        }
     }
+}
+
+void JITVisitor::cloneSyms(const ast::Exp & e)
+{
+    llvm::Type * types[] = { getTy<int8_t *>(), getTy<int8_t *>(), getTy<int64_t>(), getTy<int32_t>(), getTy<bool>() };
+    llvm::Value * __memcpy = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::memcpy, types);
+    llvm::Function * __new = static_cast<llvm::Function *>(module->getOrInsertFunction("new", llvm::FunctionType::get(getTy<int8_t *>(), llvm::ArrayRef<llvm::Type *>(getTy<uint64_t>()), false)));
 
-    void JITVisitor::visit(const ast::ForExp & e)
+    for (const auto & sym : e.getDecorator().getLoopDecoration()->getClone())
     {
-       const ast::VarDec & vd = static_cast<const ast::VarDec &>(e.getVardec());
-       if (vd.getInit().isListExp())
-       {
-           const symbol::Symbol & symIterator = vd.getSymbol();
-           const ast::ListExp & le = static_cast<const ast::ListExp &>(vd.getInit());
-           if (le.getDecorator().getResult().getRange().isValid())
-           {
-               // for i = start:step:end...
-               // is equivalent to for (int64_t i = start; i < end + step; i += step)...
-               
-               le.getStart().accept(*this);
-               llvm::Value * start = Cast::cast<int64_t>(getResult()->loadData(*this), false, *this);
-               le.getStep().accept(*this);
-               llvm::Value * step = Cast::cast<int64_t>(getResult()->loadData(*this), false, *this);
-               le.getEnd().accept(*this);
-               llvm::Value * end = Cast::cast<int64_t>(getResult()->loadData(*this), false, *this);
-               end = builder.CreateAdd(end, step);
-
-               llvm::BasicBlock * cur_block = builder.GetInsertBlock();
-               llvm::BasicBlock * condBlock = llvm::BasicBlock::Create(context, "for_cond", function);
-               llvm::BasicBlock * loopBlock = llvm::BasicBlock::Create(context, "for_loop", function);
-               llvm::BasicBlock * afterBlock = llvm::BasicBlock::Create(context, "for_after", function);
-
-               blocks.emplace(condBlock, afterBlock);
-               
-               llvm::Value * cmp_i1 = builder.CreateICmpSLT(start, end);
-               builder.CreateCondBr(cmp_i1, loopBlock, afterBlock);
-
-               
-               builder.SetInsertPoint(loopBlock);
-               llvm::PHINode * i = builder.CreatePHI(getTy<int64_t>(), 2);
-               i->addIncoming(start, cur_block);
-               JITScilabPtr & it = variables.find(symIterator)->second;
-               it->storeData(*this, i);
-
-               e.getBody().accept(*this);
-               builder.CreateBr(condBlock);
-
-               builder.SetInsertPoint(condBlock);
-               llvm::Value * ipstp_i64 = builder.CreateAdd(i, step);
-               i->addIncoming(ipstp_i64, condBlock);
-               cmp_i1 = builder.CreateICmpSLT(ipstp_i64, end);
-               builder.CreateCondBr(cmp_i1, loopBlock, afterBlock);
-
-               builder.SetInsertPoint(afterBlock);
-/*
-               llvm::Value * cmp_i1 = builder.CreateICmpSLT(start, end);
-               builder.CreateCondBr(cmp_i1, BBBody, BBAfter);
-
-               builder.SetInsertPoint(BBBody);
-               llvm::PHINode * i = builder.CreatePHI(getTy<int64_t>(), 2);
-               i->addIncoming(start, cur_block);
-               JITScilabPtr & it = variables.find(symIterator)->second;
-               it->storeData(*this, i);
-
-               e.getBody().accept(*this);
-
-               if (builder.GetInsertBlock() != BBAfter)
-               {
-                   BBAfter->moveAfter(builder.GetInsertBlock());
-               }
-               
-               llvm::Value * ipstp_i64 = builder.CreateAdd(i, step);
-               i->addIncoming(ipstp_i64, builder.GetInsertBlock());
-               cmp_i1 = builder.CreateICmpSLT(ipstp_i64, end);
-               builder.CreateCondBr(cmp_i1, BBBody, BBAfter);
-               builder.SetInsertPoint(BBAfter);
-*/
-           }
-           
-       }
-
-       
-       // e.getBody().accept(*this);
+        // TODO: bad stuff here: we must add the mangling to the name !!
+        JITScilabPtr & ptr = variables.find(sym)->second;
+        llvm::Value * x = ptr->loadData(*this);
+        llvm::Value * r = ptr->loadRows(*this);
+        llvm::Value * c = ptr->loadCols(*this);
+        llvm::Value * rc = builder.CreateMul(r, c);
+        llvm::Value * size = builder.CreateMul(rc, getConstant<int64_t>(getTySizeInBytes(x)));
+        llvm::CallInst * dest = builder.CreateCall(__new, size);
+        dest->addAttribute(0, llvm::Attribute::NoAlias);
+        llvm::Value * src = builder.CreateBitCast(x, getTy<int8_t *>());
+        llvm::Value * memcpy_args[] = { dest, src, size, getConstant<int64_t>(getTySizeInBytes(x)), getBool(false) };
+        builder.CreateCall(__memcpy, memcpy_args);
+        ptr->storeData(*this, dest);
     }
 
+    /*    if (const analysis::Clone * cl = e.getDecorator().getClone())
+        {
+            if (!cl->get().empty())
+            {
+                llvm::Type * types[] = { getTy<int8_t *>(), getTy<int8_t *>(), getTy<int64_t>(), getTy<int32_t>(), getTy<bool>() };
+                llvm::Value * __memcpy = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::memcpy, types);
+                llvm::Function * __new = static_cast<llvm::Function *>(module->getOrInsertFunction("new", llvm::FunctionType::get(getTy<int8_t *>(), llvm::ArrayRef<llvm::Type *>(getTy<uint64_t>()), false)));
+                __new->addAttribute(0, llvm::Attribute::NoAlias);
+
+                for (const auto & sym : cl->get())
+                {
+                    JITScilabPtr & ptr = variables.find(sym)->second;
+                    llvm::Value * x = ptr->loadData(*this);
+                    llvm::Value * r = ptr->loadRows(*this);
+                    llvm::Value * c = ptr->loadCols(*this);
+                    llvm::Value * rc = builder.CreateMul(r, c);
+                    llvm::Value * size = builder.CreateMul(rc, getConstant<int64_t>(getTySizeInBytes(x)));
+                    llvm::CallInst * dest = builder.CreateCall(__new, size);
+                    dest->addAttribute(0, llvm::Attribute::NoAlias);
+                    llvm::Value * src = builder.CreateBitCast(x, getTy<int8_t *>());
+                    llvm::Value * memcpy_args[] = { dest, src, size, getConstant<int64_t>(getTySizeInBytes(x)), getBool(false) };
+                    builder.CreateCall(__memcpy, memcpy_args);
+                    ptr->storeData(*this, dest);
+                }
+            }
+       }*/
+
+}
 }