[OPENMP] Additional sema analysis for 'omp atomic [update]'.
Adds additional semantic analysis and generates helper expressions required for proper codegen.

llvm-svn: 232164
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 9b66d2f..7042399 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -3236,6 +3236,211 @@
   return OMPOrderedDirective::Create(Context, StartLoc, EndLoc, AStmt);
 }
 
+namespace {
+/// \brief Helper class for checking expression in 'omp atomic [update]'
+/// construct.
+class OpenMPAtomicUpdateChecker {
+  /// \brief Error results for atomic update expressions. Values are streamed
+  /// into the note diagnostic (likely a %select index); keep the order stable.
+  enum ExprAnalysisErrorCode {
+    /// \brief A statement is not an expression statement.
+    NotAnExpression,
+    /// \brief Expression is not builtin binary or unary operation.
+    NotABinaryOrUnaryExpression,
+    /// \brief Unary operation is not post-/pre- increment/decrement operation.
+    NotAnUnaryIncDecExpression,
+    /// \brief An expression is not of scalar type.
+    NotAScalarType,
+    /// \brief A binary operation is not an assignment operation.
+    NotAnAssignmentOp,
+    /// \brief RHS part of the binary operation is not a binary expression.
+    NotABinaryExpression,
+    /// \brief RHS part is not additive/multiplicative/shift/bitwise binary
+    /// expression.
+    NotABinaryOperator,
+    /// \brief RHS binary operation does not have reference to the updated LHS
+    /// part.
+    NotAnUpdateExpression,
+    /// \brief No error is found.
+    NoError
+  };
+  /// \brief Reference to Sema.
+  Sema &SemaRef;
+  /// \brief Atomic operation supposed to be performed on source expression.
+  BinaryOperatorKind OpKind;
+  /// \brief 'x' lvalue part of the source atomic expression.
+  Expr *X;
+  /// \brief 'x' rvalue used on the RHS, needed to build the update
+  /// expression 'x = x_rval binop expr' or 'x = expr binop x_rval'.
+  Expr *XRVal;
+  /// \brief 'expr' rvalue part of the source atomic expression.
+  Expr *E;
+
+public:
+  explicit OpenMPAtomicUpdateChecker(Sema &SemaRef)
+      : SemaRef(SemaRef), OpKind(BO_PtrMemD), X(nullptr), XRVal(nullptr),
+        E(nullptr) {}
+  /// \brief Check that statement \p S is suitable for 'atomic update'
+  /// constructs and extract 'x', 'expr' and the operation kind from it.
+  /// In a dependent context the extracted expressions are reset to null.
+  /// \param DiagId Diagnostic which should be emitted if error is found.
+  /// \param NoteId Diagnostic note for the main error message.
+  /// \return true if statement is not an update expression, false otherwise.
+  bool checkStatement(Stmt *S, unsigned DiagId, unsigned NoteId);
+  /// \brief Return the 'x' lvalue part of the source atomic expression.
+  Expr *getX() const { return X; }
+  /// \brief Return the 'x' rvalue part of the source atomic expression, used in
+  /// the RHS part of the source expression.
+  Expr *getXRVal() const { return XRVal; }
+  /// \brief Return the 'expr' rvalue part of the source atomic expression.
+  Expr *getExpr() const { return E; }
+  /// \brief Return required atomic operation.
+  BinaryOperatorKind getOpKind() const { return OpKind; }
+
+private:
+  /// \brief Check the 'x = x binop expr' and 'x = expr binop x' forms.
+  bool checkBinaryOperation(BinaryOperator *AtomicBinOp, unsigned DiagId,
+                            unsigned NoteId);
+};
+} // namespace
+
+bool OpenMPAtomicUpdateChecker::checkBinaryOperation(
+    BinaryOperator *AtomicBinOp, unsigned DiagId, unsigned NoteId) {
+  // Accepted forms (anything else is diagnosed):
+  //  x = x binop expr;
+  //  x = expr binop x;
+  // Emits the primary error plus an explanatory note and reports failure.
+  auto Report = [&](ExprAnalysisErrorCode Code, SourceLocation ErrLoc,
+                    SourceRange ErrRange, SourceLocation NLoc,
+                    SourceRange NRange) {
+    SemaRef.Diag(ErrLoc, DiagId) << ErrRange;
+    SemaRef.Diag(NLoc, NoteId) << Code << NRange;
+    return true;
+  };
+
+  if (AtomicBinOp->getOpcode() != BO_Assign) {
+    SourceLocation OpLoc = AtomicBinOp->getOperatorLoc();
+    return Report(NotAnAssignmentOp, AtomicBinOp->getExprLoc(),
+                  AtomicBinOp->getSourceRange(), OpLoc,
+                  SourceRange(OpLoc, OpLoc));
+  }
+
+  // 'x' is the assignment target; the RHS must itself be a binary operator.
+  X = AtomicBinOp->getLHS();
+  Expr *AssignRHS = AtomicBinOp->getRHS();
+  auto *Inner = dyn_cast<BinaryOperator>(AssignRHS->IgnoreParenImpCasts());
+  if (!Inner)
+    return Report(NotABinaryExpression, AssignRHS->getExprLoc(),
+                  AssignRHS->getSourceRange(), AssignRHS->getExprLoc(),
+                  AssignRHS->getSourceRange());
+
+  // Only multiplicative, additive, shift and bitwise operators qualify.
+  bool IsUpdateOp = Inner->isMultiplicativeOp() || Inner->isAdditiveOp() ||
+                    Inner->isShiftOp() || Inner->isBitwiseOp();
+  if (!IsUpdateOp) {
+    SourceLocation OpLoc = Inner->getOperatorLoc();
+    return Report(NotABinaryOperator, Inner->getExprLoc(),
+                  Inner->getSourceRange(), OpLoc, SourceRange(OpLoc, OpLoc));
+  }
+
+  OpKind = Inner->getOpcode();
+  Expr *InnerLHS = Inner->getLHS();
+  Expr *InnerRHS = Inner->getRHS();
+  // Locate 'x' among the operands by comparing canonical statement profiles.
+  auto &Ctx = SemaRef.getASTContext();
+  llvm::FoldingSetNodeID XId, LHSId, RHSId;
+  X->IgnoreParenImpCasts()->Profile(XId, Ctx, /*Canonical=*/true);
+  InnerLHS->IgnoreParenImpCasts()->Profile(LHSId, Ctx, /*Canonical=*/true);
+  InnerRHS->IgnoreParenImpCasts()->Profile(RHSId, Ctx, /*Canonical=*/true);
+  if (XId == LHSId) {
+    XRVal = InnerLHS;
+    E = InnerRHS;
+  } else if (XId == RHSId) {
+    XRVal = InnerRHS;
+    E = InnerLHS;
+  } else {
+    return Report(NotAnUpdateExpression, Inner->getExprLoc(),
+                  Inner->getSourceRange(), X->getExprLoc(),
+                  X->getSourceRange());
+  }
+
+  // In a dependent context the extracted parts are discarded.
+  if (SemaRef.CurContext->isDependentContext())
+    E = X = XRVal = nullptr;
+  return false;
+}
+
+bool OpenMPAtomicUpdateChecker::checkStatement(Stmt *S, unsigned DiagId,
+                                               unsigned NoteId) {
+  ExprAnalysisErrorCode ErrorFound = NoError;
+  SourceLocation ErrorLoc, NoteLoc;
+  SourceRange ErrorRange, NoteRange;
+  // Allowed constructs are:
+  //  x++;
+  //  x--;
+  //  ++x;
+  //  --x;
+  //  x binop= expr;
+  //  x = x binop expr;
+  //  x = expr binop x;
+  if (auto *AtomicBody = dyn_cast<Expr>(S)) {
+    AtomicBody = AtomicBody->IgnoreParenImpCasts();
+    if (AtomicBody->getType()->isScalarType() ||
+        AtomicBody->isInstantiationDependent()) {
+      if (auto *AtomicCompAssignOp =
+              dyn_cast<CompoundAssignOperator>(AtomicBody)) {
+        // Check for Compound Assignment Operation: 'x binop= expr'.
+        OpKind = BinaryOperator::getOpForCompoundAssignment(
+            AtomicCompAssignOp->getOpcode());
+        X = AtomicCompAssignOp->getLHS();
+        // Skipped in dependent contexts: the result is discarded below anyway.
+        if (!SemaRef.CurContext->isDependentContext())
+          XRVal = SemaRef.PerformImplicitConversion(
+                      X, AtomicCompAssignOp->getComputationLHSType(),
+                      Sema::AA_Casting, /*AllowExplicit=*/true).get();
+        E = AtomicCompAssignOp->getRHS();
+      } else if (auto *AtomicBinOp = dyn_cast<BinaryOperator>(AtomicBody)) {
+        // Check for Binary Operation (plain assignment forms).
+        return checkBinaryOperation(AtomicBinOp, DiagId, NoteId);
+      } else if (auto *AtomicUnaryOp = dyn_cast<UnaryOperator>(AtomicBody)) {
+        // Check for Unary Operation: only ++/-- (pre- or postfix) qualify.
+        if (AtomicUnaryOp->isIncrementDecrementOp()) {
+          OpKind = AtomicUnaryOp->isIncrementOp() ? BO_Add : BO_Sub;
+          XRVal = X = AtomicUnaryOp->getSubExpr();
+          // Modelled as 'x = x + 1' / 'x = x - 1'.
+          E = SemaRef.ActOnIntegerConstant(AtomicUnaryOp->getOperatorLoc(), 1)
+                  .get();
+        } else {
+          ErrorFound = NotAnUnaryIncDecExpression;
+          ErrorLoc = AtomicUnaryOp->getExprLoc();
+          ErrorRange = AtomicUnaryOp->getSourceRange();
+          NoteLoc = AtomicUnaryOp->getOperatorLoc();
+          NoteRange = SourceRange(NoteLoc, NoteLoc);
+        }
+      } else {
+        ErrorFound = NotABinaryOrUnaryExpression;
+        NoteLoc = ErrorLoc = AtomicBody->getExprLoc();
+        NoteRange = ErrorRange = AtomicBody->getSourceRange();
+      }
+    } else {
+      ErrorFound = NotAScalarType;
+      NoteLoc = ErrorLoc = AtomicBody->getLocStart();
+      NoteRange = ErrorRange = SourceRange(NoteLoc, NoteLoc);
+    }
+  } else {
+    ErrorFound = NotAnExpression;
+    NoteLoc = ErrorLoc = S->getLocStart();
+    NoteRange = ErrorRange = SourceRange(NoteLoc, NoteLoc);
+  }
+  if (ErrorFound != NoError) {
+    SemaRef.Diag(ErrorLoc, DiagId) << ErrorRange;
+    SemaRef.Diag(NoteLoc, NoteId) << ErrorFound << NoteRange;
+    return true;
+  } else if (SemaRef.CurContext->isDependentContext())
+    E = X = XRVal = nullptr;
+  return false;
+}
+
 StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
                                             Stmt *AStmt,
                                             SourceLocation StartLoc,
