- djm@cvs.openbsd.org 2005/11/04 05:15:59
     [kex.c kex.h kexdh.c kexdhc.c kexdhs.c kexgex.c kexgexc.c kexgexs.c]
     remove hardcoded hash lengths in key exchange code, allowing
     implementation of KEX methods with different hashes (e.g. SHA-256);
     ok markus@ dtucker@ stevesk@
diff --git a/kex.c b/kex.c
index 5dce335..cd71be9 100644
--- a/kex.c
+++ b/kex.c
@@ -23,7 +23,7 @@
  */
 
 #include "includes.h"
-RCSID("$OpenBSD: kex.c,v 1.64 2005/07/25 11:59:39 markus Exp $");
+RCSID("$OpenBSD: kex.c,v 1.65 2005/11/04 05:15:59 djm Exp $");
 
 #include <openssl/crypto.h>
 
@@ -294,13 +294,17 @@
 		fatal("no kex alg");
 	if (strcmp(k->name, KEX_DH1) == 0) {
 		k->kex_type = KEX_DH_GRP1_SHA1;
+		k->evp_md = EVP_sha1();
 	} else if (strcmp(k->name, KEX_DH14) == 0) {
 		k->kex_type = KEX_DH_GRP14_SHA1;
-	} else if (strcmp(k->name, KEX_DHGEX) == 0) {
+		k->evp_md = EVP_sha1();
+	} else if (strcmp(k->name, KEX_DHGEX_SHA1) == 0) {
 		k->kex_type = KEX_DH_GEX_SHA1;
+		k->evp_md = EVP_sha1();
 	} else
 		fatal("bad kex alg %s", k->name);
 }
+
 static void
 choose_hostkeyalg(Kex *k, char *client, char *server)
 {
@@ -404,28 +408,28 @@
 }
 
 static u_char *
-derive_key(Kex *kex, int id, u_int need, u_char *hash, BIGNUM *shared_secret)
+derive_key(Kex *kex, int id, u_int need, u_char *hash, u_int hashlen,
+    BIGNUM *shared_secret)
 {
 	Buffer b;
-	const EVP_MD *evp_md = EVP_sha1();
 	EVP_MD_CTX md;
 	char c = id;
 	u_int have;
-	int mdsz = EVP_MD_size(evp_md);
+	int mdsz;
 	u_char *digest;
 
-	if (mdsz < 0)
-		fatal("derive_key: mdsz < 0");
-	digest = xmalloc(roundup(need, mdsz));
+	if ((mdsz = EVP_MD_size(kex->evp_md)) <= 0)
+		fatal("bad kex md size %d", mdsz);
+ 	digest = xmalloc(roundup(need, mdsz));
 
 	buffer_init(&b);
 	buffer_put_bignum2(&b, shared_secret);
 
 	/* K1 = HASH(K || H || "A" || session_id) */
-	EVP_DigestInit(&md, evp_md);
+	EVP_DigestInit(&md, kex->evp_md);
 	if (!(datafellows & SSH_BUG_DERIVEKEY))
 		EVP_DigestUpdate(&md, buffer_ptr(&b), buffer_len(&b));
-	EVP_DigestUpdate(&md, hash, mdsz);
+	EVP_DigestUpdate(&md, hash, hashlen);
 	EVP_DigestUpdate(&md, &c, 1);
 	EVP_DigestUpdate(&md, kex->session_id, kex->session_id_len);
 	EVP_DigestFinal(&md, digest, NULL);
@@ -436,10 +440,10 @@
 	 * Key = K1 || K2 || ... || Kn
 	 */
 	for (have = mdsz; need > have; have += mdsz) {
-		EVP_DigestInit(&md, evp_md);
+		EVP_DigestInit(&md, kex->evp_md);
 		if (!(datafellows & SSH_BUG_DERIVEKEY))
 			EVP_DigestUpdate(&md, buffer_ptr(&b), buffer_len(&b));
-		EVP_DigestUpdate(&md, hash, mdsz);
+		EVP_DigestUpdate(&md, hash, hashlen);
 		EVP_DigestUpdate(&md, digest, have);
 		EVP_DigestFinal(&md, digest + have, NULL);
 	}
@@ -455,13 +459,15 @@
 
 #define NKEYS	6
 void
-kex_derive_keys(Kex *kex, u_char *hash, BIGNUM *shared_secret)
+kex_derive_keys(Kex *kex, u_char *hash, u_int hashlen, BIGNUM *shared_secret)
 {
 	u_char *keys[NKEYS];
 	u_int i, mode, ctos;
 
-	for (i = 0; i < NKEYS; i++)
-		keys[i] = derive_key(kex, 'A'+i, kex->we_need, hash, shared_secret);
+	for (i = 0; i < NKEYS; i++) {
+		keys[i] = derive_key(kex, 'A'+i, kex->we_need, hash, hashlen,
+		    shared_secret);
+	}
 
 	debug2("kex_derive_keys");
 	for (mode = 0; mode < MODE_MAX; mode++) {