// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc.
// http://code.google.com/p/protobuf/
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
16using System;
17using System.Collections.Generic;
18using System.IO;
19using System.Text;
20using Google.ProtocolBuffers.Descriptors;
21
namespace Google.ProtocolBuffers {

  /// <summary>
  /// Reads and decodes protocol message fields.
  /// </summary>
  /// <remarks>
  /// This class contains two kinds of methods: methods that read specific
  /// protocol message constructs and field types (e.g. ReadTag and
  /// ReadInt32) and methods that read low-level values (e.g.
  /// ReadRawVarint32 and ReadRawBytes). If you are reading encoded protocol
  /// messages, you should use the former methods, but if you are reading some
  /// other format of your own design, use the latter. The names of the former
  /// methods are taken from the protocol buffer type names, not .NET types.
  /// (Hence ReadFloat instead of ReadSingle, and ReadBool instead of ReadBoolean.)
  ///
  /// TODO(jonskeet): Consider whether recursion and size limits shouldn't be readonly,
  /// set at construction time.
  /// </remarks>
  public sealed class CodedInputStream {
    /// <summary>
    /// Backing storage. When constructed from a byte array this is the
    /// caller's array itself (it is not copied); when constructed from a
    /// stream it is an internal array of <see cref="BufferSize"/> bytes.
    /// </summary>
    private readonly byte[] buffer;

    /// <summary>
    /// Number of bytes in <see cref="buffer"/> that may currently be read.
    /// Bytes lying beyond <see cref="currentLimit"/> are excluded and counted
    /// in <see cref="bufferSizeAfterLimit"/> instead.
    /// </summary>
    private int bufferSize;

    /// <summary>
    /// Number of bytes held in <see cref="buffer"/> past <see cref="bufferSize"/>
    /// that are hidden because they lie beyond the current limit.
    /// </summary>
    private int bufferSizeAfterLimit = 0;

    /// <summary>
    /// Index in <see cref="buffer"/> of the next byte to read.
    /// </summary>
    private int bufferPos = 0;

    /// <summary>
    /// The underlying stream, or null when reading from a byte array.
    /// </summary>
    private readonly Stream input;

    /// <summary>
    /// The value most recently returned by <see cref="ReadTag"/>
    /// (0 when the end of the input was reached).
    /// </summary>
    private uint lastTag = 0;

    const int DefaultRecursionLimit = 64;
    const int DefaultSizeLimit = 64 << 20; // 64MB
    const int BufferSize = 4096;

    /// <summary>
    /// The total number of bytes read before the current buffer. The
    /// total bytes read up to the current position can be computed as
    /// totalBytesRetired + bufferPos.
    /// </summary>
    private int totalBytesRetired = 0;

    /// <summary>
    /// The absolute position of the end of the current message.
    /// </summary>
    private int currentLimit = int.MaxValue;

    /// <summary>
    /// <see cref="SetRecursionLimit"/>
    /// </summary>
    private int recursionDepth = 0;
    private int recursionLimit = DefaultRecursionLimit;

    /// <summary>
    /// <see cref="SetSizeLimit"/>
    /// </summary>
    private int sizeLimit = DefaultSizeLimit;

    #region Construction
    /// <summary>
    /// Creates a new CodedInputStream reading data from the given
    /// stream.
    /// </summary>
    public static CodedInputStream CreateInstance(Stream input) {
      return new CodedInputStream(input);
    }

    /// <summary>
    /// Creates a new CodedInputStream reading data from the given
    /// byte array. The array is used directly, not copied.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="buf"/> is null</exception>
    public static CodedInputStream CreateInstance(byte[] buf) {
      if (buf == null) {
        throw new ArgumentNullException("buf");
      }
      return new CodedInputStream(buf);
    }

    private CodedInputStream(byte[] buffer) {
      this.buffer = buffer;
      this.bufferSize = buffer.Length;
      this.input = null;
    }

    private CodedInputStream(Stream input) {
      this.buffer = new byte[BufferSize];
      this.bufferSize = 0;
      this.input = input;
    }
    #endregion

