Check kernels before fusing

b/21958851

Kernel fusion for a Script Group expects kernels to be chained up
via outputs (return values) and first arguments.
Check this condition during fusion. Otherwise, BCC may crash on
broken invariants.

Change-Id: I013558c77dc3f79d6e42986121927dd6c695f27e
(cherry picked from commit 8c12d615b4ed4b1d782722a125dd1d43bc44a71b)
diff --git a/include/bcc/Renderscript/RSScriptGroupFusion.h b/include/bcc/Renderscript/RSScriptGroupFusion.h
index 51e983a..c173ac4 100644
--- a/include/bcc/Renderscript/RSScriptGroupFusion.h
+++ b/include/bcc/Renderscript/RSScriptGroupFusion.h
@@ -35,7 +35,9 @@
 /// @param sources The Sources containing the kernels.
 /// @param slots The slots where the kernels are located.
 /// @param fusedName
-/// @return True, if kernels are successfully merged. False, otherwise.
+/// @return True, if kernels are successfully fused. False, otherwise. It is up
+/// to the caller to decide how to handle an unsuccessful fusion. A script group
+/// can execute with either fused kernels or individual kernels.
 bool fuseKernels(BCCContext& Context,
                  const std::vector<Source *>& sources,
                  const std::vector<int>& slots,
diff --git a/lib/Renderscript/RSScriptGroupFusion.cpp b/lib/Renderscript/RSScriptGroupFusion.cpp
index e16445c..c76f869 100644
--- a/lib/Renderscript/RSScriptGroupFusion.cpp
+++ b/lib/Renderscript/RSScriptGroupFusion.cpp
@@ -76,7 +76,11 @@
   return function;
 }
 
-// TODO: Handle the context argument
+// The whitelist of supported signature bits. Context or user data arguments are
+// not currently supported in kernel fusion. Supporting them, or any new kind of
+// argument, in the future requires not only listing the signature bits here but
+// also implementing the necessary additional fusion logic in getFusedFuncSig(),
+// getFusedFuncType(), and fuseKernels() below.
 constexpr uint32_t ExpectedSignatureBits =
         bcinfo::MD_SIG_In |
         bcinfo::MD_SIG_Out |
@@ -177,10 +181,10 @@
                  Module* mergedModule) {
   bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
 
-  uint32_t signature;
+  uint32_t fusedFunctionSignature;
 
   llvm::FunctionType* fusedType =
-          getFusedFuncType(Context, sources, slots, mergedModule, &signature);
+          getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);
 
   if (fusedType == nullptr) {
     return false;
@@ -197,60 +201,80 @@
   Function::arg_iterator argIter = fusedKernel->arg_begin();
 
   llvm::Value* dataElement = nullptr;
-  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(signature)) {
+  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
     dataElement = argIter++;
     dataElement->setName("DataIn");
   }
 
   llvm::Value* X = nullptr;
-  if (bcinfo::MetadataExtractor::hasForEachSignatureX(signature)) {
-      X = argIter++;
-      X->setName("x");
+  if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
+    X = argIter++;
+    X->setName("x");
   }
 
   llvm::Value* Y = nullptr;
-  if (bcinfo::MetadataExtractor::hasForEachSignatureY(signature)) {
-      Y = argIter++;
-      Y->setName("y");
+  if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
+    Y = argIter++;
+    Y->setName("y");
   }
 
   llvm::Value* Z = nullptr;
-  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(signature)) {
-      Z = argIter++;
-      Z->setName("z");
+  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
+    Z = argIter++;
+    Z->setName("z");
   }
 
   auto slotIter = slots.begin();
   for (const Source* source : sources) {
     int slot = *slotIter++;
 
-    uint32_t signature;
-    const Function* function = getFunction(mergedModule, source, slot, &signature);
+    uint32_t inputFunctionSignature;
+    const Function* inputFunction =
+            getFunction(mergedModule, source, slot, &inputFunctionSignature);
+    if (inputFunction == nullptr) {
+      return false;
+    }
 
-    if (function == nullptr) {
+    // Don't try to fuse a non-kernel
+    if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
       return false;
     }
 
     std::vector<llvm::Value*> args;
-    if (dataElement != nullptr) {
+
+    if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
+      if (dataElement == nullptr) {
+        return false;
+      }
+
+      const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
+      llvm::Type* firstArgType = funcTy->getParamType(0);
+
+      if (!dataElement->getType()->canLosslesslyBitCastTo(firstArgType)) {
+        return false;
+      }
+
       args.push_back(dataElement);
+    } else {
+      // Only the first kernel in a batch is allowed to have no input
+      if (slotIter != slots.begin()) {
+        return false;
+      }
     }
 
-    // TODO: Handle the context argument
-
-    if (bcinfo::MetadataExtractor::hasForEachSignatureX(signature)) {
+    if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
       args.push_back(X);
     }
 
-    if (bcinfo::MetadataExtractor::hasForEachSignatureY(signature)) {
+    if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
       args.push_back(Y);
     }
 
-    if (bcinfo::MetadataExtractor::hasForEachSignatureZ(signature)) {
+    if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
       args.push_back(Z);
     }
 
-    dataElement = builder.CreateCall((llvm::Value*)function, args);
+    dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);
   }
 
   if (fusedKernel->getReturnType()->isVoidTy()) {
@@ -269,7 +293,7 @@
   llvm::NamedMDNode* ExportForEachMD =
     mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
   llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
-                                                 llvm::utostr_32(signature));
+                                                 llvm::utostr_32(fusedFunctionSignature));
   llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
   ExportForEachMD->addOperand(sigMDNode);