Consumed Analysis: Change callable_when so that it can take a list of states
that a function can be called in. This reduces the total number of annotations
needed and makes writing more complicated behaviour less burdensome.
Patch by chriswails@gmail.com.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@191983 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Analysis/Consumed.cpp b/lib/Analysis/Consumed.cpp
index 7cd0290..ee8dd77 100644
--- a/lib/Analysis/Consumed.cpp
+++ b/lib/Analysis/Consumed.cpp
@@ -33,6 +33,8 @@
 
 // TODO: Add notes about the actual and expected state for 
 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
+// TODO: Adjust the parser and AttributesList class to support lists of
+//       identifiers.
 // TODO: Warn about unreachable code.
 // TODO: Switch to using a bitmap to track unreachable blocks.
 // TODO: Mark variables as Unknown going into while- or for-loops only if they
@@ -66,6 +68,37 @@
   llvm_unreachable("invalid enum");
 }
 
+static bool isCallableInState(const CallableWhenAttr *CWAttr,
+                              ConsumedState State) {
+  
+  CallableWhenAttr::callableState_iterator I = CWAttr->callableState_begin(),
+                                           E = CWAttr->callableState_end();
+  
+  for (; I != E; ++I) {
+    
+    ConsumedState MappedAttrState = CS_None;
+    
+    switch (*I) {
+    case CallableWhenAttr::Unknown:
+      MappedAttrState = CS_Unknown;
+      break;
+      
+    case CallableWhenAttr::Unconsumed:
+      MappedAttrState = CS_Unconsumed;
+      break;
+      
+    case CallableWhenAttr::Consumed:
+      MappedAttrState = CS_Consumed;
+      break;
+    }
+    
+    if (MappedAttrState == State)
+      return true;
+  }
+  
+  return false;
+}
+
 static bool isConsumableType(const QualType &QT) {
   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
     return RD->hasAttr<ConsumableAttr>();
@@ -174,6 +207,8 @@
     BinTestTy BinTest;
   };
   
+  QualType TempType;
+  
 public:
   PropagationInfo() : InfoType(IT_None) {}
   
@@ -208,7 +243,9 @@
     BinTest.RTest.TestsFor = RTestsFor;
   }
   
-  PropagationInfo(ConsumedState State) : InfoType(IT_State), State(State) {}
+  PropagationInfo(ConsumedState State, QualType TempType)
+    : InfoType(IT_State), State(State), TempType(TempType) {}
+  
   PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
   
   const ConsumedState & getState() const {
@@ -216,6 +253,11 @@
     return State;
   }
   
+  const QualType & getTempType() const {
+    assert(InfoType == IT_State);
+    return TempType;
+  }
+  
   const VarTestResult & getTest() const {
     assert(InfoType == IT_Test);
     return Test;
@@ -327,51 +369,38 @@
   }
 };
 
-// TODO: When we support CallableWhenConsumed this will have to check for
-//       the different attributes and change the behavior bellow. (Deferred)
 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
                                            const FunctionDecl *FunDecl,
                                            const CallExpr *Call) {
   
-  if (!FunDecl->hasAttr<CallableWhenUnconsumedAttr>()) return;
+  if (!FunDecl->hasAttr<CallableWhenAttr>())
+    return;
+  
+  const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();
   
   if (PInfo.isVar()) {
     const VarDecl *Var = PInfo.getVar();
+    ConsumedState VarState = StateMap->getState(Var);
     
-    switch (StateMap->getState(Var)) {
-    case CS_Consumed:
-      Analyzer.WarningsHandler.warnUseWhileConsumed(
-        FunDecl->getNameAsString(), Var->getNameAsString(),
-        Call->getExprLoc());
-      break;
+    assert(VarState != CS_None && "Invalid state");
     
-    case CS_Unknown:
-      Analyzer.WarningsHandler.warnUseInUnknownState(
-        FunDecl->getNameAsString(), Var->getNameAsString(),
-        Call->getExprLoc());
-      break;
-      
-    case CS_None:
-    case CS_Unconsumed:
-      break;
-    }
+    if (isCallableInState(CWAttr, VarState))
+      return;
     
-  } else {
-    switch (PInfo.getState()) {
-    case CS_Consumed:
-      Analyzer.WarningsHandler.warnUseOfTempWhileConsumed(
-        FunDecl->getNameAsString(), Call->getExprLoc());
-      break;
+    Analyzer.WarningsHandler.warnUseInInvalidState(
+      FunDecl->getNameAsString(), Var->getNameAsString(),
+      stateToString(VarState), Call->getExprLoc());
     
-    case CS_Unknown:
-      Analyzer.WarningsHandler.warnUseOfTempInUnknownState(
-        FunDecl->getNameAsString(), Call->getExprLoc());
-      break;
-      
-    case CS_None:
-    case CS_Unconsumed:
-      break;
-    }
+  } else if (PInfo.isState()) {
+    
+    assert(PInfo.getState() != CS_None && "Invalid state");
+    
+    if (isCallableInState(CWAttr, PInfo.getState()))
+      return;
+    
+    Analyzer.WarningsHandler.warnUseOfTempInInvalidState(
+      FunDecl->getNameAsString(), stateToString(PInfo.getState()),
+      Call->getExprLoc());
   }
 }
 
@@ -421,7 +450,7 @@
       ReturnState = mapConsumableAttrState(ReturnType);
     
     PropagationMap.insert(PairType(Call,
-      PropagationInfo(ReturnState)));
+      PropagationInfo(ReturnState, ReturnType)));
   }
 }
 