    #region Validation
    /// <summary>
    /// Verifies that the last call to ReadTag() returned the given tag value.
    /// This is used to verify that a nested group ended with the correct
    /// end tag.
    /// </summary>
    /// <exception cref="InvalidProtocolBufferException">The last
    /// tag read was not the one specified</exception>
    public void CheckLastTagWas(uint value) {
      if (lastTag != value) {
        throw InvalidProtocolBufferException.InvalidEndTag();
      }
    }
    #endregion

    #region Reading of tags etc
    /// <summary>
    /// Attempt to read a field tag, returning 0 if we have reached the end
    /// of the input data. Protocol message parsers use this to read tags,
    /// since a protocol message may legally end wherever a tag occurs, and
    /// zero is not a valid tag number.
    /// </summary>
    /// <exception cref="InvalidProtocolBufferException">A tag of zero was
    /// actually present in the data</exception>
    public uint ReadTag() {
      if (bufferPos == bufferSize && !RefillBuffer(false)) {
        lastTag = 0;
        return 0;
      }

      lastTag = ReadRawVarint32();
      if (lastTag == 0) {
        // If we actually read zero, that's not a valid tag.
        throw InvalidProtocolBufferException.InvalidTag();
      }
      return lastTag;
    }

    /// <summary>
    /// Read a double field from the stream.
    /// </summary>
    public double ReadDouble() {
      // The 64-bit value is assembled arithmetically from little-endian bytes,
      // and Int64BitsToDouble reinterprets the value rather than a byte layout.
      return BitConverter.Int64BitsToDouble((long) ReadRawLittleEndian64());
    }

    /// <summary>
    /// Read a float field from the stream.
    /// </summary>
    public float ReadFloat() {
      // GetBytes and ToSingle both use the machine's byte order, so the
      // round trip preserves the 32-bit pattern read from the wire.
      uint raw = ReadRawLittleEndian32();
      byte[] rawBytes = BitConverter.GetBytes(raw);
      return BitConverter.ToSingle(rawBytes, 0);
    }

    /// <summary>
    /// Read a uint64 field from the stream.
    /// </summary>
    public ulong ReadUInt64() {
      return ReadRawVarint64();
    }

    /// <summary>
    /// Read an int64 field from the stream.
    /// </summary>
    public long ReadInt64() {
      return (long) ReadRawVarint64();
    }

    /// <summary>
    /// Read an int32 field from the stream.
    /// </summary>
    public int ReadInt32() {
      return (int) ReadRawVarint32();
    }

    /// <summary>
    /// Read a fixed64 field from the stream.
    /// </summary>
    public ulong ReadFixed64() {
      return ReadRawLittleEndian64();
    }

    /// <summary>
    /// Read a fixed32 field from the stream.
    /// </summary>
    public uint ReadFixed32() {
      return ReadRawLittleEndian32();
    }

    /// <summary>
    /// Read a bool field from the stream. Any non-zero varint is true.
    /// </summary>
    public bool ReadBool() {
      return ReadRawVarint32() != 0;
    }

    /// <summary>
    /// Reads a string field from the stream: a varint length followed by
    /// that many bytes of UTF-8 data.
    /// </summary>
    public String ReadString() {
      int size = (int) ReadRawVarint32();
      if (size > 0 && size <= bufferSize - bufferPos) {
        // Fast path: We already have the bytes in a contiguous buffer, so
        // just decode directly from it.
        String result = Encoding.UTF8.GetString(buffer, bufferPos, size);
        bufferPos += size;
        return result;
      } else {
        // Slow path: Build a byte array first then decode it. This also
        // covers zero and negative sizes (the latter is rejected by ReadRawBytes).
        return Encoding.UTF8.GetString(ReadRawBytes(size));
      }
    }

    /// <summary>
    /// Reads a group field value from the stream.
    /// </summary>
    /// <exception cref="InvalidProtocolBufferException">The recursion limit
    /// was reached, or the group did not end with the matching end-group tag</exception>
    public void ReadGroup(int fieldNumber, IBuilder builder,
                          ExtensionRegistry extensionRegistry) {
      if (recursionDepth >= recursionLimit) {
        throw InvalidProtocolBufferException.RecursionLimitExceeded();
      }
      ++recursionDepth;
      builder.WeakMergeFrom(this, extensionRegistry);
      CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
      --recursionDepth;
    }

