JIT: don't try to store a double in an uint location 17/15217/1
Calixte DENIZET [Fri, 12 Sep 2014 15:17:49 +0000 (17:17 +0200)]
Change-Id: Ia9794f4a40f73cf9e7dc356edd281d0f72ebf9b4

scilab/modules/ast/includes/analysis/AnalysisVisitor.hxx
scilab/modules/ast/includes/analysis/ForList.hxx
scilab/modules/ast/includes/analysis/SymInfo.hxx
scilab/modules/ast/includes/analysis/TIType.hxx
scilab/modules/ast/includes/exps/ast.hxx
scilab/modules/ast/includes/exps/vardec.hxx
scilab/modules/ast/includes/jit/JITValues.hxx
scilab/modules/ast/src/cpp/jit/JITVisitor.cpp
scilab/modules/ast/src/cpp/jit/jit_operations.cpp
scilab/modules/functions/sci_gateway/cpp/sci_jit.cpp

index 3870653..bdbfac6 100644 (file)
 #define __ANALYSIS_VISITOR_HXX__
 
 #include <algorithm>
-#include <memory>
+#include <chrono>
 #include <limits>
 #include <map>
+#include <memory>
 
 #include "visitor.hxx"
 #include "allexp.hxx"
@@ -43,10 +44,14 @@ private:
     unsigned int scalars_tmp[TIType::COUNT][2];
     unsigned int arrays_tmp[TIType::COUNT][2];
 
+    std::chrono::steady_clock::time_point start;
+    std::chrono::steady_clock::time_point end;
+
 public:
 
     AnalysisVisitor()
     {
+        start_chrono();
         std::fill(&scalars_tmp[0][0], &scalars_tmp[0][0] + TIType::COUNT * 2, 0);
         std::fill(&arrays_tmp[0][0], &arrays_tmp[0][0] + TIType::COUNT * 2, 0);
     }
@@ -57,9 +62,12 @@ public:
     }
 
     // Only for debug use