@@ -522,7 +551,11 @@
       }
     }
     
-    propagateReturnType(Call, FunDecl, FunDecl->getCallResultType());
+    QualType RetType = FunDecl->getCallResultType();
+    if (RetType->isReferenceType())
+      RetType = RetType->getPointeeType();
+    
+    propagateReturnType(Call, FunDecl, RetType);
   }
 }
 
@@ -540,7 +573,7 @@
     if (Constructor->isDefaultConstructor()) {
       
       PropagationMap.insert(PairType(Call,
-        PropagationInfo(consumed::CS_Consumed)));
+        PropagationInfo(consumed::CS_Consumed, ThisType)));
       
     } else if (Constructor->isMoveConstructor()) {
       
@@ -551,7 +584,7 @@
         const VarDecl* Var = PInfo.getVar();
         
         PropagationMap.insert(PairType(Call,
-          PropagationInfo(StateMap->getState(Var))));
+          PropagationInfo(StateMap->getState(Var), ThisType)));
         
         StateMap->setState(Var, consumed::CS_Consumed);
         
@@ -630,7 +663,8 @@
         
       } else if (!LPInfo.isVar() && RPInfo.isVar()) {
         PropagationMap.insert(PairType(Call,
-          PropagationInfo(StateMap->getState(RPInfo.getVar()))));
+          PropagationInfo(StateMap->getState(RPInfo.getVar()),
+                          LPInfo.getTempType())));
         
         StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
         
@@ -648,27 +682,16 @@
         
         PropagationMap.insert(PairType(Call, LPInfo));
         
-      } else {
+      } else if (LPInfo.isState()) {
         PropagationMap.insert(PairType(Call,
-          PropagationInfo(consumed::CS_Unknown)));
+          PropagationInfo(consumed::CS_Unknown, LPInfo.getTempType())));
       }
       
     } else if (LEntry == PropagationMap.end() &&
                REntry != PropagationMap.end()) {
       
-      RPInfo = REntry->second;
-      
-      if (RPInfo.isVar()) {
-        const VarDecl *Var = RPInfo.getVar();
-        
-        PropagationMap.insert(PairType(Call,
-          PropagationInfo(StateMap->getState(Var))));
-        
-        StateMap->setState(Var, consumed::CS_Consumed);
-        
-      } else {
-        PropagationMap.insert(PairType(Call, RPInfo));
-      }
+      if (REntry->second.isVar())
+        StateMap->setState(REntry->second.getVar(), consumed::CS_Consumed);
     }
     
   } else {
@@ -776,6 +799,7 @@
   }
 }
 
+// TODO: Determine whether reference types need to be checked here.
 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
   if (isConsumableType(Var->getType())) {
     if (Var->hasInit()) {
@@ -803,13 +827,12 @@
   
   if (VarState == CS_Unknown) {
     ThenStates->setState(Test.Var, Test.TestsFor);
-    if (ElseStates)
-      ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
+    ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
   
   } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
     ThenStates->markUnreachable();
     
-  } else if (VarState == Test.TestsFor && ElseStates) {
+  } else if (VarState == Test.TestsFor) {
     ElseStates->markUnreachable();
   }
 }
@@ -832,31 +855,27 @@
         ThenStates->markUnreachable();
         
       } else if (LState == LTest.TestsFor && isKnownState(RState)) {
-        if (RState == RTest.TestsFor) {
-          if (ElseStates)
-            ElseStates->markUnreachable();
-        } else {
+        if (RState == RTest.TestsFor)
+          ElseStates->markUnreachable();
+        else
           ThenStates->markUnreachable();
-        }
       }
       
     } else {
-      if (LState == CS_Unknown && ElseStates) {
+      if (LState == CS_Unknown) {
         ElseStates->setState(LTest.Var,
                              invertConsumedUnconsumed(LTest.TestsFor));
       
-      } else if (LState == LTest.TestsFor && ElseStates) {
+      } else if (LState == LTest.TestsFor) {
         ElseStates->markUnreachable();
         
       } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
                  isKnownState(RState)) {
         
-        if (RState == RTest.TestsFor) {
-          if (ElseStates)
-            ElseStates->markUnreachable();
-        } else {
+        if (RState == RTest.TestsFor)
+          ElseStates->markUnreachable();
+        else
           ThenStates->markUnreachable();
-        }
       }
     }
   }
@@ -868,7 +887,7 @@
       else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
         ThenStates->markUnreachable();
       
-    } else if (ElseStates) {
+    } else {
       if (RState == CS_Unknown)
         ElseStates->setState(RTest.Var,
                              invertConsumedUnconsumed(RTest.TestsFor));
@@ -1016,7 +1035,6 @@
   if (const IfStmt *IfNode =
     dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
     
-    bool HasElse = IfNode->getElse() != NULL;
     const Stmt *Cond = IfNode->getCond();
     
     PInfo = Visitor.getInfo(Cond);
@@ -1026,15 +1044,12 @@
     if (PInfo.isTest()) {
       CurrStates->setSource(Cond);
       FalseStates->setSource(Cond);
-      
-      splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates,
-                         HasElse ? FalseStates : NULL);
+      splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates, FalseStates);
       
     } else if (PInfo.isBinTest()) {
       CurrStates->setSource(PInfo.testSourceNode());
       FalseStates->setSource(PInfo.testSourceNode());
-      
-      splitVarStateForIfBinOp(PInfo, CurrStates, HasElse ? FalseStates : NULL);
+      splitVarStateForIfBinOp(PInfo, CurrStates, FalseStates);
       
     } else {
       delete FalseStates;