- markus@cvs.openbsd.org 2001/04/03 23:32:12
     [kex.c kex.h packet.c sshconnect2.c sshd.c]
     undo parts of recent my changes: main part of keyexchange does not
     need dispatch-callbacks, since application data is delayed until
     the keyexchange completes (if i understand the drafts correctly).
     add some infrastructure for re-keying.
diff --git a/ChangeLog b/ChangeLog
index 6c37016..df08299 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -13,6 +13,12 @@
    - todd@cvs.openbsd.org 2001/04/03 21:19:38
      [ssh_config]
      id_rsa1/2 -> id_rsa; ok markus@
+   - markus@cvs.openbsd.org 2001/04/03 23:32:12
+     [kex.c kex.h packet.c sshconnect2.c sshd.c]
+     undo parts of recent my changes: main part of keyexchange does not
+     need dispatch-callbacks, since application data is delayed until
+     the keyexchange completes (if i understand the drafts correctly).
+     add some infrastructure for re-keying.
 
 20010403
  - OpenBSD CVS Sync
@@ -4823,4 +4829,4 @@
  - Wrote replacements for strlcpy and mkdtemp
  - Released 1.0pre1
 
-$Id: ChangeLog,v 1.1051 2001/04/04 01:58:48 mouring Exp $
+$Id: ChangeLog,v 1.1052 2001/04/04 02:00:54 mouring Exp $
diff --git a/kex.c b/kex.c
index a0a5b46..3b42d32 100644
--- a/kex.c
+++ b/kex.c
@@ -23,7 +23,7 @@
  */
 
 #include "includes.h"
-RCSID("$OpenBSD: kex.c,v 1.26 2001/04/03 19:53:29 markus Exp $");
+RCSID("$OpenBSD: kex.c,v 1.27 2001/04/03 23:32:11 markus Exp $");
 
 #include <openssl/crypto.h>
 
@@ -131,7 +131,7 @@
 	for (i = 30; i <= 49; i++)
 		dispatch_set(i, &kex_protocol_error);
 	buffer_clear(&kex->peer);
-	buffer_clear(&kex->my);
+	/* buffer_clear(&kex->my); */
 	kex->flags &= ~KEX_INIT_SENT;
 }
 
@@ -152,7 +152,6 @@
 	int dlen;
 	Kex *kex = (Kex *)ctxt;
 
-	dispatch_set(SSH2_MSG_KEXINIT, &kex_protocol_error);
 	debug("SSH2_MSG_KEXINIT received");
 
 	ptr = packet_get_raw(&dlen);
@@ -274,18 +273,20 @@
 }
 
 void
