kron() with hypermats: new algo 10x faster 76/21076/3
Samuel GOUGEON [Tue, 27 Aug 2019 16:44:50 +0000 (18:44 +0200)]
 test_run ast kron
 test_run elementary_functions bug_13339

a = rand(4,4);
b = rand(500,500,4);
tic()
r1 = a .*. b;
disp(toc())
tic();
r2 = newKron(a,b);
disp(toc());
disp(and(r1==r2))

--> exec('test_10464.sce', -1)
   63.77736
   5.2832192   = 12.1 x faster
  T

With
a = rand(40,40,2);
b = rand(25,25,4);
// I get
   18.998866
   1.7675572  = 10.7x faster

Change-Id: I6344b7aa0d55bbedcb0a53bfab18dfcfb9b9a80e

scilab/modules/ast/tests/unit_tests/kron.dia.ref [deleted file]
scilab/modules/ast/tests/unit_tests/kron.tst
scilab/modules/overloading/macros/%s_k_s.sci

diff --git a/scilab/modules/ast/tests/unit_tests/kron.dia.ref b/scilab/modules/ast/tests/unit_tests/kron.dia.ref
deleted file mode 100644 (file)
index 3c600e3..0000000
+++ /dev/null
@@ -1,47 +0,0 @@
-// ============================================================================
-// Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
-// Copyright (C) 2012 - DIGITEO - Antoine ELIAS
-//
-//  This file is distributed under the same license as the Scilab package.
-// ============================================================================
-// <-- CLI SHELL MODE -->
-r = 2;
-R = [1,2;3,4];
-c = 1 + 2*%i;
-C = [1+2*%i,2+4*%i;3+6*%i,4+8*%i];
-// double .*. double
-//r .*. r
-assert_checkequal(r .*. r, 4);
-//r .*. c
-assert_checkequal(r .*. c, 2+4*%i);
-//c .*. r
-assert_checkequal(c .*. r, 2+4*%i);
-//c .*. c
-assert_checkequal(c .*. c, -3+4*%i);
-// double .*. DOUBLE
-//r .*. R
-assert_checkequal(r .*. R, [2,4;6,8]);
-//r .*. C
-assert_checkequal(r .*. C, [2+4*%i,4+8*%i;6+12*%i,8+16*%i]);
-//c .*. R
-assert_checkequal(c .*. R, [1+2*%i,2+4*%i;3+6*%i,4+8*%i]);
-//c .*. C
-assert_checkequal(c .*. C, [-3+4*%i,-6+8*%i;-9+12*%i,-12+16*%i]);
-// DOUBLE .*. double
-//R .*. r
-assert_checkequal(R .*.r, [2,4;6,8]);
-//R .*. c
-assert_checkequal(R .*.c, [1+2*%i,2+4*%i;3+6*%i,4+8*%i]);
-//C .*. r
-assert_checkequal(C .*.r, [2+4*%i,4+8*%i;6+12*%i,8+16*%i]);
-//C .*. c
-assert_checkequal(C .*.c, [-3+4*%i,-6+8*%i;-9+12*%i,-12+16*%i]);
-// DOUBLE .*. DOUBLE
-//R .*. R
-assert_checkequal(R .*. R, [1,2,2,4;3,4,6,8;3,6,4,8;9,12,12,16]);
-//R .*. C
-assert_checkequal(R .*. C, [1+2*%i,2+4*%i,2+4*%i,4+8*%i;3+6*%i,4+8*%i,6+12*%i,8+16*%i;3+6*%i,6+12*%i,4+8*%i,8+16*%i;9+18*%i,12+24*%i,12+24*%i,16+32*%i]);
-//C .*. R
-assert_checkequal(C .*. R, [1+2*%i,2+4*%i,2+4*%i,4+8*%i;3+6*%i,4+8*%i,6+12*%i,8+16*%i;3+6*%i,6+12*%i,4+8*%i,8+16*%i;9+18*%i,12+24*%i,12+24*%i,16+32*%i]);
-//C .*. C
-assert_checkequal(C .*. C, [-3+4*%i,-6+8*%i,-6+8*%i,-12+16*%i;-9+12*%i,-12+16*%i,-18+24*%i,-24+32*%i;-9+12*%i,-18+24*%i,-12+16*%i,-24+32*%i;-27+36*%i,-36+48*%i,-36+48*%i,-48+64*%i]);
index dd274d6..cbe6213 100644 (file)
@@ -1,11 +1,13 @@
 // ============================================================================
 // Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
 // Copyright (C) 2012 - DIGITEO - Antoine ELIAS
+// Copyright (C) 2019 - Samuel GOUGEON
 //
 //  This file is distributed under the same license as the Scilab package.
 // ============================================================================
 
 // <-- CLI SHELL MODE -->
+// <-- NO CHECK REF -->
 
 r = 2;
 R = [1,2;3,4];
