Yi Kong | c1a6315 | 2021-02-03 15:04:59 +0800 | [diff] [blame^] | 1 | use core; |
| 2 | use core::mem; |
| 3 | use traits::checked_pow; |
| 4 | use traits::PrimInt; |
| 5 | use Integer; |
| 6 | |
| 7 | /// Provides methods to compute an integer's square root, cube root, |
| 8 | /// and arbitrary `n`th root. |
| 9 | pub trait Roots: Integer { |
| 10 | /// Returns the truncated principal `n`th root of an integer |
| 11 | /// -- `if x >= 0 { ⌊ⁿ√x⌋ } else { ⌈ⁿ√x⌉ }` |
| 12 | /// |
| 13 | /// This is solving for `r` in `rⁿ = x`, rounding toward zero. |
| 14 | /// If `x` is positive, the result will satisfy `rⁿ ≤ x < (r+1)ⁿ`. |
| 15 | /// If `x` is negative and `n` is odd, then `(r-1)ⁿ < x ≤ rⁿ`. |
| 16 | /// |
| 17 | /// # Panics |
| 18 | /// |
| 19 | /// Panics if `n` is zero: |
| 20 | /// |
| 21 | /// ```should_panic |
| 22 | /// # use num_integer::Roots; |
| 23 | /// println!("can't compute ⁰√x : {}", 123.nth_root(0)); |
| 24 | /// ``` |
| 25 | /// |
| 26 | /// or if `n` is even and `self` is negative: |
| 27 | /// |
| 28 | /// ```should_panic |
| 29 | /// # use num_integer::Roots; |
| 30 | /// println!("no imaginary numbers... {}", (-1).nth_root(10)); |
| 31 | /// ``` |
| 32 | /// |
| 33 | /// # Examples |
| 34 | /// |
| 35 | /// ``` |
| 36 | /// use num_integer::Roots; |
| 37 | /// |
| 38 | /// let x: i32 = 12345; |
| 39 | /// assert_eq!(x.nth_root(1), x); |
| 40 | /// assert_eq!(x.nth_root(2), x.sqrt()); |
| 41 | /// assert_eq!(x.nth_root(3), x.cbrt()); |
| 42 | /// assert_eq!(x.nth_root(4), 10); |
| 43 | /// assert_eq!(x.nth_root(13), 2); |
| 44 | /// assert_eq!(x.nth_root(14), 1); |
| 45 | /// assert_eq!(x.nth_root(std::u32::MAX), 1); |
| 46 | /// |
| 47 | /// assert_eq!(std::i32::MAX.nth_root(30), 2); |
| 48 | /// assert_eq!(std::i32::MAX.nth_root(31), 1); |
| 49 | /// assert_eq!(std::i32::MIN.nth_root(31), -2); |
| 50 | /// assert_eq!((std::i32::MIN + 1).nth_root(31), -1); |
| 51 | /// |
| 52 | /// assert_eq!(std::u32::MAX.nth_root(31), 2); |
| 53 | /// assert_eq!(std::u32::MAX.nth_root(32), 1); |
| 54 | /// ``` |
| 55 | fn nth_root(&self, n: u32) -> Self; |
| 56 | |
| 57 | /// Returns the truncated principal square root of an integer -- `⌊√x⌋` |
| 58 | /// |
| 59 | /// This is solving for `r` in `r² = x`, rounding toward zero. |
| 60 | /// The result will satisfy `r² ≤ x < (r+1)²`. |
| 61 | /// |
| 62 | /// # Panics |
| 63 | /// |
| 64 | /// Panics if `self` is less than zero: |
| 65 | /// |
| 66 | /// ```should_panic |
| 67 | /// # use num_integer::Roots; |
| 68 | /// println!("no imaginary numbers... {}", (-1).sqrt()); |
| 69 | /// ``` |
| 70 | /// |
| 71 | /// # Examples |
| 72 | /// |
| 73 | /// ``` |
| 74 | /// use num_integer::Roots; |
| 75 | /// |
| 76 | /// let x: i32 = 12345; |
| 77 | /// assert_eq!((x * x).sqrt(), x); |
| 78 | /// assert_eq!((x * x + 1).sqrt(), x); |
| 79 | /// assert_eq!((x * x - 1).sqrt(), x - 1); |
| 80 | /// ``` |
| 81 | #[inline] |
| 82 | fn sqrt(&self) -> Self { |
| 83 | self.nth_root(2) |
| 84 | } |
| 85 | |
| 86 | /// Returns the truncated principal cube root of an integer -- |
| 87 | /// `if x >= 0 { ⌊∛x⌋ } else { ⌈∛x⌉ }` |
| 88 | /// |
| 89 | /// This is solving for `r` in `r³ = x`, rounding toward zero. |
| 90 | /// If `x` is positive, the result will satisfy `r³ ≤ x < (r+1)³`. |
| 91 | /// If `x` is negative, then `(r-1)³ < x ≤ r³`. |
| 92 | /// |
| 93 | /// # Examples |
| 94 | /// |
| 95 | /// ``` |
| 96 | /// use num_integer::Roots; |
| 97 | /// |
| 98 | /// let x: i32 = 1234; |
| 99 | /// assert_eq!((x * x * x).cbrt(), x); |
| 100 | /// assert_eq!((x * x * x + 1).cbrt(), x); |
| 101 | /// assert_eq!((x * x * x - 1).cbrt(), x - 1); |
| 102 | /// |
| 103 | /// assert_eq!((-(x * x * x)).cbrt(), -x); |
| 104 | /// assert_eq!((-(x * x * x + 1)).cbrt(), -x); |
| 105 | /// assert_eq!((-(x * x * x - 1)).cbrt(), -(x - 1)); |
| 106 | /// ``` |
| 107 | #[inline] |
| 108 | fn cbrt(&self) -> Self { |
| 109 | self.nth_root(3) |
| 110 | } |
| 111 | } |
| 112 | |
| 113 | /// Returns the truncated principal square root of an integer -- |
| 114 | /// see [Roots::sqrt](trait.Roots.html#method.sqrt). |
| 115 | #[inline] |
| 116 | pub fn sqrt<T: Roots>(x: T) -> T { |
| 117 | x.sqrt() |
| 118 | } |
| 119 | |
| 120 | /// Returns the truncated principal cube root of an integer -- |
| 121 | /// see [Roots::cbrt](trait.Roots.html#method.cbrt). |
| 122 | #[inline] |
| 123 | pub fn cbrt<T: Roots>(x: T) -> T { |
| 124 | x.cbrt() |
| 125 | } |
| 126 | |
| 127 | /// Returns the truncated principal `n`th root of an integer -- |
| 128 | /// see [Roots::nth_root](trait.Roots.html#tymethod.nth_root). |
| 129 | #[inline] |
| 130 | pub fn nth_root<T: Roots>(x: T, n: u32) -> T { |
| 131 | x.nth_root(n) |
| 132 | } |
| 133 | |
| 134 | macro_rules! signed_roots { |
| 135 | ($T:ty, $U:ty) => { |
| 136 | impl Roots for $T { |
| 137 | #[inline] |
| 138 | fn nth_root(&self, n: u32) -> Self { |
| 139 | if *self >= 0 { |
| 140 | (*self as $U).nth_root(n) as Self |
| 141 | } else { |
| 142 | assert!(n.is_odd(), "even roots of a negative are imaginary"); |
| 143 | -((self.wrapping_neg() as $U).nth_root(n) as Self) |
| 144 | } |
| 145 | } |
| 146 | |
| 147 | #[inline] |
| 148 | fn sqrt(&self) -> Self { |
| 149 | assert!(*self >= 0, "the square root of a negative is imaginary"); |
| 150 | (*self as $U).sqrt() as Self |
| 151 | } |
| 152 | |
| 153 | #[inline] |
| 154 | fn cbrt(&self) -> Self { |
| 155 | if *self >= 0 { |
| 156 | (*self as $U).cbrt() as Self |
| 157 | } else { |
| 158 | -((self.wrapping_neg() as $U).cbrt() as Self) |
| 159 | } |
| 160 | } |
| 161 | } |
| 162 | }; |
| 163 | } |
| 164 | |
| 165 | signed_roots!(i8, u8); |
| 166 | signed_roots!(i16, u16); |
| 167 | signed_roots!(i32, u32); |
| 168 | signed_roots!(i64, u64); |
| 169 | #[cfg(has_i128)] |
| 170 | signed_roots!(i128, u128); |
| 171 | signed_roots!(isize, usize); |
| 172 | |
| 173 | #[inline] |
| 174 | fn fixpoint<T, F>(mut x: T, f: F) -> T |
| 175 | where |
| 176 | T: Integer + Copy, |
| 177 | F: Fn(T) -> T, |
| 178 | { |
| 179 | let mut xn = f(x); |
| 180 | while x < xn { |
| 181 | x = xn; |
| 182 | xn = f(x); |
| 183 | } |
| 184 | while x > xn { |
| 185 | x = xn; |
| 186 | xn = f(x); |
| 187 | } |
| 188 | x |
| 189 | } |
| 190 | |
/// Width of `T` in bits, computed from its size in bytes.
#[inline]
fn bits<T>() -> u32 {
    let byte_count = mem::size_of::<T>() as u32;
    byte_count * 8
}
| 195 | |
| 196 | #[inline] |
| 197 | fn log2<T: PrimInt>(x: T) -> u32 { |
| 198 | debug_assert!(x > T::zero()); |
| 199 | bits::<T>() - 1 - x.leading_zeros() |
| 200 | } |
| 201 | |
/// Implements `Roots` for an unsigned primitive integer type `$T`.
///
/// Each method forwards `*self` to a nested `go` function.
///
/// * Types wider than 64 bits (`bits::<$T>() > 64`) avoid slow wide division.
///   They either narrow to `u64` or recurse on a right-shifted value and then
///   decide the final bit.
/// * Otherwise the result comes from `fixpoint` applied to a Newton-style
///   update, seeded by a `guess` function. With the `std` feature the guess
///   may use `f64` arithmetic; without it the guess is a power of two derived
///   from `log2`.
/// * For types of at most 32 bits, `cbrt` instead uses an integer-only
///   bit-by-bit method.
macro_rules! unsigned_roots {
    ($T:ident) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                fn go(a: $T, n: u32) -> $T {
                    // Specialize small roots
                    match n {
                        0 => panic!("can't find a root of degree 0!"),
                        1 => return a,
                        2 => return a.sqrt(),
                        3 => return a.cbrt(),
                        _ => (),
                    }

                    // The root of values less than 2ⁿ can only be 0 or 1.
                    // `bits::<$T>() <= n` is tested first, so `1 << n` is only evaluated
                    // when the shift amount is smaller than the type's width.
                    if bits::<$T>() <= n || a < (1 << n) {
                        return (a > 0) as $T;
                    }

                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).nth_root(n) as $T
                        } else {
                            // With r = nth_root(a >> n): (2r)ⁿ ≤ a < (2r+2)ⁿ, so the
                            // answer is either `lo = 2r` or `hi = 2r + 1`.
                            let lo = (a >> n).nth_root(n) << 1;
                            let hi = lo + 1;
                            // 128-bit `checked_mul` also involves division, but we can't always
                            // compute `hiⁿ` without risking overflow. Try to avoid it though...
                            // `hi.next_power_of_two().trailing_zeros()` is ⌈log₂ hi⌉; when that
                            // times `n` is below the bit width, `hiⁿ` fits and plain `pow` is used.
                            if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
                                match checked_pow(hi, n as usize) {
                                    Some(x) if x <= a => hi,
                                    _ => lo,
                                }
                            } else {
                                if hi.pow(n) <= a {
                                    hi
                                } else {
                                    lo
                                }
                            }
                        };
                    }

                    /// Starting estimate for the Newton iteration below.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        // for smaller inputs, `f64` doesn't justify its cost.
                        if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
                            // Power of two: 2^⌈⌊log₂ x⌋ / n⌉.
                            1 << ((log2(x) + n - 1) / n)
                        } else {
                            // x^(1/n) evaluated in floating point as e^(ln(x) / n).
                            ((x as f64).ln() / f64::from(n)).exp() as $T
                        }
                    }

                    /// Starting estimate for the Newton iteration below (integer-only).
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        1 << ((log2(x) + n - 1) / n)
                    }

                    // https://en.wikipedia.org/wiki/Nth_root_algorithm
                    // Newton step: x' = ((n-1)·x + a / xⁿ⁻¹) / n.
                    let n1 = n - 1;
                    let next = |x: $T| {
                        // If xⁿ⁻¹ overflows `$T` it is larger than `a`, so the
                        // truncated quotient `a / xⁿ⁻¹` is 0.
                        let y = match checked_pow(x, n1 as usize) {
                            Some(ax) => a / ax,
                            None => 0,
                        };
                        (y + x * n1 as $T) / n as $T
                    };
                    fixpoint(guess(a, n), next)
                }
                go(*self, n)
            }

            #[inline]
            fn sqrt(&self) -> Self {
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).sqrt() as $T
                        } else {
                            // Same reduction as in `nth_root` with n = 2: the root is
                            // `2·sqrt(a >> 2)` or one more than that.
                            let lo = (a >> 2u32).sqrt() << 1;
                            let hi = lo + 1;
                            if hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    // 0 has root 0; 1, 2 and 3 have root 1.
                    if a < 4 {
                        return (a > 0) as $T;
                    }

                    /// Floating-point starting estimate; `fixpoint` corrects it.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).sqrt() as $T
                    }

                    /// Integer-only power-of-two starting estimate.
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 1) / 2)
                    }

                    // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
                    // Step: x' = (a / x + x) / 2.
                    let next = |x: $T| (a / x + x) >> 1;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }

            #[inline]
            fn cbrt(&self) -> Self {
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).cbrt() as $T
                        } else {
                            // Same reduction as in `nth_root` with n = 3: the root is
                            // `2·cbrt(a >> 3)` or one more than that.
                            let lo = (a >> 3u32).cbrt() << 1;
                            let hi = lo + 1;
                            if hi * hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    if bits::<$T>() <= 32 {
                        // Implementation based on Hacker's Delight `icbrt2`
                        // Builds the root `y` one bit at a time, consuming 3 bits of `a`
                        // per step from the most significant end. `y2` is kept equal to
                        // `y * y`, and `b = 3·(y² + y) + 1` equals `(y+1)³ - y³`.
                        let mut x = a;
                        let mut y2 = 0;
                        let mut y = 0;
                        let smax = bits::<$T>() / 3;
                        for s in (0..smax + 1).rev() {
                            let s = s * 3;
                            y2 *= 4;
                            y *= 2;
                            let b = 3 * (y2 + y) + 1;
                            if x >> s >= b {
                                x -= b << s;
                                y2 += 2 * y + 1;
                                y += 1;
                            }
                        }
                        return y;
                    }

                    // 0 has root 0; 1 through 7 have root 1.
                    if a < 8 {
                        return (a > 0) as $T;
                    }
                    // Values that fit in `u32` use the bit-by-bit path above.
                    if a <= core::u32::MAX as $T {
                        return (a as u32).cbrt() as $T;
                    }

                    /// Floating-point starting estimate; `fixpoint` corrects it.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).cbrt() as $T
                    }

                    /// Integer-only power-of-two starting estimate.
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 2) / 3)
                    }

                    // https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
                    // Newton step: x' = (a / x² + 2·x) / 3.
                    let next = |x: $T| (a / (x * x) + x * 2) / 3;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }
        }
    };
}

unsigned_roots!(u8);
unsigned_roots!(u16);
unsigned_roots!(u32);
unsigned_roots!(u64);
#[cfg(has_i128)]
unsigned_roots!(u128);
unsigned_roots!(usize);