[c++20] P1330R0: permit simple-assignments that change the active member
of a union within constant expression evaluation.

llvm-svn: 361329
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 5084564..5ec2883 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -290,6 +290,27 @@
       }
     }
 
+    /// Shorten the path to its first \p NewLength entries and recompute the
+    /// cached most-derived-subobject data. Does nothing if already invalid.
+    void truncate(ASTContext &Ctx, APValue::LValueBase Base,
+                  unsigned NewLength) {
+      if (Invalid)
+        return;
+      assert(Base && "cannot truncate path for null pointer");
+      assert(NewLength <= Entries.size() && "not a truncation");
+      if (Entries.size() == NewLength)
+        return; // Nothing to drop; the cached data still describes the path.
+
+      Entries.resize(NewLength);
+      bool MostDerivedIsElt = false;
+      bool LeadsWithUnsizedArray = false;
+      MostDerivedPathLength = findMostDerivedSubobject(
+          Ctx, Base, Entries, MostDerivedArraySize, MostDerivedType,
+          MostDerivedIsElt, LeadsWithUnsizedArray);
+      MostDerivedIsArrayElement = MostDerivedIsElt;
+      FirstEntryIsAnUnsizedArray = LeadsWithUnsizedArray;
+    }
+
     void setInvalid() {
       Invalid = true;
       Entries.clear();
@@ -2024,10 +2045,13 @@
 static bool
 CheckConstantExpression(EvalInfo &Info, SourceLocation DiagLoc, QualType Type,
                         const APValue &Value,
-                        Expr::ConstExprUsage Usage = Expr::EvaluateForCodeGen) {
+                        Expr::ConstExprUsage Usage = Expr::EvaluateForCodeGen,
+                        SourceLocation SubobjectLoc = SourceLocation()) {
   if (!Value.hasValue()) {
     Info.FFDiag(DiagLoc, diag::note_constexpr_uninitialized)
       << true << Type;
+    if (SubobjectLoc.isValid())
+      Info.Note(SubobjectLoc, diag::note_constexpr_subobject_declared_here);
     return false;
   }
 
@@ -2043,18 +2067,20 @@
     QualType EltTy = Type->castAsArrayTypeUnsafe()->getElementType();
     for (unsigned I = 0, N = Value.getArrayInitializedElts(); I != N; ++I) {
       if (!CheckConstantExpression(Info, DiagLoc, EltTy,
-                                   Value.getArrayInitializedElt(I), Usage))
+                                   Value.getArrayInitializedElt(I), Usage,
+                                   SubobjectLoc))
         return false;
     }
     if (!Value.hasArrayFiller())
       return true;
     return CheckConstantExpression(Info, DiagLoc, EltTy, Value.getArrayFiller(),
-                                   Usage);
+                                   Usage, SubobjectLoc);
   }
   if (Value.isUnion() && Value.getUnionField()) {
     return CheckConstantExpression(Info, DiagLoc,
                                    Value.getUnionField()->getType(),
-                                   Value.getUnionValue(), Usage);
+                                   Value.getUnionValue(), Usage,
+                                   Value.getUnionField()->getLocation());
   }
   if (Value.isStruct()) {
     RecordDecl *RD = Type->castAs<RecordType>()->getDecl();
@@ -2062,7 +2088,8 @@
       unsigned BaseIndex = 0;
       for (const CXXBaseSpecifier &BS : CD->bases()) {
         if (!CheckConstantExpression(Info, DiagLoc, BS.getType(),
-                                     Value.getStructBase(BaseIndex), Usage))
+                                     Value.getStructBase(BaseIndex), Usage,
+                                     BS.getBeginLoc()))
           return false;
         ++BaseIndex;
       }
@@ -2073,7 +2100,7 @@
 
       if (!CheckConstantExpression(Info, DiagLoc, I->getType(),
                                    Value.getStructField(I->getFieldIndex()),
-                                   Usage))
+                                   Usage, I->getLocation()))
         return false;
     }
   }
