upstream commit

move dispatch to struct ssh; ok djm@
diff --git a/dispatch.c b/dispatch.c
index 64bb809..70fa84f 100644
--- a/dispatch.c
+++ b/dispatch.c
@@ -1,4 +1,4 @@
-/* $OpenBSD: dispatch.c,v 1.22 2008/10/31 15:05:34 stevesk Exp $ */
+/* $OpenBSD: dispatch.c,v 1.23 2015/01/19 20:07:45 markus Exp $ */
 /*
  * Copyright (c) 2000 Markus Friedl.  All rights reserved.
  *
@@ -36,69 +36,107 @@
 #include "dispatch.h"
 #include "packet.h"
 #include "compat.h"
+#include "ssherr.h"
 
-#define DISPATCH_MAX	255
-
-dispatch_fn *dispatch[DISPATCH_MAX];
-
-void
-dispatch_protocol_error(int type, u_int32_t seq, void *ctxt)
+int
+dispatch_protocol_error(int type, u_int32_t seq, void *ctx)
 {
+	struct ssh *ssh = active_state; /* XXX */
+	int r;
+
 	logit("dispatch_protocol_error: type %d seq %u", type, seq);
 	if (!compat20)
 		fatal("protocol error");
-	packet_start(SSH2_MSG_UNIMPLEMENTED);
-	packet_put_int(seq);
-	packet_send();
-	packet_write_wait();
+	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
+	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
+	    (r = sshpkt_send(ssh)) != 0)
+		fatal("%s: %s", __func__, ssh_err(r));
+	ssh_packet_write_wait(ssh);
+	return 0;
 }
-void
-dispatch_protocol_ignore(int type, u_int32_t seq, void *ctxt)
+
+int
+dispatch_protocol_ignore(int type, u_int32_t seq, void *ssh)
 {
 	logit("dispatch_protocol_ignore: type %d seq %u", type, seq);
+	return 0;
 }
+
 void
-dispatch_init(dispatch_fn *dflt)
+ssh_dispatch_init(struct ssh *ssh, dispatch_fn *dflt)
 {
 	u_int i;
 	for (i = 0; i < DISPATCH_MAX; i++)
-		dispatch[i] = dflt;
+		ssh->dispatch[i] = dflt;
 }
+
 void
-dispatch_range(u_int from, u_int to, dispatch_fn *fn)
+ssh_dispatch_range(struct ssh *ssh, u_int from, u_int to, dispatch_fn *fn)
 {
 	u_int i;
 
 	for (i = from; i <= to; i++) {
 		if (i >= DISPATCH_MAX)
 			break;
-		dispatch[i] = fn;
+		ssh->dispatch[i] = fn;
 	}
 }
-void
-dispatch_set(int type, dispatch_fn *fn)
-{
-	dispatch[type] = fn;
-}
-void
-dispatch_run(int mode, volatile sig_atomic_t *done, void *ctxt)
-{
-	for (;;) {
-		int type;
-		u_int32_t seqnr;
 
+void
+ssh_dispatch_set(struct ssh *ssh, int type, dispatch_fn *fn)
+{
+	ssh->dispatch[type] = fn;
+}
+
+int
+ssh_dispatch_run(struct ssh *ssh, int mode, volatile sig_atomic_t *done,
+    void *ctxt)
+{
+	int r;
+	u_char type;
+	u_int32_t seqnr;
+
+	for (;;) {
 		if (mode == DISPATCH_BLOCK) {
-			type = packet_read_seqnr(&seqnr);
+			r = ssh_packet_read_seqnr(ssh, &type, &seqnr);
+			if (r != 0)
+				return r;
 		} else {
-			type = packet_read_poll_seqnr(&seqnr);
+			r = ssh_packet_read_poll_seqnr(ssh, &type, &seqnr);
+			if (r != 0)
+				return r;
 			if (type == SSH_MSG_NONE)
-				return;
+				return 0;
 		}
-		if (type > 0 && type < DISPATCH_MAX && dispatch[type] != NULL)
-			(*dispatch[type])(type, seqnr, ctxt);
-		else
-			packet_disconnect("protocol error: rcvd type %d", type);
+		if (type > 0 && type < DISPATCH_MAX &&
+		    ssh->dispatch[type] != NULL) {
+			if (ssh->dispatch_skip_packets) {
+				debug2("skipped packet (type %u)", type);
+				ssh->dispatch_skip_packets--;
+				continue;
+			}
+			/* XXX 'ssh' will replace 'ctxt' later */
+			r = (*ssh->dispatch[type])(type, seqnr, ctxt);
+			if (r != 0)
+				return r;
+		} else {
+			r = sshpkt_disconnect(ssh,
+			    "protocol error: rcvd type %d", type);
+			if (r != 0)
+				return r;
+			return SSH_ERR_DISCONNECTED;
+		}
 		if (done != NULL && *done)
-			return;
+			return 0;
 	}
 }
+
+void
+ssh_dispatch_run_fatal(struct ssh *ssh, int mode, volatile sig_atomic_t *done,
+    void *ctxt)
+{
+	int r;
+
+	if ((r = ssh_dispatch_run(ssh, mode, done, ctxt)) != 0)
+		fatal("%s: %s", __func__, ssh_err(r));
+}