Support for pass-by-value kernels.

This change allows the JB-mr1+ target API to declare compute kernels
using "__attribute__((kernel))". This disables the use of pointers in the
function signature and forces any output to be explicitly returned and
input to be passed only by value. We still allow the user to add x, y
coordinates if they want them.

Bug: 7166741

Change-Id: I1407fceefb11c7d6c17221ca156cfce443c2b218
diff --git a/slang_rs_export_foreach.cpp b/slang_rs_export_foreach.cpp
index 484ab46..5754b41 100644
--- a/slang_rs_export_foreach.cpp
+++ b/slang_rs_export_foreach.cpp
@@ -50,6 +50,7 @@
 
 }  // namespace
 
+
 // This function takes care of additional validation and construction of
 // parameters related to forEach_* reflection.
 bool RSExportForEach::validateAndConstructParams(
@@ -60,7 +61,6 @@
   clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
 
   numParams = FD->getNumParams();
-  slangAssert(numParams > 0);
 
   if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
     if (!isRootRSFunc(FD)) {
@@ -76,8 +76,19 @@
     }
   }
 
-  // Compute kernel functions are required to return a void type for now
-  if (FD->getResultType().getCanonicalType() != C.VoidTy) {
+  mResultType = FD->getResultType().getCanonicalType();
+  // Compute kernel functions are required to return a void type or
+  // be marked explicitly as a kernel. In the case of
+  // "__attribute__((kernel))", we handle validation differently.
+  if (FD->hasAttr<clang::KernelAttr>()) {
+    return validateAndConstructKernelParams(Context, FD);
+  }
+
+  // If numParams is 0, we already marked this as a graphics root().
+  slangAssert(numParams > 0);
+
+  // Compute kernel functions of this type are required to return a void type.
+  if (mResultType != C.VoidTy) {
     DiagEngine->Report(
       clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
       DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
@@ -216,6 +227,166 @@
   return valid;
 }
 
+
+bool RSExportForEach::validateAndConstructKernelParams(RSContext *Context,
+    const clang::FunctionDecl *FD) {
+  slangAssert(Context && FD);
+  bool valid = true;
+  clang::ASTContext &C = Context->getASTContext();
+  clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
+
+  if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
+    DiagEngine->Report(
+      clang::FullSourceLoc(FD->getLocation(),
+                           DiagEngine->getSourceManager()),
+      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
+                                  "Compute kernel %0() targeting SDK levels "
+                                  "%1-%2 may not use pass-by-value with "
+                                  "__attribute__((kernel))"))
+      << FD->getName() << SLANG_MINIMUM_TARGET_API
+      << (SLANG_JB_MR1_TARGET_API - 1);
+    return false;
+  }
+
+  // Denote that we are indeed a pass-by-value kernel.
+  mKernel = true;
+
+  if (mResultType != C.VoidTy) {
+    mReturn = true;
+  }
+
+  if (mResultType->isPointerType()) {
+    DiagEngine->Report(
+      clang::FullSourceLoc(FD->getTypeSpecStartLoc(),
+                           DiagEngine->getSourceManager()),
+      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
+                                  "Compute kernel %0() cannot return a "
+                                  "pointer type: '%1'"))
+      << FD->getName() << mResultType.getAsString();
+    valid = false;
+  }
+
+  // Validate remaining parameter types
+  // TODO(all): Add support for LOD/face when we have them
+
+  size_t i = 0;
+  const clang::ParmVarDecl *PVD = NULL;
+  clang::QualType QT;
+
+  if (i < numParams) {
+    PVD = FD->getParamDecl(i);
+    QT = PVD->getType().getCanonicalType();
+
+    if (QT->isPointerType()) {
+      DiagEngine->Report(
+        clang::FullSourceLoc(PVD->getLocation(),
+                             DiagEngine->getSourceManager()),
+        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
+                                    "Compute kernel %0() cannot have "
+                                    "parameter '%1' of pointer type: '%2'"))
+        << FD->getName() << PVD->getName() << PVD->getType().getAsString();
+      valid = false;
+    } else if (QT.getUnqualifiedType() == C.UnsignedIntTy) {
+      // First parameter is either input or x, y (iff it is uint32_t).
+      llvm::StringRef ParamName = PVD->getName();
+      if (ParamName.equals("x")) {
+        mX = PVD;
+      } else if (ParamName.equals("y")) {
+        mY = PVD;
+      } else {
+        mIn = PVD;
+      }
+    } else {
+      mIn = PVD;
+    }
+
+    i++;  // advance parameter pointer
+  }
+
+  // Check that we have at least one allocation to use for dimensions.
+  if (valid && !mIn && !mReturn) {
+    DiagEngine->Report(
+      clang::FullSourceLoc(FD->getLocation(),
+                           DiagEngine->getSourceManager()),
+      DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
+                                  "Compute kernel %0() must have at least one "
+                                  "input parameter or a non-void return "
+                                  "type")) << FD->getName();
+    valid = false;
+  }
+
+  // TODO: Abstract this block away, since it is duplicate code.
+  while (i < numParams) {
+    PVD = FD->getParamDecl(i);
+    QT = PVD->getType().getCanonicalType();
+
+    if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
+      DiagEngine->Report(
+        clang::FullSourceLoc(PVD->getLocation(),
+                             DiagEngine->getSourceManager()),
+        DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
+                                    "Unexpected kernel %0() parameter '%1' "
+                                    "of type '%2'"))
+        << FD->getName() << PVD->getName() << PVD->getType().getAsString();
+      valid = false;
+    } else {
+      llvm::StringRef ParamName = PVD->getName();
+      if (ParamName.equals("x")) {
+        if (mX) {
+          ReportNameError(DiagEngine, PVD);
+          valid = false;
+        } else if (mY) {
+          // Can't go back to X after skipping Y
+          ReportNameError(DiagEngine, PVD);
+          valid = false;
+        } else {
+          mX = PVD;
+        }
+      } else if (ParamName.equals("y")) {
+        if (mY) {
+          ReportNameError(DiagEngine, PVD);
+          valid = false;
+        } else {
+          mY = PVD;
+        }
+      } else {
+        if (!mX && !mY) {
+          mX = PVD;
+        } else if (!mY) {
+          mY = PVD;
+        } else {
+          DiagEngine->Report(
+            clang::FullSourceLoc(PVD->getLocation(),
+                                 DiagEngine->getSourceManager()),
+            DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
+                                        "Unexpected kernel %0() parameter '%1' "
+                                        "of type '%2'"))
+            << FD->getName() << PVD->getName() << PVD->getType().getAsString();
+          valid = false;
+        }
+      }
+    }
+
+    i++;  // advance parameter pointer
+  }
+
+  mSignatureMetadata = 0;
+  if (valid) {
+    // Set up the bitwise metadata encoding for runtime argument passing.
+    mSignatureMetadata |= (mIn ?       0x01 : 0);
+    slangAssert(mOut == NULL);
+    mSignatureMetadata |= (mReturn ?   0x02 : 0);
+    slangAssert(mUsrData == NULL);
+    mSignatureMetadata |= (mUsrData ?  0x04 : 0);
+    mSignatureMetadata |= (mX ?        0x08 : 0);
+    mSignatureMetadata |= (mY ?        0x10 : 0);
+    mSignatureMetadata |= (mKernel ?   0x20 : 0);  // pass-by-value
+  }
+
+  return valid;
+}
+
+
 RSExportForEach *RSExportForEach::Create(RSContext *Context,
                                          const clang::FunctionDecl *FD) {
   slangAssert(Context && FD);
@@ -293,9 +464,16 @@
   if (FE->mIn) {
     const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
     FE->mInType = RSExportType::Create(Context, T);
+    if (FE->mKernel) {
+      slangAssert(FE->mInType);
+    }
   }
 
-  if (FE->mOut) {
+  if (FE->mKernel && FE->mReturn) {
+    const clang::Type *T = FE->mResultType.getTypePtr();
+    FE->mOutType = RSExportType::Create(Context, T);
+    slangAssert(FE->mOutType);
+  } else if (FE->mOut) {
     const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
     FE->mOutType = RSExportType::Create(Context, T);
   }
@@ -313,6 +491,10 @@
 
 bool RSExportForEach::isGraphicsRootRSFunc(int targetAPI,
                                            const clang::FunctionDecl *FD) {
+  if (FD->hasAttr<clang::KernelAttr>()) {
+    return false;
+  }
+
   if (!isRootRSFunc(FD)) {
     return false;
   }
@@ -335,6 +517,11 @@
 
 bool RSExportForEach::isRSForEachFunc(int targetAPI,
     const clang::FunctionDecl *FD) {
+  // Anything tagged as a kernel is definitely used with ForEach.
+  if (FD->hasAttr<clang::KernelAttr>()) {
+    return true;
+  }
+
   if (isGraphicsRootRSFunc(targetAPI, FD)) {
     return false;
   }
diff --git a/slang_rs_export_foreach.h b/slang_rs_export_foreach.h
index c8b1dd0..9bc185e 100644
--- a/slang_rs_export_foreach.h
+++ b/slang_rs_export_foreach.h
@@ -53,6 +53,10 @@
   const clang::ParmVarDecl *mZ;
   const clang::ParmVarDecl *mAr;
 
+  clang::QualType mResultType;  // return type (if present).
+  bool mReturn;  // does this kernel have a return type?
+  bool mKernel;  // is this a pass-by-value kernel?
+
   bool mDummyRoot;
 
   // TODO(all): Add support for LOD/face when we have them
@@ -60,14 +64,18 @@
     : RSExportable(Context, RSExportable::EX_FOREACH),
       mName(Name.data(), Name.size()), mParamPacketType(NULL), mInType(NULL),
       mOutType(NULL), numParams(0), mSignatureMetadata(0),
-      mIn(NULL), mOut(NULL), mUsrData(NULL),
-      mX(NULL), mY(NULL), mZ(NULL), mAr(NULL), mDummyRoot(false) {
+      mIn(NULL), mOut(NULL), mUsrData(NULL), mX(NULL), mY(NULL), mZ(NULL),
+      mAr(NULL), mResultType(clang::QualType()), mReturn(false),
+      mKernel(false), mDummyRoot(false) {
     return;
   }
 
   bool validateAndConstructParams(RSContext *Context,
                                   const clang::FunctionDecl *FD);
 
+  bool validateAndConstructKernelParams(RSContext *Context,
+                                        const clang::FunctionDecl *FD);
+
  public:
   static RSExportForEach *Create(RSContext *Context,
                                  const clang::FunctionDecl *FD);
@@ -94,6 +102,10 @@
     return (mUsrData != NULL);
   }
 
+  inline bool hasReturn() const {
+    return mReturn;
+  }
+
   inline const RSExportType *getInType() const {
     return mInType;
   }
diff --git a/slang_rs_reflection.cpp b/slang_rs_reflection.cpp
index c781241..cd088a4 100644
--- a/slang_rs_reflection.cpp
+++ b/slang_rs_reflection.cpp
@@ -648,11 +648,11 @@
   // forEach_*()
   Context::ArgTy Args;
 
-  slangAssert(EF->getNumParameters() > 0);
+  slangAssert(EF->getNumParameters() > 0 || EF->hasReturn());
 
   if (EF->hasIn())
     Args.push_back(std::make_pair("Allocation", "ain"));
-  if (EF->hasOut())
+  if (EF->hasOut() || EF->hasReturn())
     Args.push_back(std::make_pair("Allocation", "aout"));
 
   const RSExportRecordType *ERT = EF->getParamPacketType();
@@ -682,7 +682,7 @@
     genTypeCheck(C, OET, "aout");
   }
 
-  if (EF->hasIn() && EF->hasOut()) {
+  if (EF->hasIn() && (EF->hasOut() || EF->hasReturn())) {
     C.indent() << "// Verify dimensions" << std::endl;
     C.indent() << "Type tIn = ain.getType();" << std::endl;
     C.indent() << "Type tOut = aout.getType();" << std::endl;
@@ -711,7 +711,7 @@
   else
     C.out() << ", null";
 
-  if (EF->hasOut())
+  if (EF->hasOut() || EF->hasReturn())
     C.out() << ", aout";
   else
     C.out() << ", null";
@@ -730,9 +730,13 @@
 void RSReflection::genTypeInstanceFromPointer(Context &C,
                                               const RSExportType *ET) {
   if (ET->getClass() == RSExportType::ExportClassPointer) {
+    // For pointer parameters to original forEach kernels.
     const RSExportPointerType *EPT =
         static_cast<const RSExportPointerType*>(ET);
     genTypeInstance(C, EPT->getPointeeType());
+  } else {
+    // For handling pass-by-value kernel parameters.
+    genTypeInstance(C, ET);
   }
 }
 
diff --git a/tests/F_kernel_16/kernel_16.rs b/tests/F_kernel_16/kernel_16.rs
new file mode 100644
index 0000000..56007a2
--- /dev/null
+++ b/tests/F_kernel_16/kernel_16.rs
@@ -0,0 +1,7 @@
+// -target-api 16
+#pragma version(1)
+#pragma rs java_package_name(foo)
+
+void __attribute__((kernel)) root(int i) {
+}
+
diff --git a/tests/F_kernel_16/stderr.txt.expect b/tests/F_kernel_16/stderr.txt.expect
new file mode 100644
index 0000000..b8a96ec
--- /dev/null
+++ b/tests/F_kernel_16/stderr.txt.expect
@@ -0,0 +1 @@
+kernel_16.rs:5:30: error: Compute kernel root() targeting SDK levels 11-16 may not use pass-by-value with __attribute__((kernel))
diff --git a/tests/F_kernel_16/stdout.txt.expect b/tests/F_kernel_16/stdout.txt.expect
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/F_kernel_16/stdout.txt.expect
diff --git a/tests/F_kernel_badsig/kernel_badsig.rs b/tests/F_kernel_badsig/kernel_badsig.rs
new file mode 100644
index 0000000..db7652d
--- /dev/null
+++ b/tests/F_kernel_badsig/kernel_badsig.rs
@@ -0,0 +1,6 @@
+#pragma version(1)
+#pragma rs java_package_name(foo)
+
+void __attribute__((kernel)) root(uint32_t x) {
+}
+
diff --git a/tests/F_kernel_badsig/stderr.txt.expect b/tests/F_kernel_badsig/stderr.txt.expect
new file mode 100644
index 0000000..9f2b7ac
--- /dev/null
+++ b/tests/F_kernel_badsig/stderr.txt.expect
@@ -0,0 +1 @@
+kernel_badsig.rs:4:30: error: Compute kernel root() must have at least one input parameter or a non-void return type
diff --git a/tests/F_kernel_badsig/stdout.txt.expect b/tests/F_kernel_badsig/stdout.txt.expect
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/F_kernel_badsig/stdout.txt.expect
diff --git a/tests/F_kernel_noattr/kernel_noattr.rs b/tests/F_kernel_noattr/kernel_noattr.rs
new file mode 100644
index 0000000..e46888b
--- /dev/null
+++ b/tests/F_kernel_noattr/kernel_noattr.rs
@@ -0,0 +1,18 @@
+#pragma version(1)
+#pragma rs java_package_name(foo)
+
+int root(uint32_t ain) {
+  return 0;
+}
+
+void in_only(uint32_t ain) {
+}
+
+int out_only() {
+  return 0;
+}
+
+int everything(uint32_t ain, uint32_t x, uint32_t y) {
+  return 0;
+}
+
diff --git a/tests/F_kernel_noattr/stderr.txt.expect b/tests/F_kernel_noattr/stderr.txt.expect
new file mode 100644
index 0000000..d8e054f
--- /dev/null
+++ b/tests/F_kernel_noattr/stderr.txt.expect
@@ -0,0 +1,4 @@
+kernel_noattr.rs:4:5: error: Compute kernel root() is required to return a void type
+kernel_noattr.rs:4:5: error: Compute kernel root() must have at least one parameter for in or out
+kernel_noattr.rs:11:5: error: invokable non-static functions are required to return void
+kernel_noattr.rs:15:5: error: invokable non-static functions are required to return void
diff --git a/tests/F_kernel_noattr/stdout.txt.expect b/tests/F_kernel_noattr/stdout.txt.expect
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/F_kernel_noattr/stdout.txt.expect
diff --git a/tests/F_kernel_ptr_param/kernel_ptr_param.rs b/tests/F_kernel_ptr_param/kernel_ptr_param.rs
new file mode 100644
index 0000000..2467793
--- /dev/null
+++ b/tests/F_kernel_ptr_param/kernel_ptr_param.rs
@@ -0,0 +1,6 @@
+#pragma version(1)
+#pragma rs java_package_name(foo)
+
+void __attribute__((kernel)) root(int *i) {
+}
+
diff --git a/tests/F_kernel_ptr_param/stderr.txt.expect b/tests/F_kernel_ptr_param/stderr.txt.expect
new file mode 100644
index 0000000..36f864c
--- /dev/null
+++ b/tests/F_kernel_ptr_param/stderr.txt.expect
@@ -0,0 +1 @@
+kernel_ptr_param.rs:4:40: error: Compute kernel root() cannot have parameter 'i' of pointer type: 'int *'
diff --git a/tests/F_kernel_ptr_param/stdout.txt.expect b/tests/F_kernel_ptr_param/stdout.txt.expect
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/F_kernel_ptr_param/stdout.txt.expect
diff --git a/tests/F_kernel_ptr_ret_val/kernel_ptr_ret_val.rs b/tests/F_kernel_ptr_ret_val/kernel_ptr_ret_val.rs
new file mode 100644
index 0000000..f964b31
--- /dev/null
+++ b/tests/F_kernel_ptr_ret_val/kernel_ptr_ret_val.rs
@@ -0,0 +1,7 @@
+#pragma version(1)
+#pragma rs java_package_name(foo)
+
+int * __attribute__((kernel)) root() {
+  return NULL;
+}
+
diff --git a/tests/F_kernel_ptr_ret_val/stderr.txt.expect b/tests/F_kernel_ptr_ret_val/stderr.txt.expect
new file mode 100644
index 0000000..7567c1d
--- /dev/null
+++ b/tests/F_kernel_ptr_ret_val/stderr.txt.expect
@@ -0,0 +1 @@
+kernel_ptr_ret_val.rs:4:1: error: Compute kernel root() cannot return a pointer type: 'int *'
diff --git a/tests/F_kernel_ptr_ret_val/stdout.txt.expect b/tests/F_kernel_ptr_ret_val/stdout.txt.expect
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/F_kernel_ptr_ret_val/stdout.txt.expect
diff --git a/tests/P_kernel/kernel.rs b/tests/P_kernel/kernel.rs
new file mode 100644
index 0000000..d8f0c4d
--- /dev/null
+++ b/tests/P_kernel/kernel.rs
@@ -0,0 +1,18 @@
+#pragma version(1)
+#pragma rs java_package_name(foo)
+
+int __attribute__((kernel)) root(uint32_t ain) {
+  return 0;
+}
+
+void __attribute__((kernel)) in_only(uint32_t ain) {
+}
+
+int __attribute__((kernel)) out_only() {
+  return 0;
+}
+
+int __attribute__((kernel)) everything(uint32_t ain, uint32_t x, uint32_t y) {
+  return 0;
+}
+
diff --git a/tests/P_kernel/stderr.txt.expect b/tests/P_kernel/stderr.txt.expect
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/P_kernel/stderr.txt.expect
diff --git a/tests/P_kernel/stdout.txt.expect b/tests/P_kernel/stdout.txt.expect
new file mode 100644
index 0000000..a58134c
--- /dev/null
+++ b/tests/P_kernel/stdout.txt.expect
@@ -0,0 +1 @@
+Generating ScriptC_kernel.java ...