@@ -58,3 +60,22 @@ assert_checkequal(R .*. C, [1+2*%i,2+4*%i,2+4*%i,4+8*%i;3+6*%i,4+8*%i,6+12*%i,8+
 assert_checkequal(C .*. R, [1+2*%i,2+4*%i,2+4*%i,4+8*%i;3+6*%i,4+8*%i,6+12*%i,8+16*%i;3+6*%i,6+12*%i,4+8*%i,8+16*%i;9+18*%i,12+24*%i,12+24*%i,16+32*%i]);
 //C .*. C
 assert_checkequal(C .*. C, [-3+4*%i,-6+8*%i,-6+8*%i,-12+16*%i;-9+12*%i,-12+16*%i,-18+24*%i,-24+32*%i;-9+12*%i,-18+24*%i,-12+16*%i,-24+32*%i;-27+36*%i,-36+48*%i,-36+48*%i,-48+64*%i]);
+
+
+// With hypermatrices
+// ------------------
+r = [1 2 3];
+c = [1 2 3]';
+m = [1 2 ; 3 4];
+h = cat(3,[1 2],[3 4]);
+assert_checkequal(r.*.ones(1,2,2), cat(3,[1 1 2 2 3 3],[1 1 2 2 3 3]));
+assert_checkequal(c.*.ones(1,2,2), cat(3,[1 1;2 2;3 3],[1 1;2 2;3 3]));
+assert_checkequal(m.*.ones(1,2,2), cat(3,[1 1 2 2;3 3 4 4],[1 1 2 2;3 3 4 4]));
+assert_checkequal(h.*.ones(1,2,2), cat(3,[1 1 2 2],[1 1 2 2],[3 3 4 4],[3 3 4 4]));
+
+assert_checkequal(ones(1,2,2).*.r, cat(3,[1 2 3 1 2 3],[1 2 3 1 2 3]));
+assert_checkequal(ones(1,2,2).*.c, cat(3,[1 1;2 2;3 3],[1 1;2 2;3 3]));
+assert_checkequal(ones(1,2,2).*.m, cat(3,[1 2 1 2;3 4 3 4],[1 2 1 2;3 4 3 4]));
+assert_checkequal(ones(1,2,2).*.h, cat(3,[1 2 1 2],[3 4 3 4],[1 2 1 2],[3 4 3 4]));
+
+assert_checkequal(size(rand(2,3,4,5).*.rand(3,1,1,1,1,7)), [6 3 4 5 1 7]);
index 9e5f1c4..8b5bf2e 100644 (file)
@@ -1,7 +1,7 @@
 // Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
-// Copyright (C) 2014 - Samuel GOUGEON
 //
 // Copyright (C) 2012 - 2016 - Scilab Enterprises
+// Copyright (C) 2014, 2019 - Samuel GOUGEON
 //
 // This file is hereby licensed under the terms of the GNU GPL v2.0,
 // pursuant to article 5.3.4 of the CeCILL v.2.1.
 // along with this program.
 
 function r = %s_k_s(a, b)
-    
-    // Code used by %s_k_hm(), %hm_k_s() and %hm_k_hm()
-    // Fully rewritten and factorized after bug http://bugzilla.scilab.org/13339
-
-    sa = size(a)
-    sb = size(b)
-    sa = [sa ones(1,ndims(b)-ndims(a))]
-    sb = [sb ones(1,ndims(a)-ndims(b))]
-    La = length(a)
-    Lb = length(b)
-    a = a(:)
-    b = b(:)
-    ia = (1:La).' .*.ones(b);
-    ib = ones(a) .*. (1:Lb).';
-    ir = (ia-1).*Lb + ib;
-    pa = ind2sub(sa, ia)
-    pb = ind2sub(sb, ib)
-    clear ia ib
-    pr = (pa-1).*repmat(sb, La*Lb, 1) + pb
-    clear pa pb
-    nir = sub2ind(sa.*sb, pr)
-    [v,k] = gsort(nir,"g","i")
-    clear pr nir v
-    r = a.*.b
-    r = matrix(r(ir(k)), sa.*sb)
+    // The b block is replicated according to the a's size and values as weights
+
+    // Code used instead of %s_k_hm(), %hm_k_s() and %hm_k_hm()
+
+    // Computes the size of the result
+    sa = size(a); na = length(sa);
+    sb = size(b); nb = length(sb);
+    m = max(na, nb);
+    if na < m
+        sa = [sa ones(1,m-na)]
+    else
+        sb = [sb ones(1,m-nb)]
+    end
+    sc = sa .* sb;
+
+    // Computes the matrice of indices shifts within each block, in the result:
+    // We use the first block
+    ijk = ind2sub(size(b), 1:size(b,"*"))
+    if na > nb then
+        ijk = [ijk ones(size(b,"*"), na-nb)]
+    end
+    shifts = sub2ind(sc, ijk) - 1;
+
+    // Computes the index of the first element of each block, for all blocks in
+    //  the result:
+    ijk = ind2sub(size(a), 1:size(a,"*")) // indices in a
+    ijk = ijk - 1
+    for u = 1:size(ijk, 2)
+        ijk(:,u) = ijk(:,u) * sb(u)
+    end
+    if na < nb then
+        ijk = [ijk zeros(size(a,"*"), nb-na)]
+    end
+    first = sub2ind(sc, ijk + 1)
+
+    // Computes new indices
+    newI = shifts(:) * ones(1, size(a,"*"))     // Replicates shifts
+    newI = newI(:) + ..
+          (ones(size(b,"*"),1) * first(:)')(:) // Replicates base indices (1st elements)
+    clear shifts first ijk
+
+    // Replicates and weights data
+    r = b(:) * matrix(a, 1, -1);
+
+    // Reallocates elements
+    r(newI) = r(:)
+    clear newI
+
+    // Reshape the result
+    r = matrix(r, sc);
 endfunction