blob: 37bc35336adcad1647420f49db2a02d368a5d85a [file] [log] [blame]
Chih-Hung Hsiehcfc3a232020-06-10 20:13:05 -07001use std::cmp;
2use std::io::BufRead;
3use std::io::BufReader;
4use std::io::Read;
5use std::mem;
6use std::u64;
7
8#[cfg(feature = "bytes")]
Haibo Huang914311b2021-01-07 18:06:15 -08009use bytes::buf::UninitSlice;
10#[cfg(feature = "bytes")]
Chih-Hung Hsiehcfc3a232020-06-10 20:13:05 -070011use bytes::BufMut;
12#[cfg(feature = "bytes")]
13use bytes::Bytes;
14#[cfg(feature = "bytes")]
15use bytes::BytesMut;
16
Joel Galensonfa77f002021-04-02 11:32:01 -070017use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC;
Haibo Huangd32e6ee2020-08-12 13:52:04 -070018use crate::error::WireError;
Haibo Huangd32e6ee2020-08-12 13:52:04 -070019use crate::ProtobufError;
20use crate::ProtobufResult;
Chih-Hung Hsiehcfc3a232020-06-10 20:13:05 -070021
/// Capacity of the internal `BufReader` that is created when an input
/// stream is constructed from a plain `Read`.
const INPUT_STREAM_BUFFER_SIZE: usize = 4 * 1024;

/// When `true`, hot paths index the buffer with `get_unchecked`, relying on
/// the `pos_within_buf <= limit_within_buf <= buf.len()` invariant instead of
/// performing bounds checks.
const USE_UNSAFE_FOR_SPEED: bool = true;

