blob: bf05e75f7387cd80eeaf7dcc3be41e50f99c5782 [file] [log] [blame]
// Copyright (c) 2016, Google Inc.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15package main
16
import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"math/big"
	"os"
	"sort"
	"strings"
)
26
// test is one parsed test case: a run of "Key = hex" attribute lines
// terminated by a blank line (or end of input).
type test struct {
	// LineNumber is the input line holding the case's first attribute.
	LineNumber int
	// Type is the key of the first attribute; it selects which operation
	// the case exercises.
	Type string
	// Values maps every attribute key to its value, parsed as hexadecimal.
	Values map[string]*big.Int
}
32
33type testScanner struct {
34 scanner *bufio.Scanner
35 lineNo int
36 err error
37 test test
38}
39
40func newTestScanner(r io.Reader) *testScanner {
41 return &testScanner{scanner: bufio.NewScanner(r)}
42}
43
44func (s *testScanner) scanLine() bool {
45 if !s.scanner.Scan() {
46 return false
47 }
48 s.lineNo++
49 return true
50}
51
52func (s *testScanner) addAttribute(line string) (key string, ok bool) {
53 fields := strings.SplitN(line, "=", 2)
54 if len(fields) != 2 {
55 s.setError(errors.New("invalid syntax"))
56 return "", false
57 }
58
59 key = strings.TrimSpace(fields[0])
60 value := strings.TrimSpace(fields[1])
61
62 valueInt, ok := new(big.Int).SetString(value, 16)
63 if !ok {
64 s.setError(fmt.Errorf("could not parse %q", value))
65 return "", false
66 }
67 if _, dup := s.test.Values[key]; dup {
68 s.setError(fmt.Errorf("duplicate key %q", key))
69 return "", false
70 }
71 s.test.Values[key] = valueInt
72 return key, true
73}
74
75func (s *testScanner) Scan() bool {
76 s.test = test{
77 Values: make(map[string]*big.Int),
78 }
79
80 // Scan until the first attribute.
81 for {
82 if !s.scanLine() {
83 return false
84 }
85 if len(s.scanner.Text()) != 0 && s.scanner.Text()[0] != '#' {
86 break
87 }
88 }
89
90 var ok bool
91 s.test.Type, ok = s.addAttribute(s.scanner.Text())
92 if !ok {
93 return false
94 }
95 s.test.LineNumber = s.lineNo
96
97 for s.scanLine() {
98 if len(s.scanner.Text()) == 0 {
99 break
100 }
101
102 if s.scanner.Text()[0] == '#' {
103 continue
104 }
105
106 if _, ok := s.addAttribute(s.scanner.Text()); !ok {
107 return false
108 }
109 }
110 return s.scanner.Err() == nil
111}
112
113func (s *testScanner) Test() test {
114 return s.test
115}
116
117func (s *testScanner) Err() error {
118 if s.err != nil {
119 return s.err
120 }
121 return s.scanner.Err()
122}
123
124func (s *testScanner) setError(err error) {
125 s.err = fmt.Errorf("line %d: %s", s.lineNo, err)
126}
127
128func checkKeys(t test, keys ...string) bool {
129 var foundErrors bool
130
131 for _, k := range keys {
132 if _, ok := t.Values[k]; !ok {
133 fmt.Fprintf(os.Stderr, "Line %d: missing key %q.\n", t.LineNumber, k)
134 foundErrors = true
135 }
136 }
137
138 for k, _ := range t.Values {
139 var found bool
140 for _, k2 := range keys {
141 if k == k2 {
142 found = true
143 break
144 }
145 }
146 if !found {
147 fmt.Fprintf(os.Stderr, "Line %d: unexpected key %q.\n", t.LineNumber, k)
148 foundErrors = true
149 }
150 }
151
152 return !foundErrors
153}
154
155func checkResult(t test, expr, key string, r *big.Int) {
156 if t.Values[key].Cmp(r) != 0 {
157 fmt.Fprintf(os.Stderr, "Line %d: %s did not match %s.\n\tGot %s\n", t.LineNumber, expr, key, r.Text(16))
158 }
159}
160
161func main() {
162 if len(os.Args) != 2 {
163 fmt.Fprintf(os.Stderr, "Usage: %s bn_tests.txt\n", os.Args[0])
164 os.Exit(1)
165 }
166
167 in, err := os.Open(os.Args[1])
168 if err != nil {
169 fmt.Fprintf(os.Stderr, "Error opening %s: %s.\n", os.Args[0], err)
170 os.Exit(1)
171 }
172 defer in.Close()
173
174 scanner := newTestScanner(in)
175 for scanner.Scan() {
176 test := scanner.Test()
177 switch test.Type {
178 case "Sum":
179 if checkKeys(test, "A", "B", "Sum") {
180 r := new(big.Int).Add(test.Values["A"], test.Values["B"])
181 checkResult(test, "A + B", "Sum", r)
182 }
183 case "LShift1":
184 if checkKeys(test, "A", "LShift1") {
185 r := new(big.Int).Add(test.Values["A"], test.Values["A"])
186 checkResult(test, "A + A", "LShift1", r)
187 }
188 case "LShift":
189 if checkKeys(test, "A", "N", "LShift") {
190 r := new(big.Int).Lsh(test.Values["A"], uint(test.Values["N"].Uint64()))
191 checkResult(test, "A << N", "LShift", r)
192 }
193 case "RShift":
194 if checkKeys(test, "A", "N", "RShift") {
195 r := new(big.Int).Rsh(test.Values["A"], uint(test.Values["N"].Uint64()))
196 checkResult(test, "A >> N", "RShift", r)
197 }
198 case "Square":
199 if checkKeys(test, "A", "Square") {
200 r := new(big.Int).Mul(test.Values["A"], test.Values["A"])
201 checkResult(test, "A * A", "Square", r)
202 }
203 case "Product":
204 if checkKeys(test, "A", "B", "Product") {
205 r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
206 checkResult(test, "A * B", "Product", r)
207 }
208 case "Quotient":
209 if checkKeys(test, "A", "B", "Quotient", "Remainder") {
210 q, r := new(big.Int).QuoRem(test.Values["A"], test.Values["B"], new(big.Int))
211 checkResult(test, "A / B", "Quotient", q)
212 checkResult(test, "A % B", "Remainder", r)
213 }
214 case "ModMul":
215 if checkKeys(test, "A", "B", "M", "ModMul") {
216 r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
217 r = r.Mod(r, test.Values["M"])
218 checkResult(test, "A * B (mod M)", "ModMul", r)
219 }
220 case "ModExp":
221 if checkKeys(test, "A", "E", "M", "ModExp") {
222 r := new(big.Int).Exp(test.Values["A"], test.Values["E"], test.Values["M"])
223 checkResult(test, "A ^ E (mod M)", "ModExp", r)
224 }
225 case "Exp":
226 if checkKeys(test, "A", "E", "Exp") {
227 r := new(big.Int).Exp(test.Values["A"], test.Values["E"], nil)
228 checkResult(test, "A ^ E", "Exp", r)
229 }
230 case "ModSqrt":
231 bigOne := new(big.Int).SetInt64(1)
232 bigTwo := new(big.Int).SetInt64(2)
233
234 if checkKeys(test, "A", "P", "ModSqrt") {
235 test.Values["A"].Mod(test.Values["A"], test.Values["P"])
236
237 r := new(big.Int).Mul(test.Values["ModSqrt"], test.Values["ModSqrt"])
238 r = r.Mod(r, test.Values["P"])
239 checkResult(test, "ModSqrt ^ 2 (mod P)", "A", r)
240
241 if test.Values["P"].Cmp(bigTwo) > 0 {
242 pMinus1Over2 := new(big.Int).Sub(test.Values["P"], bigOne)
243 pMinus1Over2.Rsh(pMinus1Over2, 1)
244
245 if test.Values["ModSqrt"].Cmp(pMinus1Over2) > 0 {
246 fmt.Fprintf(os.Stderr, "Line %d: ModSqrt should be minimal.\n", test.LineNumber)
247 }
248 }
249 }
250 case "ModInv":
251 if checkKeys(test, "A", "M", "ModInv") {
252 r := new(big.Int).ModInverse(test.Values["A"], test.Values["M"])
253 checkResult(test, "A ^ -1 (mod M)", "ModInv", r)
254 }
Robert Sloanab8b8882018-03-26 11:39:51 -0700255 case "ModSquare":
256 if checkKeys(test, "A", "M", "ModSquare") {
257 r := new(big.Int).Mul(test.Values["A"], test.Values["A"])
258 r = r.Mod(r, test.Values["M"])
259 checkResult(test, "A * A (mod M)", "ModSquare", r)
260 }
261 case "NotModSquare":
262 if checkKeys(test, "P", "NotModSquare") {
263 if new(big.Int).ModSqrt(test.Values["NotModSquare"], test.Values["P"]) != nil {
264 fmt.Fprintf(os.Stderr, "Line %d: value was a square.\n", test.LineNumber)
265 }
266 }
267 case "GCD":
Robert Sloan49d063b2018-04-03 11:30:38 -0700268 if checkKeys(test, "A", "B", "GCD", "LCM") {
Robert Sloanab8b8882018-03-26 11:39:51 -0700269 a := test.Values["A"]
270 b := test.Values["B"]
271 // Go's GCD function does not accept zero, unlike OpenSSL.
272 var g *big.Int
273 if a.Sign() == 0 {
274 g = b
275 } else if b.Sign() == 0 {
276 g = a
277 } else {
278 g = new(big.Int).GCD(nil, nil, a, b)
279 }
280 checkResult(test, "GCD(A, B)", "GCD", g)
Robert Sloan49d063b2018-04-03 11:30:38 -0700281 if g.Sign() != 0 {
282 lcm := new(big.Int).Mul(a, b)
283 lcm = lcm.Div(lcm, g)
284 checkResult(test, "LCM(A, B)", "LCM", lcm)
285 }
Robert Sloanab8b8882018-03-26 11:39:51 -0700286 }
David Benjaminc895d6b2016-08-11 13:26:41 -0400287 default:
288 fmt.Fprintf(os.Stderr, "Line %d: unknown test type %q.\n", test.LineNumber, test.Type)
289 }
290 }
291 if scanner.Err() != nil {
292 fmt.Fprintf(os.Stderr, "Error reading tests: %s.\n", scanner.Err())
293 }
294}