KEYS: Correctly destroy key payloads when their keytype is removed

unregister_key_type() has code to mark a key as dead and make it unavailable in
one loop and then destroy all those unavailable key payloads in the next loop.
However, the loop to mark keys dead renders the key undetectable to the second
loop by changing the key type pointer also.

Fix this by the following means:

 (1) The key code has two garbage collectors: one deletes unreferenced keys and
     the other alters keyrings to delete links to old dead, revoked and expired
     keys.  They can end up holding each other up as both want to scan the key
     serial tree under spinlock.  Combine these into a single routine.

 (2) Move the dead key marking, dead link removal and dead key removal into the
     garbage collector as a three phase process running over the three cycles
     of the normal garbage collection procedure.  This is tracked by the
     KEY_GC_REAPING_DEAD_1, _2 and _3 state flags.

     unregister_key_type() then just unlinks the key type from the list, wakes
     up the garbage collector and waits for the third phase to complete.

 (3) Downgrade the key types sem in unregister_key_type() once it has deleted
     the key type from the list so that it doesn't block the keyctl() syscall.

 (4) Dead keys that cannot be simply removed in the third phase have their
     payloads destroyed with the key's semaphore write-locked to prevent
     interference by the keyctl() syscall.  There should be no in-kernel users
     of dead keys of that type by the point of unregistration, though keyctl()
     may be holding a reference.

 (5) Only perform timer recalculation in the GC if the timer actually expired.
     If it didn't, we'll get another cycle when it goes off - and if the key
     that actually triggered it has been removed, it's not a problem.

 (6) Only garbage collect link if the timer expired or if we're doing dead key
     clean up phase 2.

 (7) As only key_garbage_collector() is permitted to use rb_erase() on the key
     serial tree, it doesn't need to revalidate its cursor after dropping the
     spinlock as the node the cursor points to must still exist in the tree.

 (8) Drop the spinlock in the GC if there is contention on it or if we need to
     reschedule.  After dealing with that, get the spinlock again and resume
     scanning.

