2 * Scilab ( http://www.scilab.org/ ) - This file is part of Scilab
3 * Copyright (C) 2018- Stéphane MOTTELET
5 * This file is hereby licensed under the terms of the GNU GPL v2.0,
6 * For more information, see the COPYING file which you should have received
7 * along with this program.
14 #include "polynom.hxx"
17 #include "function.hxx"
18 #include "overload.hxx"
23 #include "localization.h"
/**
 * Computes, for every source dimension, the step to apply to a linear
 * destination index when the source index along that dimension advances by one.
 *
 * For each position i of the permutation, with j = dimsVect[i] - 1:
 *   - piOffset[j]    is the product of the sizes of the source dimensions
 *                    dimsVect[0] .. dimsVect[i-1] (1 when i == 0);
 *   - piMaxOffset[j] is piOffset[j] * piDimsArray[j], i.e. the amount to
 *                    subtract when the index along dimension j wraps to 0.
 *
 * @param iDims        number of dimensions to process
 * @param piDimsArray  sizes of the source dimensions (at least iDims entries)
 * @param dimsVect     1-based source dimension placed at each destination position
 * @param piOffset     [out] per-source-dimension step, indexed by source dimension
 * @param piMaxOffset  [out] per-source-dimension wrap amount, indexed by source dimension
 *
 * Entries of the output arrays whose dimension is not named in dimsVect are
 * left untouched. Values of dimsVect are assumed to lie in [1, iDims].
 */
void computeOffsets(int iDims, const int* piDimsArray, const std::vector<int>& dimsVect, int* piOffset, int* piMaxOffset)
{
    // Never read past the end of dimsVect, even if the caller passes a short vector.
    const int iAvailable = static_cast<int>(dimsVect.size());
    const int iCount = iAvailable < iDims ? iAvailable : iDims;

    // Running product of the sizes of the dimensions already placed.
    int iOffset = 1;
    for (int i = 0; i < iCount; ++i)
    {
        const int j = dimsVect[i] - 1;
        piOffset[j] = iOffset;
        piMaxOffset[j] = iOffset * piDimsArray[j];
        iOffset *= piDimsArray[j];
    }
}
39 T *doNativePermute(T *pIn, const std::vector<int>& dimsVect)
41 int iDims = pIn->getDims();
42 int* piDimsArray = pIn->getDimsArray();
43 int* piIndex = new int[iDims]();
44 int* piOffset = new int[iDims];
45 int* piMaxOffset = new int[iDims];
47 computeOffsets(iDims, piDimsArray, dimsVect, piOffset, piMaxOffset);
49 T* pOut = pIn->clone();
50 typename T::type* pout = pOut->get();
54 typename T::type* poutImg = pOut->getImg();
55 for (typename T::type *pin = pIn->get(), *pinImg = pIn->getImg(); pin < pIn->get() + pIn->getSize(); pin++, pinImg++)
59 for (int j = 0; j < iDims; j++)
63 poutImg += piOffset[j];
64 if (piIndex[j] < piDimsArray[j])
69 pout -= piMaxOffset[j];
70 poutImg -= piMaxOffset[j];
77 for (typename T::type *pin = pIn->get(); pin < pIn->get() + pIn->getSize(); pin++)
80 for (int j = 0; j < iDims; j++)
84 if (piIndex[j] < piDimsArray[j])
89 pout -= piMaxOffset[j];
102 template <typename T>
103 T* doPermute(T* pIn, const std::vector<int>& dimsVect)
105 int iDims = pIn->getDims();
106 int* piDimsArray = pIn->getDimsArray();
107 int* piOffset = new int[iDims];
108 int* piMaxOffset = new int[iDims];
109 int* piIndex = new int[iDims]();
111 computeOffsets(iDims, piDimsArray, dimsVect, piOffset, piMaxOffset);
113 T* pOut = pIn->clone();
115 for (int iSource = 0, iDest = 0; iSource < pIn->getSize(); iSource++)
117 pOut->set(iDest, pIn->get(iSource));
118 for (int j = 0; j < iDims; j++)
121 iDest += piOffset[j];
122 if (piIndex[j] < piDimsArray[j])
127 iDest -= piMaxOffset[j];
134 delete[] piMaxOffset;
139 types::Function::ReturnValue sci_permute(types::typed_list& in, int _iRetCount, types::typed_list& out)
143 Scierror(77, _("%s: Wrong number of input argument(s): %d expected.\n"), "permute", 2);
144 return types::Function::Error;
149 Scierror(78, _("%s: Wrong number of output argument(s): %d expected."), "permute", 1);
150 return types::Function::Error;
153 if (in[0]->isArrayOf() == false)
155 std::wstring wstFuncName = L"%" + in[0]->getShortTypeStr() + L"_permute";
156 return Overload::call(wstFuncName, in, _iRetCount, out);
159 types::GenericType* pIn = in[0]->getAs<types::GenericType>();
160 types::GenericType* pDims = in[1]->getAs<types::GenericType>();
162 int iDims = pIn->getDims();
163 int* piDimsArray = pIn->getDimsArray();
164 int iNewDims = pDims->getSize();
165 int* piNewDimsArray = NULL;
166 std::vector<int> dimsVect;
168 if ((iNewDims >= iDims) & pDims->isDouble() & !pDims->getAs<types::Double>()->isComplex())
170 // Check if 2nd argument is a permutation of [1..iNewDims]
171 types::Double* pDbl = pDims->getAs<types::Double>();
172 std::vector<double> sortedNewDimsVect(pDbl->get(), pDbl->get() + iNewDims);
173 std::sort(sortedNewDimsVect.begin(), sortedNewDimsVect.end());
174 std::vector<double> rangeVect(iNewDims);
175 std::iota(rangeVect.begin(), rangeVect.end(), 1.0);
177 if (sortedNewDimsVect == rangeVect)
179 piNewDimsArray = new int[iNewDims];
180 for (int i = 0; i < iNewDims; i++)
182 int j = (int)pDbl->get(i);
183 piNewDimsArray[i] = 1;
186 piNewDimsArray[i] = piDimsArray[j - 1];
187 dimsVect.push_back(j);
193 if (dimsVect.empty())
195 delete[] piNewDimsArray;
196 Scierror(78, _("%s: Wrong value for input argument #%d: Must be a valid permutation of [1..n>%d] integers.\n"), "permute", 2, iDims - 1);
197 return types::Function::Error;
200 types::GenericType *pOut;
202 switch (in[0]->getType())
204 case types::InternalType::ScilabDouble:
206 pOut = doNativePermute(in[0]->getAs<types::Double>(), dimsVect);
209 case types::InternalType::ScilabUInt64:
211 pOut = doNativePermute(in[0]->getAs<types::UInt64>(), dimsVect);
214 case types::InternalType::ScilabInt64:
216 pOut = doNativePermute(in[0]->getAs<types::Int64>(), dimsVect);
219 case types::InternalType::ScilabUInt32:
221 pOut = doNativePermute(in[0]->getAs<types::UInt32>(), dimsVect);
224 case types::InternalType::ScilabInt32:
226 pOut = doNativePermute(in[0]->getAs<types::Int32>(), dimsVect);
229 case types::InternalType::ScilabUInt16:
231 pOut = doNativePermute(in[0]->getAs<types::UInt16>(), dimsVect);
234 case types::InternalType::ScilabInt16:
236 pOut = doNativePermute(in[0]->getAs<types::Int16>(), dimsVect);
239 case types::InternalType::ScilabUInt8:
241 pOut = doNativePermute(in[0]->getAs<types::UInt8>(), dimsVect);
244 case types::InternalType::ScilabInt8:
246 pOut = doNativePermute(in[0]->getAs<types::Int8>(), dimsVect);
249 case types::InternalType::ScilabBool:
251 pOut = doNativePermute(in[0]->getAs<types::Bool>(), dimsVect);
254 case types::InternalType::ScilabString:
256 pOut = doPermute(in[0]->getAs<types::String>(), dimsVect);
259 case types::InternalType::ScilabPolynom:
261 pOut = doPermute(in[0]->getAs<types::Polynom>(), dimsVect);
264 case types::InternalType::ScilabStruct:
266 pOut = doPermute(in[0]->getAs<types::Struct>(), dimsVect);
269 case types::InternalType::ScilabCell:
271 pOut = doPermute(in[0]->getAs<types::Cell>(), dimsVect);
276 delete[] piNewDimsArray;
277 std::wstring wstFuncName = L"%" + in[0]->getShortTypeStr() + L"_permute";
278 return Overload::call(wstFuncName, in, _iRetCount, out);
282 pOut->reshape(piNewDimsArray, iNewDims);
284 delete[] piNewDimsArray;
288 return types::Function::OK;