blob: 225cf4daf2f57f04463b58237480afd2b7a14e21 [file] [log] [blame]
csharptest71f662c2011-05-20 15:15:34 -05001#region Copyright notice and license
2
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// http://github.com/jskeet/dotnet-protobufs/
// Original C++/Java/Python code:
// http://code.google.com/p/protobuf/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34
35#endregion
36
37using System;
38using System.Collections.Generic;
39using System.IO;
40using System.Text;
41using Google.ProtocolBuffers.Descriptors;
42
namespace Google.ProtocolBuffers
{
    /// <summary>
    /// Reads and decodes protocol message fields.
    /// </summary>
    /// <remarks>
    /// This class contains two kinds of methods: methods that read specific
    /// protocol message constructs and field types (e.g. ReadTag and
    /// ReadInt32) and methods that read low-level values (e.g.
    /// ReadRawVarint32 and ReadRawBytes). If you are reading encoded protocol
    /// messages, you should use the former methods, but if you are reading some
    /// other format of your own design, use the latter. The names of the former
    /// methods are taken from the protocol buffer type names, not .NET types.
    /// (Hence ReadFloat instead of ReadSingle, and ReadBool instead of ReadBoolean.)
    ///
    /// Positions are tracked as absolute byte offsets: the current position is
    /// totalBytesRetired + bufferPos, and limits pushed with PushLimit are stored
    /// in the same coordinate space. Comparisons of a wire-supplied length against
    /// a limit are written as subtractions so that a hostile length cannot overflow
    /// int and slip past the check.
    ///
    /// TODO(jonskeet): Consider whether recursion and size limits shouldn't be readonly,
    /// set at construction time.
    /// </remarks>
    public sealed class CodedInputStream
    {
        // Backing storage: either the caller's array (byte[] constructor) or an
        // internal BufferSize-byte array refilled from the stream.
        private readonly byte[] buffer;

        // End of readable data in buffer. It is truncated to currentLimit when the
        // limit falls inside the buffer (see RecomputeBufferSizeAfterLimit).
        private int bufferSize;

        // Bytes present in buffer but hidden beyond currentLimit.
        private int bufferSizeAfterLimit = 0;

        // Index of the next byte to read from buffer.
        private int bufferPos = 0;

        // Underlying stream; null when reading from a byte array.
        private readonly Stream input;

        // The value most recently returned by ReadTag (0 at end of input).
        private uint lastTag = 0;

        internal const int DefaultRecursionLimit = 64;
        internal const int DefaultSizeLimit = 64 << 20; // 64MB
        public const int BufferSize = 4096;

        /// <summary>
        /// The total number of bytes read before the current buffer. The
        /// total bytes read up to the current position can be computed as
        /// totalBytesRetired + bufferPos.
        /// </summary>
        private int totalBytesRetired = 0;

        /// <summary>
        /// The absolute position of the end of the current message.
        /// </summary>
        private int currentLimit = int.MaxValue;

        /// <summary>
        /// <see cref="SetRecursionLimit"/>
        /// </summary>
        private int recursionDepth = 0;

        private int recursionLimit = DefaultRecursionLimit;

        /// <summary>
        /// <see cref="SetSizeLimit"/>
        /// </summary>
        private int sizeLimit = DefaultSizeLimit;

        #region Construction

        /// <summary>
        /// Creates a new CodedInputStream reading data from the given
        /// stream.
        /// </summary>
        public static CodedInputStream CreateInstance(Stream input)
        {
            return new CodedInputStream(input);
        }

        /// <summary>
        /// Creates a new CodedInputStream reading data from the given
        /// byte array.
        /// </summary>
        public static CodedInputStream CreateInstance(byte[] buf)
        {
            return new CodedInputStream(buf, 0, buf.Length);
        }

        /// <summary>
        /// Creates a new CodedInputStream that reads from the given
        /// byte array slice.
        /// </summary>
        public static CodedInputStream CreateInstance(byte[] buf, int offset, int length)
        {
            return new CodedInputStream(buf, offset, length);
        }

        private CodedInputStream(byte[] buffer, int offset, int length)
        {
            this.buffer = buffer;
            this.bufferPos = offset;
            this.bufferSize = offset + length;
            this.input = null;
        }

        private CodedInputStream(Stream input)
        {
            this.buffer = new byte[BufferSize];
            this.bufferSize = 0;
            this.input = input;
        }

        #endregion

        #region Validation

        /// <summary>
        /// Verifies that the last call to ReadTag() returned the given tag value.
        /// This is used to verify that a nested group ended with the correct
        /// end tag.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">The last
        /// tag read was not the one specified</exception>
        [CLSCompliant(false)]
        public void CheckLastTagWas(uint value)
        {
            if (lastTag != value)
            {
                throw InvalidProtocolBufferException.InvalidEndTag();
            }
        }

