blob: c5a5b54f68b5f0bc16176339ae38c7f5d48a0800 [file] [log] [blame]
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -07001/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <iomanip>
18#include <iostream>
19#include <cmath>
20#include <sstream>
21
22#include "Generator.h"
23#include "Specification.h"
24#include "Utilities.h"
25
26using namespace std;
27
28// Converts float2 to FLOAT_32 and 2, etc.
29static void convertToRsType(const string& name, string* dataType, char* vectorSize) {
30 string s = name;
31 int last = s.size() - 1;
32 char lastChar = s[last];
33 if (lastChar >= '1' && lastChar <= '4') {
34 s.erase(last);
35 *vectorSize = lastChar;
36 } else {
37 *vectorSize = '1';
38 }
39 dataType->clear();
40 for (int i = 0; i < NUM_TYPES; i++) {
41 if (s == TYPES[i].cType) {
42 *dataType = TYPES[i].rsDataType;
43 break;
44 }
45 }
46}
47
48// Returns true if any permutation of the function have tests to b
Yang Ni12398d82015-09-18 14:57:07 -070049static bool needTestFiles(const Function& function, unsigned int versionOfTestFiles) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -070050 for (auto spec : function.getSpecifications()) {
51 if (spec->hasTests(versionOfTestFiles)) {
52 return true;
53 }
54 }
55 return false;
56}
57
58/* One instance of this class is generated for each permutation of a function for which
59 * we are generating test code. This instance will generate both the script and the Java
60 * section of the test files for this permutation. The class is mostly used to keep track
61 * of the various names shared between script and Java files.
62 * WARNING: Because the constructor keeps a reference to the FunctionPermutation, PermutationWriter
63 * should not exceed the lifetime of FunctionPermutation.
64 */
65class PermutationWriter {
66private:
67 FunctionPermutation& mPermutation;
68
69 string mRsKernelName;
70 string mJavaArgumentsClassName;
71 string mJavaArgumentsNClassName;
72 string mJavaVerifierComputeMethodName;
73 string mJavaVerifierVerifyMethodName;
74 string mJavaCheckMethodName;
75 string mJavaVerifyMethodName;
76
77 // Pointer to the files we are generating. Handy to avoid always passing them in the calls.
78 GeneratedFile* mRs;
79 GeneratedFile* mJava;
80
81 /* Shortcuts to the return parameter and the first input parameter of the function
82 * specification.
83 */
84 const ParameterDefinition* mReturnParam; // Can be nullptr. NOT OWNED.
85 const ParameterDefinition* mFirstInputParam; // Can be nullptr. NOT OWNED.
86
87 /* All the parameters plus the return param, if present. Collecting them together
88 * simplifies code generation. NOT OWNED.
89 */
90 vector<const ParameterDefinition*> mAllInputsAndOutputs;
91
92 /* We use a class to pass the arguments between the generated code and the CoreVerifier. This
93 * method generates this class. The set keeps track if we've generated this class already
94 * for this test file, as more than one permutation may use the same argument class.
95 */
96 void writeJavaArgumentClass(bool scalar, set<string>* javaGeneratedArgumentClasses) const;
97
98 // Generate the Check* method that invokes the script and calls the verifier.
99 void writeJavaCheckMethod(bool generateCallToVerifier) const;
100
101 // Generate code to define and randomly initialize the input allocation.
102 void writeJavaInputAllocationDefinition(const ParameterDefinition& param) const;
103
104 /* Generate code that instantiate an allocation of floats or integers and fills it with
105 * random data. This random data must be compatible with the specified type. This is
106 * used for the convert_* tests, as converting values that don't fit yield undefined results.
107 */
108 void writeJavaRandomCompatibleFloatAllocation(const string& dataType, const string& seed,
109 char vectorSize,
110 const NumericalType& compatibleType,
111 const NumericalType& generatedType) const;
112 void writeJavaRandomCompatibleIntegerAllocation(const string& dataType, const string& seed,
113 char vectorSize,
114 const NumericalType& compatibleType,
115 const NumericalType& generatedType) const;
116
117 // Generate code that defines an output allocation.
118 void writeJavaOutputAllocationDefinition(const ParameterDefinition& param) const;
119
120 /* Generate the code that verifies the results for RenderScript functions where each entry
121 * of a vector is evaluated independently. If verifierValidates is true, CoreMathVerifier
122 * does the actual validation instead of more commonly returning the range of acceptable values.
123 */
124 void writeJavaVerifyScalarMethod(bool verifierValidates) const;
125
126 /* Generate the code that verify the results for a RenderScript function where a vector
127 * is a point in n-dimensional space.
128 */
129 void writeJavaVerifyVectorMethod() const;
130
Pirama Arumuga Nainar93abc2d2016-02-16 15:11:29 -0800131 // Generate the line that creates the Target.
132 void writeJavaCreateTarget() const;
133
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700134 // Generate the method header of the verify function.
135 void writeJavaVerifyMethodHeader() const;
136
137 // Generate codes that copies the content of an allocation to an array.
138 void writeJavaArrayInitialization(const ParameterDefinition& p) const;
139
140 // Generate code that tests one value returned from the script.
141 void writeJavaTestAndSetValid(const ParameterDefinition& p, const string& argsIndex,
142 const string& actualIndex) const;
143 void writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
144 const string& actualIndex) const;
145 // For test:vector cases, generate code that compares returned vector vs. expected value.
146 void writeJavaVectorComparison(const ParameterDefinition& p) const;
147
148 // Muliple functions that generates code to build the error message if an error is found.
149 void writeJavaAppendOutputToMessage(const ParameterDefinition& p, const string& argsIndex,
150 const string& actualIndex, bool verifierValidates) const;
151 void writeJavaAppendInputToMessage(const ParameterDefinition& p, const string& actual) const;
152 void writeJavaAppendNewLineToMessage() const;
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700153 void writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const;
154 void writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const;
155
156 // Generate the set of instructions to call the script.
157 void writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const;
158
159 // Write an allocation definition if not already emitted in the .rs file.
160 void writeRsAllocationDefinition(const ParameterDefinition& param,
161 set<string>* rsAllocationsGenerated) const;
162
163public:
164 /* NOTE: We keep pointers to the permutation and the files. This object should not
165 * outlive the arguments.
166 */
167 PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
168 GeneratedFile* javaFile);
169 string getJavaCheckMethodName() const { return mJavaCheckMethodName; }
170
171 // Write the script test function for this permutation.
172 void writeRsSection(set<string>* rsAllocationsGenerated) const;
173 // Write the section of the Java code that calls the script and validates the results
174 void writeJavaSection(set<string>* javaGeneratedArgumentClasses) const;
175};
176
177PermutationWriter::PermutationWriter(FunctionPermutation& permutation, GeneratedFile* rsFile,
178 GeneratedFile* javaFile)
179 : mPermutation(permutation),
180 mRs(rsFile),
181 mJava(javaFile),
182 mReturnParam(nullptr),
183 mFirstInputParam(nullptr) {
184 mRsKernelName = "test" + capitalize(permutation.getName());
185
186 mJavaArgumentsClassName = "Arguments";
187 mJavaArgumentsNClassName = "Arguments";
188 const string trunk = capitalize(permutation.getNameTrunk());
189 mJavaCheckMethodName = "check" + trunk;
190 mJavaVerifyMethodName = "verifyResults" + trunk;
191
192 for (auto p : permutation.getParams()) {
193 mAllInputsAndOutputs.push_back(p);
194 if (mFirstInputParam == nullptr && !p->isOutParameter) {
195 mFirstInputParam = p;
196 }
197 }
198 mReturnParam = permutation.getReturn();
199 if (mReturnParam) {
200 mAllInputsAndOutputs.push_back(mReturnParam);
201 }
202
203 for (auto p : mAllInputsAndOutputs) {
204 const string capitalizedRsType = capitalize(p->rsType);
205 const string capitalizedBaseType = capitalize(p->rsBaseType);
206 mRsKernelName += capitalizedRsType;
207 mJavaArgumentsClassName += capitalizedBaseType;
208 mJavaArgumentsNClassName += capitalizedBaseType;
209 if (p->mVectorSize != "1") {
210 mJavaArgumentsNClassName += "N";
211 }
212 mJavaCheckMethodName += capitalizedRsType;
213 mJavaVerifyMethodName += capitalizedRsType;
214 }
215 mJavaVerifierComputeMethodName = "compute" + trunk;
216 mJavaVerifierVerifyMethodName = "verify" + trunk;
217}
218
219void PermutationWriter::writeJavaSection(set<string>* javaGeneratedArgumentClasses) const {
220 // By default, we test the results using item by item comparison.
221 const string test = mPermutation.getTest();
222 if (test == "scalar" || test == "limited") {
223 writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
224 writeJavaCheckMethod(true);
225 writeJavaVerifyScalarMethod(false);
226 } else if (test == "custom") {
227 writeJavaArgumentClass(true, javaGeneratedArgumentClasses);
228 writeJavaCheckMethod(true);
229 writeJavaVerifyScalarMethod(true);
230 } else if (test == "vector") {
231 writeJavaArgumentClass(false, javaGeneratedArgumentClasses);
232 writeJavaCheckMethod(true);
233 writeJavaVerifyVectorMethod();
234 } else if (test == "noverify") {
235 writeJavaCheckMethod(false);
236 }
237}
238
239void PermutationWriter::writeJavaArgumentClass(bool scalar,
240 set<string>* javaGeneratedArgumentClasses) const {
241 string name;
242 if (scalar) {
243 name = mJavaArgumentsClassName;
244 } else {
245 name = mJavaArgumentsNClassName;
246 }
247
248 // Make sure we have not generated the argument class already.
249 if (!testAndSet(name, javaGeneratedArgumentClasses)) {
250 mJava->indent() << "public class " << name;
251 mJava->startBlock();
252
253 for (auto p : mAllInputsAndOutputs) {
254 mJava->indent() << "public ";
255 if (p->isOutParameter && p->isFloatType && mPermutation.getTest() != "custom") {
256 *mJava << "Target.Floaty";
257 } else {
258 *mJava << p->javaBaseType;
259 }
260 if (!scalar && p->mVectorSize != "1") {
261 *mJava << "[]";
262 }
263 *mJava << " " << p->variableName << ";\n";
264 }
265 mJava->endBlock();
266 *mJava << "\n";
267 }
268}
269
270void PermutationWriter::writeJavaCheckMethod(bool generateCallToVerifier) const {
271 mJava->indent() << "private void " << mJavaCheckMethodName << "()";
272 mJava->startBlock();
273
274 // Generate the input allocations and initialization.
275 for (auto p : mAllInputsAndOutputs) {
276 if (!p->isOutParameter) {
277 writeJavaInputAllocationDefinition(*p);
278 }
279 }
280 // Generate code to enforce ordering between two allocations if needed.
281 for (auto p : mAllInputsAndOutputs) {
282 if (!p->isOutParameter && !p->smallerParameter.empty()) {
283 string smallerAlloc = "in" + capitalize(p->smallerParameter);
284 mJava->indent() << "enforceOrdering(" << smallerAlloc << ", " << p->javaAllocName
285 << ");\n";
286 }
287 }
288
289 // Generate code to check the full and relaxed scripts.
290 writeJavaCallToRs(false, generateCallToVerifier);
291 writeJavaCallToRs(true, generateCallToVerifier);
292
293 mJava->endBlock();
294 *mJava << "\n";
295}
296
297void PermutationWriter::writeJavaInputAllocationDefinition(const ParameterDefinition& param) const {
298 string dataType;
299 char vectorSize;
300 convertToRsType(param.rsType, &dataType, &vectorSize);
301
302 const string seed = hashString(mJavaCheckMethodName + param.javaAllocName);
303 mJava->indent() << "Allocation " << param.javaAllocName << " = ";
304 if (param.compatibleTypeIndex >= 0) {
305 if (TYPES[param.typeIndex].kind == FLOATING_POINT) {
306 writeJavaRandomCompatibleFloatAllocation(dataType, seed, vectorSize,
307 TYPES[param.compatibleTypeIndex],
308 TYPES[param.typeIndex]);
309 } else {
310 writeJavaRandomCompatibleIntegerAllocation(dataType, seed, vectorSize,
311 TYPES[param.compatibleTypeIndex],
312 TYPES[param.typeIndex]);
313 }
314 } else if (!param.minValue.empty()) {
315 *mJava << "createRandomFloatAllocation(mRS, Element.DataType." << dataType << ", "
316 << vectorSize << ", " << seed << ", " << param.minValue << ", " << param.maxValue
317 << ")";
318 } else {
319 /* TODO Instead of passing always false, check whether we are doing a limited test.
320 * Use instead: (mPermutation.getTest() == "limited" ? "false" : "true")
321 */
322 *mJava << "createRandomAllocation(mRS, Element.DataType." << dataType << ", " << vectorSize
323 << ", " << seed << ", false)";
324 }
325 *mJava << ";\n";
326}
327
328void PermutationWriter::writeJavaRandomCompatibleFloatAllocation(
329 const string& dataType, const string& seed, char vectorSize,
330 const NumericalType& compatibleType, const NumericalType& generatedType) const {
331 *mJava << "createRandomFloatAllocation"
332 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
333 double minValue = 0.0;
334 double maxValue = 0.0;
335 switch (compatibleType.kind) {
336 case FLOATING_POINT: {
337 // We're generating floating point values. We just worry about the exponent.
338 // Subtract 1 for the exponent sign.
339 int bits = min(compatibleType.exponentBits, generatedType.exponentBits) - 1;
340 maxValue = ldexp(0.95, (1 << bits) - 1);
341 minValue = -maxValue;
342 break;
343 }
344 case UNSIGNED_INTEGER:
345 maxValue = maxDoubleForInteger(compatibleType.significantBits,
346 generatedType.significantBits);
347 minValue = 0.0;
348 break;
349 case SIGNED_INTEGER:
350 maxValue = maxDoubleForInteger(compatibleType.significantBits,
351 generatedType.significantBits);
352 minValue = -maxValue - 1.0;
353 break;
354 }
355 *mJava << scientific << std::setprecision(19);
356 *mJava << minValue << ", " << maxValue << ")";
357 mJava->unsetf(ios_base::floatfield);
358}
359
360void PermutationWriter::writeJavaRandomCompatibleIntegerAllocation(
361 const string& dataType, const string& seed, char vectorSize,
362 const NumericalType& compatibleType, const NumericalType& generatedType) const {
363 *mJava << "createRandomIntegerAllocation"
364 << "(mRS, Element.DataType." << dataType << ", " << vectorSize << ", " << seed << ", ";
365
366 if (compatibleType.kind == FLOATING_POINT) {
367 // Currently, all floating points can take any number we generate.
368 bool isSigned = generatedType.kind == SIGNED_INTEGER;
369 *mJava << (isSigned ? "true" : "false") << ", " << generatedType.significantBits;
370 } else {
371 bool isSigned =
372 compatibleType.kind == SIGNED_INTEGER && generatedType.kind == SIGNED_INTEGER;
373 *mJava << (isSigned ? "true" : "false") << ", "
374 << min(compatibleType.significantBits, generatedType.significantBits);
375 }
376 *mJava << ")";
377}
378
379void PermutationWriter::writeJavaOutputAllocationDefinition(
380 const ParameterDefinition& param) const {
381 string dataType;
382 char vectorSize;
383 convertToRsType(param.rsType, &dataType, &vectorSize);
384 mJava->indent() << "Allocation " << param.javaAllocName << " = Allocation.createSized(mRS, "
385 << "getElement(mRS, Element.DataType." << dataType << ", " << vectorSize
386 << "), INPUTSIZE);\n";
387}
388
389void PermutationWriter::writeJavaVerifyScalarMethod(bool verifierValidates) const {
390 writeJavaVerifyMethodHeader();
391 mJava->startBlock();
392
393 string vectorSize = "1";
394 for (auto p : mAllInputsAndOutputs) {
395 writeJavaArrayInitialization(*p);
396 if (p->mVectorSize != "1" && p->mVectorSize != vectorSize) {
397 if (vectorSize == "1") {
398 vectorSize = p->mVectorSize;
399 } else {
400 cerr << "Error. Had vector " << vectorSize << " and " << p->mVectorSize << "\n";
401 }
402 }
403 }
404
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700405 mJava->indent() << "StringBuilder message = new StringBuilder();\n";
406 mJava->indent() << "boolean errorFound = false;\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700407 mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
408 mJava->startBlock();
409
410 mJava->indent() << "for (int j = 0; j < " << vectorSize << " ; j++)";
411 mJava->startBlock();
412
413 mJava->indent() << "// Extract the inputs.\n";
414 mJava->indent() << mJavaArgumentsClassName << " args = new " << mJavaArgumentsClassName
415 << "();\n";
416 for (auto p : mAllInputsAndOutputs) {
417 if (!p->isOutParameter) {
418 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i";
419 if (p->vectorWidth != "1") {
420 *mJava << " * " << p->vectorWidth << " + j";
421 }
422 *mJava << "];\n";
423 }
424 }
425 const bool hasFloat = mPermutation.hasFloatAnswers();
426 if (verifierValidates) {
427 mJava->indent() << "// Extract the outputs.\n";
428 for (auto p : mAllInputsAndOutputs) {
429 if (p->isOutParameter) {
430 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName
Jean-Luc Brouillet49736b32015-04-14 14:23:48 -0700431 << "[i * " << p->vectorWidth << " + j];\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700432 }
433 }
434 mJava->indent() << "// Ask the CoreMathVerifier to validate.\n";
435 if (hasFloat) {
Pirama Arumuga Nainar93abc2d2016-02-16 15:11:29 -0800436 writeJavaCreateTarget();
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700437 }
438 mJava->indent() << "String errorMessage = CoreMathVerifier."
439 << mJavaVerifierVerifyMethodName << "(args";
440 if (hasFloat) {
441 *mJava << ", target";
442 }
443 *mJava << ");\n";
444 mJava->indent() << "boolean valid = errorMessage == null;\n";
445 } else {
446 mJava->indent() << "// Figure out what the outputs should have been.\n";
447 if (hasFloat) {
Pirama Arumuga Nainar93abc2d2016-02-16 15:11:29 -0800448 writeJavaCreateTarget();
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700449 }
450 mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args";
451 if (hasFloat) {
452 *mJava << ", target";
453 }
454 *mJava << ");\n";
455 mJava->indent() << "// Validate the outputs.\n";
456 mJava->indent() << "boolean valid = true;\n";
457 for (auto p : mAllInputsAndOutputs) {
458 if (p->isOutParameter) {
459 writeJavaTestAndSetValid(*p, "", "[i * " + p->vectorWidth + " + j]");
460 }
461 }
462 }
463
464 mJava->indent() << "if (!valid)";
465 mJava->startBlock();
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700466 mJava->indent() << "if (!errorFound)";
467 mJava->startBlock();
468 mJava->indent() << "errorFound = true;\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700469
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700470 for (auto p : mAllInputsAndOutputs) {
471 if (p->isOutParameter) {
472 writeJavaAppendOutputToMessage(*p, "", "[i * " + p->vectorWidth + " + j]",
473 verifierValidates);
474 } else {
475 writeJavaAppendInputToMessage(*p, "args." + p->variableName);
476 }
477 }
478 if (verifierValidates) {
479 mJava->indent() << "message.append(errorMessage);\n";
480 }
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700481 mJava->indent() << "message.append(\"Errors at\");\n";
482 mJava->endBlock();
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700483
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700484 mJava->indent() << "message.append(\" [\");\n";
485 mJava->indent() << "message.append(Integer.toString(i));\n";
486 mJava->indent() << "message.append(\", \");\n";
487 mJava->indent() << "message.append(Integer.toString(j));\n";
488 mJava->indent() << "message.append(\"]\");\n";
489
490 mJava->endBlock();
491 mJava->endBlock();
492 mJava->endBlock();
493
494 mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700495 mJava->indentPlus()
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700496 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700497
498 mJava->endBlock();
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700499 *mJava << "\n";
500}
501
502void PermutationWriter::writeJavaVerifyVectorMethod() const {
503 writeJavaVerifyMethodHeader();
504 mJava->startBlock();
505
506 for (auto p : mAllInputsAndOutputs) {
507 writeJavaArrayInitialization(*p);
508 }
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700509 mJava->indent() << "StringBuilder message = new StringBuilder();\n";
510 mJava->indent() << "boolean errorFound = false;\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700511 mJava->indent() << "for (int i = 0; i < INPUTSIZE; i++)";
512 mJava->startBlock();
513
514 mJava->indent() << mJavaArgumentsNClassName << " args = new " << mJavaArgumentsNClassName
515 << "();\n";
516
517 mJava->indent() << "// Create the appropriate sized arrays in args\n";
518 for (auto p : mAllInputsAndOutputs) {
519 if (p->mVectorSize != "1") {
520 string type = p->javaBaseType;
521 if (p->isOutParameter && p->isFloatType) {
522 type = "Target.Floaty";
523 }
524 mJava->indent() << "args." << p->variableName << " = new " << type << "["
525 << p->mVectorSize << "];\n";
526 }
527 }
528
529 mJava->indent() << "// Fill args with the input values\n";
530 for (auto p : mAllInputsAndOutputs) {
531 if (!p->isOutParameter) {
532 if (p->mVectorSize == "1") {
Jean-Luc Brouillet49736b32015-04-14 14:23:48 -0700533 mJava->indent() << "args." << p->variableName << " = " << p->javaArrayName << "[i]"
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700534 << ";\n";
535 } else {
536 mJava->indent() << "for (int j = 0; j < " << p->mVectorSize << " ; j++)";
537 mJava->startBlock();
Jean-Luc Brouillet49736b32015-04-14 14:23:48 -0700538 mJava->indent() << "args." << p->variableName << "[j] = "
539 << p->javaArrayName << "[i * " << p->vectorWidth << " + j]"
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700540 << ";\n";
541 mJava->endBlock();
542 }
543 }
544 }
Pirama Arumuga Nainar93abc2d2016-02-16 15:11:29 -0800545 writeJavaCreateTarget();
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700546 mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName
547 << "(args, target);\n\n";
548
549 mJava->indent() << "// Compare the expected outputs to the actual values returned by RS.\n";
550 mJava->indent() << "boolean valid = true;\n";
551 for (auto p : mAllInputsAndOutputs) {
552 if (p->isOutParameter) {
553 writeJavaVectorComparison(*p);
554 }
555 }
556
557 mJava->indent() << "if (!valid)";
558 mJava->startBlock();
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700559 mJava->indent() << "if (!errorFound)";
560 mJava->startBlock();
561 mJava->indent() << "errorFound = true;\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700562
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700563 for (auto p : mAllInputsAndOutputs) {
564 if (p->isOutParameter) {
565 writeJavaAppendVectorOutputToMessage(*p);
566 } else {
567 writeJavaAppendVectorInputToMessage(*p);
568 }
569 }
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700570 mJava->indent() << "message.append(\"Errors at\");\n";
571 mJava->endBlock();
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700572
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700573 mJava->indent() << "message.append(\" [\");\n";
574 mJava->indent() << "message.append(Integer.toString(i));\n";
575 mJava->indent() << "message.append(\"]\");\n";
576
577 mJava->endBlock();
578 mJava->endBlock();
579
580 mJava->indent() << "assertFalse(\"Incorrect output for " << mJavaCheckMethodName << "\" +\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700581 mJava->indentPlus()
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700582 << "(relaxed ? \"_relaxed\" : \"\") + \":\\n\" + message.toString(), errorFound);\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700583
584 mJava->endBlock();
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700585 *mJava << "\n";
586}
587
Pirama Arumuga Nainar93abc2d2016-02-16 15:11:29 -0800588
589void PermutationWriter::writeJavaCreateTarget() const {
590 string name = mPermutation.getName();
591
592 const char* functionType = "NORMAL";
593 size_t end = name.find('_');
594 if (end != string::npos) {
595 if (name.compare(0, end, "native") == 0) {
596 functionType = "NATIVE";
597 } else if (name.compare(0, end, "half") == 0) {
598 functionType = "HALF";
599 } else if (name.compare(0, end, "fast") == 0) {
600 functionType = "FAST";
601 }
602 }
603
604 string floatType = mReturnParam->specType;
605 const char* precisionStr = "";
606 if (floatType.compare("f16") == 0) {
607 precisionStr = "HALF";
608 } else if (floatType.compare("f32") == 0) {
609 precisionStr = "FLOAT";
610 } else if (floatType.compare("f64") == 0) {
611 precisionStr = "DOUBLE";
612 } else {
613 cerr << "Error. Unreachable. Return type is not floating point\n";
614 }
615
616 mJava->indent() << "Target target = new Target(Target.FunctionType." <<
617 functionType << ", Target.ReturnType." << precisionStr <<
618 ", relaxed);\n";
619}
620
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700621void PermutationWriter::writeJavaVerifyMethodHeader() const {
622 mJava->indent() << "private void " << mJavaVerifyMethodName << "(";
623 for (auto p : mAllInputsAndOutputs) {
624 *mJava << "Allocation " << p->javaAllocName << ", ";
625 }
626 *mJava << "boolean relaxed)";
627}
628
629void PermutationWriter::writeJavaArrayInitialization(const ParameterDefinition& p) const {
630 mJava->indent() << p.javaBaseType << "[] " << p.javaArrayName << " = new " << p.javaBaseType
631 << "[INPUTSIZE * " << p.vectorWidth << "];\n";
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700632
633 /* For basic types, populate the array with values, to help understand failures. We have had
634 * bugs where the output buffer was all 0. We were not sure if there was a failed copy or
635 * the GPU driver was copying zeroes.
636 */
637 if (p.typeIndex >= 0) {
638 mJava->indent() << "Arrays.fill(" << p.javaArrayName << ", (" << TYPES[p.typeIndex].javaType
639 << ") 42);\n";
640 }
641
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700642 mJava->indent() << p.javaAllocName << ".copyTo(" << p.javaArrayName << ");\n";
643}
644
645void PermutationWriter::writeJavaTestAndSetValid(const ParameterDefinition& p,
646 const string& argsIndex,
647 const string& actualIndex) const {
648 writeJavaTestOneValue(p, argsIndex, actualIndex);
649 mJava->startBlock();
650 mJava->indent() << "valid = false;\n";
651 mJava->endBlock();
652}
653
654void PermutationWriter::writeJavaTestOneValue(const ParameterDefinition& p, const string& argsIndex,
655 const string& actualIndex) const {
656 mJava->indent() << "if (";
657 if (p.isFloatType) {
658 *mJava << "!args." << p.variableName << argsIndex << ".couldBe(" << p.javaArrayName
659 << actualIndex;
660 const string s = mPermutation.getPrecisionLimit();
661 if (!s.empty()) {
662 *mJava << ", " << s;
663 }
664 *mJava << ")";
665 } else {
666 *mJava << "args." << p.variableName << argsIndex << " != " << p.javaArrayName
667 << actualIndex;
668 }
669
670 if (p.undefinedIfOutIsNan && mReturnParam) {
671 *mJava << " && !args." << mReturnParam->variableName << argsIndex << ".isNaN()";
672 }
673 *mJava << ")";
674}
675
676void PermutationWriter::writeJavaVectorComparison(const ParameterDefinition& p) const {
677 if (p.mVectorSize == "1") {
678 writeJavaTestAndSetValid(p, "", "[i]");
679 } else {
680 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
681 mJava->startBlock();
682 writeJavaTestAndSetValid(p, "[j]", "[i * " + p.vectorWidth + " + j]");
683 mJava->endBlock();
684 }
685}
686
687void PermutationWriter::writeJavaAppendOutputToMessage(const ParameterDefinition& p,
688 const string& argsIndex,
689 const string& actualIndex,
690 bool verifierValidates) const {
691 if (verifierValidates) {
Jean-Luc Brouillet49736b32015-04-14 14:23:48 -0700692 mJava->indent() << "message.append(\"Output " << p.variableName << ": \");\n";
693 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
694 << ");\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700695 writeJavaAppendNewLineToMessage();
696 } else {
Jean-Luc Brouillet49736b32015-04-14 14:23:48 -0700697 mJava->indent() << "message.append(\"Expected output " << p.variableName << ": \");\n";
698 mJava->indent() << "appendVariableToMessage(message, args." << p.variableName << argsIndex
699 << ");\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700700 writeJavaAppendNewLineToMessage();
Jean-Luc Brouillet49736b32015-04-14 14:23:48 -0700701
702 mJava->indent() << "message.append(\"Actual output " << p.variableName << ": \");\n";
703 mJava->indent() << "appendVariableToMessage(message, " << p.javaArrayName << actualIndex
704 << ");\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700705
706 writeJavaTestOneValue(p, argsIndex, actualIndex);
707 mJava->startBlock();
708 mJava->indent() << "message.append(\" FAIL\");\n";
709 mJava->endBlock();
710 writeJavaAppendNewLineToMessage();
711 }
712}
713
714void PermutationWriter::writeJavaAppendInputToMessage(const ParameterDefinition& p,
715 const string& actual) const {
Jean-Luc Brouillet49736b32015-04-14 14:23:48 -0700716 mJava->indent() << "message.append(\"Input " << p.variableName << ": \");\n";
717 mJava->indent() << "appendVariableToMessage(message, " << actual << ");\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700718 writeJavaAppendNewLineToMessage();
719}
720
721void PermutationWriter::writeJavaAppendNewLineToMessage() const {
722 mJava->indent() << "message.append(\"\\n\");\n";
723}
724
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700725void PermutationWriter::writeJavaAppendVectorInputToMessage(const ParameterDefinition& p) const {
726 if (p.mVectorSize == "1") {
727 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i]");
728 } else {
729 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
730 mJava->startBlock();
731 writeJavaAppendInputToMessage(p, p.javaArrayName + "[i * " + p.vectorWidth + " + j]");
732 mJava->endBlock();
733 }
734}
735
736void PermutationWriter::writeJavaAppendVectorOutputToMessage(const ParameterDefinition& p) const {
737 if (p.mVectorSize == "1") {
738 writeJavaAppendOutputToMessage(p, "", "[i]", false);
739 } else {
740 mJava->indent() << "for (int j = 0; j < " << p.mVectorSize << " ; j++)";
741 mJava->startBlock();
742 writeJavaAppendOutputToMessage(p, "[j]", "[i * " + p.vectorWidth + " + j]", false);
743 mJava->endBlock();
744 }
745}
746
747void PermutationWriter::writeJavaCallToRs(bool relaxed, bool generateCallToVerifier) const {
748 string script = "script";
749 if (relaxed) {
750 script += "Relaxed";
751 }
752
753 mJava->indent() << "try";
754 mJava->startBlock();
755
756 for (auto p : mAllInputsAndOutputs) {
757 if (p->isOutParameter) {
758 writeJavaOutputAllocationDefinition(*p);
759 }
760 }
761
762 for (auto p : mPermutation.getParams()) {
763 if (p != mFirstInputParam) {
764 mJava->indent() << script << ".set_" << p->rsAllocName << "(" << p->javaAllocName
765 << ");\n";
766 }
767 }
768
769 mJava->indent() << script << ".forEach_" << mRsKernelName << "(";
770 bool needComma = false;
771 if (mFirstInputParam) {
772 *mJava << mFirstInputParam->javaAllocName;
773 needComma = true;
774 }
775 if (mReturnParam) {
776 if (needComma) {
777 *mJava << ", ";
778 }
779 *mJava << mReturnParam->variableName << ");\n";
780 }
781
782 if (generateCallToVerifier) {
783 mJava->indent() << mJavaVerifyMethodName << "(";
784 for (auto p : mAllInputsAndOutputs) {
785 *mJava << p->variableName << ", ";
786 }
787
788 if (relaxed) {
789 *mJava << "true";
790 } else {
791 *mJava << "false";
792 }
793 *mJava << ");\n";
794 }
795 mJava->decreaseIndent();
796 mJava->indent() << "} catch (Exception e) {\n";
797 mJava->increaseIndent();
798 mJava->indent() << "throw new RSRuntimeException(\"RenderScript. Can't invoke forEach_"
799 << mRsKernelName << ": \" + e.toString());\n";
800 mJava->endBlock();
801}
802
803/* Write the section of the .rs file for this permutation.
804 *
805 * We communicate the extra input and output parameters via global allocations.
806 * For example, if we have a function that takes three arguments, two for input
807 * and one for output:
808 *
809 * start:
810 * name: gamn
811 * ret: float3
812 * arg: float3 a
813 * arg: int b
814 * arg: float3 *c
815 * end:
816 *
817 * We'll produce:
818 *
819 * rs_allocation gAllocInB;
820 * rs_allocation gAllocOutC;
821 *
822 * float3 __attribute__((kernel)) test_gamn_float3_int_float3(float3 inA, unsigned int x) {
823 * int inB;
824 * float3 outC;
825 * float2 out;
826 * inB = rsGetElementAt_int(gAllocInB, x);
827 * out = gamn(a, in_b, &outC);
828 * rsSetElementAt_float4(gAllocOutC, &outC, x);
829 * return out;
830 * }
831 *
832 * We avoid re-using x and y from the definition because these have reserved
833 * meanings in a .rs file.
834 */
835void PermutationWriter::writeRsSection(set<string>* rsAllocationsGenerated) const {
836 // Write the allocation declarations we'll need.
837 for (auto p : mPermutation.getParams()) {
838 // Don't need allocation for one input and one return value.
839 if (p != mFirstInputParam) {
840 writeRsAllocationDefinition(*p, rsAllocationsGenerated);
841 }
842 }
843 *mRs << "\n";
844
845 // Write the function header.
846 if (mReturnParam) {
847 *mRs << mReturnParam->rsType;
848 } else {
849 *mRs << "void";
850 }
851 *mRs << " __attribute__((kernel)) " << mRsKernelName;
852 *mRs << "(";
853 bool needComma = false;
854 if (mFirstInputParam) {
855 *mRs << mFirstInputParam->rsType << " " << mFirstInputParam->variableName;
856 needComma = true;
857 }
858 if (mPermutation.getOutputCount() > 1 || mPermutation.getInputCount() > 1) {
859 if (needComma) {
860 *mRs << ", ";
861 }
862 *mRs << "unsigned int x";
863 }
864 *mRs << ")";
865 mRs->startBlock();
866
867 // Write the local variable declarations and initializations.
868 for (auto p : mPermutation.getParams()) {
869 if (p == mFirstInputParam) {
870 continue;
871 }
872 mRs->indent() << p->rsType << " " << p->variableName;
873 if (p->isOutParameter) {
874 *mRs << " = 0;\n";
875 } else {
876 *mRs << " = rsGetElementAt_" << p->rsType << "(" << p->rsAllocName << ", x);\n";
877 }
878 }
879
880 // Write the function call.
881 if (mReturnParam) {
882 if (mPermutation.getOutputCount() > 1) {
883 mRs->indent() << mReturnParam->rsType << " " << mReturnParam->variableName << " = ";
884 } else {
885 mRs->indent() << "return ";
886 }
887 }
888 *mRs << mPermutation.getName() << "(";
889 needComma = false;
890 for (auto p : mPermutation.getParams()) {
891 if (needComma) {
892 *mRs << ", ";
893 }
894 if (p->isOutParameter) {
895 *mRs << "&";
896 }
897 *mRs << p->variableName;
898 needComma = true;
899 }
900 *mRs << ");\n";
901
902 if (mPermutation.getOutputCount() > 1) {
903 // Write setting the extra out parameters into the allocations.
904 for (auto p : mPermutation.getParams()) {
905 if (p->isOutParameter) {
906 mRs->indent() << "rsSetElementAt_" << p->rsType << "(" << p->rsAllocName << ", ";
907 // Check if we need to use '&' for this type of argument.
908 char lastChar = p->variableName.back();
909 if (lastChar >= '0' && lastChar <= '9') {
910 *mRs << "&";
911 }
912 *mRs << p->variableName << ", x);\n";
913 }
914 }
915 if (mReturnParam) {
916 mRs->indent() << "return " << mReturnParam->variableName << ";\n";
917 }
918 }
919 mRs->endBlock();
920}
921
922void PermutationWriter::writeRsAllocationDefinition(const ParameterDefinition& param,
923 set<string>* rsAllocationsGenerated) const {
924 if (!testAndSet(param.rsAllocName, rsAllocationsGenerated)) {
925 *mRs << "rs_allocation " << param.rsAllocName << ";\n";
926 }
927}
928
929// Open the mJavaFile and writes the header.
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -0700930static bool startJavaFile(GeneratedFile* file, const Function& function, const string& directory,
931 const string& testName, const string& relaxedTestName) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700932 const string fileName = testName + ".java";
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -0700933 if (!file->start(directory, fileName)) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700934 return false;
935 }
936 file->writeNotices();
937
938 *file << "package android.renderscript.cts;\n\n";
939
940 *file << "import android.renderscript.Allocation;\n";
941 *file << "import android.renderscript.RSRuntimeException;\n";
Pirama Arumuga Nainar93abc2d2016-02-16 15:11:29 -0800942 *file << "import android.renderscript.Element;\n";
943 *file << "import android.renderscript.cts.Target;\n\n";
Jean-Luc Brouillet0c905c82015-07-23 14:03:19 -0700944 *file << "import java.util.Arrays;\n\n";
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700945
946 *file << "public class " << testName << " extends RSBaseCompute";
947 file->startBlock(); // The corresponding endBlock() is in finishJavaFile()
948 *file << "\n";
949
950 file->indent() << "private ScriptC_" << testName << " script;\n";
951 file->indent() << "private ScriptC_" << relaxedTestName << " scriptRelaxed;\n\n";
952
953 file->indent() << "@Override\n";
954 file->indent() << "protected void setUp() throws Exception";
955 file->startBlock();
956
957 file->indent() << "super.setUp();\n";
958 file->indent() << "script = new ScriptC_" << testName << "(mRS);\n";
959 file->indent() << "scriptRelaxed = new ScriptC_" << relaxedTestName << "(mRS);\n";
960
961 file->endBlock();
962 *file << "\n";
963 return true;
964}
965
966// Write the test method that calls all the generated Check methods.
967static void finishJavaFile(GeneratedFile* file, const Function& function,
968 const vector<string>& javaCheckMethods) {
969 file->indent() << "public void test" << function.getCapitalizedName() << "()";
970 file->startBlock();
971 for (auto m : javaCheckMethods) {
972 file->indent() << m << "();\n";
973 }
974 file->endBlock();
975
976 file->endBlock();
977}
978
979// Open the script file and write its header.
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -0700980static bool startRsFile(GeneratedFile* file, const Function& function, const string& directory,
981 const string& testName) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700982 string fileName = testName + ".rs";
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -0700983 if (!file->start(directory, fileName)) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700984 return false;
985 }
986 file->writeNotices();
987
988 *file << "#pragma version(1)\n";
989 *file << "#pragma rs java_package_name(android.renderscript.cts)\n\n";
990 return true;
991}
992
993// Write the entire *Relaxed.rs test file, as it only depends on the name.
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -0700994static bool writeRelaxedRsFile(const Function& function, const string& directory,
995 const string& testName, const string& relaxedTestName) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -0700996 string name = relaxedTestName + ".rs";
997
998 GeneratedFile file;
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -0700999 if (!file.start(directory, name)) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -07001000 return false;
1001 }
1002 file.writeNotices();
1003
1004 file << "#include \"" << testName << ".rs\"\n";
1005 file << "#pragma rs_fp_relaxed\n";
1006 file.close();
1007 return true;
1008}
1009
1010/* Write the .java and the two .rs test files. versionOfTestFiles is used to restrict which API
1011 * to test.
1012 */
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -07001013static bool writeTestFilesForFunction(const Function& function, const string& directory,
Yang Ni12398d82015-09-18 14:57:07 -07001014 unsigned int versionOfTestFiles) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -07001015 // Avoid creating empty files if we're not testing this function.
1016 if (!needTestFiles(function, versionOfTestFiles)) {
1017 return true;
1018 }
1019
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -07001020 const string testName = "Test" + function.getCapitalizedName();
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -07001021 const string relaxedTestName = testName + "Relaxed";
1022
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -07001023 if (!writeRelaxedRsFile(function, directory, testName, relaxedTestName)) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -07001024 return false;
1025 }
1026
1027 GeneratedFile rsFile; // The Renderscript test file we're generating.
1028 GeneratedFile javaFile; // The Jave test file we're generating.
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -07001029 if (!startRsFile(&rsFile, function, directory, testName)) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -07001030 return false;
1031 }
1032
Jean-Luc Brouillet62e09932015-03-22 11:14:07 -07001033 if (!startJavaFile(&javaFile, function, directory, testName, relaxedTestName)) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -07001034 return false;
1035 }
1036
1037 /* We keep track of the allocations generated in the .rs file and the argument classes defined
1038 * in the Java file, as we share these between the functions created for each specification.
1039 */
1040 set<string> rsAllocationsGenerated;
1041 set<string> javaGeneratedArgumentClasses;
1042 // Lines of Java code to invoke the check methods.
1043 vector<string> javaCheckMethods;
1044
1045 for (auto spec : function.getSpecifications()) {
1046 if (spec->hasTests(versionOfTestFiles)) {
1047 for (auto permutation : spec->getPermutations()) {
1048 PermutationWriter w(*permutation, &rsFile, &javaFile);
1049 w.writeRsSection(&rsAllocationsGenerated);
1050 w.writeJavaSection(&javaGeneratedArgumentClasses);
1051
1052 // Store the check method to be called.
1053 javaCheckMethods.push_back(w.getJavaCheckMethodName());
1054 }
1055 }
1056 }
1057
1058 finishJavaFile(&javaFile, function, javaCheckMethods);
1059 // There's no work to wrap-up in the .rs file.
1060
1061 rsFile.close();
1062 javaFile.close();
1063 return true;
1064}
1065
Yang Ni12398d82015-09-18 14:57:07 -07001066bool generateTestFiles(const string& directory, unsigned int versionOfTestFiles) {
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -07001067 bool success = true;
Jean-Luc Brouillet7c078542015-03-23 16:16:08 -07001068 for (auto f : systemSpecification.getFunctions()) {
1069 if (!writeTestFilesForFunction(*f.second, directory, versionOfTestFiles)) {
1070 success = false;
Jean-Luc Brouilletc5184e22015-03-13 13:51:24 -07001071 }
1072 }
1073 return success;
1074}