Thread safety: shared vs. exclusive locks

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@139307 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/AnalysisBasedWarnings.cpp b/lib/Sema/AnalysisBasedWarnings.cpp
index 2d6a483..c155a9e 100644
--- a/lib/Sema/AnalysisBasedWarnings.cpp
+++ b/lib/Sema/AnalysisBasedWarnings.cpp
@@ -782,6 +782,16 @@
   }
 };
 
+enum LockKind {
+  LK_Shared,    ///< Held shared, e.g. via a shared lock function.
+  LK_Exclusive  ///< Held exclusively, e.g. via an exclusive lock function.
+};
+
+enum AccessKind {
+  AK_Read,      ///< Read, e.g. via an lvalue-to-rvalue conversion.
+  AK_Written    ///< Written, e.g. by assignment or increment/decrement.
+};
+
 /// \brief This is a helper class that stores info about the most recent
 /// accquire of a Lock.
 ///
@@ -789,10 +799,19 @@
 struct LockData {
   SourceLocation AcquireLoc;
 
-  LockData(SourceLocation Loc) : AcquireLoc(Loc) {}
+  /// \brief LKind stores whether a lock is held shared or exclusively.
+  /// Note that this analysis does not currently support either re-entrant
+  /// locking or lock "upgrading" and "downgrading" between exclusive and
+  /// shared.
+  ///
+  /// FIXME: add support for re-entrant locking and lock up/downgrading
+  LockKind LKind;
+
+  LockData(SourceLocation AcquireLoc, LockKind LKind)
+    : AcquireLoc(AcquireLoc), LKind(LKind) {}
 
   bool operator==(const LockData &other) const {
-    return AcquireLoc == other.AcquireLoc;
+    return AcquireLoc == other.AcquireLoc && LKind == other.LKind;
   }
 
   bool operator!=(const LockData &other) const {
@@ -800,8 +819,9 @@
   }
 
   void Profile(llvm::FoldingSetNodeID &ID) const {
     ID.AddInteger(AcquireLoc.getRawEncoding());
+    ID.AddInteger(LKind); // Hash LKind too: operator== compares it.
   }
 };
 
 /// A Lockset maps each LockID (defined above) to information about how it has
@@ -820,10 +840,28 @@
 
   // Helper functions
   void removeLock(SourceLocation UnlockLoc, Expr *LockExp);
-  void addLock(SourceLocation LockLoc, Expr *LockExp);
+  void addLock(SourceLocation LockLoc, Expr *LockExp, LockKind LK);
   const ValueDecl *getValueDecl(Expr *Exp);
-  void checkAccess(Expr *Exp);
-  void checkDereference(Expr *Exp);
+  void warnIfLockNotHeld(const NamedDecl *D, Expr *Exp, AccessKind AK,
+                         LockID &Lock, unsigned DiagID);
+  void checkAccess(Expr *Exp, AccessKind AK);
+  void checkDereference(Expr *Exp, AccessKind AK);
+
+  template <class AttrType>
+  void addLocksToSet(LockKind LK, Attr *Attr, CXXMemberCallExpr *Exp);
+
+  /// \brief Returns true if the lockset contains a lock, regardless of whether
+  /// the lock is held exclusively or shared.
+  bool locksetContains(LockID Lock) const {
+    return LSet.lookup(Lock);
+  }
+
+  /// \brief Returns true if the lockset contains a lock held with the requested
+  /// kind (shared or exclusive).
+  bool locksetContains(LockID Lock, LockKind KindRequested) const {
+    const LockData *LockHeld = LSet.lookup(Lock);
+    return (LockHeld && KindRequested == LockHeld->LKind);
+  }
 
 public:
   BuildLockset(Sema &S, Lockset LS, Lockset::Factory &F)
@@ -843,13 +881,16 @@
 /// \brief Add a new lock to the lockset, warning if the lock is already there.
 /// \param LockLoc The source location of the acquire
 /// \param LockExp The lock expression corresponding to the lock to be added
-void BuildLockset::addLock(SourceLocation LockLoc, Expr *LockExp) {
+/// \param LK Whether the lock is held shared or exclusively
+void BuildLockset::addLock(SourceLocation LockLoc, Expr *LockExp,
+                           LockKind LK) {
+  // FIXME: deal with acquired before/after annotations
   LockID Lock(LockExp);
-  LockData NewLockData(LockLoc);
+  LockData NewLockData(LockLoc, LK);
 
-  if (LSet.contains(Lock))
+  // FIXME: Don't always warn when we have support for reentrant locks.
+  if (locksetContains(Lock))
     S.Diag(LockLoc, diag::warn_double_lock) << Lock.getName();
-
   LSet = LocksetFactory.add(LSet, Lock, NewLockData);
 }
 
@@ -877,13 +917,33 @@
   return 0;
 }
 
