Chih-Hung Hsieh | cfc3a23 | 2020-06-10 20:13:05 -0700 | [diff] [blame] | 1 | use std::cmp; |
| 2 | use std::io::BufRead; |
| 3 | use std::io::BufReader; |
| 4 | use std::io::Read; |
| 5 | use std::mem; |
| 6 | use std::u64; |
| 7 | |
| 8 | #[cfg(feature = "bytes")] |
Haibo Huang | 914311b | 2021-01-07 18:06:15 -0800 | [diff] [blame] | 9 | use bytes::buf::UninitSlice; |
| 10 | #[cfg(feature = "bytes")] |
Chih-Hung Hsieh | cfc3a23 | 2020-06-10 20:13:05 -0700 | [diff] [blame] | 11 | use bytes::BufMut; |
| 12 | #[cfg(feature = "bytes")] |
| 13 | use bytes::Bytes; |
| 14 | #[cfg(feature = "bytes")] |
| 15 | use bytes::BytesMut; |
| 16 | |
Joel Galenson | fa77f00 | 2021-04-02 11:32:01 -0700 | [diff] [blame] | 17 | use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC; |
Haibo Huang | d32e6ee | 2020-08-12 13:52:04 -0700 | [diff] [blame] | 18 | use crate::error::WireError; |
Haibo Huang | d32e6ee | 2020-08-12 13:52:04 -0700 | [diff] [blame] | 19 | use crate::ProtobufError; |
| 20 | use crate::ProtobufResult; |
Chih-Hung Hsieh | cfc3a23 | 2020-06-10 20:13:05 -0700 | [diff] [blame] | 21 | |
/// Capacity of the internal `BufReader` that is created when an input
/// stream is constructed from a plain `Read`.
const INPUT_STREAM_BUFFER_SIZE: usize = 4 * 1024;

/// When `true`, hot paths use unchecked indexing where `BufReadIter`'s own
/// position/limit invariants already guarantee the index is in bounds.
const USE_UNSAFE_FOR_SPEED: bool = true;