/// Sentinel value of `limit` meaning that no limit has been pushed.
const NO_LIMIT: u64 = u64::MAX;
29
/// Every kind of input a `BufReadIter` can pull bytes from.
enum InputSource<'a> {
    /// Caller-provided buffered reader; its own buffer is read directly.
    BufRead(&'a mut dyn BufRead),
    /// Plain reader wrapped in a `BufReader` owned by the iterator.
    Read(BufReader<&'a mut dyn Read>),
    /// In-memory byte slice.
    Slice(&'a [u8]),
    /// In-memory `Bytes`, which allows zero-copy sub-slicing.
    #[cfg(feature = "bytes")]
    Bytes(&'a Bytes),
}
38
39/// Dangerous implementation of `BufRead`.
40///
41/// Unsafe wrapper around BufRead which assumes that `BufRead` buf is
42/// not moved when `BufRead` is moved.
43///
44/// This assumption is generally incorrect, however, in practice
45/// `BufReadIter` is created either from `BufRead` reference (which
46/// cannot be moved, because it is locked by `CodedInputStream`) or from
47/// `BufReader` which does not move its buffer (we know that from
48/// inspecting rust standard library).
49///
50/// It is important for `CodedInputStream` performance that small reads
51/// (e. g. 4 bytes reads) do not involve virtual calls or switches.
52/// This is achievable with `BufReadIter`.
53pub struct BufReadIter<'a> {
54 input_source: InputSource<'a>,
55 buf: &'a [u8],
56 pos_within_buf: usize,
57 limit_within_buf: usize,
58 pos_of_buf_start: u64,
59 limit: u64,
60}
61
62impl<'a> Drop for BufReadIter<'a> {
63 fn drop(&mut self) {
64 match self.input_source {
65 InputSource::BufRead(ref mut buf_read) => buf_read.consume(self.pos_within_buf),
66 InputSource::Read(_) => {
67 // Nothing to flush, because we own BufReader
68 }
69 _ => {}
70 }
71 }
72}
73
74impl<'ignore> BufReadIter<'ignore> {
75 pub fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> {
76 BufReadIter {
77 input_source: InputSource::Read(BufReader::with_capacity(
78 INPUT_STREAM_BUFFER_SIZE,
79 read,
80 )),
81 buf: &[],
82 pos_within_buf: 0,
83 limit_within_buf: 0,
84 pos_of_buf_start: 0,
85 limit: NO_LIMIT,
86 }
87 }
88
89 pub fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> {
90 BufReadIter {
91 input_source: InputSource::BufRead(buf_read),
92 buf: &[],
93 pos_within_buf: 0,
94 limit_within_buf: 0,
95 pos_of_buf_start: 0,
96 limit: NO_LIMIT,
97 }
98 }
99
100 pub fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> {
101 BufReadIter {
102 input_source: InputSource::Slice(bytes),
103 buf: bytes,
104 pos_within_buf: 0,
105 limit_within_buf: bytes.len(),
106 pos_of_buf_start: 0,
107 limit: NO_LIMIT,
108 }
109 }
110
111 #[cfg(feature = "bytes")]
112 pub fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> {
113 BufReadIter {
114 input_source: InputSource::Bytes(bytes),
115 buf: &bytes,
116 pos_within_buf: 0,
117 limit_within_buf: bytes.len(),
118 pos_of_buf_start: 0,
119 limit: NO_LIMIT,
120 }
121 }
122
123 #[inline]
124 fn assertions(&self) {
125 debug_assert!(self.pos_within_buf <= self.limit_within_buf);
126 debug_assert!(self.limit_within_buf <= self.buf.len());
127 debug_assert!(self.pos_of_buf_start + self.pos_within_buf as u64 <= self.limit);
128 }
129
130 #[inline(always)]
131 pub fn pos(&self) -> u64 {
132 self.pos_of_buf_start + self.pos_within_buf as u64
133 }
134
135 /// Recompute `limit_within_buf` after update of `limit`
136 #[inline]
137 fn update_limit_within_buf(&mut self) {
138 if self.pos_of_buf_start + (self.buf.len() as u64) <= self.limit {
139 self.limit_within_buf = self.buf.len();
140 } else {
141 self.limit_within_buf = (self.limit - self.pos_of_buf_start) as usize;
142 }
143
144 self.assertions();
145 }
146
147 pub fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> {
148 let new_limit = match self.pos().checked_add(limit) {
149 Some(new_limit) => new_limit,
150 None => return Err(ProtobufError::WireError(WireError::Other)),
151 };
152
153 if new_limit > self.limit {
154 return Err(ProtobufError::WireError(WireError::Other));
155 }
156
157 let prev_limit = mem::replace(&mut self.limit, new_limit);
158
159 self.update_limit_within_buf();
160
161 Ok(prev_limit)
162 }
163
164 #[inline]
165 pub fn pop_limit(&mut self, limit: u64) {
166 assert!(limit >= self.limit);
167
168 self.limit = limit;
169
170 self.update_limit_within_buf();
171 }
172
173 #[inline]
174 pub fn remaining_in_buf(&self) -> &[u8] {
175 if USE_UNSAFE_FOR_SPEED {
176 unsafe {
177 &self
178 .buf
179 .get_unchecked(self.pos_within_buf..self.limit_within_buf)
180 }
181 } else {
182 &self.buf[self.pos_within_buf..self.limit_within_buf]
183 }
184 }
185
186 #[inline(always)]
187 pub fn remaining_in_buf_len(&self) -> usize {
188 self.limit_within_buf - self.pos_within_buf
189 }
190
191 #[inline(always)]
192 pub fn bytes_until_limit(&self) -> u64 {
193 if self.limit == NO_LIMIT {
194 NO_LIMIT
195 } else {
196 self.limit - (self.pos_of_buf_start + self.pos_within_buf as u64)
197 }
198 }
199
200 #[inline(always)]
201 pub fn eof(&mut self) -> ProtobufResult<bool> {
202 if self.pos_within_buf == self.limit_within_buf {
203 Ok(self.fill_buf()?.is_empty())
204 } else {
205 Ok(false)
206 }
207 }
208
209 #[inline(always)]
210 pub fn read_byte(&mut self) -> ProtobufResult<u8> {
211 if self.pos_within_buf == self.limit_within_buf {
212 self.do_fill_buf()?;
213 if self.remaining_in_buf_len() == 0 {
214 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
215 }
216 }
217
218 let r = if USE_UNSAFE_FOR_SPEED {
219 unsafe { *self.buf.get_unchecked(self.pos_within_buf) }
220 } else {
221 self.buf[self.pos_within_buf]
222 };
223 self.pos_within_buf += 1;
224 Ok(r)
225 }
226
227 /// Read at most `max` bytes, append to `Vec`.
228 ///
229 /// Returns 0 when EOF or limit reached.
230 fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize> {
231 let len = {
232 let rem = self.fill_buf()?;
233
234 let len = cmp::min(rem.len(), max);
235 vec.extend_from_slice(&rem[..len]);
236 len
237 };
238 self.pos_within_buf += len;
239 Ok(len)
240 }
241
242 /// Read exact number of bytes into `Vec`.
243 ///
244 /// `Vec` is cleared in the beginning.
245 pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> {
246 // TODO: also do some limits when reading from unlimited source
247 if count as u64 > self.bytes_until_limit() {
248 return Err(ProtobufError::WireError(WireError::TruncatedMessage));
249 }
250
251 target.clear();
252
253 if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() {
254 // avoid calling `reserve` on buf with very large buffer: could be a malformed message
255
256 target.reserve(READ_RAW_BYTES_MAX_ALLOC);
257
258 while target.len() < count {
259 let need_to_read = count - target.len();
260 if need_to_read <= target.len() {
261 target.reserve_exact(need_to_read);
262 } else {
263 target.reserve(1);
264 }
265
266 let max = cmp::min(target.capacity() - target.len(), need_to_read);
267 let read = self.read_to_vec(target, max)?;
268 if read == 0 {
269 return Err(ProtobufError::WireError(WireError::TruncatedMessage));
270 }
271 }
272 } else {
273 target.reserve_exact(count);
274
275 unsafe {
276 self.read_exact(&mut target.get_unchecked_mut(..count))?;
277 target.set_len(count);
278 }
279 }
280
281 debug_assert_eq!(count, target.len());
282
283 Ok(())
284 }
285
286 #[cfg(feature = "bytes")]
287 pub fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> {
288 if let InputSource::Bytes(bytes) = self.input_source {
289 let end = match self.pos_within_buf.checked_add(len) {
290 Some(end) => end,
291 None => return Err(ProtobufError::WireError(WireError::UnexpectedEof)),
292 };
293
294 if end > self.limit_within_buf {
295 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
296 }
297
298 let r = bytes.slice(self.pos_within_buf..end);
299 self.pos_within_buf += len;
300 Ok(r)
301 } else {
302 if len >= READ_RAW_BYTES_MAX_ALLOC {
303 // We cannot trust `len` because protobuf message could be malformed.
304 // Reading should not result in OOM when allocating a buffer.
305 let mut v = Vec::new();
306 self.read_exact_to_vec(len, &mut v)?;
307 Ok(Bytes::from(v))
308 } else {
309 let mut r = BytesMut::with_capacity(len);
310 unsafe {
Haibo Huang914311b2021-01-07 18:06:15 -0800311 let buf = Self::uninit_slice_as_mut_slice(&mut r.chunk_mut()[..len]);
Chih-Hung Hsiehcfc3a232020-06-10 20:13:05 -0700312 self.read_exact(buf)?;
313 r.advance_mut(len);
314 }
315 Ok(r.freeze())
316 }
317 }
318 }
319
Haibo Huang914311b2021-01-07 18:06:15 -0800320 #[cfg(feature = "bytes")]
321 unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [u8] {
322 use std::slice;
323 slice::from_raw_parts_mut(slice.as_mut_ptr(), slice.len())
Chih-Hung Hsiehcfc3a232020-06-10 20:13:05 -0700324 }
325
326 /// Returns 0 when EOF or limit reached.
327 pub fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> {
328 self.fill_buf()?;
329
330 let rem = &self.buf[self.pos_within_buf..self.limit_within_buf];
331
332 let len = cmp::min(rem.len(), buf.len());
Joel Galensonf9dc51b2021-08-09 10:39:22 -0700333 buf[..len].copy_from_slice(&rem[..len]);
Chih-Hung Hsiehcfc3a232020-06-10 20:13:05 -0700334 self.pos_within_buf += len;
335 Ok(len)
336 }
337
338 pub fn read_exact(&mut self, buf: &mut [u8]) -> ProtobufResult<()> {
339 if self.remaining_in_buf_len() >= buf.len() {
340 let buf_len = buf.len();
341 buf.copy_from_slice(&self.buf[self.pos_within_buf..self.pos_within_buf + buf_len]);
342 self.pos_within_buf += buf_len;
343 return Ok(());
344 }
345
346 if self.bytes_until_limit() < buf.len() as u64 {
347 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
348 }
349
350 let consume = self.pos_within_buf;
351 self.pos_of_buf_start += self.pos_within_buf as u64;
352 self.pos_within_buf = 0;
353 self.buf = &[];
354 self.limit_within_buf = 0;
355
356 match self.input_source {
357 InputSource::Read(ref mut buf_read) => {
358 buf_read.consume(consume);
359 buf_read.read_exact(buf)?;
360 }
361 InputSource::BufRead(ref mut buf_read) => {
362 buf_read.consume(consume);
363 buf_read.read_exact(buf)?;
364 }
365 _ => {
366 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
367 }
368 }
369
370 self.pos_of_buf_start += buf.len() as u64;
371
372 self.assertions();
373
374 Ok(())
375 }
376
377 fn do_fill_buf(&mut self) -> ProtobufResult<()> {
378 debug_assert!(self.pos_within_buf == self.limit_within_buf);
379
380 // Limit is reached, do not fill buf, because otherwise
381 // synchronous read from `CodedInputStream` may block.
382 if self.limit == self.pos() {
383 return Ok(());
384 }
385
386 let consume = self.buf.len();
387 self.pos_of_buf_start += self.buf.len() as u64;
388 self.buf = &[];
389 self.pos_within_buf = 0;
390 self.limit_within_buf = 0;
391
392 match self.input_source {
393 InputSource::Read(ref mut buf_read) => {
394 buf_read.consume(consume);
395 self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) };
396 }
397 InputSource::BufRead(ref mut buf_read) => {
398 buf_read.consume(consume);
399 self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) };
400 }
401 _ => {
402 return Ok(());
403 }
404 }
405
406 self.update_limit_within_buf();
407
408 Ok(())
409 }
410
411 #[inline(always)]
412 pub fn fill_buf(&mut self) -> ProtobufResult<&[u8]> {
413 if self.pos_within_buf == self.limit_within_buf {
414 self.do_fill_buf()?;
415 }
416
417 Ok(if USE_UNSAFE_FOR_SPEED {
418 unsafe {
419 self.buf
420 .get_unchecked(self.pos_within_buf..self.limit_within_buf)
421 }
422 } else {
423 &self.buf[self.pos_within_buf..self.limit_within_buf]
424 })
425 }
426
427 #[inline(always)]
428 pub fn consume(&mut self, amt: usize) {
429 assert!(amt <= self.limit_within_buf - self.pos_within_buf);
430 self.pos_within_buf += amt;
431 }
432}
433
#[cfg(all(test, feature = "bytes"))]
mod test_bytes {
    use super::*;
    use std::io::Write;