@@ -2972,7 +2999,8 @@
 
   // Walk the designator's path to find the subobject.
   for (unsigned I = 0, N = Sub.Entries.size(); /**/; ++I) {
-    if (!O->hasValue()) {
+    // Reading an indeterminate value is undefined, but assigning over one is OK.
+    if (O->isAbsent() || (O->isIndeterminate() && handler.AccessKind != AK_Assign)) {
       if (!Info.checkingPotentialConstantExpression())
         Info.FFDiag(E, diag::note_constexpr_access_uninit)
             << handler.AccessKind << O->isIndeterminate();
@@ -4888,6 +4916,159 @@
   return RuntimeCheckFailed(&Paths);
 }
 
+namespace {
+/// findSubobject handler that begins the lifetime of union member Field.
+struct StartLifetimeOfUnionMemberHandler {
+  const FieldDecl *Field;
+  static const AccessKinds AccessKind = AK_Assign;
+  using result_type = bool;
+
+  /// Value of a trivial default-initialization of an object of type
+  /// \p SubobjType: records recurse into bases and named fields, unions have
+  /// no active member, constant arrays use a filler, and all other types are
+  /// indeterminate.
+  APValue getDefaultInitValue(QualType SubobjType) {
+    if (auto *RD = SubobjType->getAsCXXRecordDecl()) {
+      // Nested unions get no active member: variant members are not started.
+      if (RD->isUnion())
+        return APValue(static_cast<const FieldDecl *>(nullptr));
+
+      APValue Result(APValue::UninitStruct(), RD->getNumBases(),
+                     std::distance(RD->field_begin(), RD->field_end()));
+      unsigned BaseIdx = 0;
+      for (const CXXBaseSpecifier &BS : RD->bases())
+        Result.getStructBase(BaseIdx++) = getDefaultInitValue(BS.getType());
+      for (const FieldDecl *FD : RD->fields())
+        if (!FD->isUnnamedBitfield())
+          Result.getStructField(FD->getFieldIndex()) =
+              getDefaultInitValue(FD->getType());
+      return Result;
+    }
+
+    const auto *CAT = dyn_cast_or_null<ConstantArrayType>(
+        SubobjType->getAsArrayTypeUnsafe());
+    if (!CAT)
+      return APValue::IndeterminateValue();
+
+    APValue Result(APValue::UninitArray(), 0, CAT->getSize().getZExtValue());
+    // Nothing is stored per element; a non-empty array's filler covers all.
+    if (Result.hasArrayFiller())
+      Result.getArrayFiller() = getDefaultInitValue(CAT->getElementType());
+    return Result;
+  }
+
+  bool failed() { return false; }
+  bool found(APValue &Subobj, QualType SubobjType) {
+    // No initialization is performed, but the object's lifetime begins. We
+    // model that as default-initialization with only trivial constructors:
+    //  * base, non-variant member and array element subobjects start living,
+    //  * variant members do not, and
+    //  * scalar subobjects that start living hold indeterminate values.
+    assert(SubobjType->isUnionType());
+    // Writing through the already-active member keeps its current value.
+    if (declaresSameEntity(Subobj.getUnionField(), Field))
+      return true;
+    Subobj.setUnion(Field, getDefaultInitValue(Field->getType()));
+    return true;
+  }
+  bool found(APSInt &, QualType) {
+    llvm_unreachable("wrong value kind for union object");
+  }
+  bool found(APFloat &, QualType) {
+    llvm_unreachable("wrong value kind for union object");
+  }
+};
+} // end anonymous namespace
+
+const AccessKinds StartLifetimeOfUnionMemberHandler::AccessKind;
+
+/// Handle a builtin simple-assignment or a call to a trivial assignment
+/// operator whose left-hand side might involve a union member access. If it
+/// does, implicitly start the lifetime of any accessed union members per
+/// C++20 [class.union]p5. Returns false if evaluation must stop.
+static bool HandleUnionActiveMemberChange(EvalInfo &Info, const Expr *LHSExpr,
+                                          const LValue &LHS) {
+  if (LHS.InvalidBase || LHS.Designator.Invalid)
+    return false;
+
+  llvm::SmallVector<std::pair<unsigned, const FieldDecl*>, 4> UnionPathLengths;
+  // C++ [class.union]p5:
+  //   define the set S(E) of subexpressions of E as follows:
+  const Expr *E = LHSExpr;
+  unsigned PathLength = LHS.Designator.Entries.size();
+  while (E) {
+    //   -- If E is of the form A.B, S(E) contains the elements of S(A)...
+    if (auto *ME = dyn_cast<MemberExpr>(E)) {
+      auto *FD = dyn_cast<FieldDecl>(ME->getMemberDecl());
+      // Stop at a reference member: it is not part of the designator path.
+      if (!FD || FD->getType()->isReferenceType())
+        break;
+
+      //    ... and also contains A.B if B names a union member
+      if (FD->getParent()->isUnion())
+        UnionPathLengths.push_back({PathLength - 1, FD});
+
+      E = ME->getBase();
+      --PathLength;
+      assert(declaresSameEntity(FD,
+                                LHS.Designator.Entries[PathLength]
+                                    .getAsBaseOrMember().getPointer()));
+
+      //   -- If E is of the form A[B] and is interpreted as a built-in array
+      //      subscripting operator, S(E) is [S(the array operand, if any)].
+    } else if (auto *ASE = dyn_cast<ArraySubscriptExpr>(E)) {
+      // Step over an ArrayToPointerDecay implicit cast.
+      auto *Base = ASE->getBase()->IgnoreImplicit();
+      if (!Base->getType()->isArrayType())
+        break;
+      E = Base;
+      --PathLength;
+    } else if (auto *ICE = dyn_cast<ImplicitCastExpr>(E)) {
+      // Step over a no-op or derived-to-base conversion. E must advance
+      // before the CK_NoOp 'continue', or we would loop on this cast forever.
+      E = ICE->getSubExpr();
+      if (ICE->getCastKind() == CK_NoOp)
+        continue;
+      if (ICE->getCastKind() != CK_DerivedToBase &&
+          ICE->getCastKind() != CK_UncheckedDerivedToBase)
+        break;
+      // The path lists bases derived-first, but we consume Entries base-first.
+      for (const CXXBaseSpecifier *Elt : llvm::reverse(ICE->path())) {
+        --PathLength;
+        (void)Elt; // Only used by the assertion.
+        assert(declaresSameEntity(Elt->getType()->getAsCXXRecordDecl(),
+                                  LHS.Designator.Entries[PathLength]
+                                      .getAsBaseOrMember().getPointer()));
+      }
+
+    //   -- Otherwise, S(E) is empty.
+    } else {
+      break;
+    }
+  }
+
+  // Common case: no unions' lifetimes are started.
+  if (UnionPathLengths.empty())
+    return true;
+
+  //   if modification of X [would access an inactive union member], an object
+  //   of the type of X is implicitly created
+  CompleteObject Obj =
+      findCompleteObject(Info, LHSExpr, AK_Assign, LHS, LHSExpr->getType());
+  if (!Obj)
+    return false;
+  for (std::pair<unsigned, const FieldDecl *> LengthAndField :
+           llvm::reverse(UnionPathLengths)) {
+    // Form a designator for the union object.
+    SubobjectDesignator D = LHS.Designator;
+    D.truncate(Info.Ctx, LHS.Base, LengthAndField.first);
+    StartLifetimeOfUnionMemberHandler StartLifetime{LengthAndField.second};
+    if (!findSubobject(Info, LHSExpr, Obj, D, StartLifetime))
+      return false;
+  }
+  return true;
+}
+
 /// Determine if a class has any fields that might need to be copied by a
 /// trivial copy or move operation.
 static bool hasFields(const CXXRecordDecl *RD) {
@@ -4958,6 +5139,9 @@
     if (!handleLValueToRValueConversion(Info, Args[0], Args[0]->getType(),
                                         RHS, RHSValue))
       return false;
+    if (Info.getLangOpts().CPlusPlus2a && MD->isTrivial() &&
+        !HandleUnionActiveMemberChange(Info, Args[0], *This))
+      return false;
     if (!handleAssignment(Info, Args[0], *This, MD->getThisType(),
                           RHSValue))
       return false;
@@ -6183,6 +6367,10 @@
   if (!Evaluate(NewVal, this->Info, E->getRHS()))
     return false;
 
+  if (Info.getLangOpts().CPlusPlus2a &&
+      !HandleUnionActiveMemberChange(Info, E->getLHS(), Result))
+    return false;
+
   return handleAssignment(this->Info, E, Result, E->getLHS()->getType(),
                           NewVal);
 }