-kex_choose_conf(Kex *k)
+kex_choose_conf(Kex *kex)
 {
+	Newkeys *newkeys;
 	char **my, **peer;
 	char **cprop, **sprop;
+	int nenc, nmac, ncomp;
 	int mode;
 	int ctos;				/* direction: if true client-to-server */
 	int need;
 
-	my   = kex_buf2prop(&k->my);
-	peer = kex_buf2prop(&k->peer);
+	my   = kex_buf2prop(&kex->my);
+	peer = kex_buf2prop(&kex->peer);
 
-	if (k->server) {
+	if (kex->server) {
 		cprop=peer;
 		sprop=my;
 	} else {
@@ -294,42 +295,44 @@
 	}
 
 	for (mode = 0; mode < MODE_MAX; mode++) {
-		int nenc, nmac, ncomp;
-		ctos = (!k->server && mode == MODE_OUT) || (k->server && mode == MODE_IN);
+		newkeys = xmalloc(sizeof(*newkeys));
+		memset(newkeys, 0, sizeof(*newkeys));
+		kex->keys[mode] = newkeys;
+		ctos = (!kex->server && mode == MODE_OUT) || (kex->server && mode == MODE_IN);
 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
-		choose_enc (&k->enc [mode], cprop[nenc],  sprop[nenc]);
-		choose_mac (&k->mac [mode], cprop[nmac],  sprop[nmac]);
-		choose_comp(&k->comp[mode], cprop[ncomp], sprop[ncomp]);
+		choose_enc (&newkeys->enc,  cprop[nenc],  sprop[nenc]);
+		choose_mac (&newkeys->mac,  cprop[nmac],  sprop[nmac]);
+		choose_comp(&newkeys->comp, cprop[ncomp], sprop[ncomp]);
 		debug("kex: %s %s %s %s",
 		    ctos ? "client->server" : "server->client",
-		    k->enc[mode].name,
-		    k->mac[mode].name,
-		    k->comp[mode].name);
+		    newkeys->enc.name,
+		    newkeys->mac.name,
+		    newkeys->comp.name);
 	}
-	choose_kex(k, cprop[PROPOSAL_KEX_ALGS], sprop[PROPOSAL_KEX_ALGS]);
-	choose_hostkeyalg(k, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
+	choose_kex(kex, cprop[PROPOSAL_KEX_ALGS], sprop[PROPOSAL_KEX_ALGS]);
+	choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS]);
 	need = 0;
 	for (mode = 0; mode < MODE_MAX; mode++) {
-	    if (need < k->enc[mode].cipher->key_len)
-		    need = k->enc[mode].cipher->key_len;
-	    if (need < k->enc[mode].cipher->block_size)
-		    need = k->enc[mode].cipher->block_size;
-	    if (need < k->mac[mode].key_len)
-		    need = k->mac[mode].key_len;
+		newkeys = kex->keys[mode];
+		if (need < newkeys->enc.cipher->key_len)
+			need = newkeys->enc.cipher->key_len;
+		if (need < newkeys->enc.cipher->block_size)
+			need = newkeys->enc.cipher->block_size;
+		if (need < newkeys->mac.key_len)
+			need = newkeys->mac.key_len;
 	}
 	/* XXX need runden? */
-	k->we_need = need;
+	kex->we_need = need;
 
 	kex_prop_free(my);
 	kex_prop_free(peer);
-
 }
 
 u_char *
-derive_key(int id, int need, u_char *hash, BIGNUM *shared_secret)
+derive_key(Kex *kex, int id, int need, u_char *hash, BIGNUM *shared_secret)
 {
 	Buffer b;
 	EVP_MD *evp_md = EVP_sha1();
@@ -346,7 +349,7 @@
 	EVP_DigestUpdate(&md, buffer_ptr(&b), buffer_len(&b));	/* shared_secret K */
 	EVP_DigestUpdate(&md, hash, mdsz);		/* transport-06 */
 	EVP_DigestUpdate(&md, &c, 1);			/* key id */
-	EVP_DigestUpdate(&md, hash, mdsz);		/* session id */
+	EVP_DigestUpdate(&md, kex->session_id, kex->session_id_len);
 	EVP_DigestFinal(&md, digest, NULL);
 
 	/* expand */
@@ -365,26 +368,36 @@
 	return digest;
 }
 
+Newkeys *x_newkeys[MODE_MAX];
+
 #define NKEYS	6
 void
-kex_derive_keys(Kex *k, u_char *hash, BIGNUM *shared_secret)
+kex_derive_keys(Kex *kex, u_char *hash, BIGNUM *shared_secret)
 {
-	int i;
-	int mode;
-	int ctos;
+	Newkeys *newkeys;
 	u_char *keys[NKEYS];
+	int i, mode, ctos;
 
 	for (i = 0; i < NKEYS; i++)
-		keys[i] = derive_key('A'+i, k->we_need, hash, shared_secret);
+		keys[i] = derive_key(kex, 'A'+i, kex->we_need, hash, shared_secret);
 
+	debug("kex_derive_keys");
 	for (mode = 0; mode < MODE_MAX; mode++) {
-		ctos = (!k->server && mode == MODE_OUT) || (k->server && mode == MODE_IN);
-		k->enc[mode].iv  = keys[ctos ? 0 : 1];
-		k->enc[mode].key = keys[ctos ? 2 : 3];
-		k->mac[mode].key = keys[ctos ? 4 : 5];
+		newkeys = kex->keys[mode];
+		ctos = (!kex->server && mode == MODE_OUT) || (kex->server && mode == MODE_IN);
+		newkeys->enc.iv  = keys[ctos ? 0 : 1];
+		newkeys->enc.key = keys[ctos ? 2 : 3];
+		newkeys->mac.key = keys[ctos ? 4 : 5];
+		x_newkeys[mode] = newkeys;
 	}
 }
 