+/// \brief Warn (with DiagID) if LSet does not hold Lock strongly enough for
+/// an AK access: reads need it shared or exclusive, writes need exclusive.
+void BuildLockset::warnIfLockNotHeld(const NamedDecl *D, Expr *Exp,
+                                     AccessKind AK, LockID &Lock,
+                                     unsigned DiagID) {
+  switch (AK) {
+    case AK_Read:
+      if (!locksetContains(Lock))
+        S.Diag(Exp->getExprLoc(), DiagID)
+          << D->getName() << Lock.getName() << LK_Shared;
+      break;
+    case AK_Written:
+      if (!locksetContains(Lock, LK_Exclusive))
+        S.Diag(Exp->getExprLoc(), DiagID)
+          << D->getName() << Lock.getName() << LK_Exclusive;
+      break;
+  }
+}
+
+
 /// \brief This method identifies variable dereferences and checks pt_guarded_by
 /// and pt_guarded_var annotations. Note that we only check these annotations
 /// at the time a pointer is dereferenced.
 /// FIXME: We need to check for other types of pointer dereferences
 /// (e.g. [], ->) and deal with them here.
 /// \param Exp An expression that has been read or written.
-void BuildLockset::checkDereference(Expr *Exp) {
+void BuildLockset::checkDereference(Expr *Exp, AccessKind AK) {
   UnaryOperator *UO = dyn_cast<UnaryOperator>(Exp);
   if (!UO || UO->getOpcode() != clang::UO_Deref)
     return;
@@ -901,11 +961,9 @@
   for(unsigned i = 0, Size = ArgAttrs.size(); i < Size; ++i) {
     if (ArgAttrs[i]->getKind() != attr::PtGuardedBy)
       continue;
-    PtGuardedByAttr *PGBAttr = cast<PtGuardedByAttr>(ArgAttrs[i]);
+    const PtGuardedByAttr *PGBAttr = cast<PtGuardedByAttr>(ArgAttrs[i]);
     LockID Lock(PGBAttr->getArg());
-    if (!LSet.contains(Lock))
-      S.Diag(Exp->getExprLoc(), diag::warn_var_deref_requires_lock)
-        << D->getName() << Lock.getName();
+    warnIfLockNotHeld(D, Exp, AK, Lock, diag::warn_var_deref_requires_lock);
   }
 }
 
@@ -913,7 +971,7 @@
 /// Whenever we identify an access (read or write) of a DeclRefExpr or
 /// MemberExpr, we need to check whether there are any guarded_by or
 /// guarded_var attributes, and make sure we hold the appropriate locks.
-void BuildLockset::checkAccess(Expr *Exp) {
+void BuildLockset::checkAccess(Expr *Exp, AccessKind AK) {
   const ValueDecl *D = getValueDecl(Exp);
   if(!D || !D->hasAttrs())
     return;
@@ -926,11 +984,9 @@
   for(unsigned i = 0, Size = ArgAttrs.size(); i < Size; ++i) {
     if (ArgAttrs[i]->getKind() != attr::GuardedBy)
       continue;
-    GuardedByAttr *GBAttr = cast<GuardedByAttr>(ArgAttrs[i]);
+    const GuardedByAttr *GBAttr = cast<GuardedByAttr>(ArgAttrs[i]);
     LockID Lock(GBAttr->getArg());
-    if (!LSet.contains(Lock))
-      S.Diag(Exp->getExprLoc(), diag::warn_variable_requires_lock)
-        << D->getName() << Lock.getName();
+    warnIfLockNotHeld(D, Exp, AK, Lock, diag::warn_variable_requires_lock);
   }
 }
 
@@ -944,8 +1000,8 @@
     case clang::UO_PreDec:
     case clang::UO_PreInc: {
       Expr *SubExp = UO->getSubExpr()->IgnoreParenCasts();
-      checkAccess(SubExp);
-      checkDereference(SubExp);
+      checkAccess(SubExp, AK_Written);
+      checkDereference(SubExp, AK_Written);
       break;
     }
     default:
@@ -960,8 +1016,8 @@
   if (!BO->isAssignmentOp())
     return;
   Expr *LHSExp = BO->getLHS()->IgnoreParenCasts();
-  checkAccess(LHSExp);
-  checkDereference(LHSExp);
+  checkAccess(LHSExp, AK_Written);
+  checkDereference(LHSExp, AK_Written);
 }
 
 /// Whenever we do an LValue to Rvalue cast, we are reading a variable and
@@ -971,10 +1027,30 @@
   if (CE->getCastKind() != CK_LValueToRValue)
     return;
   Expr *SubExp = CE->getSubExpr()->IgnoreParenCasts();
-  checkAccess(SubExp);
-  checkDereference(SubExp);
+  checkAccess(SubExp, AK_Read);
+  checkDereference(SubExp, AK_Read);
 }
 
