upstream: sshsig tweaks and improvements from and suggested by Markus

ok markus/me

OpenBSD-Commit-ID: ea4f46ad5a16b27af96e08c4877423918c4253e9
diff --git a/sshsig.c b/sshsig.c
index 0a1e146..c1f2d80 100644
--- a/sshsig.c
+++ b/sshsig.c
@@ -230,7 +230,7 @@
 		return r;
 	}
 
-	if (sversion < SIG_VERSION) {
+	if (sversion > SIG_VERSION) {
 		error("Signature version %lu is larger than supported "
 		    "version %u", (unsigned long)sversion, SIG_VERSION);
 		return SSH_ERR_INVALID_FORMAT;
@@ -241,7 +241,8 @@
 static int
 sshsig_check_hashalg(const char *hashalg)
 {
-	if (match_pattern_list(hashalg, HASHALG_ALLOWED, 0) == 1)
+	if (hashalg == NULL ||
+	    match_pattern_list(hashalg, HASHALG_ALLOWED, 0) == 1)
 		return 0;
 	error("%s: unsupported hash algorithm \"%.100s\"", __func__, hashalg);
 	return SSH_ERR_SIGN_ALG_UNSUPPORTED;
@@ -268,8 +269,6 @@
 		error("Couldn't parse signature blob: %s", ssh_err(r));
 		goto done;
 	}
-	if ((r = sshsig_check_hashalg(hashalg)) != 0)
-		goto done;
 
 	/* success */
 	r = 0;
@@ -293,6 +292,7 @@
 	char *got_namespace = NULL, *sigtype = NULL, *sig_hashalg = NULL;
 	size_t siglen;
 
+	debug("%s: verify message length %zu", __func__, sshbuf_len(h_message));
 	if (sign_keyp != NULL)
 		*sign_keyp = NULL;
 
@@ -301,9 +301,6 @@
 		r = SSH_ERR_ALLOC_FAIL;
 		goto done;
 	}
-	if ((r = sshsig_check_hashalg(hashalg)) != 0)
-		goto done;
-
 	if ((r = sshbuf_put(toverify, MAGIC_PREAMBLE,
 	    MAGIC_PREAMBLE_LEN)) != 0 ||
 	    (r = sshbuf_put_cstring(toverify, expect_namespace)) != 0 ||
@@ -382,36 +379,65 @@
 	return r;
 }
 
-int
-sshsig_sign_message(struct sshkey *key, const char *hashalg,
-    const struct sshbuf *message, const char *sig_namespace,
-    struct sshbuf **out, sshsig_signer *signer, void *signer_ctx)
+static int
+hash_buffer(const struct sshbuf *m, const char *hashalg, struct sshbuf **bp)
 {
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
-	struct sshbuf *b = NULL;
+	char *hex, hash[SSH_DIGEST_MAX_LENGTH];
 	int alg, r = SSH_ERR_INTERNAL_ERROR;
+	struct sshbuf *b = NULL;
 
-	if (out != NULL)
-		*out = NULL;
-	if (hashalg == NULL)
-	    hashalg = HASHALG_DEFAULT;
+	*bp = NULL; /* stays NULL unless the digest buffer is handed over below */
+	memset(hash, 0, sizeof(hash));
 
 	if ((r = sshsig_check_hashalg(hashalg)) != 0)
 		return r;
 	if ((alg = ssh_digest_alg_by_name(hashalg)) == -1) {
 		error("%s: can't look up hash algorithm %s",
-		    __func__, HASHALG_DEFAULT);
+		    __func__, hashalg);
 		return SSH_ERR_INTERNAL_ERROR;
 	}
-	if ((r = ssh_digest_buffer(alg, message, hash, sizeof(hash))) != 0) {
+	if ((r = ssh_digest_buffer(alg, m, hash, sizeof(hash))) != 0) {
 		error("%s: ssh_digest_buffer failed: %s", __func__, ssh_err(r));
 		return r;
 	}
-	if ((b = sshbuf_from(hash, ssh_digest_bytes(alg))) == NULL) {
-		error("%s: sshbuf_from failed", __func__);
+	if ((hex = tohex(hash, ssh_digest_bytes(alg))) != NULL) {
+		debug3("%s: final hash: %s", __func__, hex);
+		freezero(hex, strlen(hex));
+	}
+	if ((b = sshbuf_new()) == NULL) {
 		r = SSH_ERR_ALLOC_FAIL;
 		goto out;
 	}
+	if ((r = sshbuf_put(b, hash, ssh_digest_bytes(alg))) != 0) {
+		error("%s: sshbuf_put: %s", __func__, ssh_err(r));
+		goto out;
+	}
+	*bp = b;
+	b = NULL; /* transferred */
+	/* success */
+	r = 0;
+ out:
+	sshbuf_free(b);
+	explicit_bzero(hash, sizeof(hash));
+	return r; /* not 0: failures that "goto out" must reach the caller */
+}
+
+int
+sshsig_signb(struct sshkey *key, const char *hashalg,
+    const struct sshbuf *message, const char *sig_namespace,
+    struct sshbuf **out, sshsig_signer *signer, void *signer_ctx)
+{
+	struct sshbuf *b = NULL;
+	int r = SSH_ERR_INTERNAL_ERROR;
+
+	if (hashalg == NULL)
+		hashalg = HASHALG_DEFAULT;
+	if (out != NULL)
+		*out = NULL;
+	if ((r = hash_buffer(message, hashalg, &b)) != 0) {
+		error("%s: hash_buffer failed: %s", __func__, ssh_err(r));
+		goto out;
+	}
 	if ((r = sshsig_wrap_sign(key, hashalg, b, sig_namespace, out,
 	    signer, signer_ctx)) != 0)
 		goto out;
@@ -419,17 +445,15 @@
 	r = 0;
  out:
 	sshbuf_free(b);
-	explicit_bzero(hash, sizeof(hash));
 	return r;
 }
 
 int
-sshsig_verify_message(struct sshbuf *signature, const struct sshbuf *message,
+sshsig_verifyb(struct sshbuf *signature, const struct sshbuf *message,
     const char *expect_namespace, struct sshkey **sign_keyp)
 {
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
 	struct sshbuf *b = NULL;
-	int alg, r = SSH_ERR_INTERNAL_ERROR;
+	int r = SSH_ERR_INTERNAL_ERROR;
 	char *hashalg = NULL;
 
 	if (sign_keyp != NULL)
@@ -437,18 +461,9 @@
 
 	if ((r = sshsig_peek_hashalg(signature, &hashalg)) != 0)
 		return r;
-	if ((alg = ssh_digest_alg_by_name(hashalg)) == -1) {
-		error("%s: can't look up hash algorithm %s",
-		    __func__, HASHALG_DEFAULT);
-		return SSH_ERR_INTERNAL_ERROR;
-	}
-	if ((r = ssh_digest_buffer(alg, message, hash, sizeof(hash))) != 0) {
-		error("%s: ssh_digest_buffer failed: %s", __func__, ssh_err(r));
-		goto out;
-	}
-	if ((b = sshbuf_from(hash, ssh_digest_bytes(alg))) == NULL) {
-		error("%s: sshbuf_from failed", __func__);
-		r = SSH_ERR_ALLOC_FAIL;
+	debug("%s: signature made with hash \"%s\"", __func__, hashalg);
+	if ((r = hash_buffer(message, hashalg, &b)) != 0) {
+		error("%s: hash_buffer failed: %s", __func__, ssh_err(r));
 		goto out;
 	}
 	if ((r = sshsig_wrap_verify(signature, hashalg, b, expect_namespace,
@@ -459,20 +474,29 @@
  out:
 	sshbuf_free(b);
 	free(hashalg);
-	explicit_bzero(hash, sizeof(hash));
 	return r;
 }
 
 static int
-hash_file(int fd, int hashalg, u_char *hash, size_t hashlen)
+hash_file(int fd, const char *hashalg, struct sshbuf **bp)
 {
-	char *hex, rbuf[8192];
+	char *hex, rbuf[8192], hash[SSH_DIGEST_MAX_LENGTH];
 	ssize_t n, total = 0;
 	struct ssh_digest_ctx *ctx;
-	int r, oerrno;
+	int alg, oerrno, r = SSH_ERR_INTERNAL_ERROR;
+	struct sshbuf *b = NULL;
 
-	memset(hash, 0, hashlen);
-	if ((ctx = ssh_digest_start(hashalg)) == NULL) {
+	*bp = NULL; /* stays NULL unless the digest buffer is handed over below */
+	memset(hash, 0, sizeof(hash));
+
+	if ((r = sshsig_check_hashalg(hashalg)) != 0)
+		return r;
+	if ((alg = ssh_digest_alg_by_name(hashalg)) == -1) {
+		error("%s: can't look up hash algorithm %s",
+		    __func__, hashalg);
+		return SSH_ERR_INTERNAL_ERROR;
+	}
+	if ((ctx = ssh_digest_start(alg)) == NULL) {
 		error("%s: ssh_digest_start failed", __func__);
 		return SSH_ERR_INTERNAL_ERROR;
 	}
@@ -484,7 +508,8 @@
 			error("%s: read: %s", __func__, strerror(errno));
-			ssh_digest_free(ctx);
 			errno = oerrno;
-			return SSH_ERR_SYSTEM_ERROR;
+			/* ctx is freed exactly once, at out (no double free) */
+			r = SSH_ERR_SYSTEM_ERROR;
+			goto out;
 		} else if (n == 0) {
 			debug2("%s: hashed %zu bytes", __func__, total);
 			break; /* EOF */
@@ -493,20 +518,33 @@
 		if ((r = ssh_digest_update(ctx, rbuf, (size_t)n)) != 0) {
 			error("%s: ssh_digest_update: %s",
 			    __func__, ssh_err(r));
-			ssh_digest_free(ctx);
-			return r;
+			goto out;
 		}
 	}
-	if ((r = ssh_digest_final(ctx, hash, hashlen)) != 0) {
+	if ((r = ssh_digest_final(ctx, hash, sizeof(hash))) != 0) {
 		error("%s: ssh_digest_final: %s", __func__, ssh_err(r));
-		ssh_digest_free(ctx);
+		goto out;
 	}
-	if ((hex = tohex(hash, hashlen)) != NULL) {
+	if ((hex = tohex(hash, ssh_digest_bytes(alg))) != NULL) {
 		debug3("%s: final hash: %s", __func__, hex);
 		freezero(hex, strlen(hex));
 	}
+	if ((b = sshbuf_new()) == NULL) {
+		r = SSH_ERR_ALLOC_FAIL;
+		goto out;
+	}
+	if ((r = sshbuf_put(b, hash, ssh_digest_bytes(alg))) != 0) {
+		error("%s: sshbuf_put: %s", __func__, ssh_err(r));
+		goto out;
+	}
+	*bp = b;
+	b = NULL; /* transferred */
 	/* success */
+	r = 0;
+ out:
+	sshbuf_free(b);
 	ssh_digest_free(ctx);
-	return 0;
+	explicit_bzero(hash, sizeof(hash));
+	return r; /* not 0: failures that "goto out" must reach the caller */
 }
 
@@ -515,31 +553,17 @@
     int fd, const char *sig_namespace, struct sshbuf **out,
     sshsig_signer *signer, void *signer_ctx)
 {
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
 	struct sshbuf *b = NULL;
-	int alg, r = SSH_ERR_INTERNAL_ERROR;
+	int r = SSH_ERR_INTERNAL_ERROR;
 
+	if (hashalg == NULL)
+		hashalg = HASHALG_DEFAULT;
 	if (out != NULL)
 		*out = NULL;
-	if (hashalg == NULL)
-	    hashalg = HASHALG_DEFAULT;
-
-	if ((r = sshsig_check_hashalg(hashalg)) != 0)
-		return r;
-	if ((alg = ssh_digest_alg_by_name(hashalg)) == -1) {
-		error("%s: can't look up hash algorithm %s",
-		    __func__, HASHALG_DEFAULT);
-		return SSH_ERR_INTERNAL_ERROR;
-	}
-	if ((r = hash_file(fd, alg, hash, sizeof(hash))) != 0) {
+	if ((r = hash_file(fd, hashalg, &b)) != 0) {
 		error("%s: hash_file failed: %s", __func__, ssh_err(r));
 		return r;
 	}
-	if ((b = sshbuf_from(hash, ssh_digest_bytes(alg))) == NULL) {
-		error("%s: sshbuf_from failed", __func__);
-		r = SSH_ERR_ALLOC_FAIL;
-		goto out;
-	}
 	if ((r = sshsig_wrap_sign(key, hashalg, b, sig_namespace, out,
 	    signer, signer_ctx)) != 0)
 		goto out;
@@ -547,7 +571,6 @@
 	r = 0;
  out:
 	sshbuf_free(b);
-	explicit_bzero(hash, sizeof(hash));
 	return r;
 }
 
@@ -555,9 +578,8 @@
 sshsig_verify_fd(struct sshbuf *signature, int fd,
     const char *expect_namespace, struct sshkey **sign_keyp)
 {
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
 	struct sshbuf *b = NULL;
-	int alg, r = SSH_ERR_INTERNAL_ERROR;
+	int r = SSH_ERR_INTERNAL_ERROR;
 	char *hashalg = NULL;
 
 	if (sign_keyp != NULL)
@@ -565,18 +587,9 @@
 
 	if ((r = sshsig_peek_hashalg(signature, &hashalg)) != 0)
 		return r;
-	if ((alg = ssh_digest_alg_by_name(hashalg)) == -1) {
-		error("%s: can't look up hash algorithm %s",
-		    __func__, HASHALG_DEFAULT);
-		return SSH_ERR_INTERNAL_ERROR;
-	}
-	if ((r = hash_file(fd, alg, hash, sizeof(hash))) != 0) {
+	debug("%s: signature made with hash \"%s\"", __func__, hashalg);
+	if ((r = hash_file(fd, hashalg, &b)) != 0) {
 		error("%s: hash_file failed: %s", __func__, ssh_err(r));
-		return r;
-	}
-	if ((b = sshbuf_from(hash, ssh_digest_bytes(alg))) == NULL) {
-		error("%s: sshbuf_from failed", __func__);
-		r = SSH_ERR_ALLOC_FAIL;
 		goto out;
 	}
 	if ((r = sshsig_wrap_verify(signature, hashalg, b, expect_namespace,
@@ -587,7 +600,6 @@
  out:
 	sshbuf_free(b);
 	free(hashalg);
-	explicit_bzero(hash, sizeof(hash));
 	return r;
 }
 
@@ -769,14 +781,14 @@
 		linenum++;
 		r = check_allowed_keys_line(path, linenum, line, sign_key,
 		    principal, sig_namespace);
+		free(line);
+		line = NULL;
 		if (r == SSH_ERR_KEY_NOT_FOUND)
 			continue;
 		else if (r == 0) {
 			/* success */
 			fclose(f);
-			free(line);
 			return 0;
-			/* XXX continue and check revocation? */
 		} else
 			break;
 	}