/// Sentinel `limit` value meaning "no limit is in effect".
const NO_LIMIT: u64 = u64::MAX;
| 29 | |
/// Every kind of input a `BufReadIter` can pull bytes from.
///
/// `Slice` and `Bytes` are in-memory; `BufRead` and `Read` are streamed.
enum InputSource<'a> {
    /// A caller-provided in-memory byte slice.
    Slice(&'a [u8]),
    /// A caller-provided `Bytes` buffer, which allows `read_exact_bytes`
    /// to return sub-slices without copying.
    #[cfg(feature = "bytes")]
    Bytes(&'a Bytes),
    /// A caller-provided buffered reader, used as-is.
    BufRead(&'a mut dyn BufRead),
    /// A plain reader wrapped in a `BufReader` owned by the iterator.
    Read(BufReader<&'a mut dyn Read>),
}
| 38 | |
| 39 | /// Dangerous implementation of `BufRead`. |
| 40 | /// |
| 41 | /// Unsafe wrapper around BufRead which assumes that `BufRead` buf is |
| 42 | /// not moved when `BufRead` is moved. |
| 43 | /// |
| 44 | /// This assumption is generally incorrect, however, in practice |
| 45 | /// `BufReadIter` is created either from `BufRead` reference (which |
| 46 | /// cannot be moved, because it is locked by `CodedInputStream`) or from |
| 47 | /// `BufReader` which does not move its buffer (we know that from |
| 48 | /// inspecting rust standard library). |
| 49 | /// |
| 50 | /// It is important for `CodedInputStream` performance that small reads |
| 51 | /// (e. g. 4 bytes reads) do not involve virtual calls or switches. |
| 52 | /// This is achievable with `BufReadIter`. |
| 53 | pub struct BufReadIter<'a> { |
| 54 | input_source: InputSource<'a>, |
| 55 | buf: &'a [u8], |
| 56 | pos_within_buf: usize, |
| 57 | limit_within_buf: usize, |
| 58 | pos_of_buf_start: u64, |
| 59 | limit: u64, |
| 60 | } |
| 61 | |
| 62 | impl<'a> Drop for BufReadIter<'a> { |
| 63 | fn drop(&mut self) { |
| 64 | match self.input_source { |
| 65 | InputSource::BufRead(ref mut buf_read) => buf_read.consume(self.pos_within_buf), |
| 66 | InputSource::Read(_) => { |
| 67 | // Nothing to flush, because we own BufReader |
| 68 | } |
| 69 | _ => {} |
| 70 | } |
| 71 | } |
| 72 | } |
| 73 | |
| 74 | impl<'ignore> BufReadIter<'ignore> { |
| 75 | pub fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> { |
| 76 | BufReadIter { |
| 77 | input_source: InputSource::Read(BufReader::with_capacity( |
| 78 | INPUT_STREAM_BUFFER_SIZE, |
| 79 | read, |
| 80 | )), |
| 81 | buf: &[], |
| 82 | pos_within_buf: 0, |
| 83 | limit_within_buf: 0, |
| 84 | pos_of_buf_start: 0, |
| 85 | limit: NO_LIMIT, |
| 86 | } |
| 87 | } |
| 88 | |
| 89 | pub fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> { |
| 90 | BufReadIter { |
| 91 | input_source: InputSource::BufRead(buf_read), |
| 92 | buf: &[], |
| 93 | pos_within_buf: 0, |
| 94 | limit_within_buf: 0, |
| 95 | pos_of_buf_start: 0, |
| 96 | limit: NO_LIMIT, |
| 97 | } |
| 98 | } |
| 99 | |
| 100 | pub fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> { |
| 101 | BufReadIter { |
| 102 | input_source: InputSource::Slice(bytes), |
| 103 | buf: bytes, |
| 104 | pos_within_buf: 0, |
| 105 | limit_within_buf: bytes.len(), |
| 106 | pos_of_buf_start: 0, |
| 107 | limit: NO_LIMIT, |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | #[cfg(feature = "bytes")] |
| 112 | pub fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> { |
| 113 | BufReadIter { |
| 114 | input_source: InputSource::Bytes(bytes), |
| 115 | buf: &bytes, |
| 116 | pos_within_buf: 0, |
| 117 | limit_within_buf: bytes.len(), |
| 118 | pos_of_buf_start: 0, |
| 119 | limit: NO_LIMIT, |
| 120 | } |
| 121 | } |
| 122 | |
| 123 | #[inline] |
| 124 | fn assertions(&self) { |
| 125 | debug_assert!(self.pos_within_buf <= self.limit_within_buf); |
| 126 | debug_assert!(self.limit_within_buf <= self.buf.len()); |
| 127 | debug_assert!(self.pos_of_buf_start + self.pos_within_buf as u64 <= self.limit); |
| 128 | } |
| 129 | |
| 130 | #[inline(always)] |
| 131 | pub fn pos(&self) -> u64 { |
| 132 | self.pos_of_buf_start + self.pos_within_buf as u64 |
| 133 | } |
| 134 | |
| 135 | /// Recompute `limit_within_buf` after update of `limit` |
| 136 | #[inline] |
| 137 | fn update_limit_within_buf(&mut self) { |
| 138 | if self.pos_of_buf_start + (self.buf.len() as u64) <= self.limit { |
| 139 | self.limit_within_buf = self.buf.len(); |
| 140 | } else { |
| 141 | self.limit_within_buf = (self.limit - self.pos_of_buf_start) as usize; |
| 142 | } |
| 143 | |
| 144 | self.assertions(); |
| 145 | } |
| 146 | |
| 147 | pub fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> { |
| 148 | let new_limit = match self.pos().checked_add(limit) { |
| 149 | Some(new_limit) => new_limit, |
| 150 | None => return Err(ProtobufError::WireError(WireError::Other)), |
| 151 | }; |
| 152 | |
| 153 | if new_limit > self.limit { |
| 154 | return Err(ProtobufError::WireError(WireError::Other)); |
| 155 | } |
| 156 | |
| 157 | let prev_limit = mem::replace(&mut self.limit, new_limit); |
| 158 | |
| 159 | self.update_limit_within_buf(); |
| 160 | |
| 161 | Ok(prev_limit) |
| 162 | } |
| 163 | |
| 164 | #[inline] |
| 165 | pub fn pop_limit(&mut self, limit: u64) { |
| 166 | assert!(limit >= self.limit); |
| 167 | |
| 168 | self.limit = limit; |
| 169 | |
| 170 | self.update_limit_within_buf(); |
| 171 | } |
| 172 | |
| 173 | #[inline] |
| 174 | pub fn remaining_in_buf(&self) -> &[u8] { |
| 175 | if USE_UNSAFE_FOR_SPEED { |
| 176 | unsafe { |
| 177 | &self |
| 178 | .buf |
| 179 | .get_unchecked(self.pos_within_buf..self.limit_within_buf) |
| 180 | } |
| 181 | } else { |
| 182 | &self.buf[self.pos_within_buf..self.limit_within_buf] |
| 183 | } |
| 184 | } |
| 185 | |
| 186 | #[inline(always)] |
| 187 | pub fn remaining_in_buf_len(&self) -> usize { |
| 188 | self.limit_within_buf - self.pos_within_buf |
| 189 | } |
| 190 | |
| 191 | #[inline(always)] |
| 192 | pub fn bytes_until_limit(&self) -> u64 { |
| 193 | if self.limit == NO_LIMIT { |
| 194 | NO_LIMIT |
| 195 | } else { |
| 196 | self.limit - (self.pos_of_buf_start + self.pos_within_buf as u64) |
| 197 | } |
| 198 | } |
| 199 | |
| 200 | #[inline(always)] |
| 201 | pub fn eof(&mut self) -> ProtobufResult<bool> { |
| 202 | if self.pos_within_buf == self.limit_within_buf { |
| 203 | Ok(self.fill_buf()?.is_empty()) |
| 204 | } else { |
| 205 | Ok(false) |
| 206 | } |
| 207 | } |
| 208 | |
| 209 | #[inline(always)] |
| 210 | pub fn read_byte(&mut self) -> ProtobufResult<u8> { |
| 211 | if self.pos_within_buf == self.limit_within_buf { |
| 212 | self.do_fill_buf()?; |
| 213 | if self.remaining_in_buf_len() == 0 { |
| 214 | return Err(ProtobufError::WireError(WireError::UnexpectedEof)); |
| 215 | } |
| 216 | } |
| 217 | |
| 218 | let r = if USE_UNSAFE_FOR_SPEED { |
| 219 | unsafe { *self.buf.get_unchecked(self.pos_within_buf) } |
| 220 | } else { |
| 221 | self.buf[self.pos_within_buf] |
| 222 | }; |
| 223 | self.pos_within_buf += 1; |
| 224 | Ok(r) |
| 225 | } |
| 226 | |
| 227 | /// Read at most `max` bytes, append to `Vec`. |
| 228 | /// |
| 229 | /// Returns 0 when EOF or limit reached. |
| 230 | fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize> { |
| 231 | let len = { |
| 232 | let rem = self.fill_buf()?; |
| 233 | |
| 234 | let len = cmp::min(rem.len(), max); |
| 235 | vec.extend_from_slice(&rem[..len]); |
| 236 | len |
| 237 | }; |
| 238 | self.pos_within_buf += len; |
| 239 | Ok(len) |
| 240 | } |
| 241 | |
| 242 | /// Read exact number of bytes into `Vec`. |
| 243 | /// |
| 244 | /// `Vec` is cleared in the beginning. |
| 245 | pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> { |
| 246 | // TODO: also do some limits when reading from unlimited source |
| 247 | if count as u64 > self.bytes_until_limit() { |
| 248 | return Err(ProtobufError::WireError(WireError::TruncatedMessage)); |
| 249 | } |
| 250 | |
| 251 | target.clear(); |
| 252 | |
| 253 | if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() { |
| 254 | // avoid calling `reserve` on buf with very large buffer: could be a malformed message |
| 255 | |
| 256 | target.reserve(READ_RAW_BYTES_MAX_ALLOC); |
| 257 | |
| 258 | while target.len() < count { |
| 259 | let need_to_read = count - target.len(); |
| 260 | if need_to_read <= target.len() { |
| 261 | target.reserve_exact(need_to_read); |
| 262 | } else { |
| 263 | target.reserve(1); |
| 264 | } |
| 265 | |
| 266 | let max = cmp::min(target.capacity() - target.len(), need_to_read); |
| 267 | let read = self.read_to_vec(target, max)?; |
| 268 | if read == 0 { |
| 269 | return Err(ProtobufError::WireError(WireError::TruncatedMessage)); |
| 270 | } |
| 271 | } |
| 272 | } else { |
| 273 | target.reserve_exact(count); |
| 274 | |
| 275 | unsafe { |
| 276 | self.read_exact(&mut target.get_unchecked_mut(..count))?; |
| 277 | target.set_len(count); |
| 278 | } |
| 279 | } |
| 280 | |
| 281 | debug_assert_eq!(count, target.len()); |
| 282 | |
| 283 | Ok(()) |
| 284 | } |
| 285 | |
| 286 | #[cfg(feature = "bytes")] |
| 287 | pub fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> { |
| 288 | if let InputSource::Bytes(bytes) = self.input_source { |
| 289 | let end = match self.pos_within_buf.checked_add(len) { |
| 290 | Some(end) => end, |
| 291 | None => return Err(ProtobufError::WireError(WireError::UnexpectedEof)), |
| 292 | }; |
| 293 | |
| 294 | if end > self.limit_within_buf { |
| 295 | return Err(ProtobufError::WireError(WireError::UnexpectedEof)); |
| 296 | } |
| 297 | |
| 298 | let r = bytes.slice(self.pos_within_buf..end); |
| 299 | self.pos_within_buf += len; |
| 300 | Ok(r) |
| 301 | } else { |
| 302 | if len >= READ_RAW_BYTES_MAX_ALLOC { |
| 303 | // We cannot trust `len` because protobuf message could be malformed. |
| 304 | // Reading should not result in OOM when allocating a buffer. |
| 305 | let mut v = Vec::new(); |
| 306 | self.read_exact_to_vec(len, &mut v)?; |
| 307 | Ok(Bytes::from(v)) |
| 308 | } else { |
| 309 | let mut r = BytesMut::with_capacity(len); |
| 310 | unsafe { |
Haibo Huang | 914311b | 2021-01-07 18:06:15 -0800 | [diff] [blame] | 311 | let buf = Self::uninit_slice_as_mut_slice(&mut r.chunk_mut()[..len]); |
Chih-Hung Hsieh | cfc3a23 | 2020-06-10 20:13:05 -0700 | [diff] [blame] | 312 | self.read_exact(buf)?; |
| 313 | r.advance_mut(len); |
| 314 | } |
| 315 | Ok(r.freeze()) |
| 316 | } |
| 317 | } |
| 318 | } |
| 319 | |
Haibo Huang | 914311b | 2021-01-07 18:06:15 -0800 | [diff] [blame] | 320 | #[cfg(feature = "bytes")] |
| 321 | unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [u8] { |
| 322 | use std::slice; |
| 323 | slice::from_raw_parts_mut(slice.as_mut_ptr(), slice.len()) |
Chih-Hung Hsieh | cfc3a23 | 2020-06-10 20:13:05 -0700 | [diff] [blame] | 324 | } |
| 325 | |
| 326 | /// Returns 0 when EOF or limit reached. |
| 327 | pub fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> { |
| 328 | self.fill_buf()?; |
| 329 | |
| 330 | let rem = &self.buf[self.pos_within_buf..self.limit_within_buf]; |
| 331 | |
| 332 | let len = cmp::min(rem.len(), buf.len()); |
Joel Galenson | f9dc51b | 2021-08-09 10:39:22 -0700 | [diff] [blame] | 333 | buf[..len].copy_from_slice(&rem[..len]); |
Chih-Hung Hsieh | cfc3a23 | 2020-06-10 20:13:05 -0700 | [diff] [blame] | 334 | self.pos_within_buf += len; |
| 335 | Ok(len) |
| 336 | } |
| 337 | |
| 338 | pub fn read_exact(&mut self, buf: &mut [u8]) -> ProtobufResult<()> { |
| 339 | if self.remaining_in_buf_len() >= buf.len() { |
| 340 | let buf_len = buf.len(); |
| 341 | buf.copy_from_slice(&self.buf[self.pos_within_buf..self.pos_within_buf + buf_len]); |
| 342 | self.pos_within_buf += buf_len; |
| 343 | return Ok(()); |
| 344 | } |
| 345 | |
| 346 | if self.bytes_until_limit() < buf.len() as u64 { |
| 347 | return Err(ProtobufError::WireError(WireError::UnexpectedEof)); |
| 348 | } |
| 349 | |
| 350 | let consume = self.pos_within_buf; |
| 351 | self.pos_of_buf_start += self.pos_within_buf as u64; |
| 352 | self.pos_within_buf = 0; |
| 353 | self.buf = &[]; |
| 354 | self.limit_within_buf = 0; |
| 355 | |
| 356 | match self.input_source { |
| 357 | InputSource::Read(ref mut buf_read) => { |
| 358 | buf_read.consume(consume); |
| 359 | buf_read.read_exact(buf)?; |
| 360 | } |
| 361 | InputSource::BufRead(ref mut buf_read) => { |
| 362 | buf_read.consume(consume); |
| 363 | buf_read.read_exact(buf)?; |
| 364 | } |
| 365 | _ => { |
| 366 | return Err(ProtobufError::WireError(WireError::UnexpectedEof)); |
| 367 | } |
| 368 | } |
| 369 | |
| 370 | self.pos_of_buf_start += buf.len() as u64; |
| 371 | |
| 372 | self.assertions(); |
| 373 | |
| 374 | Ok(()) |
| 375 | } |
| 376 | |
| 377 | fn do_fill_buf(&mut self) -> ProtobufResult<()> { |
| 378 | debug_assert!(self.pos_within_buf == self.limit_within_buf); |
| 379 | |
| 380 | // Limit is reached, do not fill buf, because otherwise |
| 381 | // synchronous read from `CodedInputStream` may block. |
| 382 | if self.limit == self.pos() { |
| 383 | return Ok(()); |
| 384 | } |
| 385 | |
| 386 | let consume = self.buf.len(); |
| 387 | self.pos_of_buf_start += self.buf.len() as u64; |
| 388 | self.buf = &[]; |
| 389 | self.pos_within_buf = 0; |
| 390 | self.limit_within_buf = 0; |
| 391 | |
| 392 | match self.input_source { |
| 393 | InputSource::Read(ref mut buf_read) => { |
| 394 | buf_read.consume(consume); |
| 395 | self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) }; |
| 396 | } |
| 397 | InputSource::BufRead(ref mut buf_read) => { |
| 398 | buf_read.consume(consume); |
| 399 | self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) }; |
| 400 | } |
| 401 | _ => { |
| 402 | return Ok(()); |
| 403 | } |
| 404 | } |
| 405 | |
| 406 | self.update_limit_within_buf(); |
| 407 | |
| 408 | Ok(()) |
| 409 | } |
| 410 | |
| 411 | #[inline(always)] |
| 412 | pub fn fill_buf(&mut self) -> ProtobufResult<&[u8]> { |
| 413 | if self.pos_within_buf == self.limit_within_buf { |
| 414 | self.do_fill_buf()?; |
| 415 | } |
| 416 | |
| 417 | Ok(if USE_UNSAFE_FOR_SPEED { |
| 418 | unsafe { |
| 419 | self.buf |
| 420 | .get_unchecked(self.pos_within_buf..self.limit_within_buf) |
| 421 | } |
| 422 | } else { |
| 423 | &self.buf[self.pos_within_buf..self.limit_within_buf] |
| 424 | }) |
| 425 | } |
| 426 | |
| 427 | #[inline(always)] |
| 428 | pub fn consume(&mut self, amt: usize) { |
| 429 | assert!(amt <= self.limit_within_buf - self.pos_within_buf); |
| 430 | self.pos_within_buf += amt; |
| 431 | } |
| 432 | } |
| 433 | |
#[cfg(all(test, feature = "bytes"))]
mod test_bytes {
    use super::*;
    use std::io::Write;