    /// <summary>
    /// Reads a group field value from the stream and merges it into the given
    /// UnknownFieldSet.
    /// </summary>
    /// <exception cref="InvalidProtocolBufferException">The recursion limit
    /// was reached, or the group did not end with the matching end-group tag</exception>
    public void ReadUnknownGroup(int fieldNumber, UnknownFieldSet.Builder builder) {
      if (recursionDepth >= recursionLimit) {
        throw InvalidProtocolBufferException.RecursionLimitExceeded();
      }
      ++recursionDepth;
      builder.MergeFrom(this);
      CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
      --recursionDepth;
    }

    /// <summary>
    /// Reads an embedded message field value from the stream: a varint length
    /// followed by the message, which is parsed under a temporary limit.
    /// </summary>
    public void ReadMessage(IBuilder builder, ExtensionRegistry extensionRegistry) {
      int length = (int) ReadRawVarint32();
      if (recursionDepth >= recursionLimit) {
        throw InvalidProtocolBufferException.RecursionLimitExceeded();
      }
      int oldLimit = PushLimit(length);
      ++recursionDepth;
      builder.WeakMergeFrom(this, extensionRegistry);
      // The nested parse must have stopped because the limit was hit
      // (ReadTag returned 0), not because of a stray end-group tag.
      CheckLastTagWas(0);
      --recursionDepth;
      PopLimit(oldLimit);
    }

    /// <summary>
    /// Reads a bytes field value from the stream.
    /// </summary>
    public ByteString ReadBytes() {
      int size = (int) ReadRawVarint32();
      if (size > 0 && size <= bufferSize - bufferPos) {
        // Fast path: We already have the bytes in a contiguous buffer, so
        // just copy directly from it.
        ByteString result = ByteString.CopyFrom(buffer, bufferPos, size);
        bufferPos += size;
        return result;
      } else {
        // Slow path: Build a byte array first then copy it.
        return ByteString.CopyFrom(ReadRawBytes(size));
      }
    }

    /// <summary>
    /// Reads a uint32 field value from the stream.
    /// </summary>
    public uint ReadUInt32() {
      return ReadRawVarint32();
    }

    /// <summary>
    /// Reads an enum field value from the stream. The caller is responsible
    /// for converting the numeric value to an actual enum.
    /// </summary>
    public int ReadEnum() {
      return (int) ReadRawVarint32();
    }

    /// <summary>
    /// Reads an sfixed32 field value from the stream.
    /// </summary>
    public int ReadSFixed32() {
      return (int) ReadRawLittleEndian32();
    }

    /// <summary>
    /// Reads an sfixed64 field value from the stream.
    /// </summary>
    public long ReadSFixed64() {
      return (long) ReadRawLittleEndian64();
    }

    /// <summary>
    /// Reads an sint32 field value from the stream.
    /// </summary>
    public int ReadSInt32() {
      return DecodeZigZag32(ReadRawVarint32());
    }

    /// <summary>
    /// Reads an sint64 field value from the stream.
    /// </summary>
    public long ReadSInt64() {
      return DecodeZigZag64(ReadRawVarint64());
    }

