[Analyzer] Synthesize function body for std::call_once

Differential Revision: https://reviews.llvm.org/D37840

llvm-svn: 314571
diff --git a/clang/lib/Analysis/BodyFarm.cpp b/clang/lib/Analysis/BodyFarm.cpp
index 5912724..6ca758e 100644
--- a/clang/lib/Analysis/BodyFarm.cpp
+++ b/clang/lib/Analysis/BodyFarm.cpp
@@ -14,11 +14,18 @@
 
 #include "BodyFarm.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/CXXInheritance.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/Expr.h"
+#include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
+#include "clang/AST/NestedNameSpecifier.h"
 #include "clang/Analysis/CodeInjector.h"
+#include "clang/Basic/OperatorKinds.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "body-farm"
 
 using namespace clang;
 
@@ -55,7 +62,9 @@
   CompoundStmt *makeCompound(ArrayRef<Stmt*>);
   
   /// Create a new DeclRefExpr for the referenced variable.
-  DeclRefExpr *makeDeclRefExpr(const VarDecl *D);
+  DeclRefExpr *makeDeclRefExpr(const VarDecl *D,
+                               bool RefersToEnclosingVariableOrCapture = false,
+                               bool GetNonReferenceType = false);
   
   /// Create a new UnaryOperator representing a dereference.
   UnaryOperator *makeDereference(const Expr *Arg, QualType Ty);
@@ -66,9 +75,24 @@
   /// Create an implicit cast to a builtin boolean type.
   ImplicitCastExpr *makeIntegralCastToBoolean(const Expr *Arg);
   
-  // Create an implicit cast for lvalue-to-rvaluate conversions.
+  /// Create an implicit cast for lvalue-to-rvalue conversions.
   ImplicitCastExpr *makeLvalueToRvalue(const Expr *Arg, QualType Ty);
   
+  /// Cast \p Arg to an rvalue of its own type (non-reference if requested).
+  ImplicitCastExpr *makeLvalueToRvalue(const Expr *Arg,
+                                       bool GetNonReferenceType = false);
+
+  /// Make an rvalue out of a variable: builds a DeclRefExpr to \p Decl and
+  /// wraps it in an lvalue-to-rvalue cast (see makeDeclRefExpr for flags).
+  ImplicitCastExpr *
+  makeLvalueToRvalue(const VarDecl *Decl,
+                     bool RefersToEnclosingVariableOrCapture = false,
+                     bool GetNonReferenceType = false);
+
+  /// Create an rvalue implicit cast of \p Arg to \p Ty with cast kind \p CK.
+  ImplicitCastExpr *makeImplicitCast(const Expr *Arg, QualType Ty,
+                                     CastKind CK = CK_LValueToRValue);
+
   /// Create an Objective-C bool literal.
   ObjCBoolLiteralExpr *makeObjCBool(bool Val);
 
@@ -78,6 +102,18 @@
   /// Create a Return statement.
   ReturnStmt *makeReturn(const Expr *RetVal);
   
+  /// Create an integer literal of type int holding \p value.
+  IntegerLiteral *makeIntegerLiteral(uint64_t value);
+
+  /// Create a member expression referring to \p MemberDecl of \p base.
+  MemberExpr *makeMemberExpression(Expr *base, ValueDecl *MemberDecl,
+                                   bool IsArrow = false,
+                                   ExprValueKind ValueKind = VK_LValue);
+
+  /// Returns the first declaration named \p Name found by lookup in \p RD.
+  /// \return nullptr if no member with such a name exists.
+  NamedDecl *findMemberField(const CXXRecordDecl *RD, StringRef Name);
+
 private:
   ASTContext &C;
 };
@@ -106,16 +142,16 @@
   return new (C) CompoundStmt(C, Stmts, SourceLocation(), SourceLocation());
 }
 
