Teach emit-llvm for scalars to properly handle compound assignment 
operators in all their glory :)


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@41373 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/CodeGen/CGExprScalar.cpp b/CodeGen/CGExprScalar.cpp
index 9c2aa36..bdaf179 100644
--- a/CodeGen/CGExprScalar.cpp
+++ b/CodeGen/CGExprScalar.cpp
@@ -28,6 +28,7 @@
 struct BinOpInfo {
   Value *LHS;
   Value *RHS;
+  QualType Ty;  // Computation type; may differ from E->getType() for op=.
   const BinaryOperator *E;
 };
 
@@ -152,18 +153,6 @@
   }
   
   // Binary Operators.
-  BinOpInfo EmitBinOps(const BinaryOperator *E);
-  Value *VisitBinMul(const BinaryOperator *E) { return EmitMul(EmitBinOps(E)); }
-  Value *VisitBinDiv(const BinaryOperator *E) { return EmitDiv(EmitBinOps(E)); }
-  Value *VisitBinRem(const BinaryOperator *E) { return EmitRem(EmitBinOps(E)); }
-  Value *VisitBinAdd(const BinaryOperator *E) { return EmitAdd(EmitBinOps(E)); }
-  Value *VisitBinSub(const BinaryOperator *E) { return EmitSub(EmitBinOps(E)); }
-  Value *VisitBinShl(const BinaryOperator *E) { return EmitShl(EmitBinOps(E)); }
-  Value *VisitBinShr(const BinaryOperator *E) { return EmitShr(EmitBinOps(E)); }
-  Value *VisitBinAnd(const BinaryOperator *E) { return EmitAnd(EmitBinOps(E)); }
-  Value *VisitBinXor(const BinaryOperator *E) { return EmitXor(EmitBinOps(E)); }
-  Value *VisitBinOr (const BinaryOperator *E) { return EmitOr (EmitBinOps(E)); }
-
   Value *EmitMul(const BinOpInfo &Ops) {
     return Builder.CreateMul(Ops.LHS, Ops.RHS, "mul");
   }
@@ -183,6 +172,34 @@
     return Builder.CreateOr(Ops.LHS, Ops.RHS, "or");
   }
 
+  BinOpInfo EmitBinOps(const BinaryOperator *E);
+  Value *EmitCompoundAssign(const BinaryOperator *E,
+                            Value *(ScalarExprEmitter::*F)(const BinOpInfo &));
+  // HANDLEBINOP(OP) defines VisitBin<OP>, which emits through Emit<OP>, and
+  // VisitBin<OP>Assign, which hands Emit<OP> to EmitCompoundAssign.
+#define HANDLEBINOP(OP) \
+  Value *VisitBin ## OP(const BinaryOperator *E) {                 \
+    return Emit ## OP(EmitBinOps(E));                              \
+  }                                                                \
+  Value *VisitBin ## OP ## Assign(const BinaryOperator *E) {       \
+    return EmitCompoundAssign(E, &ScalarExprEmitter::Emit ## OP);  \
+  }
+  HANDLEBINOP(Mul);
+  HANDLEBINOP(Div);
+  HANDLEBINOP(Rem);
+  HANDLEBINOP(Add);
+  // Sub is not expanded here: VisitBinSub is declared separately below.
+  HANDLEBINOP(Shl);
+  HANDLEBINOP(Shr);
+  HANDLEBINOP(And);
+  HANDLEBINOP(Xor);
+  HANDLEBINOP(Or);
+#undef HANDLEBINOP
+  Value *VisitBinSub(const BinaryOperator *E);
+  Value *VisitBinSubAssign(const BinaryOperator *E) {
+    return EmitCompoundAssign(E, &ScalarExprEmitter::EmitSub);
+  }
+  
   // Comparisons.
   Value *EmitCompare(const BinaryOperator *E, unsigned UICmpOpc,
                      unsigned SICmpOpc, unsigned FCmpOpc);
@@ -368,14 +385,59 @@
   BinOpInfo Result;
   Result.LHS = Visit(E->getLHS());
   Result.RHS = Visit(E->getRHS());
+  Result.Ty  = E->getType();
   Result.E = E;
   return Result;
 }
 
