361bafbe892ae1a6665fbe2584d9c29657a800ad
[scilab.git] / scilab / modules / elementary_functions / sci_gateway / cpp / sci_permute.cpp
1 /*
2  *  Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3  *  Copyright (C) 2018- St├ęphane MOTTELET
4  *
5  * This file is hereby licensed under the terms of the GNU GPL v2.0,
6  * For more information, see the COPYING file which you should have received
7  * along with this program.
8  *
9  */
10
11 #include <numeric>
12 #include "double.hxx"
13 #include "int.hxx"
14 #include "polynom.hxx"
15 #include "cell.hxx"
16 #include "struct.hxx"
17 #include "function.hxx"
18 #include "overload.hxx"
19
20 extern "C"
21 {
22 #include "Scierror.h"
23 #include "localization.h"
24 }
25
26 void computeOffsets(int iDims, const int* piDimsArray, const std::vector<int>& dimsVect, int* piOffset, int* piMaxOffset)
27 {
28     for (int i = 0; i < iDims; ++i)
29     {
30         int iOffset = i > 0 ? iOffset * piDimsArray[dimsVect[i - 1] - 1] : 1;
31         int j = dimsVect[i] - 1;
32         piOffset[j] = iOffset;
33         piMaxOffset[j] = iOffset * piDimsArray[j];
34     }
35 }
36
37 template <typename T>
38 T *doNativePermute(T *pIn, const std::vector<int>& dimsVect)
39 {
40     int iDims = pIn->getDims();
41     int* piDimsArray = pIn->getDimsArray();
42     int* piIndex = new int[iDims]();
43     int* piOffset = new int[iDims];
44     int* piMaxOffset = new int[iDims];
45
46     computeOffsets(iDims, piDimsArray, dimsVect, piOffset, piMaxOffset);
47
48     T* pOut = pIn->clone();
49     typename T::type* pout = pOut->get();
50
51     if (pIn->isComplex())
52     {
53         typename T::type* poutImg = pOut->getImg();
54         for (typename T::type *pin = pIn->get(), *pinImg = pIn->getImg(); pin < pIn->get() + pIn->getSize(); pin++, pinImg++)
55         {
56             *pout = *pin;
57             *poutImg = *pinImg;
58             for (int j = 0; j < iDims; j++)
59             {
60                 ++piIndex[j];
61                 pout += piOffset[j];
62                 poutImg += piOffset[j];
63                 if (piIndex[j] < piDimsArray[j])
64                 {
65                     break;
66                 }
67
68                 pout -= piMaxOffset[j];
69                 poutImg -= piMaxOffset[j];
70                 piIndex[j] = 0;
71             }
72         }
73     }
74     else
75     {
76         for (typename T::type *pin = pIn->get(); pin < pIn->get() + pIn->getSize(); pin++)
77         {
78             *pout = *pin;
79             for (int j = 0; j < iDims; j++)
80             {
81                 ++piIndex[j];
82                 pout += piOffset[j];
83                 if (piIndex[j] < piDimsArray[j])
84                 {
85                     break;
86                 }
87
88                 pout -= piMaxOffset[j];
89                 piIndex[j] = 0;
90             }
91         }
92     }
93
94     delete[] piIndex;
95     delete[] piOffset;
96     delete[] piMaxOffset;
97
98     return pOut;
99 }
100
101 template <typename T>
102 T* doPermute(T* pIn, const std::vector<int>& dimsVect)
103 {
104     int iDims = pIn->getDims();
105     int* piDimsArray = pIn->getDimsArray();
106     int* piOffset = new int[iDims];
107     int* piMaxOffset = new int[iDims];
108     int* piIndex = new int[iDims]();
109
110     computeOffsets(iDims, piDimsArray, dimsVect, piOffset, piMaxOffset);
111
112     T* pOut = pIn->clone();
113
114     for (int iSource = 0, iDest = 0; iSource < pIn->getSize(); iSource++)
115     {
116         pOut->set(iDest, pIn->get(iSource));
117         for (int j = 0; j < iDims; j++)
118         {
119             ++piIndex[j];
120             iDest += piOffset[j];
121             if (piIndex[j] < piDimsArray[j])
122             {
123                 break;
124             }
125
126             iDest -= piMaxOffset[j];
127             piIndex[j] = 0;
128         }
129     }
130
131     delete[] piIndex;
132     delete[] piOffset;
133     delete[] piMaxOffset;
134
135     return pOut;
136 }
137
138 types::Function::ReturnValue sci_permute(types::typed_list& in, int _iRetCount, types::typed_list& out)
139 {
140     if (in.size() != 2)
141     {
142         Scierror(77, _("%s: Wrong number of input argument(s): %d expected.\n"), "permute", 2);
143         return types::Function::Error;
144     }
145
146     if (_iRetCount > 1)
147     {
148         Scierror(78, _("%s: Wrong number of output argument(s): %d expected."), "permute", 1);
149         return types::Function::Error;
150     }
151
152     if (in[0]->isArrayOf() == false)
153     {
154         std::wstring wstFuncName = L"%" + in[0]->getShortTypeStr() + L"_permute";
155         return Overload::call(wstFuncName, in, _iRetCount, out);
156     }
157
158     types::GenericType* pIn = in[0]->getAs<types::GenericType>();
159     types::GenericType* pDims = in[1]->getAs<types::GenericType>();
160
161     int iDims = pIn->getDims();
162     int* piDimsArray = pIn->getDimsArray();
163     int iNewDims = pDims->getSize();
164     int* piNewDimsArray = NULL;
165     std::vector<int> dimsVect;
166
167     if ((iNewDims >= iDims) & pDims->isDouble() & !pDims->getAs<types::Double>()->isComplex())
168     {
169         // Check if 2nd argument is a permutation of [1..iNewDims]
170         types::Double* pDbl = pDims->getAs<types::Double>();
171         std::vector<double> sortedNewDimsVect(pDbl->get(), pDbl->get() + iNewDims);
172         std::sort(sortedNewDimsVect.begin(), sortedNewDimsVect.end());
173         std::vector<double> rangeVect(iNewDims);
174         std::iota(rangeVect.begin(), rangeVect.end(), 1.0);
175
176         if (sortedNewDimsVect == rangeVect)
177         {
178             piNewDimsArray = new int[iNewDims];
179             for (int i = 0; i < iNewDims; i++)
180             {
181                 int j = (int)pDbl->get(i);
182                 piNewDimsArray[i] = 1;
183                 if (j <= iDims)
184                 {
185                     piNewDimsArray[i] = piDimsArray[j - 1];
186                     dimsVect.push_back(j);
187                 }
188             }
189         }
190     }
191
192     if (dimsVect.empty())
193     {
194         delete[] piNewDimsArray;
195         Scierror(78, _("%s: Wrong value for input argument #%d: Must be a valid permutation of [1..n>%d] integers.\n"), "permute", 2, iDims - 1);
196         return types::Function::Error;
197     }
198
199     types::GenericType *pOut;
200
201     switch (in[0]->getType())
202     {
203         case types::InternalType::ScilabDouble:
204         {
205             pOut = doNativePermute(in[0]->getAs<types::Double>(), dimsVect);
206             break;
207         }
208         case types::InternalType::ScilabUInt64:
209         {
210             pOut = doNativePermute(in[0]->getAs<types::UInt64>(), dimsVect);
211             break;
212         }
213         case types::InternalType::ScilabInt64:
214         {
215             pOut = doNativePermute(in[0]->getAs<types::Int64>(), dimsVect);
216             break;
217         }
218         case types::InternalType::ScilabUInt32:
219         {
220             pOut = doNativePermute(in[0]->getAs<types::UInt32>(), dimsVect);
221             break;
222         }
223         case types::InternalType::ScilabInt32:
224         {
225             pOut = doNativePermute(in[0]->getAs<types::Int32>(), dimsVect);
226             break;
227         }
228         case types::InternalType::ScilabUInt16:
229         {
230             pOut = doNativePermute(in[0]->getAs<types::UInt16>(), dimsVect);
231             break;
232         }
233         case types::InternalType::ScilabInt16:
234         {
235             pOut = doNativePermute(in[0]->getAs<types::Int16>(), dimsVect);
236             break;
237         }
238         case types::InternalType::ScilabUInt8:
239         {
240             pOut = doNativePermute(in[0]->getAs<types::UInt8>(), dimsVect);
241             break;
242         }
243         case types::InternalType::ScilabInt8:
244         {
245             pOut = doNativePermute(in[0]->getAs<types::Int8>(), dimsVect);
246             break;
247         }
248         case types::InternalType::ScilabBool:
249         {
250             pOut = doNativePermute(in[0]->getAs<types::Bool>(), dimsVect);
251             break;
252         }
253         case types::InternalType::ScilabString:
254         {
255             pOut = doPermute(in[0]->getAs<types::String>(), dimsVect);
256             break;
257         }
258         case types::InternalType::ScilabPolynom:
259         {
260             pOut = doPermute(in[0]->getAs<types::Polynom>(), dimsVect);
261             break;
262         }
263         case types::InternalType::ScilabStruct:
264         {
265             pOut = doPermute(in[0]->getAs<types::Struct>(), dimsVect);
266             break;
267         }
268         case types::InternalType::ScilabCell:
269         {
270             pOut = doPermute(in[0]->getAs<types::Cell>(), dimsVect);
271             break;
272         }
273         default:
274         {
275             delete[] piNewDimsArray;
276             std::wstring wstFuncName = L"%" + in[0]->getShortTypeStr() + L"_permute";
277             return Overload::call(wstFuncName, in, _iRetCount, out);
278         }
279     }
280
281     pOut->reshape(piNewDimsArray, iNewDims);
282
283     delete[] piNewDimsArray;
284
285     out.push_back(pOut);
286
287     return types::Function::OK;
288 }