SUNRPC: Don't change the RPCSEC_GSS context on a credential that is in use

When a server rejects our credential with an AUTH_REJECTEDCRED or similar,
we need to refresh the credential and then retry the request.
However, we do want to allow any requests that are in flight to finish
executing, so that we can at least attempt to process the replies that
depend on this instance of the credential.

The solution is to ensure that gss_refresh() looks up an entirely new
RPCSEC_GSS credential instead of attempting to create a context for the
existing invalid credential.

Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c
index 6f1b4e2..621c07f 100644
--- a/net/sunrpc/auth_gss/auth_gss.c
+++ b/net/sunrpc/auth_gss/auth_gss.c
@@ -114,28 +114,14 @@
 gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
 {
 	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
-	struct gss_cl_ctx *old;
 
-	old = gss_cred->gc_ctx;
+	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
+		return;
 	gss_get_ctx(ctx);
 	rcu_assign_pointer(gss_cred->gc_ctx, ctx);
 	set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+	smp_mb__before_clear_bit();
 	clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
-	if (old)
-		gss_put_ctx(old);
-}
-
-static int
-gss_cred_is_uptodate_ctx(struct rpc_cred *cred)
-{
-	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
-	int res = 0;
-
-	rcu_read_lock();
-	if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) && gss_cred->gc_ctx)
-		res = 1;
-	rcu_read_unlock();
-	return res;
 }
 
 static const void *
@@ -857,15 +843,12 @@
 {
 	struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
 
-	/*
-	 * If the searchflags have set RPCAUTH_LOOKUP_NEW, then
-	 * we don't really care if the credential has expired or not,
-	 * since the caller should be prepared to reinitialise it.
-	 */
-	if ((flags & RPCAUTH_LOOKUP_NEW) && test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
+	if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
 		goto out;
 	/* Don't match with creds that have expired. */
-	if (gss_cred->gc_ctx && time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
+	if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
+		return 0;
+	if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags))
 		return 0;
 out:
 	if (acred->machine_cred != gss_cred->gc_machine_cred)
@@ -933,16 +916,48 @@
 	return NULL;
 }
 
+static int gss_renew_cred(struct rpc_task *task)
+{
+	struct rpc_cred *oldcred = task->tk_msg.rpc_cred;
+	struct gss_cred *gss_cred = container_of(oldcred,
+						 struct gss_cred,
+						 gc_base);
+	struct rpc_auth *auth = oldcred->cr_auth;
+	struct auth_cred acred = {
+		.uid = oldcred->cr_uid,
+		.machine_cred = gss_cred->gc_machine_cred,
+	};
+	struct rpc_cred *new;
+
+	new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW);
+	if (IS_ERR(new))
+		return PTR_ERR(new);
+	task->tk_msg.rpc_cred = new;
+	put_rpccred(oldcred);
+	return 0;
+}
+
 /*
 * Refresh credentials. XXX - finish
 */
 static int
 gss_refresh(struct rpc_task *task)
 {
+	struct rpc_cred *cred = task->tk_msg.rpc_cred;
+	int ret = 0;
 
-	if (!gss_cred_is_uptodate_ctx(task->tk_msg.rpc_cred))
-		return gss_refresh_upcall(task);
-	return 0;
+	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
+			!test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) {
+		ret = gss_renew_cred(task);
+		if (ret < 0)
+			goto out;
+		cred = task->tk_msg.rpc_cred;
+	}
+
+	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
+		ret = gss_refresh_upcall(task);
+out:
+	return ret;
 }
 
 /* Dummy refresh routine: used only when destroying the context */