-DeclRefExpr *ASTMaker::makeDeclRefExpr(const VarDecl *D) {
-  DeclRefExpr *DR =
-    DeclRefExpr::Create(/* Ctx = */ C,
-                        /* QualifierLoc = */ NestedNameSpecifierLoc(),
-                        /* TemplateKWLoc = */ SourceLocation(),
-                        /* D = */ const_cast<VarDecl*>(D),
-                        /* RefersToEnclosingVariableOrCapture = */ false,
-                        /* NameLoc = */ SourceLocation(),
-                        /* T = */ D->getType(),
-                        /* VK = */ VK_LValue);
+DeclRefExpr *ASTMaker::makeDeclRefExpr(const VarDecl *D,
+                                       bool RefersToEnclosingVariableOrCapture,
+                                       bool GetNonReferenceType) {
+  // Optionally strip the reference from the declared type; the expression
+  // itself is always an lvalue.
+  QualType RefTy = GetNonReferenceType ? D->getType().getNonReferenceType()
+                                       : D->getType();
+  DeclRefExpr *DR = DeclRefExpr::Create(
+      C, NestedNameSpecifierLoc(), SourceLocation(), const_cast<VarDecl *>(D),
+      RefersToEnclosingVariableOrCapture, SourceLocation(), RefTy, VK_LValue);
   return DR;
 }
 
@@ -125,8 +161,38 @@
 }
 
 ImplicitCastExpr *ASTMaker::makeLvalueToRvalue(const Expr *Arg, QualType Ty) {
-  return ImplicitCastExpr::Create(C, Ty, CK_LValueToRValue,
-                                  const_cast<Expr*>(Arg), nullptr, VK_RValue);
+  // An lvalue-to-rvalue conversion is an implicit cast with a fixed kind.
+  return makeImplicitCast(Arg, Ty, CK_LValueToRValue);
+}
+
+ImplicitCastExpr *ASTMaker::makeLvalueToRvalue(const Expr *Arg,
+                                               bool GetNonReferenceType) {
+  // Convert to the operand's own type, minus any reference if requested.
+  const QualType ArgTy = Arg->getType();
+  return makeImplicitCast(
+      Arg, GetNonReferenceType ? ArgTy.getNonReferenceType() : ArgTy,
+      CK_LValueToRValue);
+}
+
+ImplicitCastExpr *
+ASTMaker::makeLvalueToRvalue(const VarDecl *Arg,
+                             bool RefersToEnclosingVariableOrCapture,
+                             bool GetNonReferenceType) {
+  // Reference the variable, then convert that reference to an rvalue whose
+  // type matches the type given to the reference.
+  const QualType VarTy = GetNonReferenceType
+                             ? Arg->getType().getNonReferenceType()
+                             : Arg->getType();
+  DeclRefExpr *Ref = makeDeclRefExpr(Arg, RefersToEnclosingVariableOrCapture,
+                                     GetNonReferenceType);
+  return makeLvalueToRvalue(Ref, VarTy);
+}
+
+ImplicitCastExpr *ASTMaker::makeImplicitCast(const Expr *Arg, QualType Ty,
+                                             CastKind CK) {
+  // No base-class path is recorded; the result is always an rvalue.
+  return ImplicitCastExpr::Create(C, Ty, CK, const_cast<Expr *>(Arg),
+                                  /*BasePath=*/nullptr, VK_RValue);
 }
 
 Expr *ASTMaker::makeIntegralCast(const Expr *Arg, QualType Ty) {
@@ -161,12 +227,225 @@
                             nullptr);
 }
 
