Modified StmtIterator to also visit the initialization expressions of EnumConstantDecls; StmtIteratorBase::GetInitializer() is renamed to GetDeclExpr() accordingly.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@43366 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/AST/StmtIterator.cpp b/AST/StmtIterator.cpp
index 2d198c0..961ca50 100644
--- a/AST/StmtIterator.cpp
+++ b/AST/StmtIterator.cpp
@@ -17,11 +17,23 @@
 
 using namespace clang;
 
+// True if 'decl' carries an expression the iterator should visit: a
+// VarDecl's initializer or an EnumConstantDecl's init expression.
+static inline bool declHasExpr(ScopedDecl *decl) {
+  if (VarDecl* VD = dyn_cast<VarDecl>(decl))
+    return VD->getInit() != NULL;
+
+  if (EnumConstantDecl* ECD = dyn_cast<EnumConstantDecl>(decl))
+    return ECD->getInitExpr() != NULL;
+
+  return false;
+}
+
 void StmtIteratorBase::NextDecl() {
   assert (FirstDecl && Ptr.D);
 
   do Ptr.D = Ptr.D->getNextDeclarator();
-  while (Ptr.D != NULL && !isa<VarDecl>(Ptr.D));
+  while (Ptr.D != NULL && !declHasExpr(Ptr.D)); // skip declarators with no expression to visit
   
   if (Ptr.D == NULL) FirstDecl = NULL;
 }
@@ -29,12 +41,8 @@
 StmtIteratorBase::StmtIteratorBase(ScopedDecl* d) {
   assert (d);
   
-  while (d != NULL) {
-    if (VarDecl* V = dyn_cast<VarDecl>(d))
-      if (V->getInit()) break;
-    
+  while (d != NULL && !declHasExpr(d))
     d = d->getNextDeclarator();
-  }
   
   FirstDecl = d;
   Ptr.D = d;
@@ -61,6 +69,11 @@
   Ptr.D = lastVD;
 }
 
-Stmt*& StmtIteratorBase::GetInitializer() const {
-  return reinterpret_cast<Stmt*&>(cast<VarDecl>(Ptr.D)->Init);
+Stmt*& StmtIteratorBase::GetDeclExpr() const {
+  // Reference to the expression slot of the current declarator, which must be
+  // a VarDecl or an EnumConstantDecl (the kinds declHasExpr accepts).
+  if (VarDecl* VD = dyn_cast<VarDecl>(Ptr.D))
+    return reinterpret_cast<Stmt*&>(VD->Init);
+  EnumConstantDecl* ECD = cast<EnumConstantDecl>(Ptr.D);
+  return reinterpret_cast<Stmt*&>(ECD->Init);
 }