Samuel Gougeon [Fri, 20 Sep 2013 08:52:52 +0000 (10:52 +0200)]
permute() is much faster when it uses sub2ind().

This change comes from http://bugzilla.scilab.org/show_bug.cgi?id=5205

Change-Id: I1bb210dc05d670d5fd90953f1af75a561d6026e3

index e22ed9b..98be866 100644 (file)
@@ -1,5 +1,6 @@
// Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
// Copyright (C) INRIA - Farid BELAHCENE
+// Copyright (C) 2013 - Samuel GOUGEON : processing rewritten, fixing http://bugzilla.scilab.org/5205
//
// This file must be used under the terms of the CeCILL.
// This source file is licensed as described in the file COPYING, which
@@ -38,60 +39,37 @@ function y = permute(x, dims)
return
end

-    // xsize vector contains the size of x
-    xsize = size(x)
-    // ysize vector contains the new size of x after the permutation
-    ind1 = find(dims<=ndims(x))
-    ind2 = find(dims>ndims(x))
-    ysize(ind1) = xsize(dims(ind1))
-    ysize(ind2) = 1
-    dims = dims(ind1)
-
-    // Delete the last dimensions of ysize which are equal to 1, ex : [2,3,1,4,1,1,1] -> [2,3,1,4]
-    i = prod(size(ysize))
-    while i>2 & ysize(i)==1 & i>max(ind1)
-        ysize(i) = []
-        i = i-1
+    // ---------------- PROCESSING --------------------
+    // Existing indices
+    s = size(x)
+    p = size(x, "*")
+    n = 1
+    for i = 1:length(s)
+        t = "x%d = ones(1,p/(prod(s(1:%d)))) .*. ((1:s(%d)) .*. ones(1,n)) ;"+..
+        " n = prod(s(1:%d))\n"
+        t = msprintf(t, i, i, i, i)
+        execstr(t)
end
+    xlist = strcat(msprintf("x%d\n",(1:length(s))'),",")
+    cstr = "sub2ind(s,"+ xlist +")"
+    execstr("LI = "+cstr)

-    // index vector contains all indices of x
-    index = zeros(1, prod(xsize)*length(xsize)); // Preallocate index
-    m = 1; // Iterator on index
-    for k=1:size(xsize,"*")
-        for j=1:size(x,"*")/prod(xsize(1:k))
-            for l=1:xsize(k)
-                temp = ones(1, prod(xsize(1:k-1)))*l;
-                index(1, m:m+size(temp,"*")-1) = temp;
-                m = m+size(temp, "*"); // Pad m with the size of the vector we just computed
-            end
-        end
-    end
-    index = matrix(index, size(x, "*"), size(xsize, "*"))
+    // New indices
+    s = s(dims)
+    cstr = "sub2ind(s,"+ strcat(msprintf("x%d\n", dims(:)), ",")+")"
+    execstr("LI2 = "+cstr)

-    // prodxsize is a vector, its ith component contains the prod of the first to the (ith-1) entries of xsize, its first component is always equal to one
-    prodxsize = ones(size(xsize, "*"), 1)
-    for i=2:size(xsize,"*")
-        prodxsize(i) = prod(xsize(1:i-1))
-    end
-    prodysize = ones(size(xsize, "*"), 1)
-    for i=2:size(ysize,"*")
-        prodysize(i) = prod(ysize(1:i-1))
-    end
+    // Clearing intermediate memory used
+    execstr("clear "+strsubst(xlist, ",", " "))

-    // newindex contains the indices of x dimensions permutation
-    for j=1:size(index,1)
-        indexj = index(j, :)
-        newindexj = ones(1:prod(size(ysize)))
-        newindexj(ind1) = indexj(dims)
-        indexj(2:\$) = indexj(2:\$)-1
-        newindexj(2:\$) = newindexj(2:\$)-1
-        if typeof(x) == "ce" then // Case x is a cell array
-            y(newindexj*prodysize).entries = x(indexj*prodxsize).entries
-        else
-            y(newindexj*prodysize) = x(indexj*prodxsize)
-        end
+    // Permutation
+    if typeof(x) == "ce"
+        y = x
+        y.dims = int32(s)
+        y(LI2).entries = x(LI).entries
+    else
+        y(LI2) = x(LI)
+        y = matrix(y, s)
end

-    y = matrix(y, ysize)
-
endfunction
index 6e94f3f..75158b1 100644 (file)
@@ -19,9 +19,15 @@ assert_checkequal(y, x');
x = string(x);
y = permute(x, [2 1]);
assert_checkequal(y, x');
+// With a complex matrix
+x = [1 2 3; 4 5 6]*%i;
+y = permute(x, [2 1]);
+refY = [1 4; 2 5; 3 6]*%i;
+assert_checkequal(y, refY);
// With a real hypermatrix
x = matrix(1:12, [2, 3, 2]);
y = permute(x, [3 1 2]);
+clear refY
refY(:, :, 1) = [1 2; 7 8];
refY(:, :, 2) = [3 4; 9 10];
refY(:, :, 3) = [5 6; 11 12];
@@ -36,6 +42,14 @@ x = string(x);
y = permute(x, [3 1 2]);
refY = string(refY);
assert_checkequal(y, refY);
+// With a complex hypermatrix
+x = matrix(1:12, [2, 3, 2])*%i;
+y = permute(x, [3 1 2]);
+clear refY
+refY(:, :, 1) = [1 2; 7 8]*%i;
+refY(:, :, 2) = [3 4; 9 10]*%i;
+refY(:, :, 3) = [5 6; 11 12]*%i;
+assert_checkequal(y, refY);
// Error checks
refMsg = msprintf(_("%s: Wrong size for input argument #%d: At least the size of input argument #%d expected.\n"), "permute", 2, 1);
assert_checkerror("permute(x, [1 2]);", refMsg);
index fc16f39..1dc48c9 100644 (file)
@@ -26,9 +26,17 @@ y = permute(x, [2 1]);

assert_checkequal(y, x');

+// With a complex matrix
+x = [1 2 3; 4 5 6]*%i;
+y = permute(x, [2 1]);
+refY = [1 4; 2 5; 3 6]*%i;
+
+assert_checkequal(y, refY);
+
// With a real hypermatrix
x = matrix(1:12, [2, 3, 2]);
y = permute(x, [3 1 2]);
+clear refY
refY(:, :, 1) = [1 2; 7 8];
refY(:, :, 2) = [3 4; 9 10];
refY(:, :, 3) = [5 6; 11 12];
@@ -49,6 +57,16 @@ refY = string(refY);

assert_checkequal(y, refY);

+// With a complex hypermatrix
+x = matrix(1:12, [2, 3, 2])*%i;
+y = permute(x, [3 1 2]);
+clear refY
+refY(:, :, 1) = [1 2; 7 8]*%i;
+refY(:, :, 2) = [3 4; 9 10]*%i;
+refY(:, :, 3) = [5 6; 11 12]*%i;
+
+assert_checkequal(y, refY);
+
// Error checks
refMsg = msprintf(_("%s: Wrong size for input argument #%d: At least the size of input argument #%d expected.\n"), "permute", 2, 1);
assert_checkerror("permute(x, [1 2]);", refMsg);