* Bug 15737 fixed: setdiff() dit not support complex numbers
[scilab.git] / scilab / modules / elementary_functions / macros / setdiff.sci
index 88fcb6d..0ddf120 100644 (file)
@@ -16,8 +16,9 @@ function [a, ka] = setdiff(a, b, orien)
 
     // History:
     // * 2018 - S. Gougeon : orien="r"|"c" added, including the hypermat case
+    // * 2019 - S. Gougeon : complex numbers supported
 
-    rhs = argn(2);
+    [lhs, rhs] = argn();
 
     // CHECKING INPUT ARGUMENTS
     // ========================
@@ -52,6 +53,7 @@ function [a, ka] = setdiff(a, b, orien)
 
     // PROCESSING
     // ==========
+    Complexes = (type(a)==1 && ~isreal(a)) | (type(b)==1 && ~isreal(b));
     // "r" or "c"
     // ----------
     if orien then
@@ -61,41 +63,84 @@ function [a, ka] = setdiff(a, b, orien)
         if ndims(b) > 2 then
             b = serialize_hypermat(b, orien)
         end
-        [a, ka] = unique(a, orien)
-        if isempty(b)
-            return
-        end
-        it = inttype(a)
-        b = unique(b, orien)
-        if orien==2
-            a = a.'
-            b = b.'
-        end
-        [c, kc] = gsort([[a iconvert(ones(a(:,1)),it)] ;
-                         [b iconvert(ones(b(:,1))*2,it)]], "lr","i")
-        k = find(or(c(1:$-1,1:$-1)~=c(2:$,1:$-1),"c") & c(1:$-1,$)==1)
-        if c($,$)==1
-            k = [k size(c,1)]
-        end
-        ka = ka(kc(k))
-        // a = a(ka,:) // in initial order
-        a = c(k,1:$-1)
-        if orien==2
-            ka = matrix(ka, 1, -1)
-            a = a.'
+        if lhs > 1
+            [a, ka] = unique(a, orien)
+            if isempty(b)
+                return
+            end
+            it = inttype(a)
+            b = unique(b, orien)
+            if orien==2
+                a = a.'
+                b = b.'
+            end
+            if Complexes
+                [c, kc] = gsort([[a iconvert(ones(a(:,1)),it)] ;
+                                 [b iconvert(ones(b(:,1))*2,it)]], ..
+                                "lr", ["i" "i"], list(abs, atan))
+            else
+                [c, kc] = gsort([[a iconvert(ones(a(:,1)),it)] ;
+                                 [b iconvert(ones(b(:,1))*2,it)]], "lr","i")
+            end
+            k = find(or(c(1:$-1,1:$-1)~=c(2:$,1:$-1),"c") & c(1:$-1,$)==1)
+            if c($,$)==1
+                k = [k size(c,1)]
+            end
+            ka = ka(kc(k))
+            // a = a(ka,:) // in initial order
+            a = c(k,1:$-1)
+            if orien==2
+                ka = matrix(ka, 1, -1)
+                a = a.'
+            end
+        else
+            a = unique(a, orien)
+            if isempty(b)
+                return
+            end
+            it = inttype(a)
+            b = unique(b, orien)
+            if orien==2
+                a = a.'
+                b = b.'
+            end
+            if Complexes
+                c = gsort([[a iconvert(ones(a(:,1)),it)] ;
+                           [b iconvert(ones(b(:,1))*2,it)]], ..
+                           "lr", ["i" "i"], list(abs, atan))
+            else
+                c = gsort([[a iconvert(ones(a(:,1)),it)] ;
+                           [b iconvert(ones(b(:,1))*2,it)]], "lr","i")
+            end
+            k = find(or(c(1:$-1,1:$-1)~=c(2:$,1:$-1),"c") & c(1:$-1,$)==1)
+            if c($,$)==1
+                k = [k size(c,1)]
+            end
+            // a = a(ka,:) // in initial order
+            a = c(k,1:$-1)
+            if orien==2
+                a = a.'
+            end
         end
+
     else
         // by element
         // ----------
-        [a,ka] = unique(a);
+        if lhs > 1
+            [a,ka] = unique(a);
+        else
+            a = unique(a);
+        end
         na = size(a,"*");
         if isempty(b)
             return
         end
-
         b = unique(b(:));
-
-        [x,k] = gsort([a(:); b], "g", "i");
+        if Complexes
+            [x,k] = gsort([a(:); b], "g", ["i" "i"], list(abs, atan));
+        else
+            [x,k] = gsort([a(:); b], "g", "i");
+        end
         d = find(x(2:$)==x(1:$-1));  //index of common entries in sorted table
         if d <> [] then
             k([d;d+1]) = [];
@@ -103,7 +148,9 @@ function [a, ka] = setdiff(a, b, orien)
 
         keep = find(k <= na);
         a = a(k(keep));
-        ka = ka(k(keep));
+        if lhs > 1
+            ka = ka(k(keep))
+        end
     end
 endfunction
 // ----------------------------------------------------------------------------