        #endregion

        #region Reading of tags etc

        /// <summary>
        /// Attempt to read a field tag, returning 0 if we have reached the end
        /// of the input data. Protocol message parsers use this to read tags,
        /// since a protocol message may legally end wherever a tag occurs, and
        /// zero is not a valid tag number.
        /// </summary>
        [CLSCompliant(false)]
        public uint ReadTag()
        {
            if (IsAtEnd)
            {
                lastTag = 0;
                return 0;
            }

            lastTag = ReadRawVarint32();
            if (lastTag == 0)
            {
                // If we actually read zero, that's not a valid tag.
                throw InvalidProtocolBufferException.InvalidTag();
            }
            return lastTag;
        }

        /// <summary>
        /// Read a double field from the stream.
        /// </summary>
        public double ReadDouble()
        {
#if SILVERLIGHT2 || COMPACT_FRAMEWORK_35
            byte[] bytes = ReadRawBytes(8);
            return BitConverter.ToDouble(bytes, 0);
#else
            return BitConverter.Int64BitsToDouble((long) ReadRawLittleEndian64());
#endif
        }

        /// <summary>
        /// Read a float field from the stream.
        /// </summary>
        public float ReadFloat()
        {
            // The wire bytes are assembled into an integer in little-endian order.
            // BitConverter.GetBytes and BitConverter.ToSingle both use the host's
            // byte order, so this round trip reinterprets the same 32 bits as a
            // float regardless of host endianness.
            uint raw = ReadRawLittleEndian32();
            byte[] rawBytes = BitConverter.GetBytes(raw);
            return BitConverter.ToSingle(rawBytes, 0);
        }

        /// <summary>
        /// Read a uint64 field from the stream.
        /// </summary>
        [CLSCompliant(false)]
        public ulong ReadUInt64()
        {
            return ReadRawVarint64();
        }

        /// <summary>
        /// Read an int64 field from the stream.
        /// </summary>
        public long ReadInt64()
        {
            return (long) ReadRawVarint64();
        }

        /// <summary>
        /// Read an int32 field from the stream.
        /// </summary>
        public int ReadInt32()
        {
            return (int) ReadRawVarint32();
        }

        /// <summary>
        /// Read a fixed64 field from the stream.
        /// </summary>
        [CLSCompliant(false)]
        public ulong ReadFixed64()
        {
            return ReadRawLittleEndian64();
        }

        /// <summary>
        /// Read a fixed32 field from the stream.
        /// </summary>
        [CLSCompliant(false)]
        public uint ReadFixed32()
        {
            return ReadRawLittleEndian32();
        }

        /// <summary>
        /// Read a bool field from the stream.
        /// </summary>
        public bool ReadBool()
        {
            return ReadRawVarint32() != 0;
        }

        /// <summary>
        /// Reads a string field from the stream.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">The length prefix is
        /// negative, or the input ends (or a limit is hit) before the string does.</exception>
        public String ReadString()
        {
            int size = (int) ReadRawVarint32();
            // No need to read any data for an empty string.
            if (size == 0)
            {
                return "";
            }
            // size > 0 is required here: a varint above int.MaxValue casts to a
            // negative size, which must be rejected by ReadRawBytes (NegativeSize)
            // rather than being handed to Encoding.GetString.
            if (size > 0 && size <= bufferSize - bufferPos)
            {
                // Fast path: We already have the bytes in a contiguous buffer, so
                // just copy directly from it.
                String result = Encoding.UTF8.GetString(buffer, bufferPos, size);
                bufferPos += size;
                return result;
            }
            // Slow path: Build a byte array first then copy it.
            return Encoding.UTF8.GetString(ReadRawBytes(size), 0, size);
        }

        /// <summary>
        /// Reads a group field value from the stream.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">The recursion limit
        /// is exceeded or the group does not end with the matching end-group tag.</exception>
        public void ReadGroup(int fieldNumber, IBuilderLite builder,
                              ExtensionRegistry extensionRegistry)
        {
            if (recursionDepth >= recursionLimit)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }
            ++recursionDepth;
            builder.WeakMergeFrom(this, extensionRegistry);
            CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
            --recursionDepth;
        }