@@ -3270,7 +3475,9 @@
   if (auto *EWC = dyn_cast<ExprWithCleanups>(Body))
     Body = EWC->getSubExpr();
 
+  BinaryOperatorKind OpKind = BO_PtrMemD;
   Expr *X = nullptr;
+  Expr *XRVal = nullptr;
   Expr *V = nullptr;
   Expr *E = nullptr;
   // OpenMP [2.12.6, atomic Construct]
@@ -3412,11 +3619,26 @@
     } else if (CurContext->isDependentContext())
       E = X = nullptr;
   } else if (AtomicKind == OMPC_update || AtomicKind == OMPC_unknown) {
-    if (!isa<Expr>(Body)) {
-      Diag(Body->getLocStart(),
-           diag::err_omp_atomic_update_not_expression_statement)
-          << (AtomicKind == OMPC_update);
+    // If clause is update:
+    //  x++;
+    //  x--;
+    //  ++x;
+    //  --x;
+    //  x binop= expr;
+    //  x = x binop expr;
+    //  x = expr binop x;
+    OpenMPAtomicUpdateChecker Checker(*this);
+    if (Checker.checkStatement(
+            Body, (AtomicKind == OMPC_update)
+                      ? diag::err_omp_atomic_update_not_expression_statement
+                      : diag::err_omp_atomic_not_expression_statement,
+            diag::note_omp_atomic_update))
       return StmtError();
+    if (!CurContext->isDependentContext()) {
+      E = Checker.getExpr();
+      X = Checker.getX();
+      XRVal = Checker.getXRVal();
+      OpKind = Checker.getOpKind();
     }
   } else if (AtomicKind == OMPC_capture) {
     if (isa<Expr>(Body) && !isa<BinaryOperator>(Body)) {
@@ -3433,7 +3655,7 @@
   getCurFunction()->setHasBranchProtectedScope();
 
   return OMPAtomicDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                    X, V, E);
+                                    OpKind, X, XRVal, V, E);
 }
 
 StmtResult Sema::ActOnOpenMPTargetDirective(ArrayRef<OMPClause *> Clauses,