+Newkeys *
+kex_get_newkeys(int mode)
+{
+	return x_newkeys[mode];
+}
+
 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH)
 void
 dump_digest(char *msg, u_char *digest, int len)
diff --git a/kex.h b/kex.h
index 58f6d82..83f54fd 100644
--- a/kex.h
+++ b/kex.h
@@ -1,4 +1,4 @@
-/*	$OpenBSD: kex.h,v 1.18 2001/04/03 19:53:29 markus Exp $	*/
+/*	$OpenBSD: kex.h,v 1.19 2001/04/03 23:32:12 markus Exp $	*/
 
 /*
  * Copyright (c) 2000 Markus Friedl.  All rights reserved.
@@ -59,64 +59,69 @@
 	DH_GEX_SHA1
 };
 
+#define KEX_INIT_SENT	0x0001
+
 typedef struct Kex Kex;
 typedef struct Mac Mac;
 typedef struct Comp Comp;
 typedef struct Enc Enc;
+typedef struct Newkeys Newkeys;
 
 struct Enc {
-	char		*name;
-	Cipher		*cipher;
-	int		enabled;
+	char	*name;
+	Cipher	*cipher;
+	int	enabled;
 	u_char	*key;
 	u_char	*iv;
 };
 struct Mac {
-	char		*name;
-	int		enabled;
-	EVP_MD		*md;
-	int		mac_len;
+	char	*name;
+	int	enabled;
+	EVP_MD	*md;
+	int	mac_len;
 	u_char	*key;
-	int		key_len;
+	int	key_len;
 };
 struct Comp {
-	int		type;
-	int		enabled;
-	char		*name;
+	int	type;
+	int	enabled;
+	char	*name;
 };
-#define KEX_INIT_SENT	0x0001
+struct Newkeys {
+	Enc	enc;
+	Mac	mac;
+	Comp	comp;
+};
 struct Kex {
-	Enc		enc [MODE_MAX];
-	Mac		mac [MODE_MAX];
-	Comp		comp[MODE_MAX];
-	int		we_need;
-	int		server;
-	char		*name;
-	int		hostkey_type;
-	int		kex_type;
-
-	/* used during kex */
-	Buffer		my;
-	Buffer		peer;
-	int		newkeys;
-	int		flags;
-	void		*state;
-	char		*client_version_string;
-	char		*server_version_string;
-
-	int		(*check_host_key)(Key *hostkey);
-	Key		*(*load_host_key)(int type);
+	u_char	*session_id;
+	int	session_id_len;
+	Newkeys	*keys[MODE_MAX];
+	int	we_need;
+	int	server;
+	char	*name;
+	int	hostkey_type;
+	int	kex_type;
+	Buffer	my;
+	Buffer	peer;
+	int	newkeys;
+	int	flags;
+	char	*client_version_string;
+	char	*server_version_string;
+	int	(*check_host_key)(Key *hostkey);
+	Key	*(*load_host_key)(int type);
 };
 
-void	kex_derive_keys(Kex *k, u_char *hash, BIGNUM *shared_secret);
-void	packet_set_kex(Kex *k);
 Kex	*kex_start(char *proposal[PROPOSAL_MAX]);
 void	kex_send_newkeys(void);
+void	kex_send_kexinit(Kex *kex);
 void	kex_protocol_error(int type, int plen, void *ctxt);
+void	kex_derive_keys(Kex *k, u_char *hash, BIGNUM *shared_secret);
 
 void	kexdh(Kex *);
 void	kexgex(Kex *);
 
+Newkeys *kex_get_newkeys(int mode);
+
 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH)
 void	dump_digest(char *msg, u_char *digest, int len);
 #endif
diff --git a/packet.c b/packet.c
index 1a634cf..a4a0b05 100644
--- a/packet.c
+++ b/packet.c
@@ -37,7 +37,7 @@
  */
 
 #include "includes.h"
-RCSID("$OpenBSD: packet.c,v 1.56 2001/03/03 21:41:07 millert Exp $");
+RCSID("$OpenBSD: packet.c,v 1.57 2001/04/03 23:32:12 markus Exp $");
 
 #include "xmalloc.h"
 #include "buffer.h"
