Transform the AST to create a correct representation of the cases inside a switch statement
diff --git a/pycparser/_c_ast.cfg b/pycparser/_c_ast.cfg
index 9feaf1a..ca2379b 100644
--- a/pycparser/_c_ast.cfg
+++ b/pycparser/_c_ast.cfg
@@ -25,7 +25,7 @@
 

 Break: []

 

-Case: [expr*, stmt*]

+Case: [expr*, stmts**]

 

 Cast: [to_type*, expr*]

 

@@ -59,7 +59,7 @@
 

 DeclList: [decls**]

 

-Default: [stmt*]

+Default: [stmts**]

 

 DoWhile: [cond*, stmt*]

 

diff --git a/pycparser/ast_transforms.py b/pycparser/ast_transforms.py
new file mode 100644
index 0000000..b30ae3c
--- /dev/null
+++ b/pycparser/ast_transforms.py
@@ -0,0 +1,105 @@
+#------------------------------------------------------------------------------
+# pycparser: ast_transforms.py
+#
+# Some utilities used by the parser to create a friendlier AST.
+#
+# Copyright (C) 2008-2012, Eli Bendersky
+# License: BSD
+#------------------------------------------------------------------------------
+
+from . import c_ast
+
+
+def fix_switch_cases(switch_node):
+    """ The 'case' statements in a 'switch' come out of parsing with one
+        child node, so subsequent statements are just tucked to the parent
+        Compound. Additionally, consecutive (fall-through) case statements
+        come out messy. This is a peculiarity of the C grammar. The following:
+
+            switch (myvar) {
+                case 10:
+                    k = 10;
+                    p = k + 1;
+                    return 10;
+                case 20:
+                case 30:
+                    return 20;
+                default:
+                    break;
+            }
+
+        Creates this tree (pseudo-dump):
+
+            Switch
+                ID: myvar
+                Compound:
+                    Case 10:
+                        k = 10
+                    p = k + 1
+                    return 10
+                    Case 20:
+                        Case 30:
+                            return 20
+                    Default:
+                        break
+
+        The goal of this transform it to fix this mess, turning it into the
+        following:
+
+            Switch
+                ID: myvar
+                Compound:
+                    Case 10:
+                        k = 10
+                        p = k + 1
+                        return 10
+                    Case 20:
+                    Case 30:
+                        return 20
+                    Default:
+                        break
+
+        A fixed AST node is returned. The argument may be modified.
+    """
+    assert isinstance(switch_node, c_ast.Switch)
+    if not isinstance(switch_node.stmt, c_ast.Compound):
+        return switch_node
+
+    # The new Compound child for the Switch, which will collect children in the
+    # correct order
+    new_compound = c_ast.Compound([], switch_node.stmt.coord)
+
+    # The last Case/Default node
+    last_case = None
+
+    # Goes over the children of the Compound below the Switch, adding them
+    # either directly below new_compound or below the last Case as appropriate
+    for child in switch_node.stmt.block_items:
+        if isinstance(child, (c_ast.Case, c_ast.Default)):
+            # If it's a Case/Default:
+            # 1. Add it to the Compound and mark as "last case"
+            # 2. If its immediate child is also a Case or Default, promote it
+            #    to a sibling.
+            new_compound.block_items.append(child)
+            _extract_nested_case(child, new_compound.block_items)
+            last_case = new_compound.block_items[-1]
+        else:
+            # Other statements are added as childrent to the last case, if it
+            # exists.
+            if last_case is None:
+                new_compound.block_items.append(child)
+            else:
+                last_case.stmts.append(child)
+
+    switch_node.stmt = new_compound
+    return switch_node
+
+
+def _extract_nested_case(case_node, stmts_list):
+    """ Recursively extract consecutive Case statements that are made nested
+        by the parser and add them to the stmts_list.
+    """
+    if isinstance(case_node.stmts[0], (c_ast.Case, c_ast.Default)):
+        stmts_list.append(case_node.stmts.pop())
+        _extract_nested_case(stmts_list[-1], stmts_list)
+
diff --git a/pycparser/c_ast.py b/pycparser/c_ast.py
index 5868b9b..a1c92fb 100644
--- a/pycparser/c_ast.py
+++ b/pycparser/c_ast.py
@@ -194,15 +194,16 @@
     attr_names = ()
 
 class Case(Node):
-    def __init__(self, expr, stmt, coord=None):
+    def __init__(self, expr, stmts, coord=None):
         self.expr = expr
-        self.stmt = stmt
+        self.stmts = stmts
         self.coord = coord
 
     def children(self):
         nodelist = []
         if self.expr is not None: nodelist.append(("expr", self.expr))
