Merge pull request #1240 from jskeet/validate_group

Validate that end-group tags match their corresponding start-group tags
diff --git a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
index 6ae0211..0e7cf04 100644
--- a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
+++ b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
@@ -470,6 +470,52 @@
         }

 

         [Test]

+        public void SkipGroup_WrongEndGroupTag()

+        {

+            // Create an output stream with:

+            // Field 1: string "field 1"

+            // Start group 2

+            //   Field 3: fixed32 (100)

+            // End group 4 (doesn't match start group 2, so skipping should give an error)

+            var stream = new MemoryStream();

+            var output = new CodedOutputStream(stream);

+            output.WriteTag(1, WireFormat.WireType.LengthDelimited);

+            output.WriteString("field 1");

+

+            // The group, deliberately terminated by an end-group tag for the wrong field...

+            output.WriteTag(2, WireFormat.WireType.StartGroup);

+            output.WriteTag(3, WireFormat.WireType.Fixed32);

+            output.WriteFixed32(100);

+            output.WriteTag(4, WireFormat.WireType.EndGroup);

+            output.Flush();

+            stream.Position = 0;

+

+            // Now act like a generated client, which skips the unknown group

+            var input = new CodedInputStream(stream);

+            Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag());

+            Assert.AreEqual("field 1", input.ReadString());

+            Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag());

+            Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField); // End-group field 4 doesn't match start-group field 2

+        }

+

+        [Test]

+        public void RogueEndGroupTag()

+        {

+            // If we have an end-group tag without a preceding start-group tag, generated

+            // code will just call SkipLastField on it as an unknown field... so that should fail.

+

+            var stream = new MemoryStream();

+            var output = new CodedOutputStream(stream);

+            output.WriteTag(1, WireFormat.WireType.EndGroup); // No matching start-group tag precedes this

+            output.Flush();

+            stream.Position = 0;

+

+            var input = new CodedInputStream(stream);

+            Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.EndGroup), input.ReadTag());

+            Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);

+        }

+

+        [Test]

         public void EndOfStreamReachedWhileSkippingGroup()

         {

             var stream = new MemoryStream();

@@ -484,7 +530,7 @@
             // Now act like a generated client

             var input = new CodedInputStream(stream);

             input.ReadTag();

-            Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());

+            Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);

         }

 

         [Test]

@@ -506,7 +552,7 @@
             // Now act like a generated client

             var input = new CodedInputStream(stream);

             Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag());

-            Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());

+            Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);

         }

 

         [Test]

diff --git a/csharp/src/Google.Protobuf.Test/GeneratedMessageTest.cs b/csharp/src/Google.Protobuf.Test/GeneratedMessageTest.cs
index 14cc6d1..6706995 100644
--- a/csharp/src/Google.Protobuf.Test/GeneratedMessageTest.cs
+++ b/csharp/src/Google.Protobuf.Test/GeneratedMessageTest.cs
@@ -679,21 +679,20 @@
         /// for details; we may want to change this.

         /// </summary>

         [Test]

-        public void ExtraEndGroupSkipped()

+        public void ExtraEndGroupThrows()

         {

             var message = SampleMessages.CreateFullTestAllTypes();

             var stream = new MemoryStream();

             var output = new CodedOutputStream(stream);

 

-            output.WriteTag(100, WireFormat.WireType.EndGroup);

             output.WriteTag(TestAllTypes.SingleFixed32FieldNumber, WireFormat.WireType.Fixed32);

             output.WriteFixed32(123);

+            output.WriteTag(100, WireFormat.WireType.EndGroup); // End-group tag with no matching start-group tag

 

             output.Flush();

 

             stream.Position = 0;

-            var parsed = TestAllTypes.Parser.ParseFrom(stream);

-            Assert.AreEqual(new TestAllTypes { SingleFixed32 = 123 }, parsed);

+            Assert.Throws<InvalidProtocolBufferException>(() => TestAllTypes.Parser.ParseFrom(stream)); // The unpaired end-group tag must be rejected

         }

 

         [Test]

diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs
index 91bed8e..1c02d95 100644
--- a/csharp/src/Google.Protobuf/CodedInputStream.cs
+++ b/csharp/src/Google.Protobuf/CodedInputStream.cs
@@ -349,6 +349,14 @@
         /// This should be called directly after <see cref="ReadTag"/>, when

         /// the caller wishes to skip an unknown field.

         /// </summary>

+        /// <remarks>

+        /// This method throws <see cref="InvalidProtocolBufferException"/> if the last-read tag was an end-group tag.

+        /// If a caller wishes to skip a group, they should skip the whole group, by calling this method after reading the

+        /// start-group tag. This behavior allows callers to call this method on any field they don't understand, correctly

+        /// resulting in an error if an end-group tag has not been paired with an earlier start-group tag.

+        /// </remarks>

+        /// <exception cref="InvalidProtocolBufferException">The last tag was an end-group tag</exception>

+        /// <exception cref="InvalidOperationException">The last read operation read to the end of the logical stream</exception>

         public void SkipLastField()

         {

             if (lastTag == 0)

@@ -358,11 +366,11 @@
             switch (WireFormat.GetTagWireType(lastTag))

             {

                 case WireFormat.WireType.StartGroup:

-                    SkipGroup();

+                    SkipGroup(lastTag);

                     break;

                 case WireFormat.WireType.EndGroup:

-                    // Just ignore; there's no data following the tag.

-                    break;

+                    throw new InvalidProtocolBufferException(

+                        "SkipLastField called on an end-group tag, indicating that the corresponding start-group was missing");

                 case WireFormat.WireType.Fixed32:

                     ReadFixed32();

                     break;

@@ -379,7 +387,7 @@
             }

         }

 

-        private void SkipGroup()

+        private void SkipGroup(uint startGroupTag)

         {

             // Note: Currently we expect this to be the way that groups are read. We could put the recursion

             // depth changes into the ReadTag method instead, potentially...

@@ -389,16 +397,28 @@
                 throw InvalidProtocolBufferException.RecursionLimitExceeded();

             }

             uint tag;

-            do

+            while (true)

             {

                 tag = ReadTag();

                 if (tag == 0)

                 {

                     throw InvalidProtocolBufferException.TruncatedMessage();

                 }

+                // Can't call SkipLastField for this case- that would throw.

+                if (WireFormat.GetTagWireType(tag) == WireFormat.WireType.EndGroup)

+                {

+                    break;

+                }

                 // This recursion will allow us to handle nested groups.

                 SkipLastField();

-            } while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup);

+            }

+            int startField = WireFormat.GetTagFieldNumber(startGroupTag);

+            int endField = WireFormat.GetTagFieldNumber(tag);

+            if (startField != endField)

+            {

+                throw new InvalidProtocolBufferException(

+                    $"Mismatched end-group tag. Started with field {startField}; ended with field {endField}");

+            }

             recursionDepth--;

         }