    /// Produces exactly `len` ASCII bytes by repeatedly appending the current
    /// length in decimal ("0123456789101214..."), then truncating.
    fn make_long_string(len: usize) -> Vec<u8> {
        let mut out = Vec::new();
        loop {
            let written = out.len();
            if written >= len {
                break;
            }
            write!(&mut out, "{}", written).expect("unexpected");
        }
        out.truncate(len);
        out
    }

    #[test]
    fn read_exact_bytes_from_slice() {
        let data = make_long_string(100);
        let mut iter = BufReadIter::from_byte_slice(&data[..]);
        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&data[..90], &head[..]);
        assert_eq!(data[90], iter.read_byte().expect("read_byte"));
    }

    #[test]
    fn read_exact_bytes_from_bytes() {
        let data = Bytes::from(make_long_string(100));
        let mut iter = BufReadIter::from_bytes(&data);
        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&data[..90], &head[..]);
        // A `Bytes` source must hand out a view of the same storage, not a copy.
        assert_eq!(data.as_ptr(), head.as_ptr());
        assert_eq!(data[90], iter.read_byte().expect("read_byte"));
    }
}
467
#[cfg(test)]
mod test {
    use super::*;
    use std::io;
    use std::io::BufRead;
    use std::io::Read;

    /// Reading exactly up to a pushed limit must report EOF without asking the
    /// underlying reader for more data, which could block a synchronous stream.
    #[test]
    fn eof_at_limit() {
        /// Serves `[0, 1, 2, 3, 4]` from a single `fill_buf` call and panics on
        /// any attempt to get more data.
        struct Read5ThenPanic {
            pos: usize,
        }

        impl Read for Read5ThenPanic {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                unreachable!();
            }
        }

