blob: f5de774279bd026aae25b6393d7d0d5768fb0082 [file] [log] [blame]
Yiming Jingebb18722021-07-16 13:15:12 -07001//! General purpose combinators
2
3use nom::bytes::streaming::take;
4use nom::combinator::map_parser;
5pub use nom::error::{make_error, ErrorKind, ParseError};
6pub use nom::{IResult, Needed};
7use nom::{InputIter, InputTake};
8use nom::{InputLength, ToUsize};
9
10/// Read the entire slice as a big endian unsigned integer, up to 8 bytes
11#[inline]
12pub fn be_var_u64<'a, E: ParseError<&'a [u8]>>(input: &'a [u8]) -> IResult<&'a [u8], u64, E> {
13 if input.is_empty() {
14 return Err(nom::Err::Incomplete(Needed::new(1)));
15 }
16 if input.len() > 8 {
17 return Err(nom::Err::Error(make_error(input, ErrorKind::TooLarge)));
18 }
19 let mut res = 0u64;
20 for byte in input {
21 res = (res << 8) + *byte as u64;
22 }
23
24 Ok((&b""[..], res))
25}
26
27/// Read the entire slice as a little endian unsigned integer, up to 8 bytes
28#[inline]
29pub fn le_var_u64<'a, E: ParseError<&'a [u8]>>(input: &'a [u8]) -> IResult<&'a [u8], u64, E> {
30 if input.is_empty() {
31 return Err(nom::Err::Incomplete(Needed::new(1)));
32 }
33 if input.len() > 8 {
34 return Err(nom::Err::Error(make_error(input, ErrorKind::TooLarge)));
35 }
36 let mut res = 0u64;
37 for byte in input.iter().rev() {
38 res = (res << 8) + *byte as u64;
39 }
40
41 Ok((&b""[..], res))
42}
43
44/// Read a slice as a big-endian value.
45#[inline]
46pub fn parse_hex_to_u64<S>(i: &[u8], size: S) -> IResult<&[u8], u64>
47where
48 S: ToUsize + Copy,
49{
50 map_parser(take(size.to_usize()), be_var_u64)(i)
51}
52
53/// Apply combinator, automatically converts between errors if the underlying type supports it
54pub fn upgrade_error<I, O, E1: ParseError<I>, E2: ParseError<I>, F>(
55 mut f: F,
56) -> impl FnMut(I) -> IResult<I, O, E2>
57where
58 F: FnMut(I) -> IResult<I, O, E1>,
59 E2: From<E1>,
60{
61 move |i| f(i).map_err(nom::Err::convert)
62}
63
64/// Create a combinator that returns the provided value, and input unchanged
65pub fn pure<I, O, E: ParseError<I>>(val: O) -> impl Fn(I) -> IResult<I, O, E>
66where
67 O: Clone,
68{
69 move |input: I| Ok((input, val.clone()))
70}
71
72/// Return a closure that takes `len` bytes from input, and applies `parser`.
73pub fn flat_take<I, C, O, E: ParseError<I>, F>(len: C, parser: F) -> impl Fn(I) -> IResult<I, O, E>
74where
75 I: InputTake + InputLength + InputIter,
76 C: ToUsize + Copy,
77 F: Fn(I) -> IResult<I, O, E>,
78{
79 // Note: this is the same as `map_parser(take(len), parser)`
80 move |input: I| {
81 let (input, o1) = take(len.to_usize())(input)?;
82 let (_, o2) = parser(o1)?;
83 Ok((input, o2))
84 }
85}
86
87/// Take `len` bytes from `input`, and apply `parser`.
88pub fn flat_takec<I: Clone, O, E: ParseError<I>, C, F>(
89 input: I,
90 len: C,
91 parser: F,
92) -> IResult<I, O, E>
93where
94 C: ToUsize + Copy,
95 F: Fn(I) -> IResult<I, O, E>,
96 I: InputTake + InputLength + InputIter,
97 O: InputLength,
98{
99 flat_take(len, parser)(input)
100}
101
102/// Helper macro for nom parsers: run first parser if condition is true, else second parser
103pub fn cond_else<I: Clone, O, E: ParseError<I>, C, F, G>(
104 cond: C,
105 first: F,
106 second: G,
107) -> impl Fn(I) -> IResult<I, O, E>
108where
109 C: Fn() -> bool,
110 F: Fn(I) -> IResult<I, O, E>,
111 G: Fn(I) -> IResult<I, O, E>,
112{
113 move |input: I| {
114 if cond() {
115 first(input)
116 } else {
117 second(input)
118 }
119 }
120}
121
/// Round `x` up to the next multiple of `n`.
///
/// The result is computed with a bit mask, so it is only correct when `n` is a
/// non-zero power of two. `n == 0` makes `n - 1` overflow, and `x + n - 1` must
/// fit in a `usize`.
pub const fn align_n2(x: usize, n: usize) -> usize {
    let low_bits = n - 1;
    (x + low_bits) & !low_bits
}
127
/// Round `x` up to the next multiple of 4 bytes (32-bit alignment).
pub const fn align32(x: usize) -> usize {
    // 4 - 1: the two low bits that must be cleared after rounding up.
    const LOW_BITS: usize = 0b11;
    (x + LOW_BITS) & !LOW_BITS
}
132
#[cfg(test)]
mod tests {
    use super::{align32, align_n2, be_var_u64, cond_else, flat_take, le_var_u64, pure};
    use nom::bytes::streaming::take;
    use nom::error::{Error, ErrorKind};
    use nom::number::streaming::{be_u16, be_u32, be_u8};
    use nom::{Err, IResult, Needed};

    #[test]
    fn test_be_var_u64() {
        let res: IResult<&[u8], u64> = be_var_u64(b"\x12\x34\x56");
        let (_, v) = res.expect("be_var_u64 failed");
        assert_eq!(v, 0x123456);
    }

    #[test]
    fn test_le_var_u64() {
        let res: IResult<&[u8], u64> = le_var_u64(b"\x12\x34\x56");
        let (_, v) = res.expect("le_var_u64 failed");
        assert_eq!(v, 0x563412);
    }

    #[test]
    fn test_var_u64_bounds() {
        // empty input: more data is requested
        let res: IResult<&[u8], u64> = be_var_u64(b"");
        assert_eq!(res, Err(Err::Incomplete(Needed::new(1))));
        let res: IResult<&[u8], u64> = le_var_u64(b"");
        assert_eq!(res, Err(Err::Incomplete(Needed::new(1))));
        // exactly 8 bytes: the widest accepted input
        let max = [0xffu8; 8];
        let res: IResult<&[u8], u64> = be_var_u64(&max);
        assert_eq!(res, Ok((&b""[..], u64::MAX)));
        let res: IResult<&[u8], u64> = le_var_u64(&max);
        assert_eq!(res, Ok((&b""[..], u64::MAX)));
        // 9 bytes: cannot fit in a u64
        let too_long = [0u8; 9];
        let expected = Err(Err::Error(Error {
            input: &too_long[..],
            code: ErrorKind::TooLarge,
        }));
        let res: IResult<&[u8], u64> = be_var_u64(&too_long);
        assert_eq!(res, expected);
        let res: IResult<&[u8], u64> = le_var_u64(&too_long);
        assert_eq!(res, expected);
    }

    #[test]
    fn test_flat_take() {
        let input = &[0x00, 0x01, 0xff];
        // read first 2 bytes and use correct combinator: OK
        let res: IResult<&[u8], u16> = flat_take(2u8, be_u16)(input);
        assert_eq!(res, Ok((&input[2..], 0x0001)));
        // read 3 bytes and use 2: OK (some input is just lost)
        let res: IResult<&[u8], u16> = flat_take(3u8, be_u16)(input);
        assert_eq!(res, Ok((&b""[..], 0x0001)));
        // read 2 bytes and use a combinator requiring more bytes
        let res: IResult<&[u8], u32> = flat_take(2u8, be_u32)(input);
        assert_eq!(res, Err(Err::Incomplete(Needed::new(2))));
    }

    #[test]
    fn test_flat_take_str() {
        let input = "abcdef";
        // read first 2 chars and use correct combinator: OK
        let res: IResult<&str, &str> = flat_take(2u8, take(2u8))(input);
        assert_eq!(res, Ok(("cdef", "ab")));
        // read 3 chars and use 2: OK (some input is just lost)
        let res: IResult<&str, &str> = flat_take(3u8, take(2u8))(input);
        assert_eq!(res, Ok(("def", "ab")));
        // read 2 chars and use a combinator requiring more chars
        let res: IResult<&str, &str> = flat_take(2u8, take(4u8))(input);
        assert_eq!(res, Err(Err::Incomplete(Needed::Unknown)));
    }

    #[test]
    fn test_cond_else() {
        let input = &[0x01][..];
        let empty = &b""[..];
        let a = 1;
        fn parse_u8(i: &[u8]) -> IResult<&[u8], u8> {
            be_u8(i)
        }
        // condition true: the first parser runs and consumes input
        assert_eq!(
            cond_else(|| a == 1, parse_u8, pure(0x02))(input),
            Ok((empty, 0x01))
        );
        // condition false: the second parser runs and input is left untouched
        assert_eq!(
            cond_else(|| a == 2, parse_u8, pure(0x02))(input),
            Ok((input, 0x02))
        );
        assert_eq!(
            cond_else(|| a == 1, pure(0x02), parse_u8)(input),
            Ok((input, 0x02))
        );
        // an error from the selected parser is returned unchanged
        assert_eq!(
            cond_else(|| a == 1, parse_u8, pure(0x02))(empty),
            Err(Err::Incomplete(Needed::new(1)))
        );
        let res: IResult<&[u8], u8> = cond_else(|| a == 1, parse_u8, parse_u8)(input);
        assert_eq!(res, Ok((empty, 0x01)));
    }

    #[test]
    fn test_align32() {
        assert_eq!(align32(0), 0);
        assert_eq!(align32(3), 4);
        assert_eq!(align32(4), 4);
        assert_eq!(align32(5), 8);
        assert_eq!(align32(8), 8);
    }

    #[test]
    fn test_align_n2() {
        assert_eq!(align_n2(0, 8), 0);
        assert_eq!(align_n2(1, 8), 8);
        assert_eq!(align_n2(8, 8), 8);
        assert_eq!(align_n2(9, 8), 16);
        assert_eq!(align_n2(5, 1), 5);
    }
}