blob: d24651bf482e1ef5377d1a5235e13afbdcecec73 [file] [log] [blame]
Yiming Jingcf21fc42021-07-16 13:23:26 -07001use super::monty::monty_modpow;
2use super::BigUint;
3
4use crate::big_digit::{self, BigDigit};
5
6use num_integer::Integer;
7use num_traits::{One, Pow, ToPrimitive, Zero};
8
9impl<'b> Pow<&'b BigUint> for BigUint {
10 type Output = BigUint;
11
12 #[inline]
13 fn pow(self, exp: &BigUint) -> BigUint {
14 if self.is_one() || exp.is_zero() {
15 BigUint::one()
16 } else if self.is_zero() {
17 BigUint::zero()
18 } else if let Some(exp) = exp.to_u64() {
19 self.pow(exp)
20 } else if let Some(exp) = exp.to_u128() {
21 self.pow(exp)
22 } else {
23 // At this point, `self >= 2` and `exp >= 2¹²⁸`. The smallest possible result given
24 // `2.pow(2¹²⁸)` would require far more memory than 64-bit targets can address!
25 panic!("memory overflow")
26 }
27 }
28}
29
30impl Pow<BigUint> for BigUint {
31 type Output = BigUint;
32
33 #[inline]
34 fn pow(self, exp: BigUint) -> BigUint {
35 Pow::pow(self, &exp)
36 }
37}
38
39impl<'a, 'b> Pow<&'b BigUint> for &'a BigUint {
40 type Output = BigUint;
41
42 #[inline]
43 fn pow(self, exp: &BigUint) -> BigUint {
44 if self.is_one() || exp.is_zero() {
45 BigUint::one()
46 } else if self.is_zero() {
47 BigUint::zero()
48 } else {
49 self.clone().pow(exp)
50 }
51 }
52}
53
54impl<'a> Pow<BigUint> for &'a BigUint {
55 type Output = BigUint;
56
57 #[inline]
58 fn pow(self, exp: BigUint) -> BigUint {
59 Pow::pow(self, &exp)
60 }
61}
62
/// Generates `Pow` impls for one primitive unsigned exponent type `$T`, covering
/// owned and borrowed forms of both the base (`BigUint` / `&BigUint`) and the
/// exponent (`$T` / `&$T`).
macro_rules! pow_impl {
    ($T:ty) => {
        impl Pow<$T> for BigUint {
            type Output = BigUint;

            /// Raises `self` to the primitive power `exp` by square-and-multiply,
            /// reading the exponent's bits from least to most significant.
            fn pow(self, exp: $T) -> BigUint {
                if exp == 0 {
                    return BigUint::one();
                }

                let mut bits = exp;
                let mut square = self;

                // Trailing zero bits of the exponent only square the base.
                while bits & 1 == 0 {
                    square = &square * &square;
                    bits >>= 1;
                }

                // Exactly one set bit: `square` already holds the answer.
                if bits == 1 {
                    return square;
                }

                // The lowest set bit seeds the product; each higher bit squares
                // `square` once and multiplies it in when that bit is set.
                let mut product = square.clone();
                while bits > 1 {
                    bits >>= 1;
                    square = &square * &square;
                    if bits & 1 == 1 {
                        product *= &square;
                    }
                }
                product
            }
        }

        impl<'b> Pow<&'b $T> for BigUint {
            type Output = BigUint;

            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                <BigUint as Pow<$T>>::pow(self, *exp)
            }
        }

        impl<'a> Pow<$T> for &'a BigUint {
            type Output = BigUint;

            /// Clones the base only when the exponent is non-zero.
            #[inline]
            fn pow(self, exp: $T) -> BigUint {
                if exp == 0 {
                    BigUint::one()
                } else {
                    <BigUint as Pow<$T>>::pow(self.clone(), exp)
                }
            }
        }

        impl<'a, 'b> Pow<&'b $T> for &'a BigUint {
            type Output = BigUint;

            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                <&'a BigUint as Pow<$T>>::pow(self, *exp)
            }
        }
    };
}
126
127pow_impl!(u8);
128pow_impl!(u16);
129pow_impl!(u32);
130pow_impl!(u64);
131pow_impl!(usize);
132pow_impl!(u128);
133
134pub(super) fn modpow(x: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint {
135 assert!(
136 !modulus.is_zero(),
137 "attempt to calculate with zero modulus!"
138 );
139
140 if modulus.is_odd() {
141 // For an odd modulus, we can use Montgomery multiplication in base 2^32.
142 monty_modpow(x, exponent, modulus)
143 } else {
144 // Otherwise do basically the same as `num::pow`, but with a modulus.
145 plain_modpow(x, &exponent.data, modulus)
146 }
147}
148
149fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint {
150 assert!(
151 !modulus.is_zero(),
152 "attempt to calculate with zero modulus!"
153 );
154
155 let i = match exp_data.iter().position(|&r| r != 0) {
156 None => return BigUint::one(),
157 Some(i) => i,
158 };
159
160 let mut base = base % modulus;
161 for _ in 0..i {
162 for _ in 0..big_digit::BITS {
163 base = &base * &base % modulus;
164 }
165 }
166
167 let mut r = exp_data[i];
168 let mut b = 0u8;
169 while r.is_even() {
170 base = &base * &base % modulus;
171 r >>= 1;
172 b += 1;
173 }
174
175 let mut exp_iter = exp_data[i + 1..].iter();
176 if exp_iter.len() == 0 && r.is_one() {
177 return base;
178 }
179
180 let mut acc = base.clone();
181 r >>= 1;
182 b += 1;
183
184 {
185 let mut unit = |exp_is_odd| {
186 base = &base * &base % modulus;
187 if exp_is_odd {
Joel Galenson7bace412021-09-22 14:05:35 -0700188 acc *= &base;
189 acc %= modulus;
Yiming Jingcf21fc42021-07-16 13:23:26 -0700190 }
191 };
192
193 if let Some(&last) = exp_iter.next_back() {
194 // consume exp_data[i]
195 for _ in b..big_digit::BITS {
196 unit(r.is_odd());
197 r >>= 1;
198 }
199
200 // consume all other digits before the last
201 for &r in exp_iter {
202 let mut r = r;
203 for _ in 0..big_digit::BITS {
204 unit(r.is_odd());
205 r >>= 1;
206 }
207 }
208 r = last;
209 }
210
211 debug_assert_ne!(r, 0);
212 while !r.is_zero() {
213 unit(r.is_odd());
214 r >>= 1;
215 }
216 }
217 acc
218}
219
/// Checks `plain_modpow` against plain `pow` followed by `%`.
#[test]
fn test_plain_modpow() {
    let two = &BigUint::from(2u32);
    let modulus = BigUint::from(0x1100u32);

    // A zero exponent (empty or all-zero digits) yields 1.
    assert_eq!(BigUint::one(), plain_modpow(two, &[], &modulus));
    assert_eq!(BigUint::one(), plain_modpow(two, &[0, 0], &modulus));

    // Single-digit exponents can be compared directly with `pow`.
    for &e in &[1u32, 2, 5, 8, 0b1100, 0b110010, 1000] {
        let exp = [BigDigit::from(e)];
        assert_eq!(two.pow(e) % &modulus, plain_modpow(two, &exp, &modulus));
    }

    // Multi-digit exponents are little-endian `BigDigit`s, so `[lo, hi]` means
    // `lo + hi * 2^BITS`: far too large to evaluate with `pow`. The expected values
    // below use a stand-in exponent written with 8-bit "digits" instead.
    //
    // That is sound because 0x1100 = 2^8 * 17. For k >= 8, 2^k is 0 mod 2^8, and
    // 2^k mod 17 repeats with period 8. So 2^k mod 0x1100 depends only on k mod 8.
    // Every exponent used here (real and stand-in) is >= 8. Both forms are
    // congruent to their lowest digit mod 8, provided BITS is a multiple of 8,
    // which is asserted next.
    assert_eq!(big_digit::BITS % 8, 0);

    let exp = vec![0, 0b1];
    assert_eq!(
        two.pow(0b1_00000000_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = vec![0, 0b10];
    assert_eq!(
        two.pow(0b10_00000000_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = vec![0, 0b110010];
    assert_eq!(
        two.pow(0b110010_00000000_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = vec![0b1, 0b1];
    assert_eq!(
        two.pow(0b1_00000001_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = vec![0b1100, 0, 0b1];
    assert_eq!(
        two.pow(0b1_00000000_00001100_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
}
251
/// Exercises all four owned/borrowed `Pow<BigUint>` impls and their
/// short-circuit branches.
#[test]
fn test_pow_biguint() {
    let base = BigUint::from(5u8);
    let exponent = BigUint::from(3u8);
    let expected = BigUint::from(125u8);

    // Every owned/borrowed combination must agree. Fully qualified calls avoid
    // any ambiguity with an inherent `pow(u32)` method on borrowed receivers.
    assert_eq!(expected, Pow::pow(&base, &exponent));
    assert_eq!(expected, Pow::pow(&base, exponent.clone()));
    assert_eq!(expected, Pow::pow(base.clone(), &exponent));
    assert_eq!(expected, base.pow(exponent));

    // Short-circuit cases: x^0 == 1 (0^0 included), 1^x == 1 and 0^x == 0. This
    // holds even for exponents too large to narrow to `u128`, which would
    // otherwise hit the "memory overflow" panic.
    let huge = BigUint::one() << 200usize;
    assert_eq!(BigUint::one(), Pow::pow(BigUint::zero(), BigUint::zero()));
    assert_eq!(BigUint::one(), Pow::pow(BigUint::one(), &huge));
    assert_eq!(BigUint::zero(), Pow::pow(BigUint::zero(), &huge));
    assert_eq!(BigUint::one(), Pow::pow(&BigUint::one(), &huge));
}