+/// \brief Add each lock named by Attr's arguments to the lockset with kind LK;
+/// if Attr (cast to AttrType) has no arguments, Exp's implicit object is added.
+template <typename AttrType>
+void BuildLockset::addLocksToSet(LockKind LK, Attr *Attr,
+                                 CXXMemberCallExpr *Exp) {
+  typedef typename AttrType::args_iterator iterator_type;
+  SourceLocation ExpLocation = Exp->getExprLoc();
+  Expr *Parent = Exp->getImplicitObjectArgument();
+  AttrType *SpecificAttr = cast<AttrType>(Attr);
+
+  if (SpecificAttr->args_size() == 0) {
+    // The lock held is the "this" object.
+    addLock(ExpLocation, Parent, LK);
+    return;
+  }
+
+  for (iterator_type I = SpecificAttr->args_begin(),
+       E = SpecificAttr->args_end(); I != E; ++I)
+    addLock(ExpLocation, *I, LK);
+}
 
 /// \brief When visiting CXXMemberCallExprs we need to examine the attributes on
 /// the method that is being called and add, remove or check locks in the
@@ -998,22 +1074,16 @@
     Attr *Attr = ArgAttrs[i];
     switch (Attr->getKind()) {
       // When we encounter an exclusive lock function, we need to add the lock
-      // to our lockset.
-      case attr::ExclusiveLockFunction: {
-        ExclusiveLockFunctionAttr *ELFAttr =
-          cast<ExclusiveLockFunctionAttr>(Attr);
-
-        if (ELFAttr->args_size() == 0) {// The lock held is the "this" object.
-          addLock(ExpLocation, Parent);
-          break;
-        }
-
-        for (ExclusiveLockFunctionAttr::args_iterator I = ELFAttr->args_begin(),
-             E = ELFAttr->args_end(); I != E; ++I)
-          addLock(ExpLocation, *I);
-        // FIXME: acquired_after/acquired_before annotations
+      // to our lockset, marked as exclusive.
+      case attr::ExclusiveLockFunction:
+        addLocksToSet<ExclusiveLockFunctionAttr>(LK_Exclusive, Attr, Exp);
         break;
-      }
+
+      // When we encounter a shared lock function, we need to add the lock
+      // to our lockset, marked as shared.
+      case attr::SharedLockFunction:
+        addLocksToSet<SharedLockFunctionAttr>(LK_Shared, Attr, Exp);
+        break;
 
       // When we encounter an unlock function, we need to remove unlocked locks
       // from the lockset, and flag a warning if they are not there.
@@ -1066,6 +1136,47 @@
     S.Diag(I->first, I->second);
 }
 
+/// \brief Compare the locks held in LSet2 with those held in LSet1, appending
+/// any resulting warnings to Warnings, and return the updated Intersection.
+///
+/// For each lock held in LSet2:
+///  - if LSet1 does not hold it, warn_lock_not_released_in_scope is queued at
+///    its LSet2 acquire location;
+///  - if LSet1 holds it with the other kind (shared vs. exclusive),
+///    warn_lock_exclusive_and_shared is queued at the LSet2 acquire location
+///    and note_lock_exclusive_and_shared at the LSet1 acquire location, and
+///    if LSet2's acquisition is the exclusive one it is added to Intersection.
+/// Locks held only in LSet1 are not examined here.
+static Lockset warnIfNotInFirstSetOrNotSameKind(Sema &S, const Lockset LSet1,
+                                                const Lockset LSet2,
+                                                DiagList &Warnings,
+                                                Lockset Intersection,
+                                                Lockset::Factory &Fact) {
+  for (Lockset::iterator I = LSet2.begin(), E = LSet2.end(); I != E; ++I) {
+    const LockID &LSet2Lock = I.getKey();
+    const LockData &LSet2LockData = I.getData();
+    if (const LockData *LD = LSet1.lookup(LSet2Lock)) {
+      if (LD->LKind != LSet2LockData.LKind) {
+        PartialDiagnostic Warning =
+          S.PDiag(diag::warn_lock_exclusive_and_shared) << LSet2Lock.getName();
+        PartialDiagnostic Note =
+          S.PDiag(diag::note_lock_exclusive_and_shared) << LSet2Lock.getName();
+        Warnings.push_back(DelayedDiag(LSet2LockData.AcquireLoc, Warning));
+        Warnings.push_back(DelayedDiag(LD->AcquireLoc, Note));
+        // Kinds differ, so exactly one side is exclusive; if it is LSet2's,
+        // record LSet2's entry so the result keeps the exclusive acquisition.
+        if (LD->LKind != LK_Exclusive)
+          Intersection = Fact.add(Intersection, LSet2Lock, LSet2LockData);
+      }
+    } else {
+      PartialDiagnostic Warning =
+        S.PDiag(diag::warn_lock_not_released_in_scope) << LSet2Lock.getName();
+      Warnings.push_back(DelayedDiag(LSet2LockData.AcquireLoc, Warning));
+    }
+  }
+  return Intersection;
+}
+
 /// \brief Compute the intersection of two locksets and issue warnings for any
 /// locks in the symmetric difference.
 ///
