| Yiming Jing | ebb1872 | 2021-07-16 13:15:12 -0700 | [diff] [blame] | 1 | //! General purpose combinators |
| 2 | |
| 3 | use nom::bytes::streaming::take; |
| 4 | use nom::combinator::map_parser; |
| 5 | pub use nom::error::{make_error, ErrorKind, ParseError}; |
| 6 | pub use nom::{IResult, Needed}; |
| 7 | use nom::{InputIter, InputTake}; |
| 8 | use nom::{InputLength, ToUsize}; |
| 9 | |
| 10 | /// Read the entire slice as a big endian unsigned integer, up to 8 bytes |
| 11 | #[inline] |
| 12 | pub fn be_var_u64<'a, E: ParseError<&'a [u8]>>(input: &'a [u8]) -> IResult<&'a [u8], u64, E> { |
| 13 | if input.is_empty() { |
| 14 | return Err(nom::Err::Incomplete(Needed::new(1))); |
| 15 | } |
| 16 | if input.len() > 8 { |
| 17 | return Err(nom::Err::Error(make_error(input, ErrorKind::TooLarge))); |
| 18 | } |
| 19 | let mut res = 0u64; |
| 20 | for byte in input { |
| 21 | res = (res << 8) + *byte as u64; |
| 22 | } |
| 23 | |
| 24 | Ok((&b""[..], res)) |
| 25 | } |
| 26 | |
| 27 | /// Read the entire slice as a little endian unsigned integer, up to 8 bytes |
| 28 | #[inline] |
| 29 | pub fn le_var_u64<'a, E: ParseError<&'a [u8]>>(input: &'a [u8]) -> IResult<&'a [u8], u64, E> { |
| 30 | if input.is_empty() { |
| 31 | return Err(nom::Err::Incomplete(Needed::new(1))); |
| 32 | } |
| 33 | if input.len() > 8 { |
| 34 | return Err(nom::Err::Error(make_error(input, ErrorKind::TooLarge))); |
| 35 | } |
| 36 | let mut res = 0u64; |
| 37 | for byte in input.iter().rev() { |
| 38 | res = (res << 8) + *byte as u64; |
| 39 | } |
| 40 | |
| 41 | Ok((&b""[..], res)) |
| 42 | } |
| 43 | |
| 44 | /// Read a slice as a big-endian value. |
| 45 | #[inline] |
| 46 | pub fn parse_hex_to_u64<S>(i: &[u8], size: S) -> IResult<&[u8], u64> |
| 47 | where |
| 48 | S: ToUsize + Copy, |
| 49 | { |
| 50 | map_parser(take(size.to_usize()), be_var_u64)(i) |
| 51 | } |
| 52 | |
| 53 | /// Apply combinator, automatically converts between errors if the underlying type supports it |
| 54 | pub fn upgrade_error<I, O, E1: ParseError<I>, E2: ParseError<I>, F>( |
| 55 | mut f: F, |
| 56 | ) -> impl FnMut(I) -> IResult<I, O, E2> |
| 57 | where |
| 58 | F: FnMut(I) -> IResult<I, O, E1>, |
| 59 | E2: From<E1>, |
| 60 | { |
| 61 | move |i| f(i).map_err(nom::Err::convert) |
| 62 | } |
| 63 | |
| 64 | /// Create a combinator that returns the provided value, and input unchanged |
| 65 | pub fn pure<I, O, E: ParseError<I>>(val: O) -> impl Fn(I) -> IResult<I, O, E> |
| 66 | where |
| 67 | O: Clone, |
| 68 | { |
| 69 | move |input: I| Ok((input, val.clone())) |
| 70 | } |
| 71 | |
| 72 | /// Return a closure that takes `len` bytes from input, and applies `parser`. |
| 73 | pub fn flat_take<I, C, O, E: ParseError<I>, F>(len: C, parser: F) -> impl Fn(I) -> IResult<I, O, E> |
| 74 | where |
| 75 | I: InputTake + InputLength + InputIter, |
| 76 | C: ToUsize + Copy, |
| 77 | F: Fn(I) -> IResult<I, O, E>, |
| 78 | { |
| 79 | // Note: this is the same as `map_parser(take(len), parser)` |
| 80 | move |input: I| { |
| 81 | let (input, o1) = take(len.to_usize())(input)?; |
| 82 | let (_, o2) = parser(o1)?; |
| 83 | Ok((input, o2)) |
| 84 | } |
| 85 | } |
| 86 | |
| 87 | /// Take `len` bytes from `input`, and apply `parser`. |
| 88 | pub fn flat_takec<I: Clone, O, E: ParseError<I>, C, F>( |
| 89 | input: I, |
| 90 | len: C, |
| 91 | parser: F, |
| 92 | ) -> IResult<I, O, E> |
| 93 | where |
| 94 | C: ToUsize + Copy, |
| 95 | F: Fn(I) -> IResult<I, O, E>, |
| 96 | I: InputTake + InputLength + InputIter, |
| 97 | O: InputLength, |
| 98 | { |
| 99 | flat_take(len, parser)(input) |
| 100 | } |
| 101 | |
| 102 | /// Helper macro for nom parsers: run first parser if condition is true, else second parser |
| 103 | pub fn cond_else<I: Clone, O, E: ParseError<I>, C, F, G>( |
| 104 | cond: C, |
| 105 | first: F, |
| 106 | second: G, |
| 107 | ) -> impl Fn(I) -> IResult<I, O, E> |
| 108 | where |
| 109 | C: Fn() -> bool, |
| 110 | F: Fn(I) -> IResult<I, O, E>, |
| 111 | G: Fn(I) -> IResult<I, O, E>, |
| 112 | { |
| 113 | move |input: I| { |
| 114 | if cond() { |
| 115 | first(input) |
| 116 | } else { |
| 117 | second(input) |
| 118 | } |
| 119 | } |
| 120 | } |
| 121 | |
/// Align input value to the next multiple of n bytes
///
/// Rounds `x` up to the nearest multiple of `n`; already-aligned values are returned as-is.
/// Valid only if `n` is a power of 2: the bit-mask trick is meaningless otherwise, and
/// `n == 0` overflows on `n - 1`. `x + n - 1` must also fit in a `usize`.
pub const fn align_n2(x: usize, n: usize) -> usize {
    let low_bits = n - 1;
    (x + low_bits) & !low_bits
}
| 127 | |
/// Align input value to the next multiple of 4 bytes
///
/// Rounds `x` up to a 32-bit (4-byte) boundary; same result as `align_n2(x, 4)`.
pub const fn align32(x: usize) -> usize {
    const MASK: usize = 0b11;
    (x + MASK) & !MASK
}
| 132 | |
#[cfg(test)]
mod tests {
    use super::{align32, align_n2, be_var_u64, cond_else, flat_take, le_var_u64, pure};
    use nom::bytes::streaming::take;
    use nom::error::ErrorKind;
    use nom::number::streaming::{be_u16, be_u32, be_u8};
    use nom::{Err, IResult, Needed};

    #[test]
    fn test_be_var_u64() {
        let res: IResult<&[u8], u64> = be_var_u64(b"\x12\x34\x56");
        let (rest, v) = res.expect("be_var_u64 failed");
        assert!(rest.is_empty());
        assert_eq!(v, 0x123456);
        // the maximum width of 8 bytes is accepted
        let all_ones = [0xffu8; 8];
        let res: IResult<&[u8], u64> = be_var_u64(&all_ones);
        assert_eq!(res.map(|(_, v)| v), Ok(u64::MAX));
    }

    #[test]
    fn test_le_var_u64() {
        let res: IResult<&[u8], u64> = le_var_u64(b"\x12\x34\x56");
        let (rest, v) = res.expect("le_var_u64 failed");
        assert!(rest.is_empty());
        assert_eq!(v, 0x563412);
    }

    #[test]
    fn test_var_u64_errors() {
        // empty input: more data needed
        let empty: &[u8] = &[];
        let res: IResult<&[u8], u64> = be_var_u64(empty);
        assert_eq!(res, Err(Err::Incomplete(Needed::new(1))));
        let res: IResult<&[u8], u64> = le_var_u64(empty);
        assert_eq!(res, Err(Err::Incomplete(Needed::new(1))));
        // more than 8 bytes cannot fit in a u64
        let too_long = [0u8; 9];
        let res: IResult<&[u8], u64> = be_var_u64(&too_long);
        assert!(matches!(res, Err(Err::Error(ref e)) if e.code == ErrorKind::TooLarge));
        let res: IResult<&[u8], u64> = le_var_u64(&too_long);
        assert!(matches!(res, Err(Err::Error(ref e)) if e.code == ErrorKind::TooLarge));
    }

    #[test]
    fn test_flat_take() {
        let input = &[0x00, 0x01, 0xff];
        // read first 2 bytes and use correct combinator: OK
        let res: IResult<&[u8], u16> = flat_take(2u8, be_u16)(input);
        assert_eq!(res, Ok((&input[2..], 0x0001)));
        // read 3 bytes and use 2: OK (some input is just lost)
        let res: IResult<&[u8], u16> = flat_take(3u8, be_u16)(input);
        assert_eq!(res, Ok((&b""[..], 0x0001)));
        // read 2 bytes and use a combinator requiring more bytes
        let res: IResult<&[u8], u32> = flat_take(2u8, be_u32)(input);
        assert_eq!(res, Err(Err::Incomplete(Needed::new(2))));
    }

    #[test]
    fn test_flat_take_str() {
        let input = "abcdef";
        // read first 2 bytes and use correct combinator: OK
        let res: IResult<&str, &str> = flat_take(2u8, take(2u8))(input);
        assert_eq!(res, Ok(("cdef", "ab")));
        // read 3 bytes and use 2: OK (some input is just lost)
        let res: IResult<&str, &str> = flat_take(3u8, take(2u8))(input);
        assert_eq!(res, Ok(("def", "ab")));
        // read 2 bytes and use a combinator requiring more bytes
        let res: IResult<&str, &str> = flat_take(2u8, take(4u8))(input);
        assert_eq!(res, Err(Err::Incomplete(Needed::Unknown)));
    }

    #[test]
    fn test_cond_else() {
        let input = &[0x01][..];
        let empty = &b""[..];
        let a = 1;
        fn parse_u8(i: &[u8]) -> IResult<&[u8], u8> {
            be_u8(i)
        }
        // condition true: first parser runs
        assert_eq!(
            cond_else(|| a == 1, parse_u8, pure(0x02))(input),
            Ok((empty, 0x01))
        );
        // condition false: second parser runs
        assert_eq!(
            cond_else(|| a == 2, pure(0x02), parse_u8)(input),
            Ok((empty, 0x01))
        );
        assert_eq!(
            cond_else(|| a == 2, parse_u8, pure(0x02))(input),
            Ok((input, 0x02))
        );
        assert_eq!(
            cond_else(|| a == 1, pure(0x02), parse_u8)(input),
            Ok((input, 0x02))
        );
        let res: IResult<&[u8], u8> = cond_else(|| a == 1, parse_u8, parse_u8)(input);
        assert_eq!(res, Ok((empty, 0x01)));
    }

    #[test]
    fn test_align32() {
        assert_eq!(align32(0), 0);
        assert_eq!(align32(3), 4);
        assert_eq!(align32(4), 4);
        assert_eq!(align32(5), 8);
        assert_eq!(align32(8), 8);
    }

    #[test]
    fn test_align_n2() {
        assert_eq!(align_n2(0, 8), 0);
        assert_eq!(align_n2(1, 8), 8);
        assert_eq!(align_n2(8, 8), 8);
        assert_eq!(align_n2(9, 16), 16);
        assert_eq!(align_n2(5, 1), 5);
        assert_eq!(align_n2(5, 4), align32(5));
    }
}