    /// <summary>
    /// Reads a field of any primitive type. Enums, groups and embedded
    /// messages are not handled by this method.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="fieldType"/> is
    /// Group, Message or Enum</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="fieldType"/>
    /// is not a known field type</exception>
    public object ReadPrimitiveField(FieldType fieldType) {
      switch (fieldType) {
        case FieldType.Double:   return ReadDouble();
        case FieldType.Float:    return ReadFloat();
        case FieldType.Int64:    return ReadInt64();
        case FieldType.UInt64:   return ReadUInt64();
        case FieldType.Int32:    return ReadInt32();
        case FieldType.Fixed64:  return ReadFixed64();
        case FieldType.Fixed32:  return ReadFixed32();
        case FieldType.Bool:     return ReadBool();
        case FieldType.String:   return ReadString();
        case FieldType.Bytes:    return ReadBytes();
        case FieldType.UInt32:   return ReadUInt32();
        case FieldType.SFixed32: return ReadSFixed32();
        case FieldType.SFixed64: return ReadSFixed64();
        case FieldType.SInt32:   return ReadSInt32();
        case FieldType.SInt64:   return ReadSInt64();
        case FieldType.Group:
          throw new ArgumentException("ReadPrimitiveField() cannot handle nested groups.");
        case FieldType.Message:
          throw new ArgumentException("ReadPrimitiveField() cannot handle embedded messages.");
        // We don't handle enums because we don't know what to do if the
        // value is not recognized.
        case FieldType.Enum:
          throw new ArgumentException("ReadPrimitiveField() cannot handle enums.");
        default:
          // The (paramName, message) overload is required: the single-string
          // constructor would treat the message as the parameter name.
          throw new ArgumentOutOfRangeException("fieldType", "Invalid field type " + fieldType);
      }
    }

    #endregion

    #region Underlying reading primitives

    /// <summary>
    /// Same code as ReadRawVarint32, but read each byte individually, checking for
    /// buffer overflow.
    /// </summary>
    private uint SlowReadRawVarint32() {
      int tmp = ReadRawByte();
      if (tmp < 128) {
        return (uint)tmp;
      }
      int result = tmp & 0x7f;
      if ((tmp = ReadRawByte()) < 128) {
        result |= tmp << 7;
      } else {
        result |= (tmp & 0x7f) << 7;
        if ((tmp = ReadRawByte()) < 128) {
          result |= tmp << 14;
        } else {
          result |= (tmp & 0x7f) << 14;
          if ((tmp = ReadRawByte()) < 128) {
            result |= tmp << 21;
          } else {
            result |= (tmp & 0x7f) << 21;
            result |= (tmp = ReadRawByte()) << 28;
            if (tmp >= 128) {
              // Discard upper 32 bits.
              for (int i = 0; i < 5; i++) {
                if (ReadRawByte() < 128) return (uint)result;
              }
              throw InvalidProtocolBufferException.MalformedVarint();
            }
          }
        }
      }
      return (uint)result;
    }

    /// <summary>
    /// Read a raw Varint from the stream.  If larger than 32 bits, discard the upper bits.
    /// This method is optimised for the case where we've got lots of data in the buffer.
    /// That means we can check the size just once, then just read directly from the buffer
    /// without constant rechecking of the buffer length.
    /// </summary>
    /// <exception cref="InvalidProtocolBufferException">The varint is longer
    /// than ten bytes, or the input ended part-way through it</exception>
    public uint ReadRawVarint32() {
      if (bufferPos + 5 > bufferSize) {
        return SlowReadRawVarint32();
      }

      int tmp = buffer[bufferPos++];
      if (tmp < 128) {
        return (uint)tmp;
      }
      int result = tmp & 0x7f;
      if ((tmp = buffer[bufferPos++]) < 128) {
        result |= tmp << 7;
      } else {
        result |= (tmp & 0x7f) << 7;
        if ((tmp = buffer[bufferPos++]) < 128) {
          result |= tmp << 14;
        } else {
          result |= (tmp & 0x7f) << 14;
          if ((tmp = buffer[bufferPos++]) < 128) {
            result |= tmp << 21;
          } else {
            result |= (tmp & 0x7f) << 21;
            result |= (tmp = buffer[bufferPos++]) << 28;
            if (tmp >= 128) {
              // Discard upper 32 bits.
              // Note that this has to use ReadRawByte() as we only ensure we've
              // got at least 5 bytes at the start of the method. This lets us
              // use the fast path in more cases, and we rarely hit this section of code.
              for (int i = 0; i < 5; i++) {
                if (ReadRawByte() < 128) return (uint)result;
              }
              throw InvalidProtocolBufferException.MalformedVarint();
            }
          }
        }
      }
      return (uint)result;
    }