@@ -1074,20 +1173,14 @@
 /// A; if () then B; else C; D; we need to check that the lockset after B and C
 /// are the same. In the event of a difference, we use the intersection of these
 /// two locksets at the start of D.
-static Lockset intersectAndWarn(Sema &S, Lockset LSet1, Lockset LSet2,
+static Lockset intersectAndWarn(Sema &S, const Lockset LSet1,
+                                const Lockset LSet2,
                                 Lockset::Factory &Fact) {
   Lockset Intersection = LSet1;
   DiagList Warnings;
 
-  for (Lockset::iterator I = LSet2.begin(), E = LSet2.end(); I != E; ++I) {
-    if (!LSet1.contains(I.getKey())) {
-      const LockID &MissingLock = I.getKey();
-      const LockData &MissingLockData = I.getData();
-      PartialDiagnostic Warning =
-        S.PDiag(diag::warn_lock_not_released_in_scope) << MissingLock.getName();
-      Warnings.push_back(DelayedDiag(MissingLockData.AcquireLoc, Warning));
-    }
-  }
+  Intersection = warnIfNotInFirstSetOrNotSameKind(S, LSet1, LSet2, Warnings,
+                                                  Intersection, Fact);
 
   for (Lockset::iterator I = LSet1.begin(), E = LSet1.end(); I != E; ++I) {
     if (!LSet2.contains(I.getKey())) {
@@ -1130,7 +1223,8 @@
 /// \param LoopReentrySet Locks held in the last CFG block of the loop
 static void warnBackEdgeUnequalLocksets(Sema &S, const Lockset LoopReentrySet,
                                         const Lockset LoopEntrySet,
-                                        SourceLocation FirstLocInLoop) {
+                                        SourceLocation FirstLocInLoop,
+                                        Lockset::Factory &Fact) {
   assert(FirstLocInLoop.isValid());
   DiagList Warnings;
 
@@ -1142,22 +1236,14 @@
       // We report this error at the location of the first statement in a loop
       PartialDiagnostic Warning =
         S.PDiag(diag::warn_expecting_lock_held_on_loop)
-          << MissingLock.getName();
+          << MissingLock.getName() << LK_Shared;
       Warnings.push_back(DelayedDiag(FirstLocInLoop, Warning));
     }
   }
 
   // Warn for locks held at the end of the loop, but not at the start.
-  for (Lockset::iterator I = LoopReentrySet.begin(), E = LoopReentrySet.end();
-       I != E; ++I) {
-    if (!LoopEntrySet.contains(I.getKey())) {
-      const LockID &MissingLock = I.getKey();
-      const LockData &MissingLockData = I.getData();
-      PartialDiagnostic Warning =
-        S.PDiag(diag::warn_lock_not_released_in_scope) << MissingLock.getName();
-      Warnings.push_back(DelayedDiag(MissingLockData.AcquireLoc, Warning));
-    }
-  }
+  warnIfNotInFirstSetOrNotSameKind(S, LoopEntrySet, LoopReentrySet, Warnings,
+                                   LoopReentrySet, Fact);
 
   EmitDiagnostics(S, Warnings);
 }
@@ -1250,13 +1336,15 @@
       SourceLocation FirstLoopLocation = getFirstStmtLocation(FirstLoopBlock);
 
       assert(FirstLoopLocation.isValid());
+
       // Fail gracefully in release code.
       if (!FirstLoopLocation.isValid())
         continue;
 
       Lockset PreLoop = EntryLocksets[FirstLoopBlock->getBlockID()];
       Lockset LoopEnd = ExitLocksets[CurrBlockID];
-      warnBackEdgeUnequalLocksets(S, LoopEnd, PreLoop, FirstLoopLocation);
+      warnBackEdgeUnequalLocksets(S, LoopEnd, PreLoop, FirstLoopLocation,
+                                  LocksetFactory);
     }
   }