@@ -121,21 +121,9 @@
 int use_ssh2_packet_format = 0;
 
 /* Session key information for Encryption and MAC */
-Kex	*kex = NULL;
+Newkeys *newkeys[MODE_MAX];
 
 void
-packet_set_kex(Kex *k)
-{
-	if( k->mac[MODE_IN ].key == NULL ||
-	    k->enc[MODE_IN ].key == NULL ||
-	    k->enc[MODE_IN ].iv  == NULL ||
-	    k->mac[MODE_OUT].key == NULL ||
-	    k->enc[MODE_OUT].key == NULL ||
-	    k->enc[MODE_OUT].iv  == NULL)
-		fatal("bad KEX");
-	kex = k;
-}
-void
 clear_enc_keys(Enc *enc, int len)
 {
 	memset(enc->iv,  0, len);
@@ -150,6 +138,7 @@
 {
 	DBG(debug("use_ssh2_packet_format"));
 	use_ssh2_packet_format = 1;
+	newkeys[MODE_IN] = newkeys[MODE_OUT] = NULL;
 }
 
 /*
@@ -522,6 +511,41 @@
 	 */
 }
 
+void
+set_newkeys(int mode)
+{
+	Enc *enc;
+	Mac *mac;
+	Comp *comp;
+	CipherContext *cc;
+
+	debug("newkeys: mode %d", mode);
+
+	cc = (mode == MODE_OUT) ? &send_context : &receive_context;
+	if (newkeys[mode] != NULL) {
+		debug("newkeys: rekeying");
+		memset(cc, 0, sizeof(*cc));
+		// free old keys, reset compression cipher-contexts;
+	}
+	newkeys[mode] = kex_get_newkeys(mode);
+	if (newkeys[mode] == NULL)
+		fatal("newkeys: no keys for mode %d", mode);
+	enc  = &newkeys[mode]->enc;
+	mac  = &newkeys[mode]->mac;
+	comp = &newkeys[mode]->comp;
+	if (mac->md != NULL)
+		mac->enabled = 1;
+	DBG(debug("cipher_init_context: %d", mode));
+	cipher_init(cc, enc->cipher, enc->key, enc->cipher->key_len,
+	    enc->iv, enc->cipher->block_size);
+	clear_enc_keys(enc, enc->cipher->key_len);
+	if (comp->type != 0 && comp->enabled == 0) {
+		comp->enabled = 1;
+		if (! packet_compression)
+			packet_start_compression(6);
+	}
+}
+
 /*
  * Finalize packet in SSH2 format (compress, mac, encrypt, enqueue)
  */
@@ -540,10 +564,10 @@
 	Comp *comp = NULL;
 	int block_size;
 
-	if (kex != NULL) {
-		enc  = &kex->enc[MODE_OUT];
-		mac  = &kex->mac[MODE_OUT];
-		comp = &kex->comp[MODE_OUT];
+	if (newkeys[MODE_OUT] != NULL) {
+		enc  = &newkeys[MODE_OUT]->enc;
+		mac  = &newkeys[MODE_OUT]->mac;
+		comp = &newkeys[MODE_OUT]->comp;
 	}
 	block_size = enc ? enc->cipher->block_size : 8;
 
@@ -622,22 +646,8 @@
 		log("outgoing seqnr wraps around");
 	buffer_clear(&outgoing_packet);
 
-	if (type == SSH2_MSG_NEWKEYS) {
-		if (kex==NULL || mac==NULL || enc==NULL || comp==NULL)
-			fatal("packet_send2: no KEX");
-		if (mac->md != NULL)
-			mac->enabled = 1;
-		DBG(debug("cipher_init send_context"));
-		cipher_init(&send_context, enc->cipher,
-		    enc->key, enc->cipher->key_len,
-		    enc->iv, enc->cipher->block_size);
-		clear_enc_keys(enc, kex->we_need);
-		if (comp->type != 0 && comp->enabled == 0) {
-			comp->enabled = 1;
-			if (! packet_compression)
-				packet_start_compression(6);
-		}
-	}
+	if (type == SSH2_MSG_NEWKEYS)
+		set_newkeys(MODE_OUT);
 }
 
 void
