Check kernels before fusing
b/21958851
Kernel fusion for a Script Group expects kernels to be chained via their
outputs (return values) and first arguments. Check this condition during
fusion and refuse to fuse kernels that violate it; otherwise, BCC may crash
on broken invariants.
Change-Id: I013558c77dc3f79d6e42986121927dd6c695f27e
diff --git a/lib/Renderscript/RSScriptGroupFusion.cpp b/lib/Renderscript/RSScriptGroupFusion.cpp
index e16445c..c76f869 100644
--- a/lib/Renderscript/RSScriptGroupFusion.cpp
+++ b/lib/Renderscript/RSScriptGroupFusion.cpp
@@ -76,7 +76,11 @@
return function;
}
-// TODO: Handle the context argument
+// The whitelist of supported signature bits. Context and user data arguments
+// are not currently supported in kernel fusion. Supporting them, or any new
+// kind of argument, in the future requires not only listing the signature bits
+// here but also implementing the necessary fusion logic in getFusedFuncSig(),
+// getFusedFuncType(), and fuseKernels() below.
constexpr uint32_t ExpectedSignatureBits =
bcinfo::MD_SIG_In |
bcinfo::MD_SIG_Out |
@@ -177,10 +181,10 @@
Module* mergedModule) {
bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
- uint32_t signature;
+ uint32_t fusedFunctionSignature;
llvm::FunctionType* fusedType =
- getFusedFuncType(Context, sources, slots, mergedModule, &signature);
+ getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);
if (fusedType == nullptr) {
return false;
@@ -197,60 +201,80 @@
Function::arg_iterator argIter = fusedKernel->arg_begin();
llvm::Value* dataElement = nullptr;
- if (bcinfo::MetadataExtractor::hasForEachSignatureIn(signature)) {
+ if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
dataElement = argIter++;
dataElement->setName("DataIn");
}
llvm::Value* X = nullptr;
- if (bcinfo::MetadataExtractor::hasForEachSignatureX(signature)) {
- X = argIter++;
- X->setName("x");
+ if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
+ X = argIter++;
+ X->setName("x");
}
llvm::Value* Y = nullptr;
- if (bcinfo::MetadataExtractor::hasForEachSignatureY(signature)) {
- Y = argIter++;
- Y->setName("y");
+ if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
+ Y = argIter++;
+ Y->setName("y");
}
llvm::Value* Z = nullptr;
- if (bcinfo::MetadataExtractor::hasForEachSignatureZ(signature)) {
- Z = argIter++;
- Z->setName("z");
+ if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
+ Z = argIter++;
+ Z->setName("z");
}
auto slotIter = slots.begin();
for (const Source* source : sources) {
int slot = *slotIter++;
- uint32_t signature;
- const Function* function = getFunction(mergedModule, source, slot, &signature);
+ uint32_t inputFunctionSignature;
+ const Function* inputFunction =
+ getFunction(mergedModule, source, slot, &inputFunctionSignature);
+ if (inputFunction == nullptr) {
+ return false;
+ }
- if (function == nullptr) {
+ // Don't try to fuse a non-kernel
+ if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
return false;
}
std::vector<llvm::Value*> args;
- if (dataElement != nullptr) {
+
+ if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
+ if (dataElement == nullptr) {
+ return false;
+ }
+
+ const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
+ llvm::Type* firstArgType = funcTy->getParamType(0);
+
+ if (!dataElement->getType()->canLosslesslyBitCastTo(firstArgType)) {
+ return false;
+ }
+
args.push_back(dataElement);
+ } else {
+ // Only the first kernel in a batch is allowed to have no input
+ if (slotIter != slots.begin()) {
+ return false;
+ }
}
- // TODO: Handle the context argument
-
- if (bcinfo::MetadataExtractor::hasForEachSignatureX(signature)) {
+ if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
args.push_back(X);
}
- if (bcinfo::MetadataExtractor::hasForEachSignatureY(signature)) {
+ if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
args.push_back(Y);
}
- if (bcinfo::MetadataExtractor::hasForEachSignatureZ(signature)) {
+ if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
args.push_back(Z);
}
- dataElement = builder.CreateCall((llvm::Value*)function, args);
+ dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);
}
if (fusedKernel->getReturnType()->isVoidTy()) {
@@ -269,7 +293,7 @@
llvm::NamedMDNode* ExportForEachMD =
mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
- llvm::utostr_32(signature));
+ llvm::utostr_32(fusedFunctionSignature));
llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
ExportForEachMD->addOperand(sigMDNode);