blob: 89d8e33edfb4b68bcb8e069911ba9ffb19dce5d3 [file] [log] [blame]
Jon Skeet60c059b2008-10-23 21:17:56 +01001// Protocol Buffers - Google's data interchange format
2// Copyright 2008 Google Inc. All rights reserved.
3// http://github.com/jskeet/dotnet-protobufs/
4// Original C++/Java/Python code:
Jon Skeet68036862008-10-22 13:30:34 +01005// http://code.google.com/p/protobuf/
6//
Jon Skeet60c059b2008-10-23 21:17:56 +01007// Redistribution and use in source and binary forms, with or without
8// modification, are permitted provided that the following conditions are
9// met:
Jon Skeet68036862008-10-22 13:30:34 +010010//
Jon Skeet60c059b2008-10-23 21:17:56 +010011// * Redistributions of source code must retain the above copyright
12// notice, this list of conditions and the following disclaimer.
13// * Redistributions in binary form must reproduce the above
14// copyright notice, this list of conditions and the following disclaimer
15// in the documentation and/or other materials provided with the
16// distribution.
17// * Neither the name of Google Inc. nor the names of its
18// contributors may be used to endorse or promote products derived from
19// this software without specific prior written permission.
Jon Skeet68036862008-10-22 13:30:34 +010020//
Jon Skeet60c059b2008-10-23 21:17:56 +010021// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Jon Skeet68036862008-10-22 13:30:34 +010032using System;
33using System.Collections.Generic;
34using System.IO;
35using System.Text;
36using Google.ProtocolBuffers.Descriptors;
37
38namespace Google.ProtocolBuffers {
39
40 /// <summary>
/// Reads and decodes protocol message fields.
42 /// </summary>
43 /// <remarks>
44 /// This class contains two kinds of methods: methods that read specific
45 /// protocol message constructs and field types (e.g. ReadTag and
46 /// ReadInt32) and methods that read low-level values (e.g.
47 /// ReadRawVarint32 and ReadRawBytes). If you are reading encoded protocol
48 /// messages, you should use the former methods, but if you are reading some
49 /// other format of your own design, use the latter. The names of the former
50 /// methods are taken from the protocol buffer type names, not .NET types.
51 /// (Hence ReadFloat instead of ReadSingle, and ReadBool instead of ReadBoolean.)
52 ///
53 /// TODO(jonskeet): Consider whether recursion and size limits shouldn't be readonly,
54 /// set at construction time.
55 /// </remarks>
56 public sealed class CodedInputStream {
/// <summary>Recursion limit applied when none has been set explicitly.</summary>
private const int DefaultRecursionLimit = 64;

/// <summary>Size limit applied when none has been set explicitly (64MB).</summary>
private const int DefaultSizeLimit = 64 << 20;

/// <summary>Length of the internal buffer allocated when reading from a stream.</summary>
private const int BufferSize = 4096;

/// <summary>Backing storage that bytes are decoded from.</summary>
private readonly byte[] buffer;

/// <summary>Source stream used to refill the buffer; null when reading from a byte array.</summary>
private readonly Stream input;

/// <summary>Number of usable bytes currently in <c>buffer</c>.</summary>
private int bufferSize;

/// <summary>
/// Number of bytes held in <c>buffer</c> beyond <c>bufferSize</c> because they
/// lie after <c>currentLimit</c>.
/// </summary>
private int bufferSizeAfterLimit;

/// <summary>Index in <c>buffer</c> of the next byte to read.</summary>
private int bufferPos;

/// <summary>The tag most recently returned by ReadTag (0 if none).</summary>
private uint lastTag;

/// <summary>
/// Count of bytes consumed before the current buffer. The absolute
/// position in the input is totalBytesRetired + bufferPos.
/// </summary>
private int totalBytesRetired;

/// <summary>
/// Absolute position at which the current message ends;
/// int.MaxValue means no limit is in force.
/// </summary>
private int currentLimit = int.MaxValue;

/// <summary>
/// Current nesting depth of groups/messages being read.
/// <see cref="SetRecursionLimit"/>
/// </summary>
private int recursionDepth;

/// <summary>
/// Maximum permitted nesting depth.
/// <see cref="SetRecursionLimit"/>
/// </summary>
private int recursionLimit = DefaultRecursionLimit;

/// <summary>
/// Maximum number of bytes that may be read from a stream.
/// <see cref="SetSizeLimit"/>
/// </summary>
private int sizeLimit = DefaultSizeLimit;
90
91 #region Construction
/// <summary>
/// Factory method: builds a CodedInputStream that takes its data
/// from <paramref name="input"/>.
/// </summary>
/// <param name="input">The stream to read from.</param>
/// <returns>A new CodedInputStream wrapping the stream.</returns>
public static CodedInputStream CreateInstance(Stream input) {
  CodedInputStream codedInput = new CodedInputStream(input);
  return codedInput;
}
99
/// <summary>
/// Factory method: builds a CodedInputStream that takes its data
/// from <paramref name="buf"/>.
/// </summary>
/// <param name="buf">The byte array to read from; it is used directly, not copied.</param>
/// <returns>A new CodedInputStream over the array.</returns>
public static CodedInputStream CreateInstance(byte[] buf) {
  CodedInputStream codedInput = new CodedInputStream(buf);
  return codedInput;
}
107
/// <summary>
/// Sets up a reader over an in-memory array: the whole array is treated
/// as already buffered and there is no stream to refill from.
/// </summary>
private CodedInputStream(byte[] buffer) {
  input = null;
  bufferSize = buffer.Length;
  this.buffer = buffer;
}
113
/// <summary>
/// Sets up a reader over a stream: allocates an empty internal buffer of
/// BufferSize bytes, to be filled from <paramref name="input"/> on demand.
/// </summary>
private CodedInputStream(Stream input) {
  this.input = input;
  bufferSize = 0;
  buffer = new byte[BufferSize];
}
119 #endregion
120
121 #region Validation
/// <summary>
/// Checks that the tag most recently returned by ReadTag() equals
/// <paramref name="value"/>. Used to confirm that a nested group was
/// closed by the matching end tag.
/// </summary>
/// <param name="value">The tag that is expected to have been read last.</param>
/// <exception cref="InvalidProtocolBufferException">The last
/// tag read was not the one specified</exception>
public void CheckLastTagWas(uint value) {
  bool tagMatches = (lastTag == value);
  if (!tagMatches) {
    throw InvalidProtocolBufferException.InvalidEndTag();
  }
}
134 #endregion
135
136 #region Reading of tags etc
/// <summary>
/// Tries to read a field tag. Returns 0 when the end of the input (or the
/// current limit) has been reached; a message may legally end wherever a
/// tag could occur, and 0 is never a valid tag.
/// </summary>
/// <returns>The tag read, or 0 at end of input.</returns>
/// <exception cref="InvalidProtocolBufferException">A tag of 0 was
/// actually present in the data.</exception>
public uint ReadTag() {
  if (!IsAtEnd) {
    uint tag = ReadRawVarint32();
    lastTag = tag;
    if (tag == 0) {
      // A literal zero on the wire is not a valid tag.
      throw InvalidProtocolBufferException.InvalidTag();
    }
    return tag;
  }

  lastTag = 0;
  return 0;
}
156
/// <summary>
/// Reads a double field: eight little-endian bytes reinterpreted as an
/// IEEE 754 double.
/// </summary>
public double ReadDouble() {
  // TODO(jonskeet): Test this on different endiannesses
  ulong bits = ReadRawLittleEndian64();
  return BitConverter.Int64BitsToDouble((long) bits);
}
164
/// <summary>
/// Reads a float field: four little-endian bytes reinterpreted as a
/// single-precision value.
/// </summary>
public float ReadFloat() {
  // TODO(jonskeet): Test this on different endiannesses
  // The integer is turned into bytes and back via BitConverter in both
  // directions, so the same byte order is used on each leg.
  byte[] rawBytes = BitConverter.GetBytes(ReadRawLittleEndian32());
  return BitConverter.ToSingle(rawBytes, 0);
}
174
/// <summary>
/// Reads a uint64 field (a raw 64-bit varint) from the stream.
/// </summary>
public ulong ReadUInt64() {
  ulong value = ReadRawVarint64();
  return value;
}
181
/// <summary>
/// Reads an int64 field: a raw 64-bit varint reinterpreted as signed.
/// </summary>
public long ReadInt64() {
  ulong raw = ReadRawVarint64();
  return (long) raw;
}
188
/// <summary>
/// Reads an int32 field: a raw 32-bit varint reinterpreted as signed.
/// </summary>
public int ReadInt32() {
  uint raw = ReadRawVarint32();
  return (int) raw;
}
195
/// <summary>
/// Reads a fixed64 field (eight little-endian bytes) from the stream.
/// </summary>
public ulong ReadFixed64() {
  ulong value = ReadRawLittleEndian64();
  return value;
}
202
/// <summary>
/// Reads a fixed32 field (four little-endian bytes) from the stream.
/// </summary>
public uint ReadFixed32() {
  uint value = ReadRawLittleEndian32();
  return value;
}
209
/// <summary>
/// Reads a bool field: a varint where any non-zero value means true.
/// </summary>
public bool ReadBool() {
  uint raw = ReadRawVarint32();
  if (raw == 0) {
    return false;
  }
  return true;
}
216
/// <summary>
/// Reads a string field: a varint byte length followed by that many bytes
/// of UTF-8 text.
/// </summary>
/// <returns>The decoded string; the empty string for a zero length.</returns>
/// <exception cref="InvalidProtocolBufferException">The encoded length is
/// negative, or the data ends (or the current limit is hit) before the
/// whole string has been read.</exception>
public String ReadString() {
  int size = (int) ReadRawVarint32();
  // No need to read any data for an empty string.
  if (size == 0) {
    return "";
  }
  // The "size > 0" guard matters: a negative length would otherwise
  // satisfy the comparison below and reach Encoding.GetString with a
  // negative count. Letting it fall through makes ReadRawBytes reject it
  // with the proper protocol exception.
  if (size > 0 && size <= bufferSize - bufferPos) {
    // Fast path: the bytes are already contiguous in the buffer, so
    // decode straight from it.
    String result = Encoding.UTF8.GetString(buffer, bufferPos, size);
    bufferPos += size;
    return result;
  }
  // Slow path: gather the bytes into their own array first, then decode.
  return Encoding.UTF8.GetString(ReadRawBytes(size));
}
236
/// <summary>
/// Reads a group field value, merging it into <paramref name="builder"/>,
/// and verifies that it was terminated by the end-group tag for
/// <paramref name="fieldNumber"/>.
/// </summary>
/// <exception cref="InvalidProtocolBufferException">The recursion limit
/// has been reached, or the group did not end with the expected tag.</exception>
public void ReadGroup(int fieldNumber, IBuilder builder,
                      ExtensionRegistry extensionRegistry) {
  if (recursionLimit <= recursionDepth) {
    throw InvalidProtocolBufferException.RecursionLimitExceeded();
  }

  recursionDepth++;
  builder.WeakMergeFrom(this, extensionRegistry);
  uint expectedEndTag = WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup);
  CheckLastTagWas(expectedEndTag);
  recursionDepth--;
}
250
/// <summary>
/// Reads a group field value, merging it into the given UnknownFieldSet
/// builder, and verifies that it was terminated by the end-group tag for
/// <paramref name="fieldNumber"/>.
/// </summary>
/// <exception cref="InvalidProtocolBufferException">The recursion limit
/// has been reached, or the group did not end with the expected tag.</exception>
public void ReadUnknownGroup(int fieldNumber, UnknownFieldSet.Builder builder) {
  if (recursionLimit <= recursionDepth) {
    throw InvalidProtocolBufferException.RecursionLimitExceeded();
  }

  recursionDepth++;
  builder.MergeFrom(this);
  uint expectedEndTag = WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup);
  CheckLastTagWas(expectedEndTag);
  recursionDepth--;
}
264
/// <summary>
/// Reads a length-delimited embedded message, merging it into
/// <paramref name="builder"/>. A limit covering exactly the message's
/// bytes is pushed for the duration of the merge and popped afterwards.
/// </summary>
/// <exception cref="InvalidProtocolBufferException">The recursion limit
/// has been reached, the length is invalid, or the nested message did not
/// end at its limit.</exception>
public void ReadMessage(IBuilder builder, ExtensionRegistry extensionRegistry) {
  int length = (int) ReadRawVarint32();
  if (recursionLimit <= recursionDepth) {
    throw InvalidProtocolBufferException.RecursionLimitExceeded();
  }

  int previousLimit = PushLimit(length);
  recursionDepth++;
  builder.WeakMergeFrom(this, extensionRegistry);
  // The nested parse must have stopped by running out of data (tag 0).
  CheckLastTagWas(0);
  recursionDepth--;
  PopLimit(previousLimit);
}
280
/// <summary>
/// Reads a bytes field: a varint length followed by that many raw bytes.
/// </summary>
/// <returns>A ByteString holding a copy of the bytes.</returns>
public ByteString ReadBytes() {
  int size = (int) ReadRawVarint32();
  int available = bufferSize - bufferPos;

  if (size <= 0 || size >= available) {
    // Slow path: let ReadRawBytes assemble (and validate) the array,
    // then copy it.
    return ByteString.CopyFrom(ReadRawBytes(size));
  }

  // Fast path: the bytes already sit contiguously in the buffer, so
  // copy them out directly.
  ByteString result = ByteString.CopyFrom(buffer, bufferPos, size);
  bufferPos += size;
  return result;
}
297
/// <summary>
/// Reads a uint32 field value (a raw 32-bit varint) from the stream.
/// </summary>
public uint ReadUInt32() {
  uint value = ReadRawVarint32();
  return value;
}
304
/// <summary>
/// Reads the numeric value of an enum field. Mapping the number to an
/// actual enum member is left to the caller.
/// </summary>
public int ReadEnum() {
  uint raw = ReadRawVarint32();
  return (int) raw;
}
312
/// <summary>
/// Reads an sfixed32 field: four little-endian bytes reinterpreted as signed.
/// </summary>
public int ReadSFixed32() {
  uint raw = ReadRawLittleEndian32();
  return (int) raw;
}
319
/// <summary>
/// Reads an sfixed64 field: eight little-endian bytes reinterpreted as signed.
/// </summary>
public long ReadSFixed64() {
  ulong raw = ReadRawLittleEndian64();
  return (long) raw;
}
326
/// <summary>
/// Reads an sint32 field: a 32-bit varint that is then ZigZag-decoded.
/// </summary>
public int ReadSInt32() {
  uint encoded = ReadRawVarint32();
  return DecodeZigZag32(encoded);
}
333
/// <summary>
/// Reads an sint64 field: a 64-bit varint that is then ZigZag-decoded.
/// </summary>
public long ReadSInt64() {
  ulong encoded = ReadRawVarint64();
  return DecodeZigZag64(encoded);
}
340
/// <summary>
/// Reads a field of any primitive type, returning the value boxed as an
/// object. Enums, groups and embedded messages are not handled by this
/// method.
/// </summary>
/// <param name="fieldType">The protocol buffer type of the field to read.</param>
/// <returns>The value read, boxed.</returns>
/// <exception cref="ArgumentException"><paramref name="fieldType"/> is
/// Group, Message or Enum.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="fieldType"/>
/// is not a known field type.</exception>
public object ReadPrimitiveField(FieldType fieldType) {
  switch (fieldType) {
    case FieldType.Double:   return ReadDouble();
    case FieldType.Float:    return ReadFloat();
    case FieldType.Int64:    return ReadInt64();
    case FieldType.UInt64:   return ReadUInt64();
    case FieldType.Int32:    return ReadInt32();
    case FieldType.Fixed64:  return ReadFixed64();
    case FieldType.Fixed32:  return ReadFixed32();
    case FieldType.Bool:     return ReadBool();
    case FieldType.String:   return ReadString();
    case FieldType.Bytes:    return ReadBytes();
    case FieldType.UInt32:   return ReadUInt32();
    case FieldType.SFixed32: return ReadSFixed32();
    case FieldType.SFixed64: return ReadSFixed64();
    case FieldType.SInt32:   return ReadSInt32();
    case FieldType.SInt64:   return ReadSInt64();
    case FieldType.Group:
      throw new ArgumentException("ReadPrimitiveField() cannot handle nested groups.");
    case FieldType.Message:
      throw new ArgumentException("ReadPrimitiveField() cannot handle embedded messages.");
    // We don't handle enums because we don't know what to do if the
    // value is not recognized.
    case FieldType.Enum:
      throw new ArgumentException("ReadPrimitiveField() cannot handle enums.");
    default:
      // Use the (paramName, message) overload: the one-string overload
      // would treat the message text as the parameter name.
      throw new ArgumentOutOfRangeException("fieldType", "Invalid field type " + fieldType);
  }
}
374
375 #endregion
376
377 #region Underlying reading primitives
378
/// <summary>
/// Byte-at-a-time counterpart of ReadRawVarint32: every byte goes through
/// ReadRawByte(), so the buffer is refilled (or an exception thrown)
/// whenever it runs out.
/// </summary>
private uint SlowReadRawVarint32() {
  int result = 0;

  // First four bytes: seven payload bits each, high bit = "more follows".
  for (int shift = 0; shift < 28; shift += 7) {
    int b = ReadRawByte();
    if (b < 128) {
      return (uint) (result | (b << shift));
    }
    result |= (b & 0x7f) << shift;
  }

  // Fifth byte supplies the top four bits; anything higher is shifted out.
  int fifth = ReadRawByte();
  result |= fifth << 28;
  if (fifth < 128) {
    return (uint) result;
  }

  // Discard upper 32 bits: consume up to five further continuation bytes.
  for (int i = 0; i < 5; i++) {
    if (ReadRawByte() < 128) {
      return (uint) result;
    }
  }
  throw InvalidProtocolBufferException.MalformedVarint();
}
414
/// <summary>
/// Reads a raw varint from the stream, keeping only the low 32 bits if
/// the encoded value is wider. When at least five bytes are buffered the
/// bytes are read straight from the buffer after that single size check;
/// otherwise the work is delegated to SlowReadRawVarint32.
/// </summary>
public uint ReadRawVarint32() {
  if (bufferPos + 5 > bufferSize) {
    return SlowReadRawVarint32();
  }

  int tmp = buffer[bufferPos++];
  if (tmp < 128) {
    return (uint) tmp;
  }
  int result = tmp & 0x7f;

  tmp = buffer[bufferPos++];
  if (tmp < 128) {
    return (uint) (result | (tmp << 7));
  }
  result |= (tmp & 0x7f) << 7;

  tmp = buffer[bufferPos++];
  if (tmp < 128) {
    return (uint) (result | (tmp << 14));
  }
  result |= (tmp & 0x7f) << 14;

  tmp = buffer[bufferPos++];
  if (tmp < 128) {
    return (uint) (result | (tmp << 21));
  }
  result |= (tmp & 0x7f) << 21;

  tmp = buffer[bufferPos++];
  result |= tmp << 28;
  if (tmp < 128) {
    return (uint) result;
  }

  // Discard upper 32 bits.
  // Only five buffered bytes were guaranteed by the check at the top, so
  // from here on ReadRawByte() is used. This branch is rarely reached.
  for (int i = 0; i < 5; i++) {
    if (ReadRawByte() < 128) {
      return (uint) result;
    }
  }
  throw InvalidProtocolBufferException.MalformedVarint();
}
459
/// <summary>
/// Reads a varint directly from <paramref name="input"/> one byte at a
/// time, so that no bytes after the end of the varint are consumed.
/// Wrapping the stream in a CodedInputStream and calling the instance
/// ReadRawVarint32() instead would probably read past the end of the
/// varint, since CodedInputStream buffers its input.
/// </summary>
/// <param name="input">The stream to read the varint from.</param>
/// <returns>The low 32 bits of the decoded value.</returns>
/// <exception cref="InvalidProtocolBufferException">The stream ended
/// inside the varint, or the varint did not terminate within 64 bits.</exception>
internal static int ReadRawVarint32(Stream input) {
  int result = 0;
  int offset = 0;

  // Bytes whose payload contributes to the 32-bit result.
  while (offset < 32) {
    int b = input.ReadByte();
    if (b == -1) {
      throw InvalidProtocolBufferException.TruncatedMessage();
    }
    result |= (b & 0x7f) << offset;
    if ((b & 0x80) == 0) {
      return result;
    }
    offset += 7;
  }

  // Keep reading up to 64 bits, ignoring the payload.
  while (offset < 64) {
    int b = input.ReadByte();
    if (b == -1) {
      throw InvalidProtocolBufferException.TruncatedMessage();
    }
    if ((b & 0x80) == 0) {
      return result;
    }
    offset += 7;
  }

  throw InvalidProtocolBufferException.MalformedVarint();
}
494
/// <summary>
/// Reads a raw varint of up to 64 bits from the stream.
/// </summary>
/// <exception cref="InvalidProtocolBufferException">The varint did not
/// terminate within ten bytes.</exception>
public ulong ReadRawVarint64() {
  ulong result = 0;
  for (int shift = 0; shift < 64; shift += 7) {
    byte b = ReadRawByte();
    result |= (ulong) (b & 0x7F) << shift;
    if (b < 0x80) {
      // High bit clear: this was the final byte.
      return result;
    }
  }
  throw InvalidProtocolBufferException.MalformedVarint();
}
511
/// <summary>
/// Reads four bytes and assembles them into a 32-bit integer, least
/// significant byte first.
/// </summary>
public uint ReadRawLittleEndian32() {
  uint value = ReadRawByte();
  value |= (uint) ReadRawByte() << 8;
  value |= (uint) ReadRawByte() << 16;
  value |= (uint) ReadRawByte() << 24;
  return value;
}
522
/// <summary>
/// Reads eight bytes and assembles them into a 64-bit integer, least
/// significant byte first.
/// </summary>
public ulong ReadRawLittleEndian64() {
  ulong value = 0;
  for (int shift = 0; shift < 64; shift += 8) {
    value |= (ulong) ReadRawByte() << shift;
  }
  return value;
}
538 #endregion
539
/// <summary>
/// Decodes a 32-bit ZigZag-encoded value.
/// </summary>
/// <remarks>
/// ZigZag maps signed integers onto unsigned ones so that values of small
/// magnitude stay small when varint encoded. (Without it, a negative
/// value has to be sign-extended to 64 bits and always occupies 10 bytes
/// on the wire.)
/// </remarks>
public static int DecodeZigZag32(uint n) {
  int magnitude = (int) (n >> 1);
  // 0 for even inputs, -1 (all bits set) for odd inputs.
  int signMask = -(int) (n & 1);
  return magnitude ^ signMask;
}
552
/// <summary>
/// Decodes a 64-bit ZigZag-encoded value.
/// </summary>
/// <remarks>
/// ZigZag maps signed integers onto unsigned ones so that values of small
/// magnitude stay small when varint encoded. (Without it, a negative
/// value always occupies 10 bytes on the wire.)
/// </remarks>
public static long DecodeZigZag64(ulong n) {
  long magnitude = (long) (n >> 1);
  // 0 for even inputs, -1 (all bits set) for odd inputs.
  long signMask = -(long) (n & 1);
  return magnitude ^ signMask;
}
565
/// <summary>
/// Set the maximum message recursion depth.
/// </summary>
/// <remarks>
/// In order to prevent malicious
/// messages from causing stack overflows, CodedInputStream limits
/// how deeply messages may be nested. The default limit is 64.
/// </remarks>
/// <param name="limit">The new maximum nesting depth; must not be negative.</param>
/// <returns>The limit that was in force before this call.</returns>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/>
/// is negative.</exception>
public int SetRecursionLimit(int limit) {
  if (limit < 0) {
    // Use the (paramName, message) overload: the one-string overload
    // would treat the message text as the parameter name.
    throw new ArgumentOutOfRangeException("limit", "Recursion limit cannot be negative: " + limit);
  }
  int oldLimit = recursionLimit;
  recursionLimit = limit;
  return oldLimit;
}
582
/// <summary>
/// Set the maximum message size.
/// </summary>
/// <remarks>
/// In order to prevent malicious messages from exhausting memory or
/// causing integer overflows, CodedInputStream limits how large a message may be.
/// The default limit is 64MB. You should set this limit as small
/// as you can without harming your app's functionality. Note that
/// size limits only apply when reading from a Stream, not
/// when constructed around a raw byte array (nor with ByteString.NewCodedInput).
/// If you want to read several messages from a single CodedInputStream, you
/// can call ResetSizeCounter() after each message to avoid hitting the
/// size limit.
/// </remarks>
/// <param name="limit">The new maximum size in bytes; must not be negative.</param>
/// <returns>The limit that was in force before this call.</returns>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/>
/// is negative.</exception>
public int SetSizeLimit(int limit) {
  if (limit < 0) {
    // Use the (paramName, message) overload: the one-string overload
    // would treat the message text as the parameter name.
    throw new ArgumentOutOfRangeException("limit", "Size limit cannot be negative: " + limit);
  }
  int oldLimit = sizeLimit;
  sizeLimit = limit;
  return oldLimit;
}
605
606 #region Internal reading and buffer management
/// <summary>
/// Restarts the byte count that is checked against the size limit
/// (see SetSizeLimit) by clearing the retired-bytes total.
/// </summary>
public void ResetSizeCounter() {
  // NOTE(review): bytes already consumed from the current buffer
  // (bufferPos) still count towards the position after this reset;
  // confirm that is intended.
  this.totalBytesRetired = 0;
}
613
/// <summary>
/// Sets currentLimit to (current position) + byteLimit. This is called
/// when descending into a length-delimited embedded message. The previous
/// limit is returned so that it can later be restored with PopLimit.
/// </summary>
/// <param name="byteLimit">Number of bytes, counted from the current
/// position, that may still be read.</param>
/// <returns>The old limit.</returns>
/// <exception cref="InvalidProtocolBufferException">
/// <paramref name="byteLimit"/> is negative, or the new limit would lie
/// beyond the limit currently in force.</exception>
public int PushLimit(int byteLimit) {
  if (byteLimit < 0) {
    throw InvalidProtocolBufferException.NegativeSize();
  }

  int position = totalBytesRetired + bufferPos;
  int oldLimit = currentLimit;

  // Compare by subtraction rather than computing position + byteLimit
  // first: the sum can exceed int.MaxValue and wrap to a negative number,
  // which would slip past a "> oldLimit" test and install a bogus limit.
  if (byteLimit > oldLimit - position) {
    throw InvalidProtocolBufferException.TruncatedMessage();
  }
  currentLimit = position + byteLimit;

  RecomputeBufferSizeAfterLimit();

  return oldLimit;
}
635
/// <summary>
/// Re-splits the buffered bytes into the part before currentLimit
/// (bufferSize) and the part after it (bufferSizeAfterLimit).
/// </summary>
private void RecomputeBufferSizeAfterLimit() {
  // Start from the full count of buffered bytes, undoing any earlier split.
  int fullSize = bufferSize + bufferSizeAfterLimit;
  int bufferEnd = totalBytesRetired + fullSize;

  // Bytes that lie beyond the limit, if the limit falls inside the buffer.
  int excess = (bufferEnd > currentLimit) ? bufferEnd - currentLimit : 0;

  bufferSizeAfterLimit = excess;
  bufferSize = fullSize - excess;
}
647
/// <summary>
/// Discards the current limit and reinstates <paramref name="oldLimit"/>,
/// normally the value returned by the matching PushLimit call.
/// </summary>
/// <param name="oldLimit">The limit to restore.</param>
public void PopLimit(int oldLimit) {
  this.currentLimit = oldLimit;
  this.RecomputeBufferSizeAfterLimit();
}
655
/// <summary>
/// Whether all the data before the current limit has been read. Always
/// false when no limit is set.
/// </summary>
public bool ReachedLimit {
  get {
    if (currentLimit != int.MaxValue) {
      return totalBytesRetired + bufferPos >= currentLimit;
    }
    return false;
  }
}
Jon Skeet2e6dc122009-05-29 06:34:52 +0100669
/// <summary>
/// True when no more data can be read: either the underlying input source
/// is exhausted or a limit created using PushLimit has been reached.
/// Reading this property may refill the buffer from the input.
/// </summary>
public bool IsAtEnd {
  get {
    if (bufferPos != bufferSize) {
      // Unread bytes remain in the buffer.
      return false;
    }
    return !RefillBuffer(false);
  }
}
Jon Skeet25a28582009-02-18 16:06:22 +0000680
/// <summary>
/// Called when the buffer is empty to read more bytes from the input. If
/// <paramref name="mustSucceed"/> is true, RefillBuffer() guarantees that
/// either at least one byte is in the buffer when it returns or an
/// exception is thrown. If <paramref name="mustSucceed"/> is false, it
/// returns false when no more bytes are available.
/// </summary>
/// <param name="mustSucceed">Whether running out of data is an error.</param>
/// <returns>true if the buffer now holds data; false otherwise.</returns>
/// <exception cref="InvalidOperationException">The buffer still has unread
/// bytes, or the stream reported a negative read count.</exception>
/// <exception cref="InvalidProtocolBufferException">No data is available
/// and <paramref name="mustSucceed"/> is true, or the size limit has been
/// exceeded.</exception>
private bool RefillBuffer(bool mustSucceed) {
  if (bufferPos < bufferSize) {
    throw new InvalidOperationException("RefillBuffer() called when buffer wasn't empty.");
  }

  if (totalBytesRetired + bufferSize == currentLimit) {
    // Oops, we hit a limit.
    if (mustSucceed) {
      throw InvalidProtocolBufferException.TruncatedMessage();
    }
    return false;
  }

  totalBytesRetired += bufferSize;
  bufferPos = 0;

  if (input == null) {
    bufferSize = 0;
  } else {
    bufferSize = input.Read(buffer, 0, buffer.Length);
  }

  if (bufferSize < 0) {
    throw new InvalidOperationException("Stream.Read returned a negative count");
  }

  if (bufferSize == 0) {
    if (mustSucceed) {
      throw InvalidProtocolBufferException.TruncatedMessage();
    }
    return false;
  }

  RecomputeBufferSizeAfterLimit();
  int totalBytesRead = totalBytesRetired + bufferSize + bufferSizeAfterLimit;
  // A negative total means the running count overflowed.
  if (totalBytesRead < 0 || totalBytesRead > sizeLimit) {
    throw InvalidProtocolBufferException.SizeLimitExceeded();
  }
  return true;
}
727
/// <summary>
/// Reads a single byte from the input, refilling the buffer first if it
/// has been used up.
/// </summary>
/// <exception cref="InvalidProtocolBufferException">
/// the end of the stream or the current limit was reached
/// </exception>
public byte ReadRawByte() {
  bool bufferExhausted = (bufferPos == bufferSize);
  if (bufferExhausted) {
    RefillBuffer(true);
  }
  byte value = buffer[bufferPos++];
  return value;
}
740
/// <summary>
/// Reads exactly <paramref name="size"/> bytes from the input into a new array.
/// </summary>
/// <param name="size">Number of bytes to read; must not be negative.</param>
/// <returns>A newly allocated array holding the bytes read.</returns>
/// <exception cref="InvalidProtocolBufferException">
/// <paramref name="size"/> is negative, or the end of the stream or the
/// current limit was reached
/// </exception>
public byte[] ReadRawBytes(int size) {
  if (size < 0) {
    throw InvalidProtocolBufferException.NegativeSize();
  }

  // Compare against the distance to the limit rather than adding size to
  // the position: that sum can exceed int.MaxValue, wrap negative, and so
  // slip past the check.
  int remainingBeforeLimit = currentLimit - totalBytesRetired - bufferPos;
  if (size > remainingBeforeLimit) {
    // Read to the end of the stream anyway.
    SkipRawBytes(remainingBeforeLimit);
    // Then fail.
    throw InvalidProtocolBufferException.TruncatedMessage();
  }

  int buffered = bufferSize - bufferPos;

  if (size <= buffered) {
    // We have all the bytes we need already.
    byte[] bytes = new byte[size];
    Array.Copy(buffer, bufferPos, bytes, 0, size);
    bufferPos += size;
    return bytes;
  }

  if (size < BufferSize) {
    // Reading more bytes than are in the buffer, but not an excessive number
    // of bytes. We can safely allocate the resulting array ahead of time.

    // First copy what we have.
    byte[] bytes = new byte[size];
    int pos = buffered;
    Array.Copy(buffer, bufferPos, bytes, 0, pos);
    bufferPos = bufferSize;

    // We want to use RefillBuffer() and then copy from the buffer into our
    // byte array rather than reading directly into our byte array because
    // the input may be unbuffered.
    RefillBuffer(true);

    while (size - pos > bufferSize) {
      Array.Copy(buffer, 0, bytes, pos, bufferSize);
      pos += bufferSize;
      bufferPos = bufferSize;
      RefillBuffer(true);
    }

    Array.Copy(buffer, 0, bytes, pos, size - pos);
    bufferPos = size - pos;

    return bytes;
  }

  // The size is very large. For security reasons, we can't allocate the
  // entire byte array yet. The size comes directly from the input, so a
  // maliciously-crafted message could provide a bogus very large size in
  // order to trick the app into allocating a lot of memory. We avoid this
  // by allocating and reading only a small chunk at a time, so that the
  // malicious message must actually *be* extremely large to cause
  // problems. Meanwhile, we limit the allowed size of a message elsewhere.

  // Remember the buffer markers since we'll have to copy the bytes out of
  // it later.
  int originalBufferPos = bufferPos;
  int originalBufferSize = bufferSize;

  // Mark the current buffer consumed.
  totalBytesRetired += bufferSize;
  bufferPos = 0;
  bufferSize = 0;

  // Read all the rest of the bytes we need.
  int sizeLeft = size - buffered;
  List<byte[]> chunks = new List<byte[]>();

  while (sizeLeft > 0) {
    byte[] chunk = new byte[Math.Min(sizeLeft, BufferSize)];
    int pos = 0;
    while (pos < chunk.Length) {
      // With no stream behind us there is nothing more to read.
      int n = (input == null) ? -1 : input.Read(chunk, pos, chunk.Length - pos);
      if (n <= 0) {
        throw InvalidProtocolBufferException.TruncatedMessage();
      }
      totalBytesRetired += n;
      pos += n;
    }
    sizeLeft -= chunk.Length;
    chunks.Add(chunk);
  }

  // OK, got everything. Now concatenate it all into one buffer.
  byte[] result = new byte[size];

  // Start by copying the leftover bytes from this.buffer.
  int newPos = originalBufferSize - originalBufferPos;
  Array.Copy(buffer, originalBufferPos, result, 0, newPos);

  // And now all the chunks.
  foreach (byte[] chunk in chunks) {
    Array.Copy(chunk, 0, result, newPos, chunk.Length);
    newPos += chunk.Length;
  }

  // Done.
  return result;
}
846
/// <summary>
/// Reads and throws away one field, whose tag has already been read.
/// </summary>
/// <param name="tag">The tag of the field to skip.</param>
/// <returns>false if the tag is an end-group tag, in which case
/// nothing is skipped. Otherwise, returns true.</returns>
/// <exception cref="InvalidProtocolBufferException">The tag carries an
/// unrecognised wire type.</exception>
public bool SkipField(uint tag) {
  switch (WireFormat.GetTagWireType(tag)) {
    case WireFormat.WireType.EndGroup:
      return false;

    case WireFormat.WireType.Varint:
      ReadInt32();
      break;

    case WireFormat.WireType.Fixed64:
      ReadRawLittleEndian64();
      break;

    case WireFormat.WireType.Fixed32:
      ReadRawLittleEndian32();
      break;

    case WireFormat.WireType.LengthDelimited:
      SkipRawBytes((int) ReadRawVarint32());
      break;

    case WireFormat.WireType.StartGroup:
      // Skip everything up to the end-group tag, then confirm it is the
      // one belonging to this field.
      SkipMessage();
      CheckLastTagWas(
          WireFormat.MakeTag(WireFormat.GetTagFieldNumber(tag),
                             WireFormat.WireType.EndGroup));
      break;

    default:
      throw InvalidProtocolBufferException.InvalidWireType();
  }
  return true;
}
878
/// <summary>
/// Reads and throws away a whole message: fields are skipped one after
/// another until the input ends or an end-group tag is met, whichever
/// comes first.
/// </summary>
public void SkipMessage() {
  uint tag;
  do {
    tag = ReadTag();
  } while (tag != 0 && SkipField(tag));
}
891
/// <summary>
/// Reads and discards <paramref name="size"/> bytes.
/// </summary>
/// <param name="size">Number of bytes to skip; must not be negative.</param>
/// <exception cref="InvalidProtocolBufferException">
/// <paramref name="size"/> is negative, or the end of the stream
/// or the current limit was reached</exception>
public void SkipRawBytes(int size) {
  if (size < 0) {
    throw InvalidProtocolBufferException.NegativeSize();
  }

  // Compare against the distance to the limit rather than adding size to
  // the position: that sum can exceed int.MaxValue and wrap negative.
  int remainingBeforeLimit = currentLimit - totalBytesRetired - bufferPos;
  if (size > remainingBeforeLimit) {
    // Read to the end of the stream anyway.
    SkipRawBytes(remainingBeforeLimit);
    // Then fail.
    throw InvalidProtocolBufferException.TruncatedMessage();
  }

  int buffered = bufferSize - bufferPos;
  if (size <= buffered) {
    // We have all the bytes we need already.
    bufferPos += size;
    return;
  }

  // Skipping more bytes than are in the buffer. First drop the buffer.
  // The WHOLE buffer is retired (not just its unread tail) so that
  // totalBytesRetired + bufferPos keeps matching the absolute position
  // that the limits are expressed in.
  totalBytesRetired += bufferSize;
  bufferPos = 0;
  bufferSize = 0;

  // Then skip the rest directly on the underlying stream.
  int remaining = size - buffered;
  if (input == null) {
    throw InvalidProtocolBufferException.TruncatedMessage();
  }

  if (input.CanSeek) {
    // Assigning a Position beyond the end of a stream does not fail, so
    // truncation has to be detected from the stream length.
    if (input.Length - input.Position < remaining) {
      input.Position = input.Length;
      throw InvalidProtocolBufferException.TruncatedMessage();
    }
    input.Position += remaining;
  } else {
    // Non-seekable stream: read into a scratch array and throw it away.
    byte[] scratch = new byte[Math.Min(BufferSize, remaining)];
    int left = remaining;
    while (left > 0) {
      int n = input.Read(scratch, 0, Math.Min(scratch.Length, left));
      if (n <= 0) {
        throw InvalidProtocolBufferException.TruncatedMessage();
      }
      left -= n;
    }
  }
  totalBytesRetired += remaining;
}
934 #endregion
935 }
936}