blob: f877c7557d7e60e71ae3264085eb020268005433 [file] [log] [blame]
Pete Bentley0c61efe2019-08-13 09:32:23 +01001/* Copyright (c) 2019, Google Inc.
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
#include <string>
#include <vector>

#include <assert.h>
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/uio.h>
#include <unistd.h>
#include <cstdarg>

#include <openssl/aes.h>
#include <openssl/sha.h>
#include <openssl/span.h>
26
// Limits on a single request read from stdin:
//  - kMaxArgs: the number of arguments, including the operation name;
//  - kMaxArgLength: the length of any one argument;
//  - kMaxNameLength: the length of the operation name.
static constexpr size_t kMaxArgs = 8;
static constexpr size_t kMaxArgLength = (1 << 20);
static constexpr size_t kMaxNameLength = 30;

// A request is one name plus at most kMaxArgs - 1 further arguments, so its
// payload is at most (kMaxArgs - 1) * kMaxArgLength + kMaxNameLength bytes.
// Keeping that well bounded means summing argument lengths cannot overflow and
// a single request cannot demand an unreasonable allocation. (The previous
// form of this check mis-parenthesised the product and inverted the
// comparison, so unsigned wrap-around made it always pass.)
static_assert((kMaxArgs - 1) * kMaxArgLength + kMaxNameLength < (1 << 30),
              "Argument limits permit excessive messages");
33
34using namespace bssl;
35
// Reads exactly |data_len| bytes from |fd| into |in_data|. A read() that fails
// with EINTR is simply retried. Returns true once every byte has been read, or
// false as soon as read() reports an error or end-of-file.
static bool ReadAll(int fd, void *in_data, size_t data_len) {
  uint8_t *cursor = static_cast<uint8_t *>(in_data);
  size_t remaining = data_len;

  while (remaining > 0) {
    const ssize_t got = read(fd, cursor, remaining);
    if (got == -1 && errno == EINTR) {
      continue;  // Interrupted by a signal; try the same read again.
    }
    if (got <= 0) {
      return false;  // Error, or EOF before |data_len| bytes arrived.
    }
    cursor += got;
    remaining -= static_cast<size_t>(got);
  }

  return true;
}
55
56template <typename... Args>
57static bool WriteReply(int fd, Args... args) {
58 std::vector<Span<const uint8_t>> spans = {args...};
59 if (spans.empty() || spans.size() > kMaxArgs) {
60 abort();
61 }
62
63 uint32_t nums[1 + kMaxArgs];
64 iovec iovs[kMaxArgs + 1];
65 nums[0] = spans.size();
66 iovs[0].iov_base = nums;
67 iovs[0].iov_len = sizeof(uint32_t) * (1 + spans.size());
68
69 for (size_t i = 0; i < spans.size(); i++) {
70 const auto &span = spans[i];
71 nums[i + 1] = span.size();
72 iovs[i + 1].iov_base = const_cast<uint8_t *>(span.data());
73 iovs[i + 1].iov_len = span.size();
74 }
75
76 const size_t num_iov = spans.size() + 1;
77 size_t iov_done = 0;
78 while (iov_done < num_iov) {
79 ssize_t r;
80 do {
81 r = writev(fd, &iovs[iov_done], num_iov - iov_done);
82 } while (r == -1 && errno == EINTR);
83
84 if (r <= 0) {
85 return false;
86 }
87
88 size_t written = r;
89 for (size_t i = iov_done; written > 0 && i < num_iov; i++) {
90 iovec &iov = iovs[i];
91
92 size_t done = written;
93 if (done > iov.iov_len) {
94 done = iov.iov_len;
95 }
96
97 iov.iov_base = reinterpret_cast<uint8_t *>(iov.iov_base) + done;
98 iov.iov_len -= done;
99 written -= done;
100
101 if (iov.iov_len == 0) {
102 iov_done++;
103 }
104 }
105
106 assert(written == 0);
107 }
108
109 return true;
110}
111
112static bool GetConfig(const Span<const uint8_t> args[]) {
113 static constexpr char kConfig[] =
114 "["
115 "{"
116 " \"algorithm\": \"SHA2-224\","
117 " \"revision\": \"1.0\","
118 " \"messageLength\": [{"
119 " \"min\": 0, \"max\": 65528, \"increment\": 8"
120 " }]"
121 "},"
122 "{"
123 " \"algorithm\": \"SHA2-256\","
124 " \"revision\": \"1.0\","
125 " \"messageLength\": [{"
126 " \"min\": 0, \"max\": 65528, \"increment\": 8"
127 " }]"
128 "},"
129 "{"
130 " \"algorithm\": \"SHA2-384\","
131 " \"revision\": \"1.0\","
132 " \"messageLength\": [{"
133 " \"min\": 0, \"max\": 65528, \"increment\": 8"
134 " }]"
135 "},"
136 "{"
137 " \"algorithm\": \"SHA2-512\","
138 " \"revision\": \"1.0\","
139 " \"messageLength\": [{"
140 " \"min\": 0, \"max\": 65528, \"increment\": 8"
141 " }]"
142 "},"
143 "{"
144 " \"algorithm\": \"SHA-1\","
145 " \"revision\": \"1.0\","
146 " \"messageLength\": [{"
147 " \"min\": 0, \"max\": 65528, \"increment\": 8"
148 " }]"
149 "},"
150 "{"
151 " \"algorithm\": \"ACVP-AES-ECB\","
152 " \"revision\": \"1.0\","
153 " \"direction\": [\"encrypt\", \"decrypt\"],"
154 " \"keyLen\": [128, 192, 256]"
155 "},"
156 "{"
157 " \"algorithm\": \"ACVP-AES-CBC\","
158 " \"revision\": \"1.0\","
159 " \"direction\": [\"encrypt\", \"decrypt\"],"
160 " \"keyLen\": [128, 192, 256]"
161 "}"
162 "]";
163 return WriteReply(
164 STDOUT_FILENO,
165 Span<const uint8_t>(reinterpret_cast<const uint8_t *>(kConfig),
166 sizeof(kConfig) - 1));
167}
168
169template <uint8_t *(*OneShotHash)(const uint8_t *, size_t, uint8_t *),
170 size_t DigestLength>
171static bool Hash(const Span<const uint8_t> args[]) {
172 uint8_t digest[DigestLength];
173 OneShotHash(args[0].data(), args[0].size(), digest);
174 return WriteReply(STDOUT_FILENO, Span<const uint8_t>(digest));
175}
176
177template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out),
178 void (*Block)(const uint8_t *in, uint8_t *out, const AES_KEY *key)>
179static bool AES(const Span<const uint8_t> args[]) {
180 AES_KEY key;
181 if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) {
182 return false;
183 }
184 if (args[1].size() % AES_BLOCK_SIZE != 0) {
185 return false;
186 }
187
188 std::vector<uint8_t> out;
189 out.resize(args[1].size());
190 for (size_t i = 0; i < args[1].size(); i += AES_BLOCK_SIZE) {
191 Block(args[1].data() + i, &out[i], &key);
192 }
193 return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
194}
195
196template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out),
197 int Direction>
198static bool AES_CBC(const Span<const uint8_t> args[]) {
199 AES_KEY key;
200 if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) {
201 return false;
202 }
203 if (args[1].size() % AES_BLOCK_SIZE != 0 ||
204 args[2].size() != AES_BLOCK_SIZE) {
205 return false;
206 }
207 uint8_t iv[AES_BLOCK_SIZE];
208 memcpy(iv, args[2].data(), AES_BLOCK_SIZE);
209
210 std::vector<uint8_t> out;
211 out.resize(args[1].size());
212 AES_cbc_encrypt(args[1].data(), out.data(), args[1].size(), &key, iv,
213 Direction);
214 return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
215}
216
217static constexpr struct {
218 const char name[kMaxNameLength + 1];
219 uint8_t expected_args;
220 bool (*handler)(const Span<const uint8_t>[]);
221} kFunctions[] = {
222 {"getConfig", 0, GetConfig},
223 {"SHA-1", 1, Hash<SHA1, SHA_DIGEST_LENGTH>},
224 {"SHA2-224", 1, Hash<SHA224, SHA224_DIGEST_LENGTH>},
225 {"SHA2-256", 1, Hash<SHA256, SHA256_DIGEST_LENGTH>},
226 {"SHA2-384", 1, Hash<SHA384, SHA256_DIGEST_LENGTH>},
227 {"SHA2-512", 1, Hash<SHA512, SHA512_DIGEST_LENGTH>},
228 {"AES/encrypt", 2, AES<AES_set_encrypt_key, AES_encrypt>},
229 {"AES/decrypt", 2, AES<AES_set_decrypt_key, AES_decrypt>},
230 {"AES-CBC/encrypt", 3, AES_CBC<AES_set_encrypt_key, AES_ENCRYPT>},
231 {"AES-CBC/decrypt", 3, AES_CBC<AES_set_decrypt_key, AES_DECRYPT>},
232};
233
234int main() {
235 uint32_t nums[1 + kMaxArgs];
236 uint8_t *buf = nullptr;
237 size_t buf_len = 0;
238 Span<const uint8_t> args[kMaxArgs];
239
240 for (;;) {
241 if (!ReadAll(STDIN_FILENO, nums, sizeof(uint32_t) * 2)) {
242 return 1;
243 }
244
245 const size_t num_args = nums[0];
246 if (num_args == 0) {
247 fprintf(stderr, "Invalid, zero-argument operation requested.\n");
248 return 2;
249 } else if (num_args > kMaxArgs) {
250 fprintf(stderr,
251 "Operation requested with %zu args, but %zu is the limit.\n",
252 num_args, kMaxArgs);
253 return 2;
254 }
255
256 if (num_args > 1 &&
257 !ReadAll(STDIN_FILENO, &nums[2], sizeof(uint32_t) * (num_args - 1))) {
258 return 1;
259 }
260
261 size_t need = 0;
262 for (size_t i = 0; i < num_args; i++) {
263 const size_t arg_length = nums[i + 1];
264 if (i == 0 && arg_length > kMaxNameLength) {
265 fprintf(stderr,
266 "Operation with name of length %zu exceeded limit of %zu.\n",
267 arg_length, kMaxNameLength);
268 return 2;
269 } else if (arg_length > kMaxArgLength) {
270 fprintf(
271 stderr,
272 "Operation with argument of length %zu exceeded limit of %zu.\n",
273 arg_length, kMaxArgLength);
274 return 2;
275 }
276
277 // static_assert around kMaxArgs etc enforces that this doesn't overflow.
278 need += arg_length;
279 }
280
281 if (need > buf_len) {
282 free(buf);
283 size_t alloced = need + (need >> 1);
284 if (alloced < need) {
285 abort();
286 }
287 buf = reinterpret_cast<uint8_t *>(malloc(alloced));
288 if (buf == nullptr) {
289 abort();
290 }
291 buf_len = alloced;
292 }
293
294 if (!ReadAll(STDIN_FILENO, buf, need)) {
295 return 1;
296 }
297
298 size_t offset = 0;
299 for (size_t i = 0; i < num_args; i++) {
300 args[i] = Span<const uint8_t>(&buf[offset], nums[i + 1]);
301 offset += nums[i + 1];
302 }
303
304 bool found = true;
305 for (const auto &func : kFunctions) {
306 if (args[0].size() == strlen(func.name) &&
307 memcmp(args[0].data(), func.name, args[0].size()) == 0) {
308 if (num_args - 1 != func.expected_args) {
309 fprintf(stderr,
310 "\'%s\' operation received %zu arguments but expected %u.\n",
311 func.name, num_args - 1, func.expected_args);
312 return 2;
313 }
314
315 if (!func.handler(&args[1])) {
316 return 4;
317 }
318
319 found = true;
320 break;
321 }
322 }
323
324 if (!found) {
325 const std::string name(reinterpret_cast<const char *>(args[0].data()),
326 args[0].size());
327 fprintf(stderr, "Unknown operation: %s\n", name.c_str());
328 return 3;
329 }
330 }
331}