rxrpc: Move data_ready peer lookup into rxrpc_find_connection()

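Move the peer lookup done in input.c by data_ready into
rxrpc_find_connection().  The function no longer takes a peer argument;
instead it extracts the source address from the skb itself.

For client-initiated packets it looks up the peer with
rxrpc_lookup_peer_rcu() and then searches that peer's service connection
tree.  For all other packets it finds the client connection by ID and
checks the connection's epoch, local endpoint and peer address against
the packet.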

Signed-off-by: David Howells <dhowells@redhat.com>
diff --git a/net/rxrpc/conn_object.c b/net/rxrpc/conn_object.c
index 89bc648..1307138 100644
--- a/net/rxrpc/conn_object.c
+++ b/net/rxrpc/conn_object.c
@@ -68,52 +68,91 @@
  * packet
  */
 struct rxrpc_connection *rxrpc_find_connection(struct rxrpc_local *local,
-					       struct rxrpc_peer *peer,
 					       struct sk_buff *skb)
 {
 	struct rxrpc_connection *conn;
+	struct rxrpc_conn_proto k;
 	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
+	struct sockaddr_rxrpc srx;
+	struct rxrpc_peer *peer;
 	struct rb_node *p;
-	u32 epoch, cid;
 
 	_enter(",{%x,%x}", sp->hdr.cid, sp->hdr.flags);
 
-	read_lock_bh(&peer->conn_lock);
+	if (rxrpc_extract_addr_from_skb(&srx, skb) < 0)
+		goto not_found;
 
-	cid	= sp->hdr.cid & RXRPC_CIDMASK;
-	epoch	= sp->hdr.epoch;
+	/* We may have to handle mixing IPv4 and IPv6 */
+	if (srx.transport.family != local->srx.transport.family) {
+		pr_warn_ratelimited("AF_RXRPC: Protocol mismatch %u not %u\n",
+				    srx.transport.family,
+				    local->srx.transport.family);
+		goto not_found;
+	}
+
+	k.epoch	= sp->hdr.epoch;
+	k.cid	= sp->hdr.cid & RXRPC_CIDMASK;
 
 	if (sp->hdr.flags & RXRPC_CLIENT_INITIATED) {
+		/* We need to look up service connections by the full protocol
+		 * parameter set.  We look up the peer first as an intermediate
+		 * step and then the connection from the peer's tree.
+		 */
+		peer = rxrpc_lookup_peer_rcu(local, &srx);
+		if (!peer)
+			goto not_found;
+
+		read_lock_bh(&peer->conn_lock);
+
 		p = peer->service_conns.rb_node;
 		while (p) {
 			conn = rb_entry(p, struct rxrpc_connection, service_node);
 
 			_debug("maybe %x", conn->proto.cid);
 
-			if (epoch < conn->proto.epoch)
+			if (k.epoch < conn->proto.epoch)
 				p = p->rb_left;
-			else if (epoch > conn->proto.epoch)
+			else if (k.epoch > conn->proto.epoch)
 				p = p->rb_right;
-			else if (cid < conn->proto.cid)
+			else if (k.cid < conn->proto.cid)
 				p = p->rb_left;
-			else if (cid > conn->proto.cid)
+			else if (k.cid > conn->proto.cid)
 				p = p->rb_right;
 			else
-				goto found;
+				goto found_service_conn;
 		}
+		read_unlock_bh(&peer->conn_lock);
 	} else {
-		conn = idr_find(&rxrpc_client_conn_ids, cid >> RXRPC_CIDSHIFT);
-		if (conn &&
-		    conn->proto.epoch == epoch &&
-		    conn->params.peer == peer)
-			goto found;
+		conn = idr_find(&rxrpc_client_conn_ids,
+				k.cid >> RXRPC_CIDSHIFT);
+		if (!conn ||
+		    conn->proto.epoch != k.epoch ||
+		    conn->params.local != local)
+			goto not_found;
+
+		peer = conn->params.peer;
+		switch (srx.transport.family) {
+		case AF_INET:
+			if (peer->srx.transport.sin.sin_port !=
+			    srx.transport.sin.sin_port ||
+			    peer->srx.transport.sin.sin_addr.s_addr !=
+			    srx.transport.sin.sin_addr.s_addr)
+				goto not_found;
+			break;
+		default:
+			BUG();
+		}
+
+		conn = rxrpc_get_connection_maybe(conn);
+		_leave(" = %p", conn);
+		return conn;
 	}
 
-	read_unlock_bh(&peer->conn_lock);
+not_found:
 	_leave(" = NULL");
 	return NULL;
 
-found:
+found_service_conn:
 	conn = rxrpc_get_connection_maybe(conn);
 	read_unlock_bh(&peer->conn_lock);
 	_leave(" = %p", conn);