RPC: Ensure that we disconnect TCP socket when client requests error out

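If a client request errors out after only part of it has been transmitted
over TCP, we need to disconnect and reconnect the TCP socket; otherwise the
truncated request left on the stream will confuse the server.

To do this, the transport write lock is now released after each transmit
attempt through call_transmit_status() and the new xprt_end_transmit()
(which replaces xprt_abort_transmit()). The TCP release_xprt method,
xs_tcp_release_xprt(), sets XPRT_CLOSE_WAIT when it sees that the lock
holder sent only part of its request.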

Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
(cherry picked from commit 031a50c8b9ea82616abd4a4e18021a25848941ce)
diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c
index 4ba271f..d6409e7 100644
--- a/net/sunrpc/clnt.c
+++ b/net/sunrpc/clnt.c
@@ -921,26 +921,43 @@
 	task->tk_status = xprt_prepare_transmit(task);
 	if (task->tk_status != 0)
 		return;
+	task->tk_action = call_transmit_status;
 	/* Encode here so that rpcsec_gss can use correct sequence number. */
 	if (rpc_task_need_encode(task)) {
-		task->tk_rqstp->rq_bytes_sent = 0;
+		BUG_ON(task->tk_rqstp->rq_bytes_sent != 0);
 		call_encode(task);
 		/* Did the encode result in an error condition? */
 		if (task->tk_status != 0)
-			goto out_nosend;
+			return;
 	}
-	task->tk_action = call_transmit_status;
 	xprt_transmit(task);
 	if (task->tk_status < 0)
 		return;
-	if (!task->tk_msg.rpc_proc->p_decode) {
-		task->tk_action = rpc_exit_task;
-		rpc_wake_up_task(task);
-	}
-	return;
-out_nosend:
-	/* release socket write lock before attempting to handle error */
-	xprt_abort_transmit(task);
+	/*
+	 * On success, ensure that we call xprt_end_transmit() before sleeping
+	 * in order to allow access to the socket to other RPC requests.
+	 */
+	call_transmit_status(task);
+	if (task->tk_msg.rpc_proc->p_decode != NULL)
+		return;
+	task->tk_action = rpc_exit_task;
+	rpc_wake_up_task(task);
+}
+
+/*
+ * 5a.	Clean up after a transmit attempt, then move on to call_status
+ */
+static void
+call_transmit_status(struct rpc_task *task)
+{
+	task->tk_action = call_status;
+	/*
+	 * -EAGAIN: still waiting on the socket's write_space() callback, so
+	 * don't release the write lock or force the request to be re-encoded.
+	 */
+	if (task->tk_status == -EAGAIN)
+		return;
+	xprt_end_transmit(task);
 	rpc_task_force_reencode(task);
 }
 
@@ -992,18 +1009,7 @@
 }
 
 /*
- * 6a.	Handle transmission errors.
- */
-static void
-call_transmit_status(struct rpc_task *task)
-{
-	if (task->tk_status != -EAGAIN)
-		rpc_task_force_reencode(task);
-	call_status(task);
-}
-
-/*
- * 6b.	Handle RPC timeout
+ * 6a.	Handle RPC timeout
  * 	We do not release the request slot, so we keep using the
  *	same XID for all retransmits.
  */
diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c
index 313b68d..e8c2bc4 100644
--- a/net/sunrpc/xprt.c
+++ b/net/sunrpc/xprt.c
@@ -707,12 +707,15 @@
 	return err;
 }
 
-void
-xprt_abort_transmit(struct rpc_task *task)
+/**
+ * xprt_end_transmit - called after an attempt to transmit a request
+ * @task: RPC task whose transmit attempt has finished
+ *
+ * Releases the transport write lock so that other requests may transmit.
+ */
+void xprt_end_transmit(struct rpc_task *task)
 {
-	struct rpc_xprt	*xprt = task->tk_xprt;
-
-	xprt_release_write(xprt, task);
+	xprt_release_write(task->tk_xprt, task);
 }
 
 /**
@@ -761,8 +764,6 @@
 			task->tk_status = -ENOTCONN;
 		else if (!req->rq_received)
 			rpc_sleep_on(&xprt->pending, task, NULL, xprt_timer);
-
-		xprt->ops->release_xprt(xprt, task);
 		spin_unlock_bh(&xprt->transport_lock);
 		return;
 	}
@@ -772,18 +773,8 @@
 	 *	 schedq, and being picked up by a parallel run of rpciod().
 	 */
 	task->tk_status = status;
-
-	switch (status) {
-	case -ECONNREFUSED:
+	if (status == -ECONNREFUSED)
 		rpc_sleep_on(&xprt->sending, task, NULL, NULL);
-	case -EAGAIN:
-	case -ENOTCONN:
-		return;
-	default:
-		break;
-	}
-	xprt_release_write(xprt, task);
-	return;
 }
 
 static inline void do_xprt_reserve(struct rpc_task *task)
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
index ee678ed..441bd53 100644
--- a/net/sunrpc/xprtsock.c
+++ b/net/sunrpc/xprtsock.c
@@ -414,6 +414,33 @@
 }
 
 /**
+ * xs_tcp_release_xprt - release a TCP transport after a transmit attempt
+ * @xprt: transport
+ * @task: task releasing the transport; may be NULL
+ *
+ * If @task gave up after sending only part of its request, set
+ * XPRT_CLOSE_WAIT so the socket is reset instead of leaving a truncated
+ * request to confuse the server. Does nothing unless @task is snd_task.
+ */
+static void xs_tcp_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task)
+{
+	struct rpc_rqst *req;
+
+	if (task != xprt->snd_task)
+		return;
+	if (task == NULL)
+		goto out_release;
+	req = task->tk_rqstp;
+	if (req->rq_bytes_sent == 0)
+		goto out_release;
+	if (req->rq_bytes_sent == req->rq_snd_buf.len)
+		goto out_release;
+	set_bit(XPRT_CLOSE_WAIT, &xprt->state);
+out_release:
+	xprt_release_xprt(xprt, task);
+}
+
+/**
  * xs_close - close a socket
  * @xprt: transport
  *
@@ -1250,7 +1277,7 @@
 
 static struct rpc_xprt_ops xs_tcp_ops = {
 	.reserve_xprt		= xprt_reserve_xprt,
-	.release_xprt		= xprt_release_xprt,
+	.release_xprt		= xs_tcp_release_xprt,
 	.set_port		= xs_set_port,
 	.connect		= xs_connect,
 	.buf_alloc		= rpc_malloc,