This has been tested in the following ways:

 (1) Run the keyutils testsuite against it.

 (2) Using the AF_RXRPC and RxKAD modules to test keytype removal:

     Load the rxrpc_s key type:

	# insmod /tmp/af-rxrpc.ko
	# insmod /tmp/rxkad.ko

     Create a key (http://people.redhat.com/~dhowells/rxrpc/listen.c):

	# /tmp/listen &
	[1] 8173

     Find the key:

	# grep rxrpc_s /proc/keys
	091086e1 I--Q--     1 perm 39390000     0     0 rxrpc_s   52:2

     Link it to a session keyring, preferably one with a higher serial number:

	# keyctl link 0x20e36251 @s

     Kill the process (the key should remain as it's linked to another place):

	# fg
	/tmp/listen
	^C

     Remove the key type:

	rmmod rxkad
	rmmod af-rxrpc

     This can be made a more effective test by altering the following part of
     the patch:

	if (unlikely(gc_state & KEY_GC_REAPING_DEAD_2)) {
		/* Make sure everyone revalidates their keys if we marked a
		 * bunch as being dead and make sure all keyring ex-payloads
		 * are destroyed.
		 */
		kdebug("dead sync");
		synchronize_rcu();

     To call synchronize_rcu() in GC phase 1 instead.  That causes that the
     keyring's old payload content to hang around longer until it's RCU
     destroyed - which usually happens after GC phase 3 is complete.  This
     allows the destroy_dead_key branch to be tested.

Reported-by: Benjamin Coddington <bcodding@gmail.com>
Signed-off-by: David Howells <dhowells@redhat.com>
Signed-off-by: James Morris <jmorris@namei.org>
diff --git a/security/keys/gc.c b/security/keys/gc.c
index d67e88b..bf4d8da 100644
--- a/security/keys/gc.c
+++ b/security/keys/gc.c
@@ -1,6 +1,6 @@
 /* Key garbage collector
  *
- * Copyright (C) 2009 Red Hat, Inc. All Rights Reserved.
+ * Copyright (C) 2009-2011 Red Hat, Inc. All Rights Reserved.
  * Written by David Howells (dhowells@redhat.com)
  *
  * This program is free software; you can redistribute it and/or
@@ -23,21 +23,31 @@
 /*
  * Reaper for unused keys.
  */
-static void key_gc_unused_keys(struct work_struct *work);
-DECLARE_WORK(key_gc_unused_work, key_gc_unused_keys);
+static void key_garbage_collector(struct work_struct *work);
+DECLARE_WORK(key_gc_work, key_garbage_collector);
 
 /*
  * Reaper for links from keyrings to dead keys.
  */
 static void key_gc_timer_func(unsigned long);
-static void key_gc_dead_links(struct work_struct *);
 static DEFINE_TIMER(key_gc_timer, key_gc_timer_func, 0, 0);
-static DECLARE_WORK(key_gc_work, key_gc_dead_links);
-static key_serial_t key_gc_cursor; /* the last key the gc considered */
-static bool key_gc_again;
-static unsigned long key_gc_executing;
+
 static time_t key_gc_next_run = LONG_MAX;
-static time_t key_gc_new_timer;
+static struct key_type *key_gc_dead_keytype;
+
+static unsigned long key_gc_flags;
+#define KEY_GC_KEY_EXPIRED	0	/* A key expired and needs unlinking */
+#define KEY_GC_REAP_KEYTYPE	1	/* A keytype is being unregistered */
+#define KEY_GC_REAPING_KEYTYPE	2	/* Cleared when keytype reaped */
+
+
+/*
+ * Any key whose type gets unregistered will be re-typed to this if it can't be
+ * immediately unlinked.
+ */
+struct key_type key_type_dead = {
+	.name = "dead",
+};
 
 /*
  * Schedule a garbage collection run.
@@ -50,31 +60,75 @@
 
 	kenter("%ld", gc_at - now);
 
-	if (gc_at <= now) {
+	if (gc_at <= now || test_bit(KEY_GC_REAP_KEYTYPE, &key_gc_flags)) {
+		kdebug("IMMEDIATE");
 		queue_work(system_nrt_wq, &key_gc_work);
 	} else if (gc_at < key_gc_next_run) {
+		kdebug("DEFERRED");
+		key_gc_next_run = gc_at;
 		expires = jiffies + (gc_at - now) * HZ;
 		mod_timer(&key_gc_timer, expires);
 	}
 }
 
 /*
- * The garbage collector timer kicked off
+ * Some key's cleanup time was met after it expired, so we need to get the
+ * reaper to go through a cycle finding expired keys.
  */
 static void key_gc_timer_func(unsigned long data)
 {
 	kenter("");
 	key_gc_next_run = LONG_MAX;
+	set_bit(KEY_GC_KEY_EXPIRED, &key_gc_flags);
 	queue_work(system_nrt_wq, &key_gc_work);
 }
 
 /*
+ * wait_on_bit() sleep function for uninterruptible waiting
+ */
+static int key_gc_wait_bit(void *flags)
+{
+	schedule();
+	return 0;
+}
+
+/*
+ * Reap keys of dead type.
+ *
+ * We use three flags to make sure we see three complete cycles of the garbage
+ * collector: the first to mark keys of that type as being dead, the second to
+ * collect dead links and the third to clean up the dead keys.  We have to be
+ * careful as there may already be a cycle in progress.
+ *
+ * The caller must be holding key_types_sem.
+ */
+void key_gc_keytype(struct key_type *ktype)
+{
+	kenter("%s", ktype->name);
+
+	key_gc_dead_keytype = ktype;
+	set_bit(KEY_GC_REAPING_KEYTYPE, &key_gc_flags);
+	smp_mb();
+	set_bit(KEY_GC_REAP_KEYTYPE, &key_gc_flags);
+
+	kdebug("schedule");
+	queue_work(system_nrt_wq, &key_gc_work);
+
+	kdebug("sleep");
+	wait_on_bit(&key_gc_flags, KEY_GC_REAPING_KEYTYPE, key_gc_wait_bit,
+		    TASK_UNINTERRUPTIBLE);
+
+	key_gc_dead_keytype = NULL;
+	kleave("");
+}
+
+/*
  * Garbage collect pointers from a keyring.
  *
- * Return true if we altered the keyring.
+ * Not called with any locks held.  The keyring's key struct will not be
+ * deallocated under us as only our caller may deallocate it.
  */
-static bool key_gc_keyring(struct key *keyring, time_t limit)
-	__releases(key_serial_lock)
+static void key_gc_keyring(struct key *keyring, time_t limit)
 {
 	struct keyring_list *klist;
 	struct key *key;
@@ -101,168 +155,21 @@
 unlock_dont_gc:
 	rcu_read_unlock();
 dont_gc:
-	kleave(" = false");
-	return false;
+	kleave(" [no gc]");
+	return;
 
 do_gc:
 	rcu_read_unlock();
-	key_gc_cursor = keyring->serial;
-	key_get(keyring);
-	spin_unlock(&key_serial_lock);
+
 	keyring_gc(keyring, limit);
-	key_put(keyring);
-	kleave(" = true");
-	return true;
+	kleave(" [gc]");
 }
 
 /*
- * Garbage collector for links to dead keys.
- *
- * This involves scanning the keyrings for dead, expired and revoked keys that
- * have overstayed their welcome
+ * Garbage collect an unreferenced, detached key
  */
-static void key_gc_dead_links(struct work_struct *work)
+static noinline void key_gc_unused_key(struct key *key)
 {
-	struct rb_node *rb;
-	key_serial_t cursor;
-	struct key *key, *xkey;
-	time_t new_timer = LONG_MAX, limit, now;
-
-	now = current_kernel_time().tv_sec;
-	kenter("[%x,%ld]", key_gc_cursor, key_gc_new_timer - now);
-
-	if (test_and_set_bit(0, &key_gc_executing)) {
-		key_schedule_gc(current_kernel_time().tv_sec + 1);
-		kleave(" [busy; deferring]");
-		return;
-	}
-
-	limit = now;
-	if (limit > key_gc_delay)
-		limit -= key_gc_delay;
-	else
-		limit = key_gc_delay;
-
-	spin_lock(&key_serial_lock);
-
-	if (unlikely(RB_EMPTY_ROOT(&key_serial_tree))) {
-		spin_unlock(&key_serial_lock);
-		clear_bit(0, &key_gc_executing);
-		return;
-	}
-
-	cursor = key_gc_cursor;
-	if (cursor < 0)
-		cursor = 0;
-	if (cursor > 0)
-		new_timer = key_gc_new_timer;
-	else
-		key_gc_again = false;
-
-	/* find the first key above the cursor */
-	key = NULL;
-	rb = key_serial_tree.rb_node;
-	while (rb) {
-		xkey = rb_entry(rb, struct key, serial_node);
-		if (cursor < xkey->serial) {
-			key = xkey;
-			rb = rb->rb_left;
-		} else if (cursor > xkey->serial) {
-			rb = rb->rb_right;
-		} else {
-			rb = rb_next(rb);
-			if (!rb)
-				goto reached_the_end;
-			key = rb_entry(rb, struct key, serial_node);
-			break;
-		}
-	}
-
-	if (!key)
-		goto reached_the_end;
-
-	/* trawl through the keys looking for keyrings */
-	for (;;) {
-		if (key->expiry > limit && key->expiry < new_timer) {
-			kdebug("will expire %x in %ld",
-			       key_serial(key), key->expiry - limit);
-			new_timer = key->expiry;
-		}
-
-		if (key->type == &key_type_keyring &&
-		    key_gc_keyring(key, limit))
-			/* the gc had to release our lock so that the keyring
-			 * could be modified, so we have to get it again */
-			goto gc_released_our_lock;
-
-		rb = rb_next(&key->serial_node);
-		if (!rb)
-			goto reached_the_end;
-		key = rb_entry(rb, struct key, serial_node);
-	}
-
-gc_released_our_lock:
-	kdebug("gc_released_our_lock");
-	key_gc_new_timer = new_timer;
-	key_gc_again = true;
-	clear_bit(0, &key_gc_executing);
-	queue_work(system_nrt_wq, &key_gc_work);
-	kleave(" [continue]");
-	return;
-
-	/* when we reach the end of the run, we set the timer for the next one */
-reached_the_end:
-	kdebug("reached_the_end");
-	spin_unlock(&key_serial_lock);
-	key_gc_new_timer = new_timer;
-	key_gc_cursor = 0;
-	clear_bit(0, &key_gc_executing);
-
-	if (key_gc_again) {
-		/* there may have been a key that expired whilst we were
-		 * scanning, so if we discarded any links we should do another
-		 * scan */
-		new_timer = now + 1;
-		key_schedule_gc(new_timer);
-	} else if (new_timer < LONG_MAX) {
-		new_timer += key_gc_delay;
-		key_schedule_gc(new_timer);
-	}
-	kleave(" [end]");
-}
-
-/*
- * Garbage collector for unused keys.
- *
- * This is done in process context so that we don't have to disable interrupts
- * all over the place.  key_put() schedules this rather than trying to do the
- * cleanup itself, which means key_put() doesn't have to sleep.
- */
-static void key_gc_unused_keys(struct work_struct *work)
-{
-	struct rb_node *_n;
-	struct key *key;
-
-go_again:
-	/* look for a dead key in the tree */
-	spin_lock(&key_serial_lock);
-
-	for (_n = rb_first(&key_serial_tree); _n; _n = rb_next(_n)) {
-		key = rb_entry(_n, struct key, serial_node);
-
-		if (atomic_read(&key->usage) == 0)
-			goto found_dead_key;
-	}
-
-	spin_unlock(&key_serial_lock);
-	return;
-
-found_dead_key:
-	/* we found a dead key - once we've removed it from the tree, we can
-	 * drop the lock */
-	rb_erase(&key->serial_node, &key_serial_tree);
-	spin_unlock(&key_serial_lock);
-
 	key_check(key);
 
 	security_key_free(key);
@@ -291,7 +198,191 @@
 	key->magic = KEY_DEBUG_MAGIC_X;
 #endif
 	kmem_cache_free(key_jar, key);
+}
 
-	/* there may, of course, be more than one key to destroy */
-	goto go_again;
+/*
+ * Garbage collector for unused keys.
+ *
+ * This is done in process context so that we don't have to disable interrupts
+ * all over the place.  key_put() schedules this rather than trying to do the
+ * cleanup itself, which means key_put() doesn't have to sleep.
+ */
+static void key_garbage_collector(struct work_struct *work)
+{
+	static u8 gc_state;		/* Internal persistent state */
+#define KEY_GC_REAP_AGAIN	0x01	/* - Need another cycle */
+#define KEY_GC_REAPING_LINKS	0x02	/* - We need to reap links */
+#define KEY_GC_SET_TIMER	0x04	/* - We need to restart the timer */
+#define KEY_GC_REAPING_DEAD_1	0x10	/* - We need to mark dead keys */
+#define KEY_GC_REAPING_DEAD_2	0x20	/* - We need to reap dead key links */
+#define KEY_GC_REAPING_DEAD_3	0x40	/* - We need to reap dead keys */
+#define KEY_GC_FOUND_DEAD_KEY	0x80	/* - We found at least one dead key */
+
+	struct rb_node *cursor;
+	struct key *key;
+	time_t new_timer, limit;
+
+	kenter("[%lx,%x]", key_gc_flags, gc_state);
+
+	limit = current_kernel_time().tv_sec;
+	if (limit > key_gc_delay)
+		limit -= key_gc_delay;
+	else
+		limit = key_gc_delay;
+
+	/* Work out what we're going to be doing in this pass */
+	gc_state &= KEY_GC_REAPING_DEAD_1 | KEY_GC_REAPING_DEAD_2;
+	gc_state <<= 1;
+	if (test_and_clear_bit(KEY_GC_KEY_EXPIRED, &key_gc_flags))
+		gc_state |= KEY_GC_REAPING_LINKS | KEY_GC_SET_TIMER;
+
+	if (test_and_clear_bit(KEY_GC_REAP_KEYTYPE, &key_gc_flags))
+		gc_state |= KEY_GC_REAPING_DEAD_1;
+	kdebug("new pass %x", gc_state);
+
+	new_timer = LONG_MAX;
+
+	/* As only this function is permitted to remove things from the key
+	 * serial tree, if cursor is non-NULL then it will always point to a
+	 * valid node in the tree - even if lock got dropped.
+	 */
+	spin_lock(&key_serial_lock);
+	cursor = rb_first(&key_serial_tree);
+
+continue_scanning:
+	while (cursor) {
+		key = rb_entry(cursor, struct key, serial_node);
+		cursor = rb_next(cursor);
+
+		if (atomic_read(&key->usage) == 0)
+			goto found_unreferenced_key;
+
+		if (unlikely(gc_state & KEY_GC_REAPING_DEAD_1)) {
+			if (key->type == key_gc_dead_keytype) {
+				gc_state |= KEY_GC_FOUND_DEAD_KEY;
+				set_bit(KEY_FLAG_DEAD, &key->flags);
+				key->perm = 0;
+				goto skip_dead_key;
+			}
+		}
+
+		if (gc_state & KEY_GC_SET_TIMER) {
+			if (key->expiry > limit && key->expiry < new_timer) {
+				kdebug("will expire %x in %ld",
+				       key_serial(key), key->expiry - limit);
+				new_timer = key->expiry;
+			}
+		}
+
+		if (unlikely(gc_state & KEY_GC_REAPING_DEAD_2))
+			if (key->type == key_gc_dead_keytype)
+				gc_state |= KEY_GC_FOUND_DEAD_KEY;
+
+		if ((gc_state & KEY_GC_REAPING_LINKS) ||
+		    unlikely(gc_state & KEY_GC_REAPING_DEAD_2)) {
+			if (key->type == &key_type_keyring)
+				goto found_keyring;
+		}
+
+		if (unlikely(gc_state & KEY_GC_REAPING_DEAD_3))
+			if (key->type == key_gc_dead_keytype)
+				goto destroy_dead_key;
+
+	skip_dead_key:
+		if (spin_is_contended(&key_serial_lock) || need_resched())
+			goto contended;
+	}
+
+contended:
+	spin_unlock(&key_serial_lock);
+
+maybe_resched:
+	if (cursor) {
+		cond_resched();
+		spin_lock(&key_serial_lock);
+		goto continue_scanning;
+	}
+
+	/* We've completed the pass.  Set the timer if we need to and queue a
+	 * new cycle if necessary.  We keep executing cycles until we find one
+	 * where we didn't reap any keys.
+	 */
+	kdebug("pass complete");
+
+	if (gc_state & KEY_GC_SET_TIMER && new_timer != (time_t)LONG_MAX) {
+		new_timer += key_gc_delay;
+		key_schedule_gc(new_timer);
+	}
+
+	if (unlikely(gc_state & KEY_GC_REAPING_DEAD_2)) {
+		/* Make sure everyone revalidates their keys if we marked a
+		 * bunch as being dead and make sure all keyring ex-payloads
+		 * are destroyed.
+		 */
+		kdebug("dead sync");
+		synchronize_rcu();
+	}
+
+	if (unlikely(gc_state & (KEY_GC_REAPING_DEAD_1 |
+				 KEY_GC_REAPING_DEAD_2))) {
+		if (!(gc_state & KEY_GC_FOUND_DEAD_KEY)) {
+			/* No remaining dead keys: short circuit the remaining
+			 * keytype reap cycles.
+			 */
+			kdebug("dead short");
+			gc_state &= ~(KEY_GC_REAPING_DEAD_1 | KEY_GC_REAPING_DEAD_2);
+			gc_state |= KEY_GC_REAPING_DEAD_3;
+		} else {
+			gc_state |= KEY_GC_REAP_AGAIN;
+		}
+	}
+
+	if (unlikely(gc_state & KEY_GC_REAPING_DEAD_3)) {
+		kdebug("dead wake");
+		smp_mb();
+		clear_bit(KEY_GC_REAPING_KEYTYPE, &key_gc_flags);
+		wake_up_bit(&key_gc_flags, KEY_GC_REAPING_KEYTYPE);
+	}
+
+	if (gc_state & KEY_GC_REAP_AGAIN)
+		queue_work(system_nrt_wq, &key_gc_work);
+	kleave(" [end %x]", gc_state);
+	return;
+
+	/* We found an unreferenced key - once we've removed it from the tree,
+	 * we can safely drop the lock.
+	 */
+found_unreferenced_key:
+	kdebug("unrefd key %d", key->serial);
+	rb_erase(&key->serial_node, &key_serial_tree);
+	spin_unlock(&key_serial_lock);
+
+	key_gc_unused_key(key);
+	gc_state |= KEY_GC_REAP_AGAIN;
+	goto maybe_resched;
+
+	/* We found a keyring and we need to check the payload for links to
+	 * dead or expired keys.  We don't flag another reap immediately as we
+	 * have to wait for the old payload to be destroyed by RCU before we
+	 * can reap the keys to which it refers.
+	 */
+found_keyring:
+	spin_unlock(&key_serial_lock);
+	kdebug("scan keyring %d", key->serial);
+	key_gc_keyring(key, limit);
+	goto maybe_resched;
+
+	/* We found a dead key that is still referenced.  Reset its type and
+	 * destroy its payload with its semaphore held.
+	 */
+destroy_dead_key:
+	spin_unlock(&key_serial_lock);
+	kdebug("destroy key %d", key->serial);
+	down_write(&key->sem);
+	key->type = &key_type_dead;
+	if (key_gc_dead_keytype->destroy)
+		key_gc_dead_keytype->destroy(key);
+	memset(&key->payload, KEY_DESTROY, sizeof(key->payload));
+	up_write(&key->sem);
+	goto maybe_resched;
 }
diff --git a/security/keys/internal.h b/security/keys/internal.h
index a7cd1a6..c7a7cae 100644
--- a/security/keys/internal.h
+++ b/security/keys/internal.h
@@ -31,6 +31,7 @@
 	no_printk(KERN_DEBUG FMT"\n", ##__VA_ARGS__)
 #endif
 
+extern struct key_type key_type_dead;
 extern struct key_type key_type_user;
 
 /*****************************************************************************/
@@ -147,10 +148,11 @@
 
 extern long join_session_keyring(const char *name);
 
-extern struct work_struct key_gc_unused_work;
+extern struct work_struct key_gc_work;
 extern unsigned key_gc_delay;
 extern void keyring_gc(struct key *keyring, time_t limit);
 extern void key_schedule_gc(time_t expiry_at);
+extern void key_gc_keytype(struct key_type *ktype);
 
 extern int key_task_permission(const key_ref_t key_ref,
 			       const struct cred *cred,
diff --git a/security/keys/key.c b/security/keys/key.c
index 1f3ed44..4414abd 100644
--- a/security/keys/key.c
+++ b/security/keys/key.c
@@ -39,11 +39,6 @@
 /* We serialise key instantiation and link */
 DEFINE_MUTEX(key_construction_mutex);
 
-/* Any key who's type gets unegistered will be re-typed to this */
-static struct key_type key_type_dead = {
-	.name		= "dead",
-};
-
 #ifdef KEY_DEBUGGING
 void __key_check(const struct key *key)
 {
@@ -602,7 +597,7 @@
 		key_check(key);
 
 		if (atomic_dec_and_test(&key->usage))
-			queue_work(system_nrt_wq, &key_gc_unused_work);
+			queue_work(system_nrt_wq, &key_gc_work);
 	}
 }
 EXPORT_SYMBOL(key_put);
@@ -980,49 +975,11 @@
  */
 void unregister_key_type(struct key_type *ktype)
 {
-	struct rb_node *_n;
-	struct key *key;
-
 	down_write(&key_types_sem);
-
-	/* withdraw the key type */
 	list_del_init(&ktype->link);
-
-	/* mark all the keys of this type dead */
-	spin_lock(&key_serial_lock);
-
-	for (_n = rb_first(&key_serial_tree); _n; _n = rb_next(_n)) {
-		key = rb_entry(_n, struct key, serial_node);
-
-		if (key->type == ktype) {
-			key->type = &key_type_dead;
-			set_bit(KEY_FLAG_DEAD, &key->flags);
-		}
-	}
-
-	spin_unlock(&key_serial_lock);
-
-	/* make sure everyone revalidates their keys */
-	synchronize_rcu();
-
-	/* we should now be able to destroy the payloads of all the keys of
-	 * this type with impunity */
-	spin_lock(&key_serial_lock);
-
-	for (_n = rb_first(&key_serial_tree); _n; _n = rb_next(_n)) {
-		key = rb_entry(_n, struct key, serial_node);
-
-		if (key->type == ktype) {
-			if (ktype->destroy)
-				ktype->destroy(key);
-			memset(&key->payload, KEY_DESTROY, sizeof(key->payload));
-		}
-	}
-
-	spin_unlock(&key_serial_lock);
-	up_write(&key_types_sem);
-
-	key_schedule_gc(0);
+	downgrade_write(&key_types_sem);
+	key_gc_keytype(ktype);
+	up_read(&key_types_sem);
 }
 EXPORT_SYMBOL(unregister_key_type);