Consumed analysis: add return_typestate attribute.
Patch by chris.wailes@gmail.com

Functions can now declare the state that the consumable value they return will
be in. This state is then assumed on the caller side and checked on the callee side.
Constructors now use this attribute instead of the 'consumes' attribute.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@189843 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Analysis/Consumed.cpp b/lib/Analysis/Consumed.cpp
index 59ebbb2..6ffdb23 100644
--- a/lib/Analysis/Consumed.cpp
+++ b/lib/Analysis/Consumed.cpp
@@ -31,6 +31,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/raw_ostream.h"
 
+// TODO: Add notes about the actual and expected states to typestate warnings.
 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
 // TODO: Warn about unreachable code.
 // TODO: Switch to using a bitmap to track unreachable blocks.
@@ -88,6 +89,19 @@
   return FunDecl->hasAttr<TestsUnconsumedAttr>();
 }
 
+/// Maps the state named by a return_typestate attribute onto ConsumedState.
+static ConsumedState mapReturnTypestateAttrState(const ReturnTypestateAttr *A) {
+  switch (A->getState()) {
+  case ReturnTypestateAttr::Unknown:
+    return CS_Unknown;
+  case ReturnTypestateAttr::Unconsumed:
+    return CS_Unconsumed;
+  case ReturnTypestateAttr::Consumed:
+    return CS_Consumed;
+  }
+  llvm_unreachable("invalid enum"); // Only reached for an unhandled state.
+}
+
 static StringRef stateToString(ConsumedState State) {
   switch (State) {
   case consumed::CS_None:
@@ -256,6 +270,8 @@
   void forwardInfo(const Stmt *From, const Stmt *To);
   void handleTestingFunctionCall(const CallExpr *Call, const VarDecl *Var);
   bool isLikeMoveAssignment(const CXXMethodDecl *MethodDecl);
+  void propagateReturnType(const Stmt *Call, const FunctionDecl *Fun,
+                           QualType ReturnType);
   
 public:
 
@@ -272,6 +288,7 @@
   void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
   void VisitMemberExpr(const MemberExpr *MExpr);
   void VisitParmVarDecl(const ParmVarDecl *Param);
+  void VisitReturnStmt(const ReturnStmt *Ret);
   void VisitUnaryOperator(const UnaryOperator *UOp);
   void VisitVarDecl(const VarDecl *Var);
 
@@ -373,6 +390,24 @@
           MethodDecl->getParamDecl(0)->getType()->isRValueReferenceType());
 }
 
+// Records in PropagationMap the consumed state of the value produced by Call,
+// provided ReturnType is consumable; other values are not tracked.
+void ConsumedStmtVisitor::propagateReturnType(const Stmt *Call,
+                                              const FunctionDecl *Fun,
+                                              QualType ReturnType) {
+  if (!isConsumableType(ReturnType))
+    return;
+
+  // Use the state promised by Fun's return_typestate attribute, if it has
+  // one; otherwise nothing is known about the returned value.
+  ConsumedState ReturnState = CS_Unknown;
+  if (const ReturnTypestateAttr *RTSAttr =
+        Fun->getAttr<ReturnTypestateAttr>())
+    ReturnState = mapReturnTypestateAttrState(RTSAttr);
+
+  PropagationMap.insert(PairType(Call, PropagationInfo(ReturnState)));
+}
+
 void ConsumedStmtVisitor::Visit(const Stmt *StmtNode) {
   
   ConstStmtVisitor<ConsumedStmtVisitor>::Visit(StmtNode);
@@ -469,6 +504,8 @@
         StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
       }
     }
+    
+    propagateReturnType(Call, FunDecl, FunDecl->getCallResultType());
   }
 }
 
@@ -483,8 +520,7 @@
   QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
   
   if (isConsumableType(ThisType)) {
-    if (Constructor->hasAttr<ConsumesAttr>() ||
-        Constructor->isDefaultConstructor()) {
+    if (Constructor->isDefaultConstructor()) {
       
       PropagationMap.insert(PairType(Call,
         PropagationInfo(consumed::CS_Consumed)));
@@ -513,8 +549,7 @@
         PropagationMap.insert(PairType(Call, Entry->second));
       
     } else {
-      PropagationMap.insert(PairType(Call,
-        PropagationInfo(consumed::CS_Unconsumed)));
+      propagateReturnType(Call, Constructor, ThisType);
     }
   }
 }
@@ -677,6 +712,24 @@
     StateMap->setState(Param, consumed::CS_Unknown);
 }
 
+// Warn when the returned value's state differs from the expected one.
+void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
+  ConsumedState ExpectedState = Analyzer.getExpectedReturnState();
+  if (!ExpectedState)
+    return; // No expected return state: nothing to check.
+
+  InfoEntry Entry = PropagationMap.find(Ret->getRetValue());
+  if (Entry == PropagationMap.end())
+    return; // Nothing is known about the returned expression.
+  assert(Entry->second.isState() || Entry->second.isVar());
+
+  ConsumedState RetState = Entry->second.isState() ?
+    Entry->second.getState() : StateMap->getState(Entry->second.getVar());
+  if (RetState != ExpectedState)
+    Analyzer.WarningsHandler.warnReturnTypestateMismatch(Ret->getReturnLoc(),
+      stateToString(ExpectedState), stateToString(RetState));
+}
+
 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
   InfoEntry Entry = PropagationMap.find(UOp->getSubExpr()->IgnoreParens());
   if (Entry == PropagationMap.end()) return;
@@ -997,6 +1050,53 @@
   
   if (!D) return;
   
+  // FIXME: This should be removed when template instantiation propagates
+  //        attributes at template specialization definition, not declaration.
+  //        When it is removed the test needs to be enabled in SemaDeclAttr.cpp.
+  QualType ReturnType;
+  if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
+    ASTContext &CurrContext = AC.getASTContext();
+    ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
+    
+  } else {
+    ReturnType = D->getCallResultType();
+  }
+  
+  // Determine the expected return value.
+  if (D->hasAttr<ReturnTypestateAttr>()) {
+    
+    ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>();
+    
+    const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
+    if (!RD || !RD->hasAttr<ConsumableAttr>()) {
+        // FIXME: This branch can be removed with the code above.
+        WarningsHandler.warnReturnTypestateForUnconsumableType(
+          RTSAttr->getLocation(), ReturnType.getAsString());
+        ExpectedReturnState = CS_None;
+        
+    } else {
+      switch (RTSAttr->getState()) {
+      case ReturnTypestateAttr::Unknown:
+        ExpectedReturnState = CS_Unknown;
+        break;
+        
+      case ReturnTypestateAttr::Unconsumed:
+        ExpectedReturnState = CS_Unconsumed;
+        break;
+        
+      case ReturnTypestateAttr::Consumed:
+        ExpectedReturnState = CS_Consumed;
+        break;
+      }
+    }
+    
+  } else if (isConsumableType(ReturnType)) {
+    ExpectedReturnState = CS_Unknown;
+      
+  } else {
+    ExpectedReturnState = CS_None;
+  }
+  
   BlockInfo = ConsumedBlockInfo(AC.getCFG());
   
   PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();