* Bug 16274 fixed: assert_checkequal() with NaN or void in containers
[scilab.git] / scilab / modules / development_tools / macros / assert / assert_checkequal.sci
1 // Copyright (C) 2008-2009 - INRIA - Michael Baudin
2 // Copyright (C) 2010 - 2011 - DIGITEO - Michael Baudin
3 // Copyright (C) 2012 - 2016 - Scilab Enterprises
4 // Copyright (C) 2019 - 2020 - Samuel GOUGEON
5 //
6 // This file is hereby licensed under the terms of the GNU GPL v2.0,
7 // pursuant to article 5.3.4 of the CeCILL v.2.1.
8 // This file was originally licensed under the terms of the CeCILL v2.1,
9 // and continues to be available under such terms.
10 // For more information, see the COPYING file which you should have received
11 // along with this program.
12
function [flag, errmsg] = assert_checkequal(computed, expected)
    //  Check that computed and expected are equal.
    //
    //  flag   : %t if computed and expected are equal, %f otherwise.
    //  errmsg : "" if they are equal, otherwise a message describing the
    //           first discrepancy found.
    //  When errmsg is not requested by the caller, a discrepancy is reported
    //  by calling assert_generror(errmsg).
    //
    //  - Different types or different sizes always raise an error.
    //  - NaN values are considered equal to NaN (see comparedoubles()).
    //  - Containers (list, tlist, mlist, struct, cell) are compared
    //    recursively (see compareContainers()).
    [lhs,rhs] = argn()
    if ( rhs <> 2 ) then
        errmsg = gettext("%s: Wrong number of input arguments: %d expected.\n")
        error(msprintf(errmsg, "assert_checkequal", 2))
    end

    // Check types of variables
    if ( typeof(computed) <> typeof(expected) ) then
        errmsg = gettext("%s: Incompatible input arguments #%d and #%d: Same types expected.\n")
        error(msprintf(errmsg, "assert_checkequal", 1, 2))
    end

    //
    // Check sizes of variables
    if type(computed)==15 then
        ncom = length(computed)
        nexp = length(expected)
    elseif or(typeof(computed)==["ce" "st"])
        ncom = size(computed)
        nexp = size(expected)
    else
        try
            ncom = size(computed)
            nexp = size(expected)
        catch   // non-sizeable objects: 1:$, iolib, sin, sind, etc
            ncom = -2
            nexp = -2
        end
    end
    if ( or(ncom <> nexp) ) then
        errmsg = msprintf(gettext ( "%s: Incompatible input arguments #%d and #%d: Same sizes expected.\n"), "assert_checkequal", 1 , 2)
        error(errmsg)
    end

    // Both outputs must be defined on every exit path, including the early
    // returns below: [flag, errmsg] = assert_checkequal(sin, sin) used to
    // fail on an undefined errmsg.
    errmsg = ""
    // Index of the first discrepancy, when one can be located ([] otherwise)
    k = []

    // sparse or full real or complex matrices
    if or(type(computed) == [1 5])  then
        cisreal = isreal(computed)
        eisreal = isreal(expected)
        if ( cisreal & ~eisreal ) then
            errmsg = msprintf(gettext("%s: Computed is real, but expected is complex."), "assert_checkequal")
            error(errmsg)
        end
        if ( ~cisreal & eisreal ) then
            errmsg = msprintf(gettext("%s: Computed is complex, but expected is real."), "assert_checkequal")
            error(errmsg)
        end
        if cisreal & eisreal then
            [flag, k] = comparedoubles ( computed , expected )
        else
            [flag, k] = comparedoubles ( real(computed) , real(expected) )
            if flag then
                [flag ,k] = comparedoubles ( imag(computed) , imag(expected) )
            end
        end

    elseif or(typeof(computed)==["implicitlist" "fptr" "function"])
                                    // http://bugzilla.scilab.org/16104 C) D) E)
        flag = computed==expected
        if ~flag then
            if typeof(computed) == "implicitlist"
                errmsg = _("%s: Assertion failed: expected= %s  while computed= %s")
                errmsg = msprintf(errmsg,"assert_checkequal",string(expected),string(computed))

            elseif typeof(computed) == "function"
                c = macr2tree(computed).name+"()"
                e = macr2tree(expected).name+"()"
                errmsg = _("%s: Assertion failed: expected= %s  while computed= %s")
                errmsg = msprintf(errmsg,"assert_checkequal", e, c)

            else
                // no way to get the names of built-in functions
                errmsg = _("%s: Assertion failed: expected and computed are two distinct built-in functions.")
                errmsg = msprintf(errmsg,"assert_checkequal")
            end
            if lhs < 2 then
                assert_generror ( errmsg )
            end
        end
        return

    elseif type(computed) == 14   // library : http://bugzilla.scilab.org/16104#c1
        flag = and(string(computed)==string(expected))
        if ~flag then
            errmsg = gettext("%s: Assertion failed: expected= %s  while computed= %s")
            c = "lib@" + string(computed)(1)
            e = "lib@" + string(expected)(1)
            errmsg = msprintf(errmsg,"assert_checkequal", e, c)
            if lhs < 2 then
                assert_generror ( errmsg )
            end
        end
        return

    elseif or(type(computed)==[15 16 17 ])
        [flag, k] = compareContainers(computed , expected)

    elseif type(computed) == 0
        flag = %t

    else
        b = and(computed == expected)
        flag = b || isequal(computed, expected)
        if ~flag    // implies ~b
            k = find(computed<>expected, 1);
        end
    end

    if flag then
        return      // errmsg is ""
    end

    // Sets the message according to the type and size of the pair:
    estr = elementToString(expected, k)
    cstr = elementToString(computed, k)
    // prod(ncom) is the number of elements (or of list items). length()
    // can't be used here: for strings it returns numbers of characters.
    if k <> [] & prod(ncom) > 1
        estr = msprintf(_("expected(%d)= "),k) + estr
        cstr = msprintf(_("computed(%d)= "),k) + cstr
    else
        estr = _("expected= ") + estr
        cstr = _("computed= ") + cstr
    end
    //
    ierr = execstr("mdiff = string(mean(computed - expected))", "errcatch");
    if ( ierr == 0 ) then
        errmsg = msprintf(gettext("%s: Assertion failed: %s  while %s (mean diff = %s)"),"assert_checkequal",estr, cstr, mdiff)
    else
        errmsg = msprintf(gettext("%s: Assertion failed: %s  while %s"),"assert_checkequal", estr, cstr)
    end
    if lhs < 2 then
        // If no output variable is given, generate an error
        assert_generror ( errmsg )
    end