-    inline void print_info() const
+    inline void print_info()
     {
-        std::wcout << L"Scalars:" << std::endl;
+        stop_chrono();
+        std::wcout << L"Analysis duration: " << get_duration() << L" s" << std::endl;
+
+        std::wcout << L"Temporary scalars:" << std::endl;
         for (unsigned int i = 0; i < TIType::COUNT; ++i)
         {
             if (scalars_tmp[i][0] || scalars_tmp[i][1])
@@ -70,7 +78,7 @@ public:
 
         std::wcout << std::endl;
 
-        std::wcout << L"Arrays:" << std::endl;
+        std::wcout << L"Temporary arrays:" << std::endl;
         for (unsigned int i = 0; i < TIType::COUNT; ++i)
         {
             if (arrays_tmp[i][0] || arrays_tmp[i][1])
@@ -89,6 +97,21 @@ public:
         std::wcout << std::endl;
     }
 
+    void start_chrono()
+    {
+        start = std::chrono::steady_clock::now();
+    }
+
+    void stop_chrono()
+    {
+        end = std::chrono::steady_clock::now();
+    }
+
+    double get_duration() const
+    {
+        return (double)std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count() * 1e-9d;
+    }
+
 private:
 
     inline void add_tmp(const TIType & t, const int n = 1, const bool scalar = false)
@@ -430,6 +453,12 @@ private:
     {
         e.vardec_get().accept(*this);
         e.body_get().accept(*this);
+
+        MapSymInfo::const_iterator it = symsinfo.find(e.vardec_get().name_get());
+        if (it->second.read)
+        {
+            e.vardec_get().list_info_get().set_read_in_loop(true);
+        }
     }
 
     void visit(ast::BreakExp & e)
index ab72a41..1778218 100644 (file)
@@ -32,6 +32,7 @@ template<typename T>
 class ForList
 {
     bool constant;
+    bool read_in_loop;
     double min;
     double step;
     double max;
@@ -42,9 +43,9 @@ class ForList
 
 public:
 
-    ForList() : constant(false) { }
+    ForList() : constant(false), read_in_loop(false) { }
 
-    ForList(const double m, const double s, const double M) : constant(true), min(m), step(s), max(M)
+    ForList(const double m, const double s, const double M) : constant(true), read_in_loop(false), min(m), step(s), max(M)
     {
         if (!isempty())
         {
@@ -71,71 +72,83 @@ public:
         }
     }
 
-    bool is_constant() const
+    inline bool is_read_in_loop() const
+    {
+        return read_in_loop;
+    }
+
+    inline void set_read_in_loop(const bool read)
+    {
+        read_in_loop = read;
+    }
+
+    inline bool is_constant() const
     {
         return constant;
     }
 
-    bool is_int() const
+    inline bool is_int() const
     {
         return _int;
     }
 
-    bool is_uint() const
+    inline bool is_uint() const
     {
         return _unsigned;
     }
 
     template<typename U>
-    U get_min() const
+    inline U get_min() const
     {
         return std::is_integral<U>::value ? TRUNC(min) : min;
     }
 
     template<typename U>
-    U get_step() const
+    inline U get_step() const
     {
         return std::is_integral<U>::value ? TRUNC(step) : step;
     }
 
     template<typename U>
-    U get_max() const
+    inline U get_max() const
     {
         return std::is_integral<U>::value ? TRUNC(max) : max;
     }
 
-    TIType get_type() const
+    inline TIType get_type() const
     {
-        if (isempty())
-        {
-            return TIType(TIType::EMPTY);
-        }
-
-        if (is_int())
-        {
-            if (is_uint())
+        /*
+            if (isempty())
             {
-                if (std::is_same<T, int32_t>::value)
-                {
-                    return TIType(TIType::UINT32);
-                }
-                else
-                {
-                    return TIType(TIType::UINT64);
-                }
+                return TIType(TIType::EMPTY);
             }
-            else
+
+            if (is_int())
             {
-                if (std::is_same<T, int64_t>::value)
+                if (is_uint())
                 {
-                    return TIType(TIType::INT32);
+                    if (std::is_same<T, int32_t>::value)
+                    {
+                        return TIType(TIType::UINT32);
+                    }
+                    else
+                    {
+                        return TIType(TIType::UINT64);
+                    }
                 }
                 else
                 {
-                    return TIType(TIType::INT64);
+                    if (std::is_same<T, int64_t>::value)
+                    {
+                        return TIType(TIType::INT32);
+                    }
+                    else
+                    {
+                        return TIType(TIType::INT64);
+                    }
                 }
             }
-        }
+        */
 
         return TIType(TIType::DOUBLE);
     }
index 9017d3d..55f9eb7 100644 (file)
@@ -71,6 +71,11 @@ struct SymInfo
         set(k2);
     }
 
+    inline bool is_just_read() const
+    {
+        return !write && !replace && read;
+    }
+
     friend std::wostream & operator<<(std::wostream & out, const SymInfo & si)
     {
         out << L"Symbol use { w: " << (si.write ? L'T' : L'F')
index 403c425..720fa77 100644 (file)
@@ -32,7 +32,7 @@ struct TIType
 
     inline bool isscalar() const
     {
-        return type != EMPTY && rows == 1 && cols == 1;
+        return rows == 1 && cols == 1;
     }
 
     inline bool isknown() const
@@ -45,6 +45,56 @@ struct TIType
         return type == r.type && rows == r.rows && cols == r.cols;
     }
 
+    inline std::string get_mangling() const
+    {
+        const bool sc = rows == 1 && cols == 1;
+        switch (type)
+        {
+            case EMPTY :
+                return "E";
+            case BOOLEAN :
+                return sc ? "S_b" : "M_b";
+            case COMPLEX :
+                return sc ? "S_c" : "M_c";
+            case DOUBLE :
+                return sc ? "S_d" : "M_d";
+            case INT16 :
+                return sc ? "S_i16" : "M_i16";
+            case INT32 :
+                return sc ? "S_i32" : "M_i32";
+            case INT64 :
+                return sc ? "S_i64" : "M_i64";
+            case INT8 :
+                return sc ? "S_i8" : "M_i8";
+            case POLYNOMIAL :
+                return sc ? "S_p" : "M_p";
+            case STRING :
+                return sc ? "S_s" : "M_s";
+            case SPARSE :
+                return sc ? "S_sp" : "M_sp";
+            case UINT16 :
+                return sc ? "S_ui16" : "M_ui16";
+            case UINT32 :
+                return sc ? "S_ui32" : "M_ui32";
+            case UINT64 :
+                return sc ? "S_ui64" : "M_ui64";
+            case UINT8 :
+                return sc ? "S_ui8" : "M_ui8";
+            default :
+                return "??";
+        }
+    }
+
+    inline static std::string get_unary_mangling(const std::string & pre, const TIType & l)
+    {
+        return pre + "_" + l.get_mangling();
+    }
+
+    inline static std::string get_binary_mangling(const std::string & pre, const TIType & l, const TIType & r)
+    {
+        return pre + "_" + l.get_mangling() + "_" + r.get_mangling();
+    }
+
     friend std::wostream & operator<<(std::wostream & out, const TIType & type)
     {
         switch (type.type)
@@ -122,4 +172,4 @@ struct hash<analysis::TIType>
 };
 } // namespace std
 
-#endif // __TITYPE_HXX__
+#endif // __TITYPE_HXX__
\ No newline at end of file
index 6ed00a4..c6c92d4 100644 (file)
@@ -97,7 +97,7 @@ public:
         nodeNumber = _nodeNumber;
     }
 
-    analysis::Decorator decorator_get() const
+    const analysis::Decorator & decorator_get() const
     {
         return decorator;
     }
index 225b33d..31f7305 100644 (file)
@@ -123,7 +123,7 @@ public:
         return list_info;
     }
 