    /// Deterministic ASCII test data of exactly `len` bytes.
    fn make_long_string(len: usize) -> Vec<u8> {
        let mut s = Vec::new();
        while s.len() < len {
            let len = s.len();
            write!(&mut s, "{}", len).expect("unexpected");
        }
        s.truncate(len);
        s
    }

    #[test]
    fn read_exact_bytes_from_slice() {
        let bytes = make_long_string(100);
        let mut bri = BufReadIter::from_byte_slice(&bytes[..]);
        assert_eq!(&bytes[..90], &bri.read_exact_bytes(90).unwrap()[..]);
        assert_eq!(bytes[90], bri.read_byte().expect("read_byte"));
    }

    #[test]
    fn read_exact_bytes_from_bytes() {
        let bytes = Bytes::from(make_long_string(100));
        let mut bri = BufReadIter::from_bytes(&bytes);
        let read = bri.read_exact_bytes(90).unwrap();
        assert_eq!(&bytes[..90], &read[..]);
        // Zero-copy: the result must point into the original buffer.
        assert_eq!(&bytes[..90].as_ptr(), &read.as_ptr());
        assert_eq!(bytes[90], bri.read_byte().expect("read_byte"));
    }

    /// Covers the copying path that fills a fresh `BytesMut` through the
    /// underlying reader.
    #[test]
    fn read_exact_bytes_from_read() {
        let data = make_long_string(100);
        let mut reader: &[u8] = &data;
        let mut bri = BufReadIter::from_read(&mut reader);
        assert_eq!(&data[..90], &bri.read_exact_bytes(90).unwrap()[..]);
        assert_eq!(data[90], bri.read_byte().expect("read_byte"));
        assert_eq!(91, bri.pos());
    }
}
| 467 | |
#[cfg(test)]
mod test {
    use super::*;
    use std::io;
    use std::io::BufRead;
    use std::io::Read;