endfunction
// ---------------------------------------------------------------------------
function str = elementToString(x, k)
    // Returns the text displayed in the failure message for x(k), or for
    // x(1) when k is [] (k: index of the first discrepancy).
    // If that element can't be extracted, x as a whole is displayed.
    if or(typeof(x) == ["sparse", "boolean sparse"])
        str = string(full(x(k)))
        return
    end
    if k <> []
        s = "x(k)"
    else
        s = "x(1)"
    end
    err = execstr("v = "+s+"; t = type("+s+")", "errcatch")
    if err <> 0
        v = x
        t = type(v)
    end
    if t==0
        str = "(void)"
    elseif t==9
        str = msprintf("%s(uid:%d)", v.type, v.uid)
    else
        str = string(v)
    end
endfunction
190 // ---------------------------------------------------------------------------
function [flag, k] = comparedoubles ( computed , expected )
    // Compares two real (full or sparse) arrays element-wise, NaN being
    // considered as equal to NaN.
    //
    // flag : %t if both arrays have the same size and equal elements.
    // k    : linear index of the first differing element, or [] if there is
    //        none (or if the sizes differ).
    //
    // No random "joker" value is drawn to replace NaN: doing so reset the
    // seed of rand(), silently altering the random sequence of the caller.
    k = []
    // == and <> broadcast a 1x1 operand: sizes must be checked explicitly
    if or(size(computed) <> size(expected)) then
        flag = %f
        return
    end
    if size(computed, "*") > 0 then
        bothNaN = isnan(computed) & isnan(expected)
        k = find((computed <> expected) & (~bothNaN), 1)
    end
    flag = k == []
