Added support to StmtIterator to traverse the size expression of a VLA type
declared in a sizeof.  For example:

 sizeof(int[foo()]);

the expression "foo()" is an expression that is executed during the evaluation
of sizeof.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@45043 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/AST/Expr.cpp b/AST/Expr.cpp
index 11aef7f..2276705 100644
--- a/AST/Expr.cpp
+++ b/AST/Expr.cpp
@@ -1036,10 +1036,15 @@
 
 // SizeOfAlignOfTypeExpr
 Stmt::child_iterator SizeOfAlignOfTypeExpr::child_begin() { 
-  return child_iterator(); 
+  // If the type is a VLA type (and not a typedef), the size expression of the
+  // VLA needs to be treated as an executable expression.
+  if (VariableArrayType* T = dyn_cast<VariableArrayType>(Ty.getTypePtr()))
+    return child_iterator(T);
+  else
+    return child_iterator(); 
 }
 Stmt::child_iterator SizeOfAlignOfTypeExpr::child_end() {
-  return child_iterator();
+  return child_iterator(); 
 }
 
 // ArraySubscriptExpr
diff --git a/AST/StmtIterator.cpp b/AST/StmtIterator.cpp
index ae6de11..94a5397 100644
--- a/AST/StmtIterator.cpp
+++ b/AST/StmtIterator.cpp
@@ -31,19 +31,21 @@
 
 void StmtIteratorBase::NextVA() {
   assert (getVAPtr());
-  assert (decl);
 
   VariableArrayType* p = getVAPtr();
   p = FindVA(p->getElementType().getTypePtr());
   setVAPtr(p);
 
-  if (!p) {
+  if (!p && decl) {
     if (VarDecl* VD = dyn_cast<VarDecl>(decl)) 
       if (VD->Init)
         return;
       
     NextDecl();
   }
+  else {
+    RawVAPtr = 0;
+  }    
 }
 
 void StmtIteratorBase::NextDecl(bool ImmediateAdvance) {
@@ -94,6 +96,12 @@
   NextDecl(false);
 }
 
+StmtIteratorBase::StmtIteratorBase(VariableArrayType* t)
+: decl(NULL), RawVAPtr(VASizeMode) {
+  RawVAPtr |= reinterpret_cast<uintptr_t>(t);
+}
+
+
 Stmt*& StmtIteratorBase::GetDeclExpr() const {
   if (VariableArrayType* VAPtr = getVAPtr()) {
     assert (VAPtr->SizeExpr);
diff --git a/include/clang/AST/StmtIterator.h b/include/clang/AST/StmtIterator.h
index 207e4ed..cd097c2 100644
--- a/include/clang/AST/StmtIterator.h
+++ b/include/clang/AST/StmtIterator.h
@@ -25,21 +25,21 @@
   
 class StmtIteratorBase {
 protected:
-  enum { DeclMode = 0x1 };
+  enum { DeclMode = 0x1, VASizeMode = 0x2, Flags = 0x3 };
   union { Stmt** stmt; ScopedDecl* decl; };
   uintptr_t RawVAPtr;
 
-  bool inDeclMode() const { 
-    return RawVAPtr & DeclMode ? true : false;
-  }
+  bool inDeclMode() const { return RawVAPtr & DeclMode ? true : false; }  
+  bool inVASizeMode() const { return RawVAPtr & VASizeMode ? true : false; }  
+  bool hasFlags() const { return RawVAPtr & Flags ? true : false; }
   
   VariableArrayType* getVAPtr() const {
-    return reinterpret_cast<VariableArrayType*>(RawVAPtr & ~DeclMode);
+    return reinterpret_cast<VariableArrayType*>(RawVAPtr & ~Flags);
   }
   
   void setVAPtr(VariableArrayType* P) {
-    assert (inDeclMode());
-    RawVAPtr = reinterpret_cast<uintptr_t>(P) | DeclMode;
+    assert (inDeclMode() || inVASizeMode());    
+    RawVAPtr = reinterpret_cast<uintptr_t>(P) | (RawVAPtr & Flags);
   }
   
   void NextDecl(bool ImmediateAdvance = true);
@@ -49,6 +49,7 @@
 
   StmtIteratorBase(Stmt** s) : stmt(s), RawVAPtr(0) {}
   StmtIteratorBase(ScopedDecl* d);
+  StmtIteratorBase(VariableArrayType* t);
   StmtIteratorBase() : stmt(NULL), RawVAPtr(0) {}
 };
   
@@ -64,14 +65,17 @@
   StmtIteratorImpl() {}                                                
   StmtIteratorImpl(Stmt** s) : StmtIteratorBase(s) {}
   StmtIteratorImpl(ScopedDecl* d) : StmtIteratorBase(d) {}
-
+  StmtIteratorImpl(VariableArrayType* t) : StmtIteratorBase(t) {}
   
   DERIVED& operator++() {
     if (inDeclMode()) {
       if (getVAPtr()) NextVA();
       else NextDecl();
     }
-    else ++stmt;
+    else if (inVASizeMode())
+      NextVA();            
+    else
+      ++stmt;
       
     return static_cast<DERIVED&>(*this);
   }
@@ -91,7 +95,7 @@
   }
   
   REFERENCE operator*() const { 
-    return (REFERENCE) (inDeclMode() ? GetDeclExpr() : *stmt);
+    return (REFERENCE) (hasFlags() ? GetDeclExpr() : *stmt);
   }
   
   REFERENCE operator->() const { return operator*(); }   
@@ -100,6 +104,7 @@
 struct StmtIterator : public StmtIteratorImpl<StmtIterator,Stmt*&> {
   explicit StmtIterator() : StmtIteratorImpl<StmtIterator,Stmt*&>() {}
   StmtIterator(Stmt** S) : StmtIteratorImpl<StmtIterator,Stmt*&>(S) {}
+  StmtIterator(VariableArrayType* t):StmtIteratorImpl<StmtIterator,Stmt*&>(t) {}
   StmtIterator(ScopedDecl* D) : StmtIteratorImpl<StmtIterator,Stmt*&>(D) {}
 };
 
@@ -107,7 +112,7 @@
                                                    const Stmt*> {
   explicit ConstStmtIterator() : 
     StmtIteratorImpl<ConstStmtIterator,const Stmt*>() {}
-
+                                                     
   ConstStmtIterator(const StmtIterator& RHS) : 
     StmtIteratorImpl<ConstStmtIterator,const Stmt*>(RHS) {}
 };