Allow implicit casts during arithmetic for OCUVector operations
Add codegen support and test for said casts.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@45443 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/CodeGen/CGExprScalar.cpp b/CodeGen/CGExprScalar.cpp
index 8d5f814..0fc1d0f 100644
--- a/CodeGen/CGExprScalar.cpp
+++ b/CodeGen/CGExprScalar.cpp
@@ -360,6 +360,15 @@
     return Builder.CreatePtrToInt(Src, DstTy, "conv");
   }
   
+  // A scalar source can be splatted to a vector of the same element type
+  if (isa<llvm::VectorType>(DstTy) && !isa<VectorType>(SrcType)) {
+    const llvm::VectorType *VT = cast<llvm::VectorType>(DstTy);
+    assert((VT->getElementType() == Src->getType()) &&
+           "Vector element type must match scalar type to splat.");
+    return CGF.EmitVector(&Src, DstType->getAsVectorType()->getNumElements(), 
+                          true);
+  }
+
   if (isa<llvm::VectorType>(Src->getType()) ||
       isa<llvm::VectorType>(DstTy)) {
     return Builder.CreateBitCast(Src, DstTy, "conv");
@@ -1049,14 +1058,15 @@
 }
 
 llvm::Value *CodeGenFunction::EmitVector(llvm::Value * const *Vals, 
-                                         unsigned NumVals)
+                                         unsigned NumVals, bool isSplat)
 {
   llvm::Value *Vec
   = llvm::UndefValue::get(llvm::VectorType::get(Vals[0]->getType(), NumVals));
   
   for (unsigned i = 0, e = NumVals ; i != e; ++i) {
+    llvm::Value *Val = isSplat ? Vals[0] : Vals[i];
     llvm::Value *Idx = llvm::ConstantInt::get(llvm::Type::Int32Ty, i);
-    Vec = Builder.CreateInsertElement(Vec, Vals[i], Idx, "tmp");
+    Vec = Builder.CreateInsertElement(Vec, Val, Idx, "tmp");
   }
   
   return Vec;  
diff --git a/CodeGen/CodeGenFunction.h b/CodeGen/CodeGenFunction.h
index a263ca0..216b7cd 100644
--- a/CodeGen/CodeGenFunction.h
+++ b/CodeGen/CodeGenFunction.h
@@ -395,7 +395,8 @@
   llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   
   llvm::Value *EmitShuffleVector(llvm::Value* V1, llvm::Value *V2, ...);
-  llvm::Value *EmitVector(llvm::Value * const *Vals, unsigned NumVals);
+  llvm::Value *EmitVector(llvm::Value * const *Vals, unsigned NumVals,
+                          bool isSplat = false);
   
   llvm::Value *EmitObjCStringLiteral(const ObjCStringLiteral *E);
 
diff --git a/Sema/SemaExpr.cpp b/Sema/SemaExpr.cpp
index 014e457..20118c2 100644
--- a/Sema/SemaExpr.cpp
+++ b/Sema/SemaExpr.cpp
@@ -1125,15 +1125,15 @@
   }
   else if (lhsType->isArithmeticType() && rhsType->isArithmeticType()) {
     if (lhsType->isVectorType() || rhsType->isVectorType()) {
+      // For OCUVector, allow vector splats; float -> <n x float>
+      if (const OCUVectorType *LV = lhsType->getAsOCUVectorType()) {
+        if (LV->getElementType().getTypePtr() == rhsType.getTypePtr())
+          return Compatible;
+      }
       if (!getLangOptions().LaxVectorConversions) {
         if (lhsType.getCanonicalType() != rhsType.getCanonicalType())
           return Incompatible;
       } else {
-        // For OCUVector, allow vector splats; float -> <n x float>
-        if (const OCUVectorType *LV = lhsType->getAsOCUVectorType()) {
-          if (LV->getElementType().getTypePtr() == rhsType.getTypePtr())
-            return Compatible;
-        }
         if (lhsType->isVectorType() && rhsType->isVectorType()) {
           // If LHS and RHS are both integer or both floating point types, and
           // the total vector length is the same, allow the conversion.  This is
@@ -1218,6 +1218,27 @@
   // make sure the vector types are identical. 
   if (lhsType == rhsType)
     return lhsType;
+
+  // if the lhs is an ocu vector and the rhs is a scalar of the same type,
+  // promote the rhs to the vector type.
+  if (const OCUVectorType *V = lhsType->getAsOCUVectorType()) {
+    if (V->getElementType().getCanonicalType().getTypePtr()
+        == rhsType.getCanonicalType().getTypePtr()) {
+      promoteExprToType(rex, lhsType);
+      return lhsType;
+    }
+  }
+
+  // if the rhs is an ocu vector and the lhs is a scalar of the same type,
+  // promote the lhs to the vector type.
+  if (const OCUVectorType *V = rhsType->getAsOCUVectorType()) {
+    if (V->getElementType().getCanonicalType().getTypePtr()
+        == lhsType.getCanonicalType().getTypePtr()) {
+      promoteExprToType(lex, rhsType);
+      return rhsType;
+    }
+  }
+
   // You cannot convert between vector values of different size.
   Diag(loc, diag::err_typecheck_vector_not_convertable, 
        lex->getType().getAsString(), rex->getType().getAsString(),
diff --git a/test/CodeGen/ocu-vector.c b/test/CodeGen/ocu-vector.c
index 9e904f6..ee6e737 100644
--- a/test/CodeGen/ocu-vector.c
+++ b/test/CodeGen/ocu-vector.c
@@ -1,7 +1,6 @@
 // RUN: clang -emit-llvm %s
 
 typedef __attribute__(( ocu_vector_type(4) )) float float4;
-//typedef __attribute__(( ocu_vector_type(3) )) float float3;
 typedef __attribute__(( ocu_vector_type(2) )) float float2;
 
 
@@ -33,3 +32,16 @@
   float d = 4.0f;
   *out = ((float4) {a,b,c,d});
 }
+
+static void test5(float4 *out) {
+  float a;
+  float4 b;
+  
+  a = 1.0f;
+  b = a;
+  b = b * 5.0f;
+  b = 5.0f * b;
+  b *= a;
+  
+  *out = b;
+}