    /// <summary>
    /// Read a raw varint from the stream.
    /// </summary>
    /// <exception cref="InvalidProtocolBufferException">The varint did not
    /// terminate within ten bytes</exception>
    public ulong ReadRawVarint64() {
      int shift = 0;
      ulong result = 0;
      while (shift < 64) {
        byte b = ReadRawByte();
        result |= (ulong)(b & 0x7F) << shift;
        if ((b & 0x80) == 0) {
          return result;
        }
        shift += 7;
      }
      throw InvalidProtocolBufferException.MalformedVarint();
    }

    /// <summary>
    /// Read a 32-bit little-endian integer from the stream.
    /// </summary>
    public uint ReadRawLittleEndian32() {
      uint b1 = ReadRawByte();
      uint b2 = ReadRawByte();
      uint b3 = ReadRawByte();
      uint b4 = ReadRawByte();
      return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24);
    }

    /// <summary>
    /// Read a 64-bit little-endian integer from the stream.
    /// </summary>
    public ulong ReadRawLittleEndian64() {
      ulong b1 = ReadRawByte();
      ulong b2 = ReadRawByte();
      ulong b3 = ReadRawByte();
      ulong b4 = ReadRawByte();
      ulong b5 = ReadRawByte();
      ulong b6 = ReadRawByte();
      ulong b7 = ReadRawByte();
      ulong b8 = ReadRawByte();
      return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24)
          | (b5 << 32) | (b6 << 40) | (b7 << 48) | (b8 << 56);
    }
    #endregion

    /// <summary>
    /// Decode a 32-bit value with ZigZag encoding.
    /// </summary>
    /// <remarks>
    /// ZigZag encodes signed integers into values that can be efficiently
    /// encoded with varint.  (Otherwise, negative values must be
    /// sign-extended to 64 bits to be varint encoded, thus always taking
    /// 10 bytes on the wire.)
    /// </remarks>
    public static int DecodeZigZag32(uint n) {
      return (int)(n >> 1) ^ -(int)(n & 1);
    }

    /// <summary>
    /// Decode a 64-bit value with ZigZag encoding.
    /// </summary>
    /// <remarks>
    /// ZigZag encodes signed integers into values that can be efficiently
    /// encoded with varint.  (Otherwise, negative values must be
    /// sign-extended to 64 bits to be varint encoded, thus always taking
    /// 10 bytes on the wire.)
    /// </remarks>
    public static long DecodeZigZag64(ulong n) {
      return (long)(n >> 1) ^ -(long)(n & 1);
    }

    /// <summary>
    /// Set the maximum message recursion depth.
    /// </summary>
    /// <remarks>
    /// In order to prevent malicious
    /// messages from causing stack overflows, CodedInputStream limits
    /// how deeply messages may be nested.  The default limit is 64.
    /// </remarks>
    /// <returns>The previous recursion limit.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/> is negative</exception>
    public int SetRecursionLimit(int limit) {
      if (limit < 0) {
        throw new ArgumentOutOfRangeException("limit", "Recursion limit cannot be negative: " + limit);
      }
      int oldLimit = recursionLimit;
      recursionLimit = limit;
      return oldLimit;
    }

    /// <summary>
    /// Set the maximum message size.
    /// </summary>
    /// <remarks>
    /// In order to prevent malicious messages from exhausting memory or
    /// causing integer overflows, CodedInputStream limits how large a message may be.
    /// The default limit is 64MB.  You should set this limit as small
    /// as you can without harming your app's functionality.  Note that
    /// size limits only apply when reading from a Stream, not
    /// when constructed around a raw byte array (nor with ByteString.NewCodedInput).
    /// </remarks>
    /// <returns>The previous size limit.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/> is negative</exception>
    public int SetSizeLimit(int limit) {
      if (limit < 0) {
        throw new ArgumentOutOfRangeException("limit", "Size limit cannot be negative: " + limit);
      }
      int oldLimit = sizeLimit;
      sizeLimit = limit;
      return oldLimit;
    }

    #region Internal reading and buffer management
    /// <summary>
    /// Sets currentLimit to (current position) + byteLimit. This is called
    /// when descending into a length-delimited embedded message. The previous
    /// limit is returned.
    /// </summary>
    /// <returns>The old limit.</returns>
    /// <exception cref="InvalidProtocolBufferException"><paramref name="byteLimit"/>
    /// is negative, or the new limit would lie beyond the current one</exception>
    public int PushLimit(int byteLimit) {
      if (byteLimit < 0) {
        throw InvalidProtocolBufferException.NegativeSize();
      }
      int position = totalBytesRetired + bufferPos;
      int oldLimit = currentLimit;
      // Compare by subtraction: position + byteLimit could overflow to a
      // negative number and slip past a "newLimit > oldLimit" test.
      if (byteLimit > oldLimit - position) {
        throw InvalidProtocolBufferException.TruncatedMessage();
      }
      currentLimit = position + byteLimit;

      RecomputeBufferSizeAfterLimit();

      return oldLimit;
    }

    /// <summary>
    /// Re-splits the bytes held in the buffer into the readable part
    /// (bufferSize) and the part hidden beyond currentLimit
    /// (bufferSizeAfterLimit).
    /// </summary>
    private void RecomputeBufferSizeAfterLimit() {
      // Undo any previous split before applying the current limit.
      bufferSize += bufferSizeAfterLimit;
      int bufferEnd = totalBytesRetired + bufferSize;
      if (bufferEnd > currentLimit) {
        // Limit is in current buffer.
        bufferSizeAfterLimit = bufferEnd - currentLimit;
        bufferSize -= bufferSizeAfterLimit;
      } else {
        bufferSizeAfterLimit = 0;
      }
    }

    /// <summary>
    /// Discards the current limit, restoring the previous limit.
    /// </summary>
    /// <param name="oldLimit">The value returned by the matching call to PushLimit.</param>
    public void PopLimit(int oldLimit) {
      currentLimit = oldLimit;
      RecomputeBufferSizeAfterLimit();
    }

    /// <summary>
    /// Called when buffer is empty to read more bytes from the
    /// input.  If <paramref name="mustSucceed"/> is true, RefillBuffer() guarantees that
    /// either there will be at least one byte in the buffer when it returns
    /// or it will throw an exception.  If <paramref name="mustSucceed"/> is false,
    /// RefillBuffer() returns false if no more bytes were available.
    /// </summary>
    /// <param name="mustSucceed">Whether running out of data is an error.</param>
    /// <returns>true if the buffer now holds at least one readable byte.</returns>
    private bool RefillBuffer(bool mustSucceed) {
      if (bufferPos < bufferSize) {
        throw new InvalidOperationException("RefillBuffer() called when buffer wasn't empty.");
      }

      if (totalBytesRetired + bufferSize == currentLimit) {
        // Oops, we hit a limit.
        if (mustSucceed) {
          throw InvalidProtocolBufferException.TruncatedMessage();
        } else {
          return false;
        }
      }

      totalBytesRetired += bufferSize;

      bufferPos = 0;
      // A byte-array-backed instance has no stream and so nothing more to read.
      bufferSize = (input == null) ? 0 : input.Read(buffer, 0, buffer.Length);
      if (bufferSize == 0) {
        if (mustSucceed) {
          throw InvalidProtocolBufferException.TruncatedMessage();
        } else {
          return false;
        }
      } else {
        RecomputeBufferSizeAfterLimit();
        int totalBytesRead =
          totalBytesRetired + bufferSize + bufferSizeAfterLimit;
        // A negative total means the int counter has overflowed.
        if (totalBytesRead > sizeLimit || totalBytesRead < 0) {
          throw InvalidProtocolBufferException.SizeLimitExceeded();
        }
        return true;
      }
    }

    /// <summary>
    /// Read one byte from the input.
    /// </summary>
    /// <exception cref="InvalidProtocolBufferException">
    /// The end of the stream or the current limit was reached
    /// </exception>
    public byte ReadRawByte() {
      if (bufferPos == bufferSize) {
        RefillBuffer(true);
      }
      return buffer[bufferPos++];
    }

    /// <summary>
    /// Read a fixed size of bytes from the input.
    /// </summary>
    /// <exception cref="InvalidProtocolBufferException">
    /// <paramref name="size"/> is negative, or the end of the stream or the
    /// current limit was reached
    /// </exception>
    public byte[] ReadRawBytes(int size) {
      if (size < 0) {
        throw InvalidProtocolBufferException.NegativeSize();
      }

      // Subtraction form avoids overflow of position + size.
      int bytesToLimit = currentLimit - totalBytesRetired - bufferPos;
      if (size > bytesToLimit) {
        // Read to the end of the stream anyway.
        SkipRawBytes(bytesToLimit);
        // Then fail.
        throw InvalidProtocolBufferException.TruncatedMessage();
      }

      if (size <= bufferSize - bufferPos) {
        // We have all the bytes we need already.
        byte[] bytes = new byte[size];
        Array.Copy(buffer, bufferPos, bytes, 0, size);
        bufferPos += size;
        return bytes;
      } else if (size < BufferSize) {
        // Reading more bytes than are in the buffer, but not an excessive number
        // of bytes.  We can safely allocate the resulting array ahead of time.

        // First copy what we have.
        byte[] bytes = new byte[size];
        int pos = bufferSize - bufferPos;
        Array.Copy(buffer, bufferPos, bytes, 0, pos);
        bufferPos = bufferSize;

        // We want to use RefillBuffer() and then copy from the buffer into our
        // byte array rather than reading directly into our byte array because
        // the input may be unbuffered.
        RefillBuffer(true);

        while (size - pos > bufferSize) {
          Array.Copy(buffer, 0, bytes, pos, bufferSize);
          pos += bufferSize;
          bufferPos = bufferSize;
          RefillBuffer(true);
        }

        Array.Copy(buffer, 0, bytes, pos, size - pos);
        bufferPos = size - pos;

        return bytes;
      } else {
        // The size is very large.  For security reasons, we can't allocate the
        // entire byte array yet.  The size comes directly from the input, so a
        // maliciously-crafted message could provide a bogus very large size in
        // order to trick the app into allocating a lot of memory.  We avoid this
        // by allocating and reading only a small chunk at a time, so that the
        // malicious message must actually *be* extremely large to cause
        // problems.  Meanwhile, we limit the allowed size of a message elsewhere.

        // Remember the buffer markers since we'll have to copy the bytes out of
        // it later.
        int originalBufferPos = bufferPos;
        int originalBufferSize = bufferSize;

        // Mark the current buffer consumed.
        totalBytesRetired += bufferSize;
        bufferPos = 0;
        bufferSize = 0;

        // Read all the rest of the bytes we need.
        int sizeLeft = size - (originalBufferSize - originalBufferPos);
        List<byte[]> chunks = new List<byte[]>();

        while (sizeLeft > 0) {
          byte[] chunk = new byte[Math.Min(sizeLeft, BufferSize)];
          int pos = 0;
          while (pos < chunk.Length) {
            int n = (input == null) ? -1 : input.Read(chunk, pos, chunk.Length - pos);
            if (n <= 0) {
              throw InvalidProtocolBufferException.TruncatedMessage();
            }
            totalBytesRetired += n;
            pos += n;
          }
          sizeLeft -= chunk.Length;
          chunks.Add(chunk);
        }

        // OK, got everything.  Now concatenate it all into one buffer.
        byte[] bytes = new byte[size];

        // Start by copying the leftover bytes from this.buffer.
        int newPos = originalBufferSize - originalBufferPos;
        Array.Copy(buffer, originalBufferPos, bytes, 0, newPos);

        // And now all the chunks.
        foreach (byte[] chunk in chunks) {
          Array.Copy(chunk, 0, bytes, newPos, chunk.Length);
          newPos += chunk.Length;
        }

        // Done.
        return bytes;
      }
    }

    /// <summary>
    /// Reads and discards a single field, given its tag value.
    /// </summary>
    /// <returns>false if the tag is an end-group tag, in which case
    /// nothing is skipped. Otherwise, returns true.</returns>
    /// <exception cref="InvalidProtocolBufferException">The tag carries an
    /// unknown wire type, or the field data is malformed or truncated</exception>
    public bool SkipField(uint tag) {
      switch (WireFormat.GetTagWireType(tag)) {
        case WireFormat.WireType.Varint:
          // ReadInt32 consumes the whole varint even if it is wider than 32 bits.
          ReadInt32();
          return true;
        case WireFormat.WireType.Fixed64:
          ReadRawLittleEndian64();
          return true;
        case WireFormat.WireType.LengthDelimited:
          SkipRawBytes((int) ReadRawVarint32());
          return true;
        case WireFormat.WireType.StartGroup:
          SkipMessage();
          CheckLastTagWas(
              WireFormat.MakeTag(WireFormat.GetTagFieldNumber(tag),
                                 WireFormat.WireType.EndGroup));
          return true;
        case WireFormat.WireType.EndGroup:
          return false;
        case WireFormat.WireType.Fixed32:
          ReadRawLittleEndian32();
          return true;
        default:
          throw InvalidProtocolBufferException.InvalidWireType();
      }
    }

    /// <summary>
    /// Reads and discards an entire message.  This will read either until EOF
    /// or until an endgroup tag, whichever comes first.
    /// </summary>
    public void SkipMessage() {
      while (true) {
        uint tag = ReadTag();
        if (tag == 0 || !SkipField(tag)) {
          return;
        }
      }
    }

    /// <summary>
    /// Reads and discards <paramref name="size"/> bytes.
    /// </summary>
    /// <exception cref="InvalidProtocolBufferException"><paramref name="size"/>
    /// is negative, or the end of the stream or the current limit was reached</exception>
    public void SkipRawBytes(int size) {
      if (size < 0) {
        throw InvalidProtocolBufferException.NegativeSize();
      }

      // Subtraction form avoids overflow of position + size.
      int bytesToLimit = currentLimit - totalBytesRetired - bufferPos;
      if (size > bytesToLimit) {
        // Read to the end of the stream anyway.
        SkipRawBytes(bytesToLimit);
        // Then fail.
        throw InvalidProtocolBufferException.TruncatedMessage();
      }

      int buffered = bufferSize - bufferPos;
      if (size <= buffered) {
        // We have all the bytes we need already. This must include the case
        // size == buffered: discarding the buffer there would lose any bytes
        // it holds beyond the current limit (bufferSizeAfterLimit).
        bufferPos += size;
        return;
      }

      // Skipping more bytes than are in the buffer. Retire the whole buffer:
      // the bytes before bufferPos were already consumed and the rest are
      // being skipped, so the absolute position advances by bufferSize - bufferPos.
      totalBytesRetired += bufferSize;
      bufferPos = 0;
      bufferSize = 0;

      // Then skip directly from the underlying stream for the rest.
      SkipFromInput(size - buffered);
    }

    /// <summary>
    /// Discards <paramref name="count"/> bytes straight from the underlying
    /// stream, bypassing the (empty) buffer, and adds them to totalBytesRetired.
    /// </summary>
    private void SkipFromInput(int count) {
      if (input == null) {
        // Byte-array-backed: there is nothing beyond the buffer.
        throw InvalidProtocolBufferException.TruncatedMessage();
      }

      if (input.CanSeek) {
        long available = input.Length - input.Position;
        if (available < count) {
          // Consume what is left, as reading would have, then fail.
          input.Seek(0, SeekOrigin.End);
          throw InvalidProtocolBufferException.TruncatedMessage();
        }
        input.Seek(count, SeekOrigin.Current);
        totalBytesRetired += count;
        return;
      }

      // Non-seekable stream: read and throw away. The buffer has just been
      // marked empty, so it can serve as scratch space.
      while (count > 0) {
        int n = input.Read(buffer, 0, Math.Min(count, buffer.Length));
        if (n <= 0) {
          throw InvalidProtocolBufferException.TruncatedMessage();
        }
        totalBytesRetired += n;
        count -= n;
      }
    }
    #endregion
  }
}