SUNRPC: Enforce atomic updates of rpc_cred->cr_flags
Convert to the use of atomic bitops...
Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
diff --git a/include/linux/sunrpc/auth.h b/include/linux/sunrpc/auth.h
index d5bfc67..8586503 100644
--- a/include/linux/sunrpc/auth.h
+++ b/include/linux/sunrpc/auth.h
@@ -36,19 +36,19 @@
struct hlist_node cr_hash; /* hash chain */
struct rpc_auth * cr_auth;
const struct rpc_credops *cr_ops;
- unsigned long cr_expire; /* when to gc */
- atomic_t cr_count; /* ref count */
- unsigned short cr_flags; /* various flags */
#ifdef RPC_DEBUG
unsigned long cr_magic; /* 0x0f4aa4f0 */
#endif
+ unsigned long cr_expire; /* when to gc */
+ unsigned long cr_flags; /* various flags */
+ atomic_t cr_count; /* ref count */
uid_t cr_uid;
/* per-flavor data */
};
-#define RPCAUTH_CRED_NEW 0x0001
-#define RPCAUTH_CRED_UPTODATE 0x0002
+#define RPCAUTH_CRED_NEW 0
+#define RPCAUTH_CRED_UPTODATE 1
#define RPCAUTH_CRED_MAGIC 0x0f4aa4f0
diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c
index 2156327..4d7c78b 100644
--- a/net/sunrpc/auth.c
+++ b/net/sunrpc/auth.c
@@ -190,8 +190,8 @@
if (atomic_read(&cred->cr_count) != 1)
return;
if (time_after(jiffies, cred->cr_expire + auth->au_credcache->expire))
- cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
- if (!(cred->cr_flags & RPCAUTH_CRED_UPTODATE)) {
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+ if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0) {
__hlist_del(&cred->cr_hash);
hlist_add_head(&cred->cr_hash, free);
}
@@ -267,7 +267,7 @@
if (!IS_ERR(new))
goto retry;
cred = new;
- } else if ((cred->cr_flags & RPCAUTH_CRED_NEW)
+ } else if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)
&& cred->cr_ops->cr_init != NULL
&& !(flags & RPCAUTH_LOOKUP_NEW)) {
int res = cred->cr_ops->cr_init(auth, cred);
@@ -440,17 +440,19 @@
void
rpcauth_invalcred(struct rpc_task *task)
{
+ struct rpc_cred *cred = task->tk_msg.rpc_cred;
+
dprintk("RPC: %5u invalidating %s cred %p\n",
- task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred);
- spin_lock(&rpc_credcache_lock);
- if (task->tk_msg.rpc_cred)
- task->tk_msg.rpc_cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
- spin_unlock(&rpc_credcache_lock);
+ task->tk_pid, task->tk_auth->au_ops->au_name, cred);
+ if (cred)
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
}
int
rpcauth_uptodatecred(struct rpc_task *task)
{
- return !(task->tk_msg.rpc_cred) ||
- (task->tk_msg.rpc_cred->cr_flags & RPCAUTH_CRED_UPTODATE);
+ struct rpc_cred *cred = task->tk_msg.rpc_cred;
+
+ return cred == NULL ||
+ test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
}
diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c
index 766de0a..55c47ae 100644
--- a/net/sunrpc/auth_gss/auth_gss.c
+++ b/net/sunrpc/auth_gss/auth_gss.c
@@ -114,8 +114,8 @@
write_lock(&gss_ctx_lock);
old = gss_cred->gc_ctx;
gss_cred->gc_ctx = ctx;
- cred->cr_flags |= RPCAUTH_CRED_UPTODATE;
- cred->cr_flags &= ~RPCAUTH_CRED_NEW;
+ set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
+ clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
write_unlock(&gss_ctx_lock);
if (old)
gss_put_ctx(old);
@@ -128,7 +128,7 @@
int res = 0;
read_lock(&gss_ctx_lock);
- if ((cred->cr_flags & RPCAUTH_CRED_UPTODATE) && gss_cred->gc_ctx)
+ if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) && gss_cred->gc_ctx)
res = 1;
read_unlock(&gss_ctx_lock);
return res;
@@ -732,7 +732,7 @@
* Note: in order to force a call to call_refresh(), we deliberately
* fail to flag the credential as RPCAUTH_CRED_UPTODATE.
*/
- cred->gc_base.cr_flags = RPCAUTH_CRED_NEW;
+ cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW;
cred->gc_service = gss_auth->service;
return &cred->gc_base;
@@ -764,7 +764,7 @@
* we don't really care if the credential has expired or not,
* since the caller should be prepared to reinitialise it.
*/
- if ((flags & RPCAUTH_LOOKUP_NEW) && (rc->cr_flags & RPCAUTH_CRED_NEW))
+ if ((flags & RPCAUTH_LOOKUP_NEW) && test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
goto out;
/* Don't match with creds that have expired. */
if (gss_cred->gc_ctx && time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
@@ -820,7 +820,7 @@
mic.data = (u8 *)(p + 1);
maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
if (maj_stat == GSS_S_CONTEXT_EXPIRED) {
- cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
} else if (maj_stat != 0) {
printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat);
goto out_put_ctx;
@@ -873,7 +873,7 @@
maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
if (maj_stat == GSS_S_CONTEXT_EXPIRED)
- cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
if (maj_stat)
goto out_bad;
/* We leave it to unwrap to calculate au_rslack. For now we just
@@ -927,7 +927,7 @@
maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
status = -EIO; /* XXX? */
if (maj_stat == GSS_S_CONTEXT_EXPIRED)
- cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
else if (maj_stat)
return status;
q = xdr_encode_opaque(p, NULL, mic.len);
@@ -1026,7 +1026,7 @@
/* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
* done anyway, so it's safe to put the request on the wire: */
if (maj_stat == GSS_S_CONTEXT_EXPIRED)
- cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
else if (maj_stat)
return status;
@@ -1113,7 +1113,7 @@
maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
if (maj_stat == GSS_S_CONTEXT_EXPIRED)
- cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
if (maj_stat != GSS_S_COMPLETE)
return status;
return 0;
@@ -1138,7 +1138,7 @@
maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf);
if (maj_stat == GSS_S_CONTEXT_EXPIRED)
- cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
+ clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
if (maj_stat != GSS_S_COMPLETE)
return status;
if (ntohl(*(*p)++) != rqstp->rq_seqno)
diff --git a/net/sunrpc/auth_null.c b/net/sunrpc/auth_null.c
index fe9b6aa..6c905fb 100644
--- a/net/sunrpc/auth_null.c
+++ b/net/sunrpc/auth_null.c
@@ -76,7 +76,7 @@
static int
nul_refresh(struct rpc_task *task)
{
- task->tk_msg.rpc_cred->cr_flags |= RPCAUTH_CRED_UPTODATE;
+ set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_msg.rpc_cred->cr_flags);
return 0;
}
@@ -136,7 +136,7 @@
.cr_auth = &null_auth,
.cr_ops = &null_credops,
.cr_count = ATOMIC_INIT(1),
- .cr_flags = RPCAUTH_CRED_UPTODATE,
+ .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE,
#ifdef RPC_DEBUG
.cr_magic = RPCAUTH_CRED_MAGIC,
#endif
diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c
index f17dabb..29d50ff 100644
--- a/net/sunrpc/auth_unix.c
+++ b/net/sunrpc/auth_unix.c
@@ -72,7 +72,7 @@
return ERR_PTR(-ENOMEM);
rpcauth_init_cred(&cred->uc_base, acred, auth, &unix_credops);
- cred->uc_base.cr_flags = RPCAUTH_CRED_UPTODATE;
+ cred->uc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE;
if (flags & RPCAUTH_LOOKUP_ROOTCREDS) {
cred->uc_uid = 0;
cred->uc_gid = 0;
@@ -172,7 +172,7 @@
static int
unx_refresh(struct rpc_task *task)
{
- task->tk_msg.rpc_cred->cr_flags |= RPCAUTH_CRED_UPTODATE;
+ set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_msg.rpc_cred->cr_flags);
return 0;
}