-        if self.stmt is not None: nodelist.append(("stmt", self.stmt))
+        for i, child in enumerate(self.stmts or []):
+            nodelist.append(("stmts[%d]" % i, child))
         return tuple(nodelist)
 
     attr_names = ()
@@ -303,13 +304,14 @@
     attr_names = ()
 
 class Default(Node):
-    def __init__(self, stmt, coord=None):
-        self.stmt = stmt
+    def __init__(self, stmts, coord=None):
+        self.stmts = stmts
         self.coord = coord
 
     def children(self):
         nodelist = []
-        if self.stmt is not None: nodelist.append(("stmt", self.stmt))
+        for i, child in enumerate(self.stmts or []):
+            nodelist.append(("stmts[%d]" % i, child))
         return tuple(nodelist)
 
     attr_names = ()
diff --git a/pycparser/c_parser.py b/pycparser/c_parser.py
index a53ccd4..cd21c3d 100644
--- a/pycparser/c_parser.py
+++ b/pycparser/c_parser.py
@@ -13,6 +13,7 @@
 from . import c_ast
 from .c_lexer import CLexer
 from .plyparser import PLYParser, Coord, ParseError
+from .ast_transforms import fix_switch_cases
 
 
 class CParser(PLYParser):    
@@ -1085,11 +1086,11 @@
     
     def p_labeled_statement_2(self, p):
         """ labeled_statement : CASE constant_expression COLON statement """
-        p[0] = c_ast.Case(p[2], p[4], self._coord(p.lineno(1)))
+        p[0] = c_ast.Case(p[2], [p[4]], self._coord(p.lineno(1)))
         
     def p_labeled_statement_3(self, p):
         """ labeled_statement : DEFAULT COLON statement """
-        p[0] = c_ast.Default(p[3], self._coord(p.lineno(1)))
+        p[0] = c_ast.Default([p[3]], self._coord(p.lineno(1)))
         
     def p_selection_statement_1(self, p):
         """ selection_statement : IF LPAREN expression RPAREN statement """
@@ -1101,7 +1102,8 @@
     
     def p_selection_statement_3(self, p):
         """ selection_statement : SWITCH LPAREN expression RPAREN statement """
-        p[0] = c_ast.Switch(p[3], p[5], self._coord(p.lineno(1)))
+        p[0] = fix_switch_cases(
+                c_ast.Switch(p[3], p[5], self._coord(p.lineno(1))))
     
     def p_iteration_statement_1(self, p):
         """ iteration_statement : WHILE LPAREN expression RPAREN statement """
diff --git a/tests/test_c_parser.py b/tests/test_c_parser.py
index c4292d7..dbf7533 100644
--- a/tests/test_c_parser.py
+++ b/tests/test_c_parser.py
@@ -1284,6 +1284,14 @@
         self.assert_num_klass_nodes(ps1, Return, 1)
 
     def test_switch_statement(self):
+        def assert_case_node(node, const_value):
+            self.failUnless(isinstance(node, Case))
+            self.failUnless(isinstance(node.expr, Constant))
+            self.assertEqual(node.expr.value, const_value)
+
+        def assert_default_node(node):
+            self.failUnless(isinstance(node, Default))
+
         s1 = r'''
         int foo(void) {
             switch (myvar) {
@@ -1301,7 +1309,46 @@
         }
         '''
         ps1 = self.parse(s1)
-        #~ ps1.show()
+        switch = ps1.ext[0].body.block_items[0]
+
+        block = switch.stmt.block_items
+        assert_case_node(block[0], '10')
+        self.assertEqual(len(block[0].stmts), 3)
+        assert_case_node(block[1], '20')
+        self.assertEqual(len(block[1].stmts), 0)
+        assert_case_node(block[2], '30')
+        self.assertEqual(len(block[2].stmts), 1)
+        assert_default_node(block[3])
+
+        s2 = r'''
+        int foo(void) {
+            switch (myvar) {
+                default:
+                    joe = moe;
+                    return 10;
+                case 10:
+                case 20:
+                case 30:
+                case 40:
+                    break;
+            }
+            return 0;
+        }
+        '''
+        ps2 = self.parse(s2)
+        switch = ps2.ext[0].body.block_items[0]
+
+        block = switch.stmt.block_items
+        assert_default_node(block[0])
+        self.assertEqual(len(block[0].stmts), 2)
+        assert_case_node(block[1], '10')
+        self.assertEqual(len(block[1].stmts), 0)
+        assert_case_node(block[2], '20')
+        self.assertEqual(len(block[1].stmts), 0)
+        assert_case_node(block[3], '30')
+        self.assertEqual(len(block[1].stmts), 0)
+        assert_case_node(block[4], '40')
+        self.assertEqual(len(block[4].stmts), 1)
 
     def test_for_statement(self):
         s2 = r'''