-    analysis::ForList64 & ForList_get()
+    analysis::ForList64 & list_info_get()
     {
         return list_info;
     }
index 26234d8..efcc567 100644 (file)
@@ -50,6 +50,7 @@ public:
     {
         return val;
     }
+
     inline llvm::Type * get_type() const
     {
         return type;
index 3ab0daf..50bb6e6 100644 (file)
@@ -168,6 +168,9 @@ JITVisitor::JITVisitor(const analysis::AnalysisVisitor & _analysis) : ast::Const
             }
         }
     }
+
+
+
 }
 
 void JITVisitor::run()
@@ -177,18 +180,24 @@ void JITVisitor::run()
     llvm::Value * llvmCtxt = getPointer(ctxt);
     llvm::Value * toCall_M = module.getOrInsertFunction("putInContext_M_D_ds", getLLVMFuncTy<void, char *, char *, double *, int , int>(context));
     llvm::Value * toCall_S = module.getOrInsertFunction("putInContext_S_D_d", getLLVMFuncTy<void, char *, char *, double>(context));
+    const analysis::AnalysisVisitor::MapSymInfo & info = analysis.get_infos();
 
     for (JITSymbolMap::const_iterator i = symMap3.begin(), end = symMap3.end(); i != end; ++i)
     {
-        symbol::Variable * var = ctxt->getOrCreate(i->first);
-        llvm::Value * llvmVar = getPointer(var);
-        if (i->second.get()->is_scalar())
+        analysis::AnalysisVisitor::MapSymInfo::const_iterator it = info.find(i->first);
+        if (it != info.end() && !it->second.is_just_read())
         {
-            builder.CreateCall3(toCall_S, llvmCtxt, llvmVar, i->second.get()->load(*this));
-        }
-        else
-        {
-            builder.CreateCall5(toCall_M, llvmCtxt, llvmVar, i->second.get()->load(*this), i->second.get()->loadR(*this), i->second.get()->loadC(*this));
+            std::wcout << L"push in context: " << i->first.name_get() << std::endl;
+            symbol::Variable * var = ctxt->getOrCreate(i->first);
+            llvm::Value * llvmVar = getPointer(var);
+            if (i->second.get()->is_scalar())
+            {
+                builder.CreateCall3(toCall_S, llvmCtxt, llvmVar, i->second.get()->load(*this));
+            }
+            else
+            {
+                builder.CreateCall5(toCall_M, llvmCtxt, llvmVar, i->second.get()->load(*this), i->second.get()->loadR(*this), i->second.get()->loadC(*this));
+            }
         }
     }
 
@@ -344,7 +353,6 @@ void JITVisitor::visit(const ast::OpExp &e)
     e.left_get().accept(*this);
     std::shared_ptr<JITVal> pITL = result_get();
 
-    /*getting what to assign*/
     e.right_get().accept(*this);
     std::shared_ptr<JITVal> & pITR = result_get();
 
