[ast] upgrade sparse extraction after 5dc990d1 91/21491/5
St├ęphane Mottelet [Tue, 9 Jun 2020 09:09:36 +0000 (11:09 +0200)]
http://bugzilla.scilab.org/show_bug.cgi?id=14487#c7

Change-Id: I12f32a24510d6d9d3a62f990193c91bc17d95d91

scilab/modules/ast/src/cpp/types/sparse.cpp
scilab/modules/ast/tests/nonreg_tests/bug_14487.tst

index ed9e98c..fb2a5a6 100644 (file)
@@ -2040,19 +2040,42 @@ GenericType* Sparse::extract(typed_list* _pArgs)
     {
         if (piMaxDim[0] <= getSize())
         {
-            int iNewRows = 0;
-            int iNewCols = 0;
+            int iNewRows = 1;
+            int iNewCols = 1;
+            types::GenericType* pGT = (*_pArgs)[0]->getAs<GenericType>();
 
-            if (getRows() == 1 && getCols() != 1 && (*_pArgs)[0]->isColon() == false)
+            if ((*_pArgs)[0]->isColon())
+            {
+                iNewRows = piCountDim[0];
+            }
+            else if ( (!isScalar() && isVector()) && ((*_pArgs)[0]->isImplicitList() || pGT->isVector()) )
+            {
+                if (getRows() == 1)
+                {
+                    iNewCols = piCountDim[0];
+                }
+                else
+                {
+                    iNewRows = piCountDim[0];
+                }
+            }
+            else if ((*_pArgs)[0]->isImplicitList())
             {
-                //special case for row vector
-                iNewRows = 1;
                 iNewCols = piCountDim[0];
             }
             else
             {
-                iNewRows = piCountDim[0];
-                iNewCols = 1;
+                int *i_piDims = pGT->getDimsArray();
+                int i_iDims = pGT->getDims();
+                if (i_iDims > 2)
+                {
+                    iNewRows = piCountDim[0];
+                }
+                else
+                {
+                    iNewRows = i_piDims[0];
+                    iNewCols = i_piDims[1];
+                }
             }
 
             double* pIdx = pArg[0]->getAs<Double>()->get();
@@ -4092,19 +4115,41 @@ GenericType* SparseBool::extract(typed_list* _pArgs)
         // Check that we stay inside the input size.
         if (piMaxDim[0] <= getSize())
         {
-            int iNewRows = 0;
-            int iNewCols = 0;
+            int iNewRows = 1;
+            int iNewCols = 1;
 
-            if (getRows() == 1 && getCols() != 1 && (*_pArgs)[0]->isColon() == false)
+            if ((*_pArgs)[0]->isColon())
+            {
+                iNewRows = piCountDim[0];
+            }
+            else if ( (!isScalar() && isVector()) && ((*_pArgs)[0]->isImplicitList() || (*_pArgs)[0]->getAs<GenericType>()->isVector()) )
+            {
+                if (getRows() == 1)
+                {
+                    iNewCols = piCountDim[0];
+                }
+                else
+                {
+                    iNewRows = piCountDim[0];
+                }
+            }
+            else if ((*_pArgs)[0]->isImplicitList())
             {
-                //special case for row vector
-                iNewRows = 1;
                 iNewCols = piCountDim[0];
             }
             else
             {
-                iNewRows = piCountDim[0];
-                iNewCols = 1;
+                int *i_piDims = (*_pArgs)[0]->getAs<GenericType>()->getDimsArray();
+                int i_iDims = (*_pArgs)[0]->getAs<GenericType>()->getDims();
+                if (i_iDims > 2)
+                {
+                    iNewRows = piCountDim[0];
+                }
+                else
+                {
+                    iNewRows = i_piDims[0];
+                    iNewCols = i_piDims[1];
+                }
             }
 
             pOut = new SparseBool(iNewRows, iNewCols);
@@ -4596,3 +4641,4 @@ void neg(const int r, const int c, const T * const in, Eigen::SparseMatrix<bool,
     out->finalize();
 }
 }
+
index 71395bc..a1a11f6 100644 (file)
@@ -16,6 +16,8 @@
 // <-- Short Description -->
 // Matrix indexing is not coherent with MATLAB convention
 
+//full
+
 x=rand(); // scalar
 i1=ones(1,4);
 i2=ones(2,4);
@@ -60,3 +62,86 @@ assert_checkequal(size(x(i1)), [1 4]);
 assert_checkequal(size(x(i1')), [4 1]);
 assert_checkequal(size(x(i2)), [2 4]);
 assert_checkequal(size(x(i3)), [2 2 2]);
+
+//sparse
+
+x=sparse(rand());
+i1=ones(1,4);
+i2=ones(2,4);
+i3=ones(2,2,2);
+assert_checkequal(size(x(i1)), [1 4]);
+assert_checkequal(size(x(i1')), [4 1]);
+assert_checkequal(size(x(i2)), [2 4]);
+assert_checkequal(size(x(i3)), [8 1]);
+
+x=sparse(rand(1,4)); // row vector
+i1=ones(1,4);
+i2=ones(2,4);
+i3=ones(2,2,2);
+assert_checkequal(size(x(i1)), [1 4]);
+assert_checkequal(size(x(i1')), [1 4]);
+assert_checkequal(size(x(i2)), [2 4]);
+assert_checkequal(size(x(i3)), [8 1]);
+
+x=sparse(rand(4,1)); // column vector
+i1=ones(1,4);
+i2=ones(2,4);
+i3=ones(2,2,2);
+assert_checkequal(size(x(i1)), [4 1]);
+assert_checkequal(size(x(i1')), [4 1]);
+assert_checkequal(size(x(i2)), [2 4]);
+assert_checkequal(size(x(i3)), [8 1]);
+
+x=sparse(rand(3,3)); // matrix
+i1=ones(1,4);
+i2=ones(2,4);
+i3=ones(2,2,2);
+assert_checkequal(size(x(i1)), [1 4]);
+assert_checkequal(size(x(i1')), [4 1]);
+assert_checkequal(size(x(i2)), [2 4]);
+assert_checkequal(size(x(i3)), [8 1]);
+
+
+//sparse boolean
+
+x=sparse(rand()>0.5);
+i1=ones(1,4);
+i2=ones(2,4);
+i3=ones(2,2,2);
+assert_checkequal(size(x(i1)), [1 4]);
+assert_checkequal(size(x(i1')), [4 1]);
+assert_checkequal(size(x(i2)), [2 4]);
+assert_checkequal(size(x(i3)), [8 1]);
+
+x=sparse(rand(1,4)>0.5); // row vector
+i1=ones(1,4);
+i2=ones(2,4);
+i3=ones(2,2,2);
+assert_checkequal(size(x(i1)), [1 4]);
+assert_checkequal(size(x(i1')), [1 4]);
+assert_checkequal(size(x(i2)), [2 4]);
+assert_checkequal(size(x(i3)), [8 1]);
+
+x=sparse(rand(4,1)>0.5); // column vector
+i1=ones(1,4);
+i2=ones(2,4);
+i3=ones(2,2,2);
+assert_checkequal(size(x(i1)), [4 1]);
+assert_checkequal(size(x(i1')), [4 1]);
+assert_checkequal(size(x(i2)), [2 4]);
+assert_checkequal(size(x(i3)), [8 1]);
+
+x=sparse(rand(3,3)>0.5); // matrix
+i1=ones(1,4);
+i2=ones(2,4);
+i3=ones(2,2,2);
+assert_checkequal(size(x(i1)), [1 4]);
+assert_checkequal(size(x(i1')), [4 1]);
+assert_checkequal(size(x(i2)), [2 4]);
+assert_checkequal(size(x(i3)), [8 1]);
+
+
+
+
+
+