endfunction
202 // ---------------------------------------------------------------------------
function [areEqual, k] = compareContainers(computed , expected)
    // Recursively compares two values that may be containers (list, tlist,
    // mlist, struct, cell) or any value stored in a container.
    // http://bugzilla.scilab.org/15293
    // http://bugzilla.scilab.org/16274
    //
    // areEqual : %t if computed and expected are equal. NaN are equal to NaN
    //            (see comparedoubles), and undefined list elements must be
    //            undefined at the same indices in both lists.
    // k        : index of the first discrepancy when it can be located,
    //            [] otherwise. For lists, it is the index of the first
    //            differing element; for other containers, it is the index
    //            returned by the nested comparison.
    tc = typeof(computed)
    te = typeof(expected)
    k = []
    areEqual = tc == te
    if ~areEqual
        return
    end
    if or(type(computed)==[1 5])
        // Sizes first: == would broadcast a 1x1 value onto any matrix
        // filled with it, making list([1 1]) equal to list(1).
        if ~haveSameSizes(computed, expected)
            areEqual = %f
            return
        end
        if and(computed == expected)
            return
        end
        if isreal(computed) <> isreal(expected)
            areEqual = %f
            return
        end
        [areEqual, k] = comparedoubles(real(computed), real(expected))
        if areEqual
            [areEqual, k] = comparedoubles(imag(computed), imag(expected))
        end

    elseif or(type(computed)==[16 17]) then
        if and(computed == expected)
            return
        end
        if or(size(computed) <> size(expected)) then
            areEqual = %f
            return
        end
        if tc=="ce"
            // Compare the contents one by one. (computed{:} would expand
            // into as many arguments as the cell has elements.)
            for i = 1:size(computed, "*")
                [areEqual, k] = compareContainers(computed{i} , expected{i})
                if ~areEqual
                    break
                end
            end
        else
            fc = fieldnames(computed)
            areEqual = and(fc == fieldnames(expected))
            if areEqual
                for f = fc'
                    [areEqual, k] = compareContainers(computed(f) , expected(f))
                    if ~areEqual
                        break
                    end
                end
            end
        end

    elseif type(computed)==14   // Libraries
        areEqual = and(string(computed)==string(expected))

    elseif tc=="list"
        if and(computed == expected)
            return
        end
        if length(computed) <> length(expected)
            areEqual = %f
            return
        end
        dfc = definedfields(computed)
        dfe = definedfields(expected)
        if or(dfc <> dfe)
            // First index defined in one list but not in the other
            onlyc = setdiff(dfc, dfe)
            onlye = setdiff(dfe, dfc)
            k = min([onlyc(:) ; onlye(:)])
            areEqual = %f
            return
        end
        for k = dfc
            areEqual = compareContainers(computed(k) , expected(k))
            if ~areEqual
                break
            end
        end

    elseif tc=="void"   // te is "void" as well (tc == te)
        return

    elseif type(computed) <> 0
        // Same reason as for numbers: avoid matching through broadcasting
        if ~haveSameSizes(computed, expected)
            areEqual = %f
            return
        end
        b = and(computed == expected)
        areEqual = b || isequal(computed, expected)
        if ~areEqual    // implies ~b
            k = find(computed <> expected, 1);
        end
    end
endfunction
// ---------------------------------------------------------------------------
function r = haveSameSizes(a, b)
    // Tells whether a and b have the same dimensions.
    // Objects for which size() is not defined (built-in functions, 1:$, ...)
    // are reported as having the same size: their equality is checked
    // afterwards by the caller.
    r = %t
    try
        sa = size(a)
        sb = size(b)
        r = size(sa, "*") == size(sb, "*")
        if r then
            r = and(sa == sb)
        end
    catch
        r = %t
    end
endfunction