        /// <summary>
        /// Reads a group field value from the stream and merges it into the given
        /// UnknownFieldSet.
        /// </summary>
        [Obsolete]
        public void ReadUnknownGroup(int fieldNumber, IBuilderLite builder)
        {
            if (recursionDepth >= recursionLimit)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }
            ++recursionDepth;
            builder.WeakMergeFrom(this);
            CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
            --recursionDepth;
        }

        /// <summary>
        /// Reads an embedded message field value from the stream.
        /// </summary>
        /// <remarks>
        /// The length prefix is pushed as a limit so the builder sees end-of-input
        /// exactly at the end of the embedded message; the previous limit is
        /// restored afterwards.
        /// </remarks>
        public void ReadMessage(IBuilderLite builder, ExtensionRegistry extensionRegistry)
        {
            int length = (int) ReadRawVarint32();
            if (recursionDepth >= recursionLimit)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }
            int oldLimit = PushLimit(length);
            ++recursionDepth;
            builder.WeakMergeFrom(this, extensionRegistry);
            // The builder must have consumed up to the limit (ReadTag returned 0).
            CheckLastTagWas(0);
            --recursionDepth;
            PopLimit(oldLimit);
        }

        /// <summary>
        /// Reads a bytes field value from the stream.
        /// </summary>
        public ByteString ReadBytes()
        {
            int size = (int) ReadRawVarint32();
            if (size > 0 && size <= bufferSize - bufferPos)
            {
                // Fast path: We already have the bytes in a contiguous buffer, so
                // just copy directly from it.
                ByteString result = ByteString.CopyFrom(buffer, bufferPos, size);
                bufferPos += size;
                return result;
            }
            else
            {
                // Slow path: Build a byte array first then copy it. ReadRawBytes
                // also rejects negative sizes and handles the empty case.
                return ByteString.CopyFrom(ReadRawBytes(size));
            }
        }

        /// <summary>
        /// Reads a uint32 field value from the stream.
        /// </summary>
        [CLSCompliant(false)]
        public uint ReadUInt32()
        {
            return ReadRawVarint32();
        }

        /// <summary>
        /// Reads an enum field value from the stream. The caller is responsible
        /// for converting the numeric value to an actual enum.
        /// </summary>
        public int ReadEnum()
        {
            return (int) ReadRawVarint32();
        }

        /// <summary>
        /// Reads an sfixed32 field value from the stream.
        /// </summary>
        public int ReadSFixed32()
        {
            return (int) ReadRawLittleEndian32();
        }

        /// <summary>
        /// Reads an sfixed64 field value from the stream.
        /// </summary>
        public long ReadSFixed64()
        {
            return (long) ReadRawLittleEndian64();
        }

        /// <summary>
        /// Reads an sint32 field value from the stream.
        /// </summary>
        public int ReadSInt32()
        {
            return DecodeZigZag32(ReadRawVarint32());
        }

        /// <summary>
        /// Reads an sint64 field value from the stream.
        /// </summary>
        public long ReadSInt64()
        {
            return DecodeZigZag64(ReadRawVarint64());
        }

        /// <summary>
        /// Reads a field of any primitive type. Enums, groups and embedded
        /// messages are not handled by this method.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="fieldType"/> is
        /// Group, Message or Enum.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="fieldType"/>
        /// is not a known field type.</exception>
        public object ReadPrimitiveField(FieldType fieldType)
        {
            switch (fieldType)
            {
                case FieldType.Double:
                    return ReadDouble();
                case FieldType.Float:
                    return ReadFloat();
                case FieldType.Int64:
                    return ReadInt64();
                case FieldType.UInt64:
                    return ReadUInt64();
                case FieldType.Int32:
                    return ReadInt32();
                case FieldType.Fixed64:
                    return ReadFixed64();
                case FieldType.Fixed32:
                    return ReadFixed32();
                case FieldType.Bool:
                    return ReadBool();
                case FieldType.String:
                    return ReadString();
                case FieldType.Bytes:
                    return ReadBytes();
                case FieldType.UInt32:
                    return ReadUInt32();
                case FieldType.SFixed32:
                    return ReadSFixed32();
                case FieldType.SFixed64:
                    return ReadSFixed64();
                case FieldType.SInt32:
                    return ReadSInt32();
                case FieldType.SInt64:
                    return ReadSInt64();
                case FieldType.Group:
                    throw new ArgumentException("ReadPrimitiveField() cannot handle nested groups.");
                case FieldType.Message:
                    throw new ArgumentException("ReadPrimitiveField() cannot handle embedded messages.");
                // We don't handle enums because we don't know what to do if the
                // value is not recognized.
                case FieldType.Enum:
                    throw new ArgumentException("ReadPrimitiveField() cannot handle enums.");
                default:
                    // The single-string constructor treats its argument as the parameter
                    // name, so pass the name and the message separately.
                    throw new ArgumentOutOfRangeException("fieldType", "Invalid field type " + fieldType);
            }
        }

        #endregion

        #region Underlying reading primitives

        /// <summary>
        /// Same code as ReadRawVarint32, but read each byte individually, checking for
        /// buffer overflow.
        /// </summary>
        private uint SlowReadRawVarint32()
        {
            int tmp = ReadRawByte();
            if (tmp < 128)
            {
                return (uint) tmp;
            }
            int result = tmp & 0x7f;
            if ((tmp = ReadRawByte()) < 128)
            {
                result |= tmp << 7;
            }
            else
            {
                result |= (tmp & 0x7f) << 7;
                if ((tmp = ReadRawByte()) < 128)
                {
                    result |= tmp << 14;
                }
                else
                {
                    result |= (tmp & 0x7f) << 14;
                    if ((tmp = ReadRawByte()) < 128)
                    {
                        result |= tmp << 21;
                    }
                    else
                    {
                        result |= (tmp & 0x7f) << 21;
                        result |= (tmp = ReadRawByte()) << 28;
                        if (tmp >= 128)
                        {
                            // Discard upper 32 bits: consume up to five more bytes
                            // (ten bytes in total) looking for the terminating byte.
                            for (int i = 0; i < 5; i++)
                            {
                                if (ReadRawByte() < 128)
                                {
                                    return (uint) result;
                                }
                            }
                            throw InvalidProtocolBufferException.MalformedVarint();
                        }
                    }
                }
            }
            return (uint) result;
        }

        /// <summary>
        /// Read a raw Varint from the stream. If larger than 32 bits, discard the upper bits.
        /// This method is optimised for the case where we've got lots of data in the buffer.
        /// That means we can check the size just once, then just read directly from the buffer
        /// without constant rechecking of the buffer length.
        /// </summary>
        [CLSCompliant(false)]
        public uint ReadRawVarint32()
        {
            if (bufferPos + 5 > bufferSize)
            {
                return SlowReadRawVarint32();
            }

            int tmp = buffer[bufferPos++];
            if (tmp < 128)
            {
                return (uint) tmp;
            }
            int result = tmp & 0x7f;
            if ((tmp = buffer[bufferPos++]) < 128)
            {
                result |= tmp << 7;
            }
            else
            {
                result |= (tmp & 0x7f) << 7;
                if ((tmp = buffer[bufferPos++]) < 128)
                {
                    result |= tmp << 14;
                }
                else
                {
                    result |= (tmp & 0x7f) << 14;
                    if ((tmp = buffer[bufferPos++]) < 128)
                    {
                        result |= tmp << 21;
                    }
                    else
                    {
                        result |= (tmp & 0x7f) << 21;
                        result |= (tmp = buffer[bufferPos++]) << 28;
                        if (tmp >= 128)
                        {
                            // Discard upper 32 bits.
                            // Note that this has to use ReadRawByte() as we only ensure we've
                            // got at least 5 bytes at the start of the method. This lets us
                            // use the fast path in more cases, and we rarely hit this section of code.
                            for (int i = 0; i < 5; i++)
                            {
                                if (ReadRawByte() < 128)
                                {
                                    return (uint) result;
                                }
                            }
                            throw InvalidProtocolBufferException.MalformedVarint();
                        }
                    }
                }
            }
            return (uint) result;
        }

        /// <summary>
        /// Reads a varint from the input one byte at a time, so that it does not
        /// read any bytes after the end of the varint. If you simply wrapped the
        /// stream in a CodedInputStream and called ReadRawVarint32() on it,
        /// then you would probably end up reading past the end of the varint since
        /// CodedInputStream buffers its input.
        /// </summary>
        /// <param name="input">The stream to read from; bytes are consumed with ReadByte.</param>
        /// <returns>The low 32 bits of the decoded varint.</returns>
        /// <exception cref="InvalidProtocolBufferException">The stream ends before
        /// the varint does, or the varint is longer than ten bytes.</exception>
        [CLSCompliant(false)]
        public static uint ReadRawVarint32(Stream input)
        {
            int result = 0;
            int offset = 0;
            for (; offset < 32; offset += 7)
            {
                int b = input.ReadByte();
                if (b == -1)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                result |= (b & 0x7f) << offset;
                if ((b & 0x80) == 0)
                {
                    return (uint) result;
                }
            }
            // Keep reading up to 64 bits.
            for (; offset < 64; offset += 7)
            {
                int b = input.ReadByte();
                if (b == -1)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                if ((b & 0x80) == 0)
                {
                    return (uint) result;
                }
            }
            throw InvalidProtocolBufferException.MalformedVarint();
        }

        /// <summary>
        /// Read a raw varint from the stream.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">The varint is longer
        /// than ten bytes, or the input ends first.</exception>
        [CLSCompliant(false)]
        public ulong ReadRawVarint64()
        {
            int shift = 0;
            ulong result = 0;
            while (shift < 64)
            {
                byte b = ReadRawByte();
                result |= (ulong) (b & 0x7F) << shift;
                if ((b & 0x80) == 0)
                {
                    return result;
                }
                shift += 7;
            }
            throw InvalidProtocolBufferException.MalformedVarint();
        }

        /// <summary>
        /// Read a 32-bit little-endian integer from the stream.
        /// </summary>
        [CLSCompliant(false)]
        public uint ReadRawLittleEndian32()
        {
            uint b1 = ReadRawByte();
            uint b2 = ReadRawByte();
            uint b3 = ReadRawByte();
            uint b4 = ReadRawByte();
            return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24);
        }

        /// <summary>
        /// Read a 64-bit little-endian integer from the stream.
        /// </summary>
        [CLSCompliant(false)]
        public ulong ReadRawLittleEndian64()
        {
            ulong b1 = ReadRawByte();
            ulong b2 = ReadRawByte();
            ulong b3 = ReadRawByte();
            ulong b4 = ReadRawByte();
            ulong b5 = ReadRawByte();
            ulong b6 = ReadRawByte();
            ulong b7 = ReadRawByte();
            ulong b8 = ReadRawByte();
            return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24)
                   | (b5 << 32) | (b6 << 40) | (b7 << 48) | (b8 << 56);
        }

        #endregion

        /// <summary>
        /// Decode a 32-bit value with ZigZag encoding.
        /// </summary>
        /// <remarks>
        /// ZigZag encodes signed integers into values that can be efficiently
        /// encoded with varint. (Otherwise, negative values must be
        /// sign-extended to 64 bits to be varint encoded, thus always taking
        /// 10 bytes on the wire.)
        /// </remarks>
        [CLSCompliant(false)]
        public static int DecodeZigZag32(uint n)
        {
            return (int) (n >> 1) ^ -(int) (n & 1);
        }

        /// <summary>
        /// Decode a 64-bit value with ZigZag encoding.
        /// </summary>
        /// <remarks>
        /// ZigZag encodes signed integers into values that can be efficiently
        /// encoded with varint. (Otherwise, negative values must be
        /// sign-extended to 64 bits to be varint encoded, thus always taking
        /// 10 bytes on the wire.)
        /// </remarks>
        [CLSCompliant(false)]
        public static long DecodeZigZag64(ulong n)
        {
            return (long) (n >> 1) ^ -(long) (n & 1);
        }

        /// <summary>
        /// Set the maximum message recursion depth.
        /// </summary>
        /// <remarks>
        /// In order to prevent malicious
        /// messages from causing stack overflows, CodedInputStream limits
        /// how deeply messages may be nested. The default limit is 64.
        /// </remarks>
        /// <returns>The previous recursion limit.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/> is negative.</exception>
        public int SetRecursionLimit(int limit)
        {
            if (limit < 0)
            {
                // Pass the parameter name and message separately; the single-string
                // constructor would treat the message as the parameter name.
                throw new ArgumentOutOfRangeException("limit", "Recursion limit cannot be negative: " + limit);
            }
            int oldLimit = recursionLimit;
            recursionLimit = limit;
            return oldLimit;
        }

        /// <summary>
        /// Set the maximum message size.
        /// </summary>
        /// <remarks>
        /// In order to prevent malicious messages from exhausting memory or
        /// causing integer overflows, CodedInputStream limits how large a message may be.
        /// The default limit is 64MB. You should set this limit as small
        /// as you can without harming your app's functionality. Note that
        /// size limits only apply when reading from a Stream, not
        /// when constructed around a raw byte array (nor with ByteString.NewCodedInput).
        /// If you want to read several messages from a single CodedInputStream, you
        /// can call ResetSizeCounter() after each message to avoid hitting the
        /// size limit.
        /// </remarks>
        /// <returns>The previous size limit.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/> is negative.</exception>
        public int SetSizeLimit(int limit)
        {
            if (limit < 0)
            {
                throw new ArgumentOutOfRangeException("limit", "Size limit cannot be negative: " + limit);
            }
            int oldLimit = sizeLimit;
            sizeLimit = limit;
            return oldLimit;
        }

        #region Internal reading and buffer management

        /// <summary>
        /// Resets the current size counter to zero (see SetSizeLimit).
        /// </summary>
        /// <remarks>
        /// Limits pushed with PushLimit are stored as absolute positions derived from
        /// this counter. Call this between messages, when no pushed limit is active.
        /// </remarks>
        public void ResetSizeCounter()
        {
            totalBytesRetired = 0;
        }

        /// <summary>
        /// Sets currentLimit to (current position) + byteLimit. This is called
        /// when descending into a length-delimited embedded message. The previous
        /// limit is returned.
        /// </summary>
        /// <returns>The old limit.</returns>
        /// <exception cref="InvalidProtocolBufferException"><paramref name="byteLimit"/>
        /// is negative, or it would extend past the current limit.</exception>
        public int PushLimit(int byteLimit)
        {
            if (byteLimit < 0)
            {
                throw InvalidProtocolBufferException.NegativeSize();
            }
            int currentPosition = totalBytesRetired + bufferPos;
            int oldLimit = currentLimit;
            // Compare via subtraction: currentPosition + byteLimit can overflow int for
            // a length prefix near int.MaxValue, which would yield a negative limit.
            if (byteLimit > oldLimit - currentPosition)
            {
                throw InvalidProtocolBufferException.TruncatedMessage();
            }
            currentLimit = currentPosition + byteLimit;

            RecomputeBufferSizeAfterLimit();

            return oldLimit;
        }

        /// <summary>
        /// Re-derives bufferSize and bufferSizeAfterLimit so that bufferSize never
        /// extends past currentLimit.
        /// </summary>
        private void RecomputeBufferSizeAfterLimit()
        {
            bufferSize += bufferSizeAfterLimit;
            int bufferEnd = totalBytesRetired + bufferSize;
            if (bufferEnd > currentLimit)
            {
                // Limit is in current buffer.
                bufferSizeAfterLimit = bufferEnd - currentLimit;
                bufferSize -= bufferSizeAfterLimit;
            }
            else
            {
                bufferSizeAfterLimit = 0;
            }
        }

        /// <summary>
        /// Discards the current limit and restores <paramref name="oldLimit"/>, which
        /// should be the value returned by the matching PushLimit call.
        /// </summary>
        public void PopLimit(int oldLimit)
        {
            currentLimit = oldLimit;
            RecomputeBufferSizeAfterLimit();
        }

        /// <summary>
        /// Returns whether or not all the data before the limit has been read.
        /// Always false when no limit has been pushed.
        /// </summary>
        public bool ReachedLimit
        {
            get
            {
                if (currentLimit == int.MaxValue)
                {
                    return false;
                }
                int currentAbsolutePosition = totalBytesRetired + bufferPos;
                return currentAbsolutePosition >= currentLimit;
            }
        }

        /// <summary>
        /// Returns true if the stream has reached the end of the input. This is the
        /// case if either the end of the underlying input source has been reached or
        /// the stream has reached a limit created using PushLimit.
        /// </summary>
        public bool IsAtEnd
        {
            get { return bufferPos == bufferSize && !RefillBuffer(false); }
        }

        /// <summary>
        /// Called when buffer is empty to read more bytes from the
        /// input. If <paramref name="mustSucceed"/> is true, RefillBuffer() guarantees that
        /// either there will be at least one byte in the buffer when it returns
        /// or it will throw an exception. If <paramref name="mustSucceed"/> is false,
        /// RefillBuffer() returns false if no more bytes were available.
        /// </summary>
        /// <param name="mustSucceed">Whether running out of data is an error.</param>
        /// <returns>true if at least one byte is now available in the buffer.</returns>
        /// <exception cref="InvalidProtocolBufferException">mustSucceed is true and no
        /// data is available, or the size limit has been exceeded.</exception>
        private bool RefillBuffer(bool mustSucceed)
        {
            if (bufferPos < bufferSize)
            {
                throw new InvalidOperationException("RefillBuffer() called when buffer wasn't empty.");
            }

            if (totalBytesRetired + bufferSize == currentLimit)
            {
                // Oops, we hit a limit.
                if (mustSucceed)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                else
                {
                    return false;
                }
            }

            totalBytesRetired += bufferSize;

            bufferPos = 0;
            bufferSize = (input == null) ? 0 : input.Read(buffer, 0, buffer.Length);
            if (bufferSize < 0)
            {
                throw new InvalidOperationException("Stream.Read returned a negative count");
            }
            if (bufferSize == 0)
            {
                if (mustSucceed)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                else
                {
                    return false;
                }
            }
            else
            {
                RecomputeBufferSizeAfterLimit();
                int totalBytesRead =
                    totalBytesRetired + bufferSize + bufferSizeAfterLimit;
                // The "< 0" test catches int overflow of the running total.
                if (totalBytesRead > sizeLimit || totalBytesRead < 0)
                {
                    throw InvalidProtocolBufferException.SizeLimitExceeded();
                }
                return true;
            }
        }

        /// <summary>
        /// Read one byte from the input.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">
        /// the end of the stream or the current limit was reached
        /// </exception>
        public byte ReadRawByte()
        {
            if (bufferPos == bufferSize)
            {
                RefillBuffer(true);
            }
            return buffer[bufferPos++];
        }

        /// <summary>
        /// Read a fixed size of bytes from the input.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">
        /// the size is negative, the end of the stream or the current limit was
        /// reached, or (when reading from a stream) the read would exceed the size limit
        /// </exception>
        public byte[] ReadRawBytes(int size)
        {
            if (size < 0)
            {
                throw InvalidProtocolBufferException.NegativeSize();
            }

            int currentPosition = totalBytesRetired + bufferPos;
            // Compare via subtraction: currentPosition + size can overflow int for a
            // hostile length prefix and would then slip past this check.
            if (size > currentLimit - currentPosition)
            {
                // Read to the end of the stream anyway.
                SkipRawBytes(currentLimit - currentPosition);
                // Then fail.
                throw InvalidProtocolBufferException.TruncatedMessage();
            }

            if (size <= bufferSize - bufferPos)
            {
                // We have all the bytes we need already.
                byte[] bytes = new byte[size];
                Array.Copy(buffer, bufferPos, bytes, 0, size);
                bufferPos += size;
                return bytes;
            }
            else if (size < BufferSize)
            {
                // Reading more bytes than are in the buffer, but not an excessive number
                // of bytes. We can safely allocate the resulting array ahead of time.

                // First copy what we have.
                byte[] bytes = new byte[size];
                int pos = bufferSize - bufferPos;
                Array.Copy(buffer, bufferPos, bytes, 0, pos);
                bufferPos = bufferSize;

                // We want to use RefillBuffer() and then copy from the buffer into our
                // byte array rather than reading directly into our byte array because
                // the input may be unbuffered.
                RefillBuffer(true);

                while (size - pos > bufferSize)
                {
                    Array.Copy(buffer, 0, bytes, pos, bufferSize);
                    pos += bufferSize;
                    bufferPos = bufferSize;
                    RefillBuffer(true);
                }

                Array.Copy(buffer, 0, bytes, pos, size - pos);
                bufferPos = size - pos;

                return bytes;
            }
            else
            {
                // The size is very large. For security reasons, we can't allocate the
                // entire byte array yet. The size comes directly from the input, so a
                // maliciously-crafted message could provide a bogus very large size in
                // order to trick the app into allocating a lot of memory. We avoid this
                // by allocating and reading only a small chunk at a time, so that the
                // malicious message must actually *be* extremely large to cause
                // problems. Meanwhile, we limit the allowed size of a message elsewhere.

                // This path reads straight from the stream and so bypasses the size check
                // in RefillBuffer; enforce the size limit here before allocating anything.
                // (Size limits do not apply to byte-array input, where input is null.)
                if (input != null && size > sizeLimit - currentPosition)
                {
                    throw InvalidProtocolBufferException.SizeLimitExceeded();
                }

                // Remember the buffer markers since we'll have to copy the bytes out of
                // it later.
                int originalBufferPos = bufferPos;
                int originalBufferSize = bufferSize;

                // Mark the current buffer consumed.
                totalBytesRetired += bufferSize;
                bufferPos = 0;
                bufferSize = 0;

                // Read all the rest of the bytes we need.
                int sizeLeft = size - (originalBufferSize - originalBufferPos);
                List<byte[]> chunks = new List<byte[]>();

                while (sizeLeft > 0)
                {
                    byte[] chunk = new byte[Math.Min(sizeLeft, BufferSize)];
                    int pos = 0;
                    while (pos < chunk.Length)
                    {
                        int n = (input == null) ? -1 : input.Read(chunk, pos, chunk.Length - pos);
                        if (n <= 0)
                        {
                            throw InvalidProtocolBufferException.TruncatedMessage();
                        }
                        totalBytesRetired += n;
                        pos += n;
                    }
                    sizeLeft -= chunk.Length;
                    chunks.Add(chunk);
                }

                // OK, got everything. Now concatenate it all into one buffer.
                byte[] bytes = new byte[size];

                // Start by copying the leftover bytes from this.buffer.
                int newPos = originalBufferSize - originalBufferPos;
                Array.Copy(buffer, originalBufferPos, bytes, 0, newPos);

                // And now all the chunks.
                foreach (byte[] chunk in chunks)
                {
                    Array.Copy(chunk, 0, bytes, newPos, chunk.Length);
                    newPos += chunk.Length;
                }

                // Done.
                return bytes;
            }
        }

        /// <summary>
        /// Reads and discards a single field, given its tag value.
        /// </summary>
        /// <returns>false if the tag is an end-group tag, in which case
        /// nothing is skipped. Otherwise, returns true.</returns>
        /// <exception cref="InvalidProtocolBufferException">The wire type is invalid,
        /// the input is truncated, or nested groups exceed the recursion limit.</exception>
        [CLSCompliant(false)]
        public bool SkipField(uint tag)
        {
            switch (WireFormat.GetTagWireType(tag))
            {
                case WireFormat.WireType.Varint:
                    ReadInt32();
                    return true;
                case WireFormat.WireType.Fixed64:
                    ReadRawLittleEndian64();
                    return true;
                case WireFormat.WireType.LengthDelimited:
                    SkipRawBytes((int) ReadRawVarint32());
                    return true;
                case WireFormat.WireType.StartGroup:
                    // Skipping a group recurses via SkipMessage, so apply the same
                    // recursion limit as ReadGroup to prevent stack overflow on
                    // maliciously nested unknown groups.
                    if (recursionDepth >= recursionLimit)
                    {
                        throw InvalidProtocolBufferException.RecursionLimitExceeded();
                    }
                    ++recursionDepth;
                    SkipMessage();
                    CheckLastTagWas(
                        WireFormat.MakeTag(WireFormat.GetTagFieldNumber(tag),
                                           WireFormat.WireType.EndGroup));
                    --recursionDepth;
                    return true;
                case WireFormat.WireType.EndGroup:
                    return false;
                case WireFormat.WireType.Fixed32:
                    ReadRawLittleEndian32();
                    return true;
                default:
                    throw InvalidProtocolBufferException.InvalidWireType();
            }
        }

        /// <summary>
        /// Reads and discards an entire message. This will read either until EOF
        /// or until an endgroup tag, whichever comes first.
        /// </summary>
        public void SkipMessage()
        {
            while (true)
            {
                uint tag = ReadTag();
                if (tag == 0 || !SkipField(tag))
                {
                    return;
                }
            }
        }

        /// <summary>
        /// Reads and discards <paramref name="size"/> bytes.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">the size is negative, or
        /// the end of the stream or the current limit was reached</exception>
        public void SkipRawBytes(int size)
        {
            if (size < 0)
            {
                throw InvalidProtocolBufferException.NegativeSize();
            }

            int currentPosition = totalBytesRetired + bufferPos;
            // Compare via subtraction to avoid int overflow on hostile sizes.
            if (size > currentLimit - currentPosition)
            {
                // Read to the end of the stream anyway.
                SkipRawBytes(currentLimit - currentPosition);
                // Then fail.
                throw InvalidProtocolBufferException.TruncatedMessage();
            }

            if (size <= bufferSize - bufferPos)
            {
                // We have all the bytes we need already.
                bufferPos += size;
            }
            else
            {
                // Skipping more bytes than are in the buffer. First skip what we have.
                int pos = bufferSize - bufferPos;
                // Retire the whole buffer, not just its unread tail: the new absolute
                // position must be totalBytesRetired + bufferSize, otherwise limits
                // would be checked against a position bufferPos bytes too small.
                totalBytesRetired += bufferSize;
                bufferPos = 0;
                bufferSize = 0;

                // Then skip directly from the Stream for the rest. In this branch
                // size > pos always holds.
                if (input == null)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                SkipImpl(size - pos);
                totalBytesRetired += size - pos;
            }
        }

        /// <summary>
        /// Abstraction of skipping to cope with streams which can't really skip.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">Fewer than
        /// <paramref name="amountToSkip"/> bytes remain in the stream.</exception>
        private void SkipImpl(int amountToSkip)
        {
            if (input.CanSeek)
            {
                // Common seekable streams (e.g. MemoryStream, FileStream) accept a position
                // beyond Length, so truncation has to be detected against Length explicitly.
                long remaining = input.Length - input.Position;
                if (amountToSkip > remaining)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                input.Position += amountToSkip;
            }
            else
            {
                // Never request more than is still to be skipped: every byte Read returns
                // is consumed from the stream, so over-reading would throw away data that
                // belongs to subsequent fields.
                byte[] skipBuffer = new byte[Math.Min(amountToSkip, 1024)];
                while (amountToSkip > 0)
                {
                    int bytesRead = input.Read(skipBuffer, 0, Math.Min(amountToSkip, skipBuffer.Length));
                    if (bytesRead <= 0)
                    {
                        throw InvalidProtocolBufferException.TruncatedMessage();
                    }
                    amountToSkip -= bytesRead;
                }
            }
        }

        #endregion
    }
}