    #[test]
    fn eof_at_limit() {
        struct Read5ThenPanic {
            pos: usize,
        }

        impl Read for Read5ThenPanic {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                unreachable!();
            }
        }

        impl BufRead for Read5ThenPanic {
            fn fill_buf(&mut self) -> io::Result<&[u8]> {
                assert_eq!(0, self.pos);
                static ZERO_TO_FIVE: &'static [u8] = &[0, 1, 2, 3, 4];
                Ok(ZERO_TO_FIVE)
            }

            fn consume(&mut self, amt: usize) {
                if amt == 0 {
                    // drop of BufReadIter
                    return;
                }

                assert_eq!(0, self.pos);
                assert_eq!(5, amt);
                self.pos += amt;
            }
        }

        let mut read = Read5ThenPanic { pos: 0 };
        let mut buf_read_iter = BufReadIter::from_buf_read(&mut read);
        assert_eq!(0, buf_read_iter.pos());
        let _prev_limit = buf_read_iter.push_limit(5);
        buf_read_iter.read_byte().expect("read_byte");
        buf_read_iter
            .read_exact(&mut [1, 2, 3, 4])
            .expect("read_exact");
        assert!(buf_read_iter.eof().expect("eof"));
    }

    /// `read_exact_to_vec` replaces previous contents and advances the cursor.
    #[test]
    fn read_exact_to_vec_from_slice() {
        let data = [1u8, 2, 3, 4, 5];
        let mut iter = BufReadIter::from_byte_slice(&data);
        let mut v = vec![9u8, 9];
        iter.read_exact_to_vec(3, &mut v)
            .expect("read_exact_to_vec");
        assert_eq!(vec![1u8, 2, 3], v);
        assert_eq!(3, iter.pos());
        assert_eq!(4, iter.read_byte().expect("read_byte"));
    }

    /// A read larger than the internal `BufReader` buffer must go through the
    /// underlying reader and still produce the right bytes and position.
    #[test]
    fn read_exact_to_vec_from_read_spanning_buffer() {
        let data: Vec<u8> = (0..10_000u32).map(|i| i as u8).collect();
        let mut reader: &[u8] = &data;
        let mut iter = BufReadIter::from_read(&mut reader);
        assert_eq!(0, iter.read_byte().expect("read_byte"));
        let mut v = Vec::new();
        iter.read_exact_to_vec(9_000, &mut v)
            .expect("read_exact_to_vec");
        assert_eq!(&data[1..9_001], &v[..]);
        assert_eq!(9_001, iter.pos());
    }

    /// Running out of input must fail and leave no bytes in `target`.
    #[test]
    fn read_exact_to_vec_truncated_slice() {
        let data = [1u8, 2, 3];
        let mut iter = BufReadIter::from_byte_slice(&data);
        let mut v = Vec::new();
        assert!(iter.read_exact_to_vec(5, &mut v).is_err());
        assert!(v.is_empty());
    }

    /// Reading past a pushed limit must fail even if the input has more data.
    #[test]
    fn read_exact_to_vec_respects_limit() {
        let data = [1u8, 2, 3, 4, 5];
        let mut iter = BufReadIter::from_byte_slice(&data);
        iter.push_limit(2).expect("push_limit");
        let mut v = Vec::new();
        assert!(iter.read_exact_to_vec(3, &mut v).is_err());
    }
}