blob: cfef80c47ba956d45c212f88b9a4955be0301ac1 [file] [log] [blame]
Yiming Jingcf21fc42021-07-16 13:23:26 -07001mod biguint {
2 use num_bigint::BigUint;
3 use num_traits::{One, Zero};
4 use std::{i32, u32};
5
6 fn check<T: Into<BigUint>>(x: T, n: u32) {
7 let x: BigUint = x.into();
8 let root = x.nth_root(n);
9 println!("check {}.nth_root({}) = {}", x, n, root);
10
11 if n == 2 {
12 assert_eq!(root, x.sqrt())
13 } else if n == 3 {
14 assert_eq!(root, x.cbrt())
15 }
16
17 let lo = root.pow(n);
18 assert!(lo <= x);
19 assert_eq!(lo.nth_root(n), root);
20 if !lo.is_zero() {
21 assert_eq!((&lo - 1u32).nth_root(n), &root - 1u32);
22 }
23
24 let hi = (&root + 1u32).pow(n);
25 assert!(hi > x);
26 assert_eq!(hi.nth_root(n), &root + 1u32);
27 assert_eq!((&hi - 1u32).nth_root(n), root);
28 }
29
30 #[test]
31 fn test_sqrt() {
32 check(99u32, 2);
33 check(100u32, 2);
34 check(120u32, 2);
35 }
36
37 #[test]
38 fn test_cbrt() {
39 check(8u32, 3);
40 check(26u32, 3);
41 }
42
43 #[test]
44 fn test_nth_root() {
45 check(0u32, 1);
46 check(10u32, 1);
47 check(100u32, 4);
48 }
49
50 #[test]
51 #[should_panic]
52 fn test_nth_root_n_is_zero() {
53 check(4u32, 0);
54 }
55
56 #[test]
57 fn test_nth_root_big() {
58 let x = BigUint::from(123_456_789_u32);
59 let expected = BigUint::from(6u32);
60
61 assert_eq!(x.nth_root(10), expected);
62 check(x, 10);
63 }
64
65 #[test]
66 fn test_nth_root_googol() {
67 let googol = BigUint::from(10u32).pow(100u32);
68
69 // perfect divisors of 100
70 for &n in &[2, 4, 5, 10, 20, 25, 50, 100] {
71 let expected = BigUint::from(10u32).pow(100u32 / n);
72 assert_eq!(googol.nth_root(n), expected);
73 check(googol.clone(), n);
74 }
75 }
76
77 #[test]
78 fn test_nth_root_twos() {
79 const EXP: u32 = 12;
80 const LOG2: usize = 1 << EXP;
81 let x = BigUint::one() << LOG2;
82
83 // the perfect divisors are just powers of two
84 for exp in 1..=EXP {
85 let n = 2u32.pow(exp);
86 let expected = BigUint::one() << (LOG2 / n as usize);
87 assert_eq!(x.nth_root(n), expected);
88 check(x.clone(), n);
89 }
90
91 // degenerate cases should return quickly
92 assert!(x.nth_root(x.bits() as u32).is_one());
93 assert!(x.nth_root(i32::MAX as u32).is_one());
94 assert!(x.nth_root(u32::MAX).is_one());
95 }
96
97 #[test]
98 fn test_roots_rand1() {
99 // A random input that found regressions
100 let s = "575981506858479247661989091587544744717244516135539456183849\
101 986593934723426343633698413178771587697273822147578889823552\
102 182702908597782734558103025298880194023243541613924361007059\
103 353344183590348785832467726433749431093350684849462759540710\
104 026019022227591412417064179299354183441181373862905039254106\
105 4781867";
106 let x: BigUint = s.parse().unwrap();
107
108 check(x.clone(), 2);
109 check(x.clone(), 3);
110 check(x.clone(), 10);
111 check(x, 100);
112 }
113}
114
115mod bigint {
116 use num_bigint::BigInt;
117 use num_traits::Signed;
118
119 fn check(x: i64, n: u32) {
120 let big_x = BigInt::from(x);
121 let res = big_x.nth_root(n);
122
123 if n == 2 {
124 assert_eq!(&res, &big_x.sqrt())
125 } else if n == 3 {
126 assert_eq!(&res, &big_x.cbrt())
127 }
128
129 if big_x.is_negative() {
130 assert!(res.pow(n) >= big_x);
131 assert!((res - 1u32).pow(n) < big_x);
132 } else {
133 assert!(res.pow(n) <= big_x);
134 assert!((res + 1u32).pow(n) > big_x);
135 }
136 }
137
138 #[test]
139 fn test_nth_root() {
140 check(-100, 3);
141 }
142
143 #[test]
144 #[should_panic]
145 fn test_nth_root_x_neg_n_even() {
146 check(-100, 4);
147 }
148
149 #[test]
150 #[should_panic]
151 fn test_sqrt_x_neg() {
152 check(-4, 2);
153 }
154
155 #[test]
156 fn test_cbrt() {
157 check(8, 3);
158 check(-8, 3);
159 }
160}