Template instantiation for switch, case, and default statements


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@71916 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/SemaTemplateInstantiateStmt.cpp b/lib/Sema/SemaTemplateInstantiateStmt.cpp
index d59f859..9252677 100644
--- a/lib/Sema/SemaTemplateInstantiateStmt.cpp
+++ b/lib/Sema/SemaTemplateInstantiateStmt.cpp
@@ -39,7 +39,10 @@
     OwningStmtResult VisitDeclStmt(DeclStmt *S);
     OwningStmtResult VisitNullStmt(NullStmt *S);
     OwningStmtResult VisitCompoundStmt(CompoundStmt *S);
+    OwningStmtResult VisitCaseStmt(CaseStmt *S);
+    OwningStmtResult VisitDefaultStmt(DefaultStmt *S);
     OwningStmtResult VisitIfStmt(IfStmt *S);
+    OwningStmtResult VisitSwitchStmt(SwitchStmt *S);
     OwningStmtResult VisitWhileStmt(WhileStmt *S);
     OwningStmtResult VisitDoStmt(DoStmt *S);
     OwningStmtResult VisitForStmt(ForStmt *S);
@@ -150,6 +153,50 @@
                                               S->getRBracLoc()));
 }
 
+Sema::OwningStmtResult TemplateStmtInstantiator::VisitCaseStmt(CaseStmt *S) {
+  // Instantiate the case value(s). The right-hand value belongs to the GNU
+  // case-range extension ("case lo ... hi:"); both go through InstantiateExpr.
+  OwningExprResult Lo = SemaRef.InstantiateExpr(S->getLHS(), TemplateArgs);
+  if (Lo.isInvalid())
+    return SemaRef.StmtError();
+  OwningExprResult Hi = SemaRef.InstantiateExpr(S->getRHS(), TemplateArgs);
+  if (Hi.isInvalid())
+    return SemaRef.StmtError();
+
+  // Rebuild the case label from the instantiated values, reusing the source
+  // locations recorded in the template pattern.
+  OwningStmtResult NewCase = SemaRef.ActOnCaseStmt(S->getCaseLoc(), move(Lo),
+                                                   S->getEllipsisLoc(),
+                                                   move(Hi),
+                                                   S->getColonLoc());
+  if (NewCase.isInvalid())
+    return SemaRef.StmtError();
+
+  // Instantiate the statement the label applies to, then attach it as the
+  // body of the freshly built case.
+  OwningStmtResult Body = SemaRef.InstantiateStmt(S->getSubStmt(),
+                                                  TemplateArgs);
+  if (Body.isInvalid())
+    return SemaRef.StmtError();
+
+  SemaRef.ActOnCaseStmtBody(NewCase.get(), move(Body));
+  return move(NewCase);
+}
+
+Sema::OwningStmtResult 
+TemplateStmtInstantiator::VisitDefaultStmt(DefaultStmt *S) {
+  // A default label has no expressions of its own; only the statement it
+  // labels needs to be instantiated.
+  OwningStmtResult Labeled = SemaRef.InstantiateStmt(S->getSubStmt(),
+                                                     TemplateArgs);
+  if (Labeled.isInvalid())
+    return SemaRef.StmtError();
+
+  // Rebuild the label around the instantiated statement, with a null CurScope.
+  return SemaRef.ActOnDefaultStmt(S->getDefaultLoc(), S->getColonLoc(),
+                                  move(Labeled), /*CurScope=*/0);
+}
+
 Sema::OwningStmtResult TemplateStmtInstantiator::VisitIfStmt(IfStmt *S) {
   // Instantiate the condition
   OwningExprResult Cond = SemaRef.InstantiateExpr(S->getCond(), TemplateArgs);
@@ -170,6 +217,28 @@
                              S->getElseLoc(), move(Else));
 }
 
+Sema::OwningStmtResult 
+TemplateStmtInstantiator::VisitSwitchStmt(SwitchStmt *S) {
+  // Instantiate the condition.
+  OwningExprResult Cond = SemaRef.InstantiateExpr(S->getCond(), TemplateArgs);
+  if (Cond.isInvalid())
+    return SemaRef.StmtError();
+  // Start the switch statement; this pushes it onto Sema's switch stack.
+  OwningStmtResult Switch = SemaRef.ActOnStartOfSwitchStmt(move(Cond));
+  if (Switch.isInvalid())
+    return SemaRef.StmtError();
+  // Instantiate the body. ActOnFinishSwitchStmt (which pops the stack) is
+  // skipped on failure, so pop here rather than leave a dangling entry.
+  OwningStmtResult Body = SemaRef.InstantiateStmt(S->getBody(), TemplateArgs);
+  if (Body.isInvalid()) {
+    SemaRef.getSwitchStack().pop_back();
+    return SemaRef.StmtError();
+  }
+  // Complete the switch statement.
+  return SemaRef.ActOnFinishSwitchStmt(S->getSwitchLoc(), move(Switch),
+                                       move(Body));
+}
+
 Sema::OwningStmtResult TemplateStmtInstantiator::VisitWhileStmt(WhileStmt *S) {
   // Instantiate the condition
   OwningExprResult Cond = SemaRef.InstantiateExpr(S->getCond(), TemplateArgs);