Check in support for OpenCL conditional operator on vector types.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@114371 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/include/clang/Basic/DiagnosticSemaKinds.td b/include/clang/Basic/DiagnosticSemaKinds.td
index 9508c91..5cb242f 100644
--- a/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2740,6 +2740,8 @@
   "used type %0 where arithmetic or pointer type is required">;
 def ext_typecheck_cond_one_void : Extension<
   "C99 forbids conditional expressions with only one void side">;
+def err_typecheck_cond_expect_scalar_or_vector : Error<
+  "used type %0 where arithmetic, pointer, or vector type is required">;
 def err_typecheck_cast_to_incomplete : Error<
   "cast to incomplete type %0">;
 def ext_typecheck_cast_nonscalar : Extension<
diff --git a/lib/CodeGen/CGExprScalar.cpp b/lib/CodeGen/CGExprScalar.cpp
index bacb564..73e94d1 100644
--- a/lib/CodeGen/CGExprScalar.cpp
+++ b/lib/CodeGen/CGExprScalar.cpp
@@ -2169,7 +2169,52 @@
       return Visit(Live);
   }
 
+  // OpenCL: If the condition is a vector, we can treat this condition like
+  // the select function.
+  if (CGF.getContext().getLangOptions().OpenCL 
+      && E->getCond()->getType()->isVectorType()) {
+    llvm::Value *CondV = CGF.EmitScalarExpr(E->getCond());
+    llvm::Value *LHS = Visit(E->getLHS());
+    llvm::Value *RHS = Visit(E->getRHS());
+    
+    const llvm::Type *condType = ConvertType(E->getCond()->getType());
+    const llvm::VectorType *vecTy = cast<llvm::VectorType>(condType);
+    
+    unsigned numElem = vecTy->getNumElements();      
+    const llvm::Type *elemType = vecTy->getElementType();
+    
+    std::vector<llvm::Constant*> Zvals;
+    for (unsigned i = 0; i < numElem; ++i)
+      Zvals.push_back(llvm::ConstantInt::get(elemType,0));
 
+    llvm::Value *zeroVec = llvm::ConstantVector::get(Zvals);    
+    llvm::Value *TestMSB = Builder.CreateICmpSLT(CondV, zeroVec);
+    llvm::Value *tmp = Builder.CreateSExt(TestMSB, 
+                                          llvm::VectorType::get(elemType,
+                                                                numElem),         
+                                          "sext");
+    llvm::Value *tmp2 = Builder.CreateNot(tmp);
+    
+    // Cast float to int to perform ANDs if necessary.
+    llvm::Value *RHSTmp = RHS;
+    llvm::Value *LHSTmp = LHS;
+    bool wasCast = false;
+    const llvm::VectorType *rhsVTy = cast<llvm::VectorType>(RHS->getType());
+    if (rhsVTy->getElementType()->isFloatTy()) {
+      RHSTmp = Builder.CreateBitCast(RHS, tmp2->getType());
+      LHSTmp = Builder.CreateBitCast(LHS, tmp->getType());
+      wasCast = true;
+    }
+    
+    llvm::Value *tmp3 = Builder.CreateAnd(RHSTmp, tmp2);
+    llvm::Value *tmp4 = Builder.CreateAnd(LHSTmp, tmp);
+    llvm::Value *tmp5 = Builder.CreateOr(tmp3, tmp4, "cond");
+    if (wasCast)
+      tmp5 = Builder.CreateBitCast(tmp5, RHS->getType());
+
+    return tmp5;
+  }
+  
   // If this is a really simple expression (like x ? 4 : 5), emit this as a
   // select instead of as control flow.  We can only do this if it is cheap and
   // safe to evaluate the LHS and RHS unconditionally.
diff --git a/lib/Sema/SemaExpr.cpp b/lib/Sema/SemaExpr.cpp
index 2db253a..79b2727 100644
--- a/lib/Sema/SemaExpr.cpp
+++ b/lib/Sema/SemaExpr.cpp
@@ -4222,15 +4222,47 @@
 
   // first, check the condition.
   if (!CondTy->isScalarType()) { // C99 6.5.15p2
-    Diag(Cond->getLocStart(), diag::err_typecheck_cond_expect_scalar)
-      << CondTy;
-    return QualType();
+    // OpenCL: Sec 6.3.i says the condition is allowed to be a vector or scalar.
+    // Throw an error if its not either.
+    if (getLangOptions().OpenCL) {
+      if (!CondTy->isVectorType()) {
+        Diag(Cond->getLocStart(), 
+             diag::err_typecheck_cond_expect_scalar_or_vector)
+          << CondTy;
+        return QualType();
+      }
+    }
+    else {
+      Diag(Cond->getLocStart(), diag::err_typecheck_cond_expect_scalar)
+        << CondTy;
+      return QualType();
+    }
   }
 
   // Now check the two expressions.
   if (LHSTy->isVectorType() || RHSTy->isVectorType())
     return CheckVectorOperands(QuestionLoc, LHS, RHS);
 
+  // OpenCL: If the condition is a vector, and both operands are scalar,
+  // attempt to implicity convert them to the vector type to act like the
+  // built in select.
+  if (getLangOptions().OpenCL && CondTy->isVectorType()) {
+    // Both operands should be of scalar type.
+    if (!LHSTy->isScalarType()) {
+      Diag(LHS->getLocStart(), diag::err_typecheck_cond_expect_scalar)
+        << CondTy;
+      return QualType();
+    }
+    if (!RHSTy->isScalarType()) {
+      Diag(RHS->getLocStart(), diag::err_typecheck_cond_expect_scalar)
+        << CondTy;
+      return QualType();
+    }
+    // Implicity convert these scalars to the type of the condition.
+    ImpCastExprToType(LHS, CondTy, CK_IntegralCast);
+    ImpCastExprToType(RHS, CondTy, CK_IntegralCast);
+  }
+  
   // If both operands have arithmetic type, do the usual arithmetic conversions
   // to find a common type: C99 6.5.15p3,5.
   if (LHSTy->isArithmeticType() && RHSTy->isArithmeticType()) {
diff --git a/test/Sema/opencl-cond.c b/test/Sema/opencl-cond.c
new file mode 100644
index 0000000..d654a15
--- /dev/null
+++ b/test/Sema/opencl-cond.c
@@ -0,0 +1,5 @@
+// RUN: %clang_cc1 %s -x cl -verify -pedantic -fsyntax-only
+
+typedef __attribute__((ext_vector_type(4))) float float4;
+
+float4 foo(float4 a, float4 b, float4 c, float4 d) { return a < b ? c : d; }