+IntegerLiteral *ASTMaker::makeIntegerLiteral(uint64_t value) {
+  return IntegerLiteral::Create(C,
+                                llvm::APInt(
+                                    /*numBits=*/C.getTypeSize(C.IntTy), value),
+                                /*QualType=*/C.IntTy, SourceLocation());
+}
+
+MemberExpr *ASTMaker::makeMemberExpression(Expr *base, ValueDecl *MemberDecl,
+                                           bool IsArrow,
+                                           ExprValueKind ValueKind) {
+  // The found declaration is always recorded with public access, regardless
+  // of the member's declared access specifier.
+  DeclAccessPair FoundDecl = DeclAccessPair::make(MemberDecl, AS_public);
+  return MemberExpr::Create(
+      C, base, IsArrow, SourceLocation(), NestedNameSpecifierLoc(),
+      SourceLocation(), MemberDecl, FoundDecl,
+      DeclarationNameInfo(MemberDecl->getDeclName(), SourceLocation()),
+      /* TemplateArgumentListInfo= */ nullptr, MemberDecl->getType(), ValueKind,
+      OK_Ordinary);
+}
+
+NamedDecl *ASTMaker::findMemberField(const CXXRecordDecl *RD, StringRef Name) {
+  if (!RD)
+    return nullptr;
+
+  const IdentifierInfo &II = C.Idents.get(Name);
+  DeclarationName DeclName = C.DeclarationNames.getIdentifier(&II);
+
+  // Skip any declaration whose context is a function or method.
+  for (NamedDecl *FoundDecl : RD->lookup(DeclName))
+    if (!FoundDecl->getDeclContext()->isFunctionOrMethod())
+      return FoundDecl;
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // Creation functions for faux ASTs.
 //===----------------------------------------------------------------------===//
 
 typedef Stmt *(*FunctionFarmer)(ASTContext &C, const FunctionDecl *D);
 
+/// Build `Callback(CallArgs...)` for a callback that is not a lambda.
+/// NOTE(review): the parameter's value is used directly as the callee, which
+/// presumes it is a function pointer; confirm other callable kinds are not
+/// routed here.
+static CallExpr *
+create_call_once_funcptr_call(ASTContext &C, ASTMaker &M,
+                              const ParmVarDecl *Callback,
+                              SmallVectorImpl<Expr *> &CallArgs) {
+  return new (C) CallExpr(
+      /*C=*/C,
+      /*fn=*/M.makeLvalueToRvalue(Callback),
+      /*args=*/CallArgs,
+      /*t=*/C.VoidTy,
+      /*VK=*/VK_RValue,
+      /*rparenloc=*/SourceLocation());
+}
+
+/// Build `Callback(CallArgs...)` for a lambda callback: an explicit call to
+/// the closure type's operator(), with the closure object inserted in front
+/// of \p CallArgs.
+static CallExpr *
+create_call_once_lambda_call(ASTContext &C, ASTMaker &M,
+                             const ParmVarDecl *Callback, QualType CallbackType,
+                             SmallVectorImpl<Expr *> &CallArgs) {
+  CXXRecordDecl *CallbackDecl = CallbackType->getAsCXXRecordDecl();
+  assert(CallbackDecl != nullptr);
+  assert(CallbackDecl->isLambda());
+  FunctionDecl *CallOperatorDecl = CallbackDecl->getLambdaCallOperator();
+  assert(CallOperatorDecl != nullptr);
+
+  DeclRefExpr *CallOperatorDeclRef =
+      DeclRefExpr::Create(/* Ctx = */ C,
+                          /* QualifierLoc = */ NestedNameSpecifierLoc(),
+                          /* TemplateKWLoc = */ SourceLocation(),
+                          CallOperatorDecl,
+                          /* RefersToEnclosingVariableOrCapture= */ false,
+                          /* NameLoc = */ SourceLocation(),
+                          /* T = */ CallOperatorDecl->getType(),
+                          /* VK = */ VK_LValue);
+
+  // The closure object goes first, ahead of the forwarded arguments.
+  CallArgs.insert(
+      CallArgs.begin(),
+      M.makeDeclRefExpr(Callback,
+                        /* RefersToEnclosingVariableOrCapture= */ true,
+                        /* GetNonReferenceType= */ true));
+
+  return new (C)
+      CXXOperatorCallExpr(/*C=*/C, OO_Call, CallOperatorDeclRef,
+                          /*args=*/CallArgs,
+                          /*t=*/C.VoidTy,
+                          /*VK=*/VK_RValue,
+                          /*operatorloc=*/SourceLocation(), FPOptions());
+}
+
+/// Create a fake body for std::call_once.
+/// Emulates the following function body:
+///
+/// \code
+/// typedef struct once_flag_s {
+///   unsigned long __state_ = 0;
+/// } once_flag;
+/// template<class Callable, class... Args>
+/// void call_once(once_flag& o, Callable func, Args... args) {
+///   if (!o.__state_) {
+///     func(args...);
+///     o.__state_ = 1;
+///   }
+/// }
+/// \endcode
+///
+/// No body is synthesized (nullptr is returned) when \p D has fewer than two
+/// parameters, when the flag is not a C++ record with an integer member
+/// named `__state_`, or when the callback is a non-lambda class object.
+static Stmt *create_call_once(ASTContext &C, const FunctionDecl *D) {
+  DEBUG(llvm::dbgs() << "Generating body for call_once\n");
+
+  // We need at least two parameters: the flag and the callback.
+  if (D->param_size() < 2)
+    return nullptr;
+
+  ASTMaker M(C);
+
+  const ParmVarDecl *Flag = D->getParamDecl(0);
+  const ParmVarDecl *Callback = D->getParamDecl(1);
+  QualType FlagType = Flag->getType().getNonReferenceType();
+  QualType CallbackType = Callback->getType().getNonReferenceType();
+
+  // Note: here we are assuming libc++ implementation of call_once,
+  // which has a struct with a field `__state_`.
+  // Body farming might not work for other `call_once` implementations.
+  CXXRecordDecl *FlagCXXDecl = FlagType->getAsCXXRecordDecl();
+  if (!FlagCXXDecl) {
+    DEBUG(llvm::dbgs() << "std::once_flag is not a C++ record, "
+                       << "unable to synthesize call_once body, ignoring "
+                       << "the call.\n");
+    return nullptr;
+  }
+
+  NamedDecl *FoundDecl = M.findMemberField(FlagCXXDecl, "__state_");
+  if (!FoundDecl) {
+    DEBUG(llvm::dbgs() << "No field __state_ found on std::once_flag struct, "
+                       << "unable to synthesize call_once body, ignoring "
+                       << "the call.\n");
+    return nullptr;
+  }
+
+  // The flag is tested through an integral-to-boolean cast and assigned an
+  // integer literal below, so only an integer-typed member can be modeled.
+  ValueDecl *FieldDecl = dyn_cast<ValueDecl>(FoundDecl);
+  if (!FieldDecl || !FieldDecl->getType()->isIntegerType()) {
+    DEBUG(llvm::dbgs() << "Member __state_ of std::once_flag is not an "
+                       << "integer, unable to synthesize call_once body, "
+                       << "ignoring the call.\n");
+    return nullptr;
+  }
+
+  // Only lambdas and non-class callables (e.g. function pointers) are
+  // modeled; any other class object (a functor) would otherwise be used
+  // directly as the callee of a plain CallExpr.
+  CXXRecordDecl *CallbackCXXDecl = CallbackType->getAsCXXRecordDecl();
+  bool IsLambdaCall = CallbackCXXDecl && CallbackCXXDecl->isLambda();
+  if (CallbackCXXDecl && !IsLambdaCall) {
+    DEBUG(llvm::dbgs() << "Not supported: synthesizing body for functors "
+                       << "when body farming std::call_once, ignoring "
+                       << "the call.\n");
+    return nullptr;
+  }
+
+  // All arguments past the first two are passed to the callback.
+  // NOTE(review): reference-typed parameters keep their reference type here;
+  // confirm this matches the callback's parameter types.
+  SmallVector<Expr *, 5> CallArgs;
+  for (unsigned int i = 2; i < D->getNumParams(); i++)
+    CallArgs.push_back(M.makeLvalueToRvalue(D->getParamDecl(i)));
+
+  CallExpr *CallbackCall;
+  if (IsLambdaCall)
+    CallbackCall =
+        create_call_once_lambda_call(C, M, Callback, CallbackType, CallArgs);
+  else
+    CallbackCall = create_call_once_funcptr_call(C, M, Callback, CallArgs);
+
+  DeclRefExpr *FlagDecl =
+      M.makeDeclRefExpr(Flag,
+                        /* RefersToEnclosingVariableOrCapture=*/true,
+                        /* GetNonReferenceType=*/true);
+  MemberExpr *Deref = M.makeMemberExpression(FlagDecl, FieldDecl);
+  assert(Deref->isLValue());
+  QualType DerefType = Deref->getType();
+
+  // Negation predicate: !o.__state_
+  UnaryOperator *FlagCheck = new (C) UnaryOperator(
+      /* input= */
+      M.makeImplicitCast(M.makeLvalueToRvalue(Deref, DerefType), DerefType,
+                         CK_IntegralToBoolean),
+      /* opc= */ UO_LNot,
+      /* type= */ C.IntTy,
+      /* VK= */ VK_RValue,
+      /* OK= */ OK_Ordinary, SourceLocation());
+
+  // Create assignment: o.__state_ = 1
+  BinaryOperator *FlagAssignment = M.makeAssignment(
+      Deref, M.makeIntegralCast(M.makeIntegerLiteral(1), DerefType), DerefType);
+
+  // Only while the flag is unset: run the callback, then set the flag.
+  IfStmt *Out = new (C)
+      IfStmt(C, SourceLocation(),
+             /* IsConstexpr= */ false,
+             /* init= */ nullptr,
+             /* var= */ nullptr,
+             /* cond= */ FlagCheck,
+             /* then= */ M.makeCompound({CallbackCall, FlagAssignment}));
+
+  return Out;
+}
+
 /// Create a fake body for dispatch_once.
 static Stmt *create_dispatch_once(ASTContext &C, const FunctionDecl *D) {
   // Check if we have at least two parameters.
@@ -202,15 +452,17 @@
   ASTMaker M(C);
   
   // (1) Create the call.
-  DeclRefExpr *DR = M.makeDeclRefExpr(Block);
-  ImplicitCastExpr *ICE = M.makeLvalueToRvalue(DR, Ty);
-  CallExpr *CE = new (C) CallExpr(C, ICE, None, C.VoidTy, VK_RValue,
-                                  SourceLocation());
+  CallExpr *CE = new (C) CallExpr(
+      /*C=*/C,
+      /*fn=*/M.makeLvalueToRvalue(Block),
+      /*args=*/None,
+      /*t=*/C.VoidTy,
+      /*VK=*/VK_RValue,
+      /*rparenloc=*/SourceLocation());
 
   // (2) Create the assignment to the predicate.
-  IntegerLiteral *IL =
-    IntegerLiteral::Create(C, llvm::APInt(C.getTypeSize(C.IntTy), (uint64_t) 1),
-                           C.IntTy, SourceLocation());
+  IntegerLiteral *IL = M.makeIntegerLiteral(1);
+
   BinaryOperator *B =
     M.makeAssignment(
        M.makeDereference(
@@ -234,13 +486,20 @@
         PredicateTy),
     PredicateTy);
   
-  UnaryOperator *UO = new (C) UnaryOperator(LValToRval, UO_LNot, C.IntTy,
-                                           VK_RValue, OK_Ordinary,
-                                           SourceLocation());
+  UnaryOperator *UO = new (C) UnaryOperator(
+      /* input= */ LValToRval,
+      /* opc= */ UO_LNot,
+      /* type= */ C.IntTy,
+      /* VK= */ VK_RValue,
+      /* OK= */ OK_Ordinary, SourceLocation());
   
   // (5) Create the 'if' statement.
-  IfStmt *If = new (C) IfStmt(C, SourceLocation(), false, nullptr, nullptr,
-                              UO, CS);
+  IfStmt *If = new (C) IfStmt(C, SourceLocation(),
+                              /* IsConstexpr= */ false,
+                              /* init= */ nullptr,
+                              /* var= */ nullptr,
+                              /* cond= */ UO,
+                              /* then= */ CS);
   return If;
 }
 
@@ -370,8 +629,9 @@
   if (Name.startswith("OSAtomicCompareAndSwap") ||
       Name.startswith("objc_atomicCompareAndSwap")) {
     FF = create_OSAtomicCompareAndSwap;
-  }
-  else {
+  } else if (Name == "call_once" && D->getDeclContext()->isStdNamespace()) {
+    FF = create_call_once;
+  } else {
     FF = llvm::StringSwitch<FunctionFarmer>(Name)
           .Case("dispatch_sync", create_dispatch_sync)
           .Case("dispatch_once", create_dispatch_once)