@@ -833,10 +843,10 @@
 	Mac *mac   = NULL;
 	Comp *comp = NULL;
 
-	if (kex != NULL) {
-		enc  = &kex->enc[MODE_IN];
-		mac  = &kex->mac[MODE_IN];
-		comp = &kex->comp[MODE_IN];
+	if (newkeys[MODE_IN] != NULL) {
+		enc  = &newkeys[MODE_IN]->enc;
+		mac  = &newkeys[MODE_IN]->mac;
+		comp = &newkeys[MODE_IN]->comp;
 	}
 	maclen = mac && mac->enabled ? mac->mac_len : 0;
 	block_size = enc ? enc->cipher->block_size : 8;
@@ -930,22 +940,8 @@
 	/* extract packet type */
 	type = (u_char)buf[0];
 
-	if (type == SSH2_MSG_NEWKEYS) {
-		if (kex==NULL || mac==NULL || enc==NULL || comp==NULL)
-			fatal("packet_read_poll2: no KEX");
-		if (mac->md != NULL)
-			mac->enabled = 1;
-		DBG(debug("cipher_init receive_context"));
-		cipher_init(&receive_context, enc->cipher,
-		    enc->key, enc->cipher->key_len,
-		    enc->iv, enc->cipher->block_size);
-		clear_enc_keys(enc, kex->we_need);
-		if (comp->type != 0 && comp->enabled == 0) {
-			comp->enabled = 1;
-			if (! packet_compression)
-				packet_start_compression(6);
-		}
-	}
+	if (type == SSH2_MSG_NEWKEYS)
+		set_newkeys(MODE_IN);
 
 #ifdef PACKET_DEBUG
 	fprintf(stderr, "read/plain[%d]:\r\n", type);
@@ -1339,8 +1335,8 @@
 
 	have = buffer_len(&outgoing_packet);
 	debug2("packet_inject_ignore: current %d", have);
-	if (kex != NULL)
-	enc  = &kex->enc[MODE_OUT];
+	if (newkeys[MODE_OUT] != NULL)
+		enc  = &newkeys[MODE_OUT]->enc;
 	blocksize = enc ? enc->cipher->block_size : 8;
 	padlen = blocksize - (have % blocksize);
 	if (padlen < 4)
diff --git a/sshconnect2.c b/sshconnect2.c
index 4ed39a2..dd3f36b 100644
--- a/sshconnect2.c
+++ b/sshconnect2.c
@@ -23,7 +23,7 @@
  */
 
 #include "includes.h"
-RCSID("$OpenBSD: sshconnect2.c,v 1.61 2001/04/03 19:53:29 markus Exp $");
+RCSID("$OpenBSD: sshconnect2.c,v 1.62 2001/04/03 23:32:12 markus Exp $");
 
 #include <openssl/bn.h>
 #include <openssl/md5.h>
@@ -117,6 +117,9 @@
 	/* start key exchange */
 	dispatch_run(DISPATCH_BLOCK, &kex->newkeys, kex);
 
+	session_id2 = kex->session_id;
+	session_id2_len = kex->session_id_len;
+
 #ifdef DEBUG_KEXDH
 	/* send 1st encrypted/maced/compressed message */
 	packet_start(SSH2_MSG_IGNORE);
diff --git a/sshd.c b/sshd.c
index c546759..bdcae2c 100644
--- a/sshd.c
+++ b/sshd.c
@@ -40,7 +40,7 @@
  */
 
 #include "includes.h"
-RCSID("$OpenBSD: sshd.c,v 1.186 2001/04/03 19:53:29 markus Exp $");
+RCSID("$OpenBSD: sshd.c,v 1.187 2001/04/03 23:32:12 markus Exp $");
 
 #include <openssl/dh.h>
 #include <openssl/bn.h>
@@ -1434,6 +1434,9 @@
 	/* start key exchange */
 	dispatch_run(DISPATCH_BLOCK, &kex->newkeys, kex);
 
+	session_id2 = kex->session_id;
+	session_id2_len = kex->session_id_len;
+
 #ifdef DEBUG_KEXDH
 	/* send 1st encrypted/maced/compressed message */
 	packet_start(SSH2_MSG_IGNORE);