+Value *ScalarExprEmitter::EmitCompoundAssign(const BinaryOperator *E,
+                      Value *(ScalarExprEmitter::*Func)(const BinOpInfo &)) {
+  // Emit "LHS op= RHS": load the LHS, apply Func in the computation type,
+  // convert the result back to the LHS type, store it, and return it.
+  QualType LHSTy = E->getLHS()->getType(), RHSTy = E->getRHS()->getType();
+  BinOpInfo OpInfo;
+  // Load the LHS and RHS operands.
+  LValue LHSLV = EmitLValue(E->getLHS());
+  OpInfo.LHS = EmitLoadOfLValue(LHSLV, LHSTy);
+  
+  // FIXME: It is possible for the RHS to be complex.
+  OpInfo.RHS = Visit(E->getRHS());
+  
+  // Convert the LHS/RHS values to the computation type.
+  const CompoundAssignOperator *CAO = cast<CompoundAssignOperator>(E);
+  QualType ComputeType = CAO->getComputationType();
+  
+  // FIXME: it's possible for the computation type to be complex if the RHS
+  // is complex.  Handle this!
+  OpInfo.LHS = CGF.EmitConversion(RValue::get(OpInfo.LHS), LHSTy,
+                                  ComputeType).getVal();
+  
+  // Convert the RHS too, unless this is -= with a pointer LHS (ptr -= int).
+  if (E->getOpcode() != BinaryOperator::SubAssign ||
+      !LHSTy->isPointerType()) {
+    OpInfo.RHS = CGF.EmitConversion(RValue::get(OpInfo.RHS), RHSTy,
+                                    ComputeType).getVal();
+  }
+  OpInfo.Ty = ComputeType;
+  OpInfo.E = E;
+  
+  // Expand the binary operator.
+  Value *Result = (this->*Func)(OpInfo);
+  
+  // Truncate the result back to the LHS type.
+  Result = CGF.EmitConversion(RValue::get(Result), ComputeType, LHSTy).getVal();
+  
+  // Store the result value into the LHS lvalue.
+  CGF.EmitStoreThroughLValue(RValue::get(Result), LHSLV, E->getType());
+
+  return Result;
+}
+
+
 Value *ScalarExprEmitter::EmitDiv(const BinOpInfo &Ops) {
   if (Ops.LHS->getType()->isFloatingPoint())
     return Builder.CreateFDiv(Ops.LHS, Ops.RHS, "div");
-  else if (Ops.E->getType()->isUnsignedIntegerType())
+  else if (Ops.Ty->isUnsignedIntegerType())
     return Builder.CreateUDiv(Ops.LHS, Ops.RHS, "div");
   else
     return Builder.CreateSDiv(Ops.LHS, Ops.RHS, "div");
@@ -383,7 +445,7 @@
 
 Value *ScalarExprEmitter::EmitRem(const BinOpInfo &Ops) {
   // Rem in C can't be a floating point type: C99 6.5.5p2.
-  if (Ops.E->getType()->isUnsignedIntegerType())
+  if (Ops.Ty->isUnsignedIntegerType())
     return Builder.CreateURem(Ops.LHS, Ops.RHS, "rem");
   else
     return Builder.CreateSRem(Ops.LHS, Ops.RHS, "rem");
@@ -391,8 +453,10 @@
 
 
 Value *ScalarExprEmitter::EmitAdd(const BinOpInfo &Ops) {
-  if (!Ops.E->getType()->isPointerType())
+  if (!Ops.Ty->isPointerType())
     return Builder.CreateAdd(Ops.LHS, Ops.RHS, "add");
+  
+  // FIXME: What about a pointer to a VLA?
   if (isa<llvm::PointerType>(Ops.LHS->getType())) // pointer + int
     return Builder.CreateGEP(Ops.LHS, Ops.RHS, "add.ptr");
   // int + pointer
@@ -403,30 +467,36 @@
   if (!isa<llvm::PointerType>(Ops.LHS->getType()))
     return Builder.CreateSub(Ops.LHS, Ops.RHS, "sub");
   
-  // FIXME: This isn't right for -=.
-  QualType LHSTy = Ops.E->getLHS()->getType();
-  QualType RHSTy = Ops.E->getRHS()->getType();
-  
-  const PointerType *RHSPtrType = dyn_cast<PointerType>(RHSTy.getTypePtr());
-  if (RHSPtrType == 0) {   // pointer - int
-    Value *NegatedRHS = Builder.CreateNeg(Ops.RHS, "sub.ptr.neg");
-    return Builder.CreateGEP(Ops.LHS, NegatedRHS, "sub.ptr");
-  }
+  // pointer - int
+  assert(!isa<llvm::PointerType>(Ops.RHS->getType()) &&
+         "ptr-ptr shouldn't get here");
+  // FIXME: The pointer could point to a VLA.
+  Value *NegatedRHS = Builder.CreateNeg(Ops.RHS, "sub.ptr.neg");
+  return Builder.CreateGEP(Ops.LHS, NegatedRHS, "sub.ptr");
+}
+
+Value *ScalarExprEmitter::VisitBinSub(const BinaryOperator *E) {
+  // "X - Y" is different from "X -= Y" in one case: when Y is a pointer.  In
+  // the compound assignment case it is invalid, so just handle it here.
+  if (!E->getRHS()->getType()->isPointerType())
+    return EmitSub(EmitBinOps(E));
   
   // pointer - pointer
-  const PointerType *LHSPtrType = cast<PointerType>(LHSTy.getTypePtr());
+  Value *LHS = Visit(E->getLHS());
+  Value *RHS = Visit(E->getRHS());
+  
+  const PointerType *LHSPtrType = E->getLHS()->getType()->getAsPointerType();
+  assert(LHSPtrType == E->getRHS()->getType()->getAsPointerType() &&
+         "Can't subtract different pointer types");
+  
   QualType LHSElementType = LHSPtrType->getPointeeType();
-  assert(LHSElementType == RHSPtrType->getPointeeType() &&
-         "can't subtract pointers with differing element types");
   uint64_t ElementSize = CGF.getContext().getTypeSize(LHSElementType,
                                                       SourceLocation()) / 8;
-  const llvm::Type *ResultType = ConvertType(Ops.E->getType());
-  Value *CastLHS = Builder.CreatePtrToInt(Ops.LHS, ResultType,
-                                                "sub.ptr.lhs.cast");
-  Value *CastRHS = Builder.CreatePtrToInt(Ops.RHS, ResultType,
-                                                "sub.ptr.rhs.cast");
-  Value *BytesBetween = Builder.CreateSub(CastLHS, CastRHS,
-                                                "sub.ptr.sub");
+  
+  const llvm::Type *ResultType = ConvertType(E->getType());
+  LHS = Builder.CreatePtrToInt(LHS, ResultType, "sub.ptr.lhs.cast");
+  RHS = Builder.CreatePtrToInt(RHS, ResultType, "sub.ptr.rhs.cast");
+  Value *BytesBetween = Builder.CreateSub(LHS, RHS, "sub.ptr.sub");
   
   // HACK: LLVM doesn't have an divide instruction that 'knows' there is no
   // remainder.  As such, we handle common power-of-two cases here to generate
@@ -436,11 +506,13 @@
     llvm::ConstantInt::get(ResultType, llvm::Log2_64(ElementSize));
     return Builder.CreateAShr(BytesBetween, ShAmt, "sub.ptr.shr");
   }
+  
   // Otherwise, do a full sdiv.
   Value *BytesPerElt = llvm::ConstantInt::get(ResultType, ElementSize);
   return Builder.CreateSDiv(BytesBetween, BytesPerElt, "sub.ptr.div");
 }
 
+
 Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) {
   // LLVM requires the LHS and RHS to be the same type: promote or truncate the
   // RHS to the same size as the LHS.
@@ -458,7 +530,7 @@
   if (Ops.LHS->getType() != RHS->getType())
     RHS = Builder.CreateIntCast(RHS, Ops.LHS->getType(), false, "sh_prom");
   
-  if (Ops.E->getType()->isUnsignedIntegerType())
+  if (Ops.Ty->isUnsignedIntegerType())
     return Builder.CreateLShr(Ops.LHS, RHS, "shr");
   return Builder.CreateAShr(Ops.LHS, RHS, "shr");
 }