@@ -354,7 +362,10 @@ void JITVisitor::visit(const ast::OpExp &e)
     {
         case ast::OpExp::plus:
         {
-            if (pITL.get()->is_scalar())
+            //const analysis::TIType & LT = e.left_get().decorator_get().res.get_type();
+            //const analysis::TIType & RT = e.right_get().decorator_get().res.get_type();
+
+            if (pITL.get()->is_scalar() && pITR.get()->is_scalar())
             {
                 result_set(add_D_D(pITL, pITR, *this));
             }
@@ -464,6 +475,7 @@ void JITVisitor::visit(const ast::ForExp &e)
         bool use_uint = false;
         bool inc = true;
         bool known_step = false;
+        bool it_read_in_loop = list_info.is_read_in_loop();
 
         if (list_info.is_constant())
         {
@@ -559,9 +571,12 @@ void JITVisitor::visit(const ast::ForExp &e)
 
         // TODO: the call to uitofp is not removed even if it use mainly useless...
         // a=1;b=1;jit("for i=1:21;c=a+b;a=b;b=c;end;")
-        JITSymbolMap::const_iterator i = symMap3.find(varName);
-        tmp = use_int ? (use_uint ? builder.CreateUIToFP(phi, getLLVMTy<double>(context)) : builder.CreateSIToFP(phi, getLLVMTy<double>(context))) : phi;
-        i->second.get()->store(tmp, *this);
+        if (it_read_in_loop)
+        {
+            JITSymbolMap::const_iterator i = symMap3.find(varName);
+            tmp = use_int ? (use_uint ? builder.CreateUIToFP(phi, getLLVMTy<double>(context)) : builder.CreateSIToFP(phi, getLLVMTy<double>(context))) : phi;
+            i->second.get()->store(tmp, *this);
+        }
 
         phi->addIncoming(start, cur_block);
 
index 83bb6ca..4328269 100644 (file)
@@ -61,6 +61,32 @@ std::shared_ptr<JITVal> add_M_M(std::shared_ptr<JITVal> & L, std::shared_ptr<JIT
     return std::shared_ptr<JITVal>(new JITMatrixVal<double>(visitor, L.get()->getR(), L.get()->getC(), alloc));
 }
 
+std::shared_ptr<JITVal> add_M_M_1(std::shared_ptr<JITVal> & L, std::shared_ptr<JITVal> & R, JITVisitor & visitor)
+{
+    // TODO: voir comment on peut avec l'analyzer recuperer une info sur le LHS de l'assignement pr eviter de reallouer
+    // un tableau a chaque fois et dc reutiliser le precedent... (pbs potentiels avec truc du genre a=a*b mais pas avec a=a+b)
+
+    llvm::LLVMContext & context = visitor.getContext();
+    llvm::IRBuilder<> & builder = visitor.getBuilder();
+    llvm::Value * size = builder.CreateMul(L.get()->loadR(visitor), L.get()->loadC(visitor));
+    llvm::Value * eight = llvm::ConstantInt::get(getLLVMTy<int>(context), int(8));
+    llvm::Value * malloc_size = builder.CreateMul(size, eight);
+    llvm::BasicBlock * cur_block = builder.GetInsertBlock();
+    llvm::Value * alloc = llvm::CallInst::CreateMalloc(cur_block, getLLVMTy<int>(context), getLLVMTy<double>(context), malloc_size);
+    cur_block->getInstList().push_back(llvm::cast<llvm::Instruction>(alloc));
+
+    llvm::Value * toCall = visitor.getModule().getOrInsertFunction("add_M_M_d_d", getLLVMFuncTy<void, double *, long long, double *, double *>(context));
+
+
+    size = builder.CreateIntCast(size, getLLVMTy<long long>(context), false);
+    builder.CreateCall4(toCall, L.get()->load(visitor), size, R.get()->load(visitor), alloc);
+
+    //llvm::Value * toCall_debug = visitor.getPointer(reinterpret_cast<void *>(&debug), getLLVMPtrFuncTy<void, double *>(context));
+    //builder.CreateCall(toCall_debug, alloc);
+
+    return std::shared_ptr<JITVal>(new JITMatrixVal<double>(visitor, L.get()->getR(), L.get()->getC(), alloc));
+}
+
 std::shared_ptr<JITVal> sub_M_M(std::shared_ptr<JITVal> & L, std::shared_ptr<JITVal> & R, JITVisitor & visitor)
 {
     llvm::LLVMContext & context = visitor.getContext();
index 50b7d84..f7be592 100644 (file)
 
 extern "C"
 {
-#include "sci_malloc.h"
-#include "os_wcsicmp.h"
 #include "Scierror.h"
-#include "sciprint.h"
 #include "localization.h"
-#include "os_swprintf.h"
 }
 
 using namespace std;
@@ -101,5 +97,5 @@ Function::ReturnValue sci_jit(types::typed_list &in, int _iRetCount, types::type
 
     delete pExp;
 
-    return Function::OK;
+    return Function::OK;//a=[1 2;3 4];b=[5 6; 7 8];jit("for i=1:21;c=a+b;a=b;b=c;end;"),a,b
 }