Error on more illegal kernel argument types for OpenCL

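bool and half kernel parameters are now rejected, as are structs /
unions that contain bool, half, event_t, pointer or image fields
(directly or through nested structs). Plain pointer parameters remain
legal; pointer-to-pointer parameters are still rejected. Does not yet
reject size_t and related integer types that are also disallowed.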

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@186908 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/SemaDecl.cpp b/lib/Sema/SemaDecl.cpp
index 8b7c472..aca71b0 100644
--- a/lib/Sema/SemaDecl.cpp
+++ b/lib/Sema/SemaDecl.cpp
@@ -6069,6 +6069,173 @@
   }
 }
 
+enum OpenCLParamType {
+  ValidKernelParam,   // none of the restricted categories below matched
+  PtrPtrKernelParam,  // pointer whose pointee is itself a pointer
+  PtrKernelParam,     // any other pointer, or an image type
+  InvalidKernelParam, // bool, event_t or half
+  RecordKernelParam   // record type; its fields still need checking
+};
+
+static OpenCLParamType getOpenCLKernelParameterType(QualType PT) {
+  if (PT->isPointerType()) { // classification is first-match-wins
+    QualType PointeeType = PT->getPointeeType();
+    return PointeeType->isPointerType() ? PtrPtrKernelParam : PtrKernelParam;
+  }
+
+  // TODO: Forbid the other integer types (size_t, ptrdiff_t...) when they can
+  // be used as builtin types.
+
+  if (PT->isImageType())
+    return PtrKernelParam; // images are grouped with plain pointers
+
+  if (PT->isBooleanType())
+    return InvalidKernelParam;
+
+  if (PT->isEventT())
+    return InvalidKernelParam;
+
+  if (PT->isHalfType())
+    return InvalidKernelParam;
+
+  if (PT->isRecordType())
+    return RecordKernelParam; // the fields are not inspected here
+
+  return ValidKernelParam; // nothing above matched
+}
+
+static void checkIsValidOpenCLKernelParameter(
+  Sema &S,              // used to emit the diagnostics
+  Declarator &D,        // marked invalid when Param is rejected
+  ParmVarDecl *Param,   // the kernel parameter being checked
+  llvm::SmallPtrSet<const Type *, 16> &ValidTypes) { // accepted-type cache
+  QualType PT = Param->getType();
+
+  // Diagnose Param if its type is illegal for an OpenCL kernel argument.
+  // ValidTypes caches accepted types so types seen again are not rechecked.
+  if (ValidTypes.count(PT.getTypePtr()))
+    return;
+
+  switch (getOpenCLKernelParameterType(PT)) {
+  case PtrPtrKernelParam:
+    // OpenCL v1.2 s6.9.a:
+    // A kernel function argument cannot be declared as a
+    // pointer to a pointer type.
+    S.Diag(Param->getLocation(), diag::err_opencl_ptrptr_kernel_param);
+    D.setInvalidType();
+    return;
+
+    // OpenCL v1.2 s6.9.k:
+    // Arguments to kernel functions in a program cannot be declared with the
+    // built-in scalar types bool, half, size_t, ptrdiff_t, intptr_t, and
+    // uintptr_t or a struct and/or union that contain fields declared to be
+    // one of these built-in scalar types.
+
+  case InvalidKernelParam:
+    // OpenCL v1.2 s6.8 n:
+    // A kernel function argument cannot be declared
+    // of event_t type. (bool and half land here too, per s6.9.k above.)
+    S.Diag(Param->getLocation(), diag::err_bad_kernel_param_type) << PT;
+    D.setInvalidType();
+    return;
+
+  case PtrKernelParam:
+  case ValidKernelParam:
+    ValidTypes.insert(PT.getTypePtr());
+    return;
+
+  case RecordKernelParam:
+    break;
+  }
+
+  // Work list of records to inspect; a null entry closes a nesting level.
+  SmallVector<const Decl *, 4> VisitStack;
+
+  // Chain of fields leading to the record currently being visited. The
+  // leading null stands in for the parameter itself, which is not a field.
+  SmallVector<const FieldDecl *, 4> HistoryStack;
+  HistoryStack.push_back((const FieldDecl *) 0);
+
+  const RecordDecl *PD = PT->castAs<RecordType>()->getDecl();
+  VisitStack.push_back(PD);
+
+  assert(VisitStack.back() && "First decl null?");
+
+  do {
+    const Decl *Next = VisitStack.pop_back_val();
+    if (!Next) {
+      assert(!HistoryStack.empty());
+      // Marker: everything nested in this record passed, so cache its type
+      if (const FieldDecl *Hist = HistoryStack.pop_back_val())
+        ValidTypes.insert(Hist->getType().getTypePtr());
+
+      continue;
+    }
+
+    // Adds everything except the original parameter declaration (which is not a
+    // field itself) to the history stack.
+    const RecordDecl *RD;
+    if (const FieldDecl *Field = dyn_cast<FieldDecl>(Next)) {
+      HistoryStack.push_back(Field);
+      RD = Field->getType()->castAs<RecordType>()->getDecl();
+    } else {
+      RD = cast<RecordDecl>(Next);
+    }
+
+    // Add a null marker so we know when we've gone back up a level
+    VisitStack.push_back((const Decl *) 0);
+
+    for (RecordDecl::field_iterator I = RD->field_begin(),
+           E = RD->field_end(); I != E; ++I) {
+      const FieldDecl *FD = *I;
+      QualType QT = FD->getType();
+
+      if (ValidTypes.count(QT.getTypePtr()))
+        continue;
+
+      OpenCLParamType ParamType = getOpenCLKernelParameterType(QT);
+      if (ParamType == ValidKernelParam)
+        continue;
+
+      if (ParamType == RecordKernelParam) {
+        VisitStack.push_back(FD);
+        continue;
+      }
+
+      // OpenCL v1.2 s6.9.p:
+      // Arguments to kernel functions that are declared to be a struct or union
+      // do not allow OpenCL objects to be passed as elements of the struct or
+      // union.
+      if (ParamType == PtrKernelParam || ParamType == PtrPtrKernelParam) {
+        S.Diag(Param->getLocation(),
+               diag::err_record_with_pointers_kernel_param)
+          << PT->isUnionType()
+          << PT;
+      } else {
+        S.Diag(Param->getLocation(), diag::err_bad_kernel_param_type) << PT;
+      }
+
+      S.Diag(PD->getLocation(), diag::note_within_field_of_type)
+        << PD->getDeclName();
+
+      // We have an error, now let's go back up through history and show where
+      // the offending field came from (+ 1 skips the null placeholder entry)
+      for (ArrayRef<const FieldDecl *>::const_iterator I = HistoryStack.begin() + 1,
+             E = HistoryStack.end(); I != E; ++I) {
+        const FieldDecl *OuterField = *I;
+        S.Diag(OuterField->getLocation(), diag::note_within_field_of_type)
+          << OuterField->getType();
+      }
+
+      S.Diag(FD->getLocation(), diag::note_illegal_field_declared_here)
+        << QT->isPointerType()
+        << QT;
+      D.setInvalidType();
+      return;
+    }
+  } while (!VisitStack.empty());
+}
+
 NamedDecl*
 Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
                               TypeSourceInfo *TInfo, LookupResult &Previous,
@@ -6803,27 +6970,12 @@
            diag::err_expected_kernel_void_return_type);
       D.setInvalidType();
     }
-    
+
+    llvm::SmallPtrSet<const Type *, 16> ValidTypes;
     for (FunctionDecl::param_iterator PI = NewFD->param_begin(),
          PE = NewFD->param_end(); PI != PE; ++PI) {
       ParmVarDecl *Param = *PI;
-      QualType PT = Param->getType();
-
-      // OpenCL v1.2 s6.9.a:
-      // A kernel function argument cannot be declared as a
-      // pointer to a pointer type.
-      if (PT->isPointerType() && PT->getPointeeType()->isPointerType()) {
-        Diag(Param->getLocation(), diag::err_opencl_ptrptr_kernel_arg);
-        D.setInvalidType();
-      }
-
-      // OpenCL v1.2 s6.8 n:
-      // A kernel function argument cannot be declared
-      // of event_t type.
-      if (PT->isEventT()) {
-        Diag(Param->getLocation(), diag::err_event_t_kernel_arg);
-        D.setInvalidType();
-      }
+      checkIsValidOpenCLKernelParameter(*this, D, Param, ValidTypes);
     }
   }