        impl BufRead for Read5ThenPanic {
            fn fill_buf(&mut self) -> io::Result<&[u8]> {
                assert_eq!(0, self.pos);
                static ZERO_TO_FIVE: &[u8] = &[0, 1, 2, 3, 4];
                Ok(ZERO_TO_FIVE)
            }

            fn consume(&mut self, amt: usize) {
                if amt == 0 {
                    // `BufReadIter` consumes its (still empty) previous buffer
                    // right before the first `fill_buf`.
                    return;
                }

                // In this test a non-zero consume only happens when
                // `BufReadIter` is dropped after reading all five bytes.
                assert_eq!(0, self.pos);
                assert_eq!(5, amt);
                self.pos += amt;
            }
        }

        let mut read = Read5ThenPanic { pos: 0 };
        let mut buf_read_iter = BufReadIter::from_buf_read(&mut read);
        assert_eq!(0, buf_read_iter.pos());
        buf_read_iter.push_limit(5).expect("push_limit");
        assert_eq!(0, buf_read_iter.read_byte().expect("read_byte"));
        let mut rest = [0u8; 4];
        buf_read_iter.read_exact(&mut rest).expect("read_exact");
        assert_eq!([1, 2, 3, 4], rest);
        assert_eq!(5, buf_read_iter.pos());
        assert!(buf_read_iter.eof().expect("eof"));
    }
}
516}