C++1y constant expression evaluation: support for compound assignments on integers.

llvm-svn: 181287
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index a3315e3..5695af2 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -1299,6 +1299,146 @@
   return false;
 }
 
+/// Perform the given integer operation, which is known to need at most BitWidth
+/// bits, and check for overflow in the original type (if that type was not an
+/// unsigned type).
+///
+/// When LHS is signed, the result is computed in BitWidth bits and truncated
+/// back to LHS's width. If that truncation changes the value, the overflow is
+/// reported as a warning in integer-overflow check mode, and through
+/// HandleOverflow otherwise; the truncated value is returned either way.
+template<typename Operation>
+static APSInt CheckedIntArithmetic(EvalInfo &Info, const Expr *E,
+                                   const APSInt &LHS, const APSInt &RHS,
+                                   unsigned BitWidth, Operation Op) {
+  if (LHS.isUnsigned())
+    return Op(LHS, RHS);
+
+  APSInt Value(Op(LHS.extend(BitWidth), RHS.extend(BitWidth)), false);
+  APSInt Result = Value.trunc(LHS.getBitWidth());
+  if (Result.extend(BitWidth) != Value) {
+    if (Info.getIntOverflowCheckMode())
+      Info.Ctx.getDiagnostics().Report(E->getExprLoc(),
+        diag::warn_integer_constant_overflow)
+          << Result.toString(10) << E->getType();
+    else
+      HandleOverflow(Info, E, Value, E->getType());
+  }
+  return Result;
+}
+
+/// Perform the given binary integer operation.
+///
+/// \param E The expression being evaluated; it supplies the location and type
+///        used in diagnostics.
+/// \param Result Receives the value of the operation. The arithmetic, bitwise
+///        and shift operators replace it with the computed APSInt, but the
+///        comparison operators only assign a truth value to it, so the caller
+///        must give Result the desired width and signedness beforehand.
+///        Result may alias LHS: every case computes its value before storing.
+/// \returns false (after emitting a diagnostic) if the operation cannot be
+///          evaluated; true otherwise.
+static bool handleIntIntBinOp(EvalInfo &Info, const Expr *E, const APSInt &LHS,
+                              BinaryOperatorKind Opcode, APSInt RHS,
+                              APSInt &Result) {
+  switch (Opcode) {
+  default:
+    Info.Diag(E);
+    return false;
+  case BO_Mul:
+    Result = CheckedIntArithmetic(Info, E, LHS, RHS, LHS.getBitWidth() * 2,
+                                  std::multiplies<APSInt>());
+    return true;
+  case BO_Add:
+    Result = CheckedIntArithmetic(Info, E, LHS, RHS, LHS.getBitWidth() + 1,
+                                  std::plus<APSInt>());
+    return true;
+  case BO_Sub:
+    Result = CheckedIntArithmetic(Info, E, LHS, RHS, LHS.getBitWidth() + 1,
+                                  std::minus<APSInt>());
+    return true;
+  case BO_And: Result = LHS & RHS; return true;
+  case BO_Xor: Result = LHS ^ RHS; return true;
+  case BO_Or:  Result = LHS | RHS; return true;
+  case BO_Div:
+  case BO_Rem:
+    if (RHS == 0) {
+      Info.Diag(E, diag::note_expr_divide_by_zero);
+      return false;
+    }
+    // Check for overflow case: INT_MIN / -1 or INT_MIN % -1. The latter is
+    // not actually undefined behavior in C++11 due to a language defect.
+    if (RHS.isNegative() && RHS.isAllOnesValue() &&
+        LHS.isSigned() && LHS.isMinSignedValue())
+      HandleOverflow(Info, E, -LHS.extend(LHS.getBitWidth() + 1), E->getType());
+    Result = (Opcode == BO_Rem ? LHS % RHS : LHS / RHS);
+    return true;
+  case BO_Shl: {
+    if (Info.getLangOpts().OpenCL)
+      // OpenCL 6.3j: shift values are effectively % word size of LHS.
+      RHS &= APSInt(llvm::APInt(RHS.getBitWidth(),
+                    static_cast<uint64_t>(LHS.getBitWidth() - 1)),
+                    RHS.isUnsigned());
+    else if (RHS.isSigned() && RHS.isNegative()) {
+      // During constant-folding, a negative shift is an opposite shift. Such
+      // a shift is not a constant expression.
+      Info.CCEDiag(E, diag::note_constexpr_negative_shift) << RHS;
+      RHS = -RHS;
+      goto shift_right;
+    }
+  shift_left:
+    // C++11 [expr.shift]p1: Shift width must be less than the bit width of
+    // the shifted type.
+    unsigned SA = (unsigned) RHS.getLimitedValue(LHS.getBitWidth()-1);
+    if (SA != RHS) {
+      Info.CCEDiag(E, diag::note_constexpr_large_shift)
+        << RHS << E->getType() << LHS.getBitWidth();
+    } else if (LHS.isSigned()) {
+      // C++11 [expr.shift]p2: A signed left shift must have a non-negative
+      // operand, and must not overflow the corresponding unsigned type.
+      if (LHS.isNegative())
+        Info.CCEDiag(E, diag::note_constexpr_lshift_of_negative) << LHS;
+      else if (LHS.countLeadingZeros() < SA)
+        Info.CCEDiag(E, diag::note_constexpr_lshift_discards);
+    }
+    Result = LHS << SA;
+    return true;
+  }
+  case BO_Shr: {
+    if (Info.getLangOpts().OpenCL)
+      // OpenCL 6.3j: shift values are effectively % word size of LHS.
+      RHS &= APSInt(llvm::APInt(RHS.getBitWidth(),
+                    static_cast<uint64_t>(LHS.getBitWidth() - 1)),
+                    RHS.isUnsigned());
+    else if (RHS.isSigned() && RHS.isNegative()) {
+      // During constant-folding, a negative shift is an opposite shift. Such a
+      // shift is not a constant expression.
+      Info.CCEDiag(E, diag::note_constexpr_negative_shift) << RHS;
+      RHS = -RHS;
+      goto shift_left;
+    }
+  shift_right:
+    // C++11 [expr.shift]p1: Shift width must be less than the bit width of the
+    // shifted type.
+    unsigned SA = (unsigned) RHS.getLimitedValue(LHS.getBitWidth()-1);
+    if (SA != RHS)
+      Info.CCEDiag(E, diag::note_constexpr_large_shift)
+        << RHS << E->getType() << LHS.getBitWidth();
+    Result = LHS >> SA;
+    return true;
+  }
+
+  // The comparisons assign a truth value, leaving Result's width and
+  // signedness as the caller set them.
+  case BO_LT: Result = LHS < RHS; return true;
+  case BO_GT: Result = LHS > RHS; return true;
+  case BO_LE: Result = LHS <= RHS; return true;
+  case BO_GE: Result = LHS >= RHS; return true;
+  case BO_EQ: Result = LHS == RHS; return true;
+  case BO_NE: Result = LHS != RHS; return true;
+  }
+}
+
 /// Cast an lvalue referring to a base subobject to a derived class, by
 /// truncating the lvalue's path to the given length.
 static bool CastToDerivedClass(EvalInfo &Info, const Expr *E, LValue &Result,
@@ -2159,6 +2281,136 @@
 }
 
 namespace {
+/// Subobject handler for findSubobject that applies a compound assignment
+/// (<op>=) to the designated subobject. Only integer subobjects are supported
+/// so far; every other kind of subobject is rejected with a diagnostic.
+struct CompoundAssignSubobjectHandler {
+  EvalInfo &Info;
+  /// The expression being evaluated (the compound assignment itself).
+  const Expr *E;
+  /// The type the current value is converted to before the operation.
+  QualType PromotedLHSType;
+  /// The binary operator that corresponds to the compound assignment.
+  BinaryOperatorKind Opcode;
+  /// The already-evaluated right-hand operand.
+  const APValue &RHS;
+
+  static const AccessKinds AccessKind = AK_Assign;
+
+  typedef bool result_type;
+
+  bool checkConst(QualType QT) {
+    // Assigning to a const object has undefined behavior.
+    if (QT.isConstQualified()) {
+      Info.Diag(E, diag::note_constexpr_modify_const_type) << QT;
+      return false;
+    }
+    return true;
+  }
+
+  bool failed() { return false; }
+  bool found(APValue &Subobj, QualType SubobjType) {
+    switch (Subobj.getKind()) {
+    case APValue::Int:
+      return found(Subobj.getInt(), SubobjType);
+    case APValue::Float:
+      return found(Subobj.getFloat(), SubobjType);
+    case APValue::ComplexInt:
+    case APValue::ComplexFloat:
+      // FIXME: Implement complex compound assignment.
+      Info.Diag(E);
+      return false;
+    case APValue::LValue:
+      return foundPointer(Subobj, SubobjType);
+    default:
+      // FIXME: can this happen?
+      Info.Diag(E);
+      return false;
+    }
+  }
+  bool found(APSInt &Value, QualType SubobjType) {
+    if (!checkConst(SubobjType))
+      return false;
+
+    if (!SubobjType->isIntegerType() || !RHS.isInt()) {
+      // We only support an integer-typed subobject combined with an integer
+      // RHS. This rejects compound assignment on integer-cast-to-pointer
+      // values as well as non-integer right operands.
+      Info.Diag(E);
+      return false;
+    }
+
+    // Perform the operation in the promoted type, then convert the result
+    // back to the subobject's type. LHS is both an operand and the result of
+    // handleIntIntBinOp; that function computes each value before storing it.
+    APSInt LHS = HandleIntToIntCast(Info, E, PromotedLHSType,
+                                    SubobjType, Value);
+    if (!handleIntIntBinOp(Info, E, LHS, Opcode, RHS.getInt(), LHS))
+      return false;
+    Value = HandleIntToIntCast(Info, E, SubobjType, PromotedLHSType, LHS);
+    return true;
+  }
+  bool found(APFloat &Value, QualType SubobjType) {
+    if (!checkConst(SubobjType))
+      return false;
+
+    // FIXME: Implement.
+    Info.Diag(E);
+    return false;
+  }
+  bool foundPointer(APValue &Subobj, QualType SubobjType) {
+    if (!checkConst(SubobjType))
+      return false;
+
+    QualType PointeeType;
+    if (const PointerType *PT = SubobjType->getAs<PointerType>())
+      PointeeType = PT->getPointeeType();
+    else {
+      Info.Diag(E);
+      return false;
+    }
+
+    // FIXME: Implement compound assignment on pointers. Until then PointeeType
+    // is unused and every pointer subobject is rejected.
+    Info.Diag(E);
+    return false;
+  }
+  bool foundString(APValue &Subobj, QualType SubobjType, uint64_t Character) {
+    llvm_unreachable("shouldn't encounter string elements here");
+  }
+};
+} // end anonymous namespace
+
+const AccessKinds CompoundAssignSubobjectHandler::AccessKind;
+
+/// Perform a compound assignment of LVal <op>= RVal.
+///
+/// \param LValType The type of the object designated by LVal.
+/// \param PromotedLValType The type the current value of LVal is converted to
+///        before the operation is performed.
+/// \param Opcode The binary operator to apply.
+/// \returns false if the assignment cannot be evaluated.
+static bool handleCompoundAssignment(
+    EvalInfo &Info, const Expr *E,
+    const LValue &LVal, QualType LValType, QualType PromotedLValType,
+    BinaryOperatorKind Opcode, const APValue &RVal) {
+  if (LVal.Designator.Invalid)
+    return false;
+
+  // Compound assignment modifies an object, which is only supported when
+  // evaluating C++1y.
+  if (!Info.getLangOpts().CPlusPlus1y) {
+    Info.Diag(E);
+    return false;
+  }
+
+  CompleteObject Obj = findCompleteObject(Info, E, AK_Assign, LVal, LValType);
+  CompoundAssignSubobjectHandler Handler = { Info, E, PromotedLValType, Opcode,
+                                             RVal };
+  return Obj && findSubobject(Info, E, Obj, LVal.Designator, Handler);
+}
+
+namespace {
 struct IncDecSubobjectHandler {
   EvalInfo &Info;
   const Expr *E;
@@ -3655,14 +3887,10 @@
   if (!Evaluate(RHS, this->Info, CAO->getRHS()))
     return false;
 
-  // FIXME:
-  //return handleCompoundAssignment(
-  //    this->Info, CAO,
-  //    Result, CAO->getLHS()->getType(), CAO->getComputationLHSType(),
-  //    RHS, CAO->getRHS()->getType(),
-  //    CAO->getOpForCompoundAssignment(CAO->getOpcode()),
-  //    CAO->getComputationResultType());
-  return Error(CAO);
+  BinaryOperatorKind OpKind = CAO->getOpForCompoundAssignment(CAO->getOpcode());
+  QualType AssignedTy = CAO->getLHS()->getType();
+  return handleCompoundAssignment(this->Info, CAO, Result, AssignedTy,
+                                  CAO->getComputationLHSType(), OpKind, RHS);
 }
 
 bool LValueExprEvaluator::VisitBinAssign(const BinaryOperator *E) {
@@ -5170,29 +5398,6 @@
          A.getLValueCallIndex() == B.getLValueCallIndex();
 }
 
-/// Perform the given integer operation, which is known to need at most BitWidth
-/// bits, and check for overflow in the original type (if that type was not an
-/// unsigned type).
-template<typename Operation>
-static APSInt CheckedIntArithmetic(EvalInfo &Info, const Expr *E,
-                                   const APSInt &LHS, const APSInt &RHS,
-                                   unsigned BitWidth, Operation Op) {
-  if (LHS.isUnsigned())
-    return Op(LHS, RHS);
-
-  APSInt Value(Op(LHS.extend(BitWidth), RHS.extend(BitWidth)), false);
-  APSInt Result = Value.trunc(LHS.getBitWidth());
-  if (Result.extend(BitWidth) != Value) {
-    if (Info.getIntOverflowCheckMode())
-      Info.Ctx.getDiagnostics().Report(E->getExprLoc(),
-        diag::warn_integer_constant_overflow)
-          << Result.toString(10) << E->getType();
-    else
-      HandleOverflow(Info, E, Value, E->getType());
-  }
-  return Result;
-}
-
 namespace {
 
 /// \brief Data recursive integer evaluator of certain binary operators.
@@ -5437,108 +5642,20 @@
     Result = APValue(LHSAddrExpr, RHSAddrExpr);
     return true;
   }
-  
-  // All the following cases expect both operands to be an integer
+
+  // All the remaining cases expect both operands to be an integer
   if (!LHSVal.isInt() || !RHSVal.isInt())
     return Error(E);
-  
-  const APSInt &LHS = LHSVal.getInt();
-  APSInt RHS = RHSVal.getInt();
-  
-  switch (E->getOpcode()) {
-    default:
-      return Error(E);
-    case BO_Mul:
-      return Success(CheckedIntArithmetic(Info, E, LHS, RHS,
-                                          LHS.getBitWidth() * 2,
-                                          std::multiplies<APSInt>()), E,
-                     Result);
-    case BO_Add:
-      return Success(CheckedIntArithmetic(Info, E, LHS, RHS,
-                                          LHS.getBitWidth() + 1,
-                                          std::plus<APSInt>()), E, Result);
-    case BO_Sub:
-      return Success(CheckedIntArithmetic(Info, E, LHS, RHS,
-                                          LHS.getBitWidth() + 1,
-                                          std::minus<APSInt>()), E, Result);
-    case BO_And: return Success(LHS & RHS, E, Result);
-    case BO_Xor: return Success(LHS ^ RHS, E, Result);
-    case BO_Or:  return Success(LHS | RHS, E, Result);
-    case BO_Div:
-    case BO_Rem:
-      if (RHS == 0)
-        return Error(E, diag::note_expr_divide_by_zero);
-      // Check for overflow case: INT_MIN / -1 or INT_MIN % -1. The latter is
-      // not actually undefined behavior in C++11 due to a language defect.
-      if (RHS.isNegative() && RHS.isAllOnesValue() &&
-          LHS.isSigned() && LHS.isMinSignedValue())
-        HandleOverflow(Info, E, -LHS.extend(LHS.getBitWidth() + 1), E->getType());
-      return Success(E->getOpcode() == BO_Rem ? LHS % RHS : LHS / RHS, E,
-                     Result);
-    case BO_Shl: {
-      if (Info.getLangOpts().OpenCL)
-        // OpenCL 6.3j: shift values are effectively % word size of LHS.
-        RHS &= APSInt(llvm::APInt(RHS.getBitWidth(),
-                      static_cast<uint64_t>(LHS.getBitWidth() - 1)),
-                      RHS.isUnsigned());
-      else if (RHS.isSigned() && RHS.isNegative()) {
-        // During constant-folding, a negative shift is an opposite shift. Such
-        // a shift is not a constant expression.
-        CCEDiag(E, diag::note_constexpr_negative_shift) << RHS;
-        RHS = -RHS;
-        goto shift_right;
-      }
-      
-    shift_left:
-      // C++11 [expr.shift]p1: Shift width must be less than the bit width of
-      // the shifted type.
-      unsigned SA = (unsigned) RHS.getLimitedValue(LHS.getBitWidth()-1);
-      if (SA != RHS) {
-        CCEDiag(E, diag::note_constexpr_large_shift)
-        << RHS << E->getType() << LHS.getBitWidth();
-      } else if (LHS.isSigned()) {
-        // C++11 [expr.shift]p2: A signed left shift must have a non-negative
-        // operand, and must not overflow the corresponding unsigned type.
-        if (LHS.isNegative())
-          CCEDiag(E, diag::note_constexpr_lshift_of_negative) << LHS;
-        else if (LHS.countLeadingZeros() < SA)
-          CCEDiag(E, diag::note_constexpr_lshift_discards);
-      }
-      
-      return Success(LHS << SA, E, Result);
-    }
-    case BO_Shr: {
-      if (Info.getLangOpts().OpenCL)
-        // OpenCL 6.3j: shift values are effectively % word size of LHS.
-        RHS &= APSInt(llvm::APInt(RHS.getBitWidth(),
-                      static_cast<uint64_t>(LHS.getBitWidth() - 1)),
-                      RHS.isUnsigned());
-      else if (RHS.isSigned() && RHS.isNegative()) {
-        // During constant-folding, a negative shift is an opposite shift. Such a
-        // shift is not a constant expression.
-        CCEDiag(E, diag::note_constexpr_negative_shift) << RHS;
-        RHS = -RHS;
-        goto shift_left;
-      }
-      
-    shift_right:
-      // C++11 [expr.shift]p1: Shift width must be less than the bit width of the
-      // shifted type.
-      unsigned SA = (unsigned) RHS.getLimitedValue(LHS.getBitWidth()-1);
-      if (SA != RHS)
-        CCEDiag(E, diag::note_constexpr_large_shift)
-        << RHS << E->getType() << LHS.getBitWidth();
-      
-      return Success(LHS >> SA, E, Result);
-    }
-      
-    case BO_LT: return Success(LHS < RHS, E, Result);
-    case BO_GT: return Success(LHS > RHS, E, Result);
-    case BO_LE: return Success(LHS <= RHS, E, Result);
-    case BO_GE: return Success(LHS >= RHS, E, Result);
-    case BO_EQ: return Success(LHS == RHS, E, Result);
-    case BO_NE: return Success(LHS != RHS, E, Result);
-  }
+
+  // Comparisons only assign a truth value to Value, so its width and
+  // signedness are taken from the type of E rather than from the operands.
+  // FIXME: Don't do this in the cases where we can deduce it.
+  APSInt Value(Info.Ctx.getIntWidth(E->getType()),
+               E->getType()->isUnsignedIntegerOrEnumerationType());
+  bool Evaluated = handleIntIntBinOp(Info, E, LHSVal.getInt(),
+                                     E->getOpcode(), RHSVal.getInt(), Value);
+  // On failure handleIntIntBinOp has already issued a diagnostic.
+  return Evaluated && Success(Value, E, Result);
 }
 
 void DataRecursiveIntBinOpEvaluator::process(EvalResult &Result) {