[IPV4]: More broken memory allocation fixes for fib_trie

Below is a patch to preallocate memory when resizing the trie (inflate/halve).
If preallocation fails, it just skips the resize of this tnode for this time.

The oops we got when killing bgpd (with full routing) is now gone.
Patrick's memory patch is also used.

Signed-off-by: Robert Olsson <robert.olsson@its.uu.se>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/fib_trie.c b/net/ipv4/fib_trie.c
index 9038b91..4be234c 100644
--- a/net/ipv4/fib_trie.c
+++ b/net/ipv4/fib_trie.c
@@ -43,7 +43,7 @@
  *		2 of the License, or (at your option) any later version.
  */
 
-#define VERSION "0.324"
+#define VERSION "0.325"
 
 #include <linux/config.h>
 #include <asm/uaccess.h>
@@ -136,6 +136,7 @@
 	unsigned int semantic_match_passed;
 	unsigned int semantic_match_miss;
 	unsigned int null_node_hit;
+	unsigned int resize_node_skipped;
 };
 #endif
 
@@ -164,8 +165,8 @@
 static void tnode_put_child_reorg(struct tnode *tn, int i, struct node *n, int wasfull);
 static int tnode_child_length(struct tnode *tn);
 static struct node *resize(struct trie *t, struct tnode *tn);
-static struct tnode *inflate(struct trie *t, struct tnode *tn);
-static struct tnode *halve(struct trie *t, struct tnode *tn);
+static struct tnode *inflate(struct trie *t, struct tnode *tn, int *err);
+static struct tnode *halve(struct trie *t, struct tnode *tn, int *err);
 static void tnode_free(struct tnode *tn);
 static void trie_dump_seq(struct seq_file *seq, struct trie *t);
 extern struct fib_alias *fib_find_alias(struct list_head *fah, u8 tos, u32 prio);
@@ -481,6 +482,7 @@
 static struct node *resize(struct trie *t, struct tnode *tn) 
 {
 	int i;
+	int err = 0;
 
  	if (!tn)
 		return NULL;
@@ -577,12 +579,20 @@
 	 */
 
 	check_tnode(tn);
-
+	
+	err = 0;
 	while ((tn->full_children > 0 &&
 	       50 * (tn->full_children + tnode_child_length(tn) - tn->empty_children) >=
 				inflate_threshold * tnode_child_length(tn))) {
 
-		tn = inflate(t, tn);
+		tn = inflate(t, tn, &err);
+
+		if(err) {
+#ifdef CONFIG_IP_FIB_TRIE_STATS
+			t->stats.resize_node_skipped++;
+#endif
+			break;
+		}
 	}
 
 	check_tnode(tn);
@@ -591,11 +601,22 @@
 	 * Halve as long as the number of empty children in this
 	 * node is above threshold.
 	 */
+
+	err = 0;
 	while (tn->bits > 1 &&
 	       100 * (tnode_child_length(tn) - tn->empty_children) <
-	       halve_threshold * tnode_child_length(tn))
+	       halve_threshold * tnode_child_length(tn)) {
 
-		tn = halve(t, tn);
+		tn = halve(t, tn, &err);
+
+		if(err) {
+#ifdef CONFIG_IP_FIB_TRIE_STATS
+			t->stats.resize_node_skipped++;
+#endif
+			break;
+		}
+	}
+
   
 	/* Only one child remains */
 
@@ -620,7 +641,7 @@
 	return (struct node *) tn;
 }
 
-static struct tnode *inflate(struct trie *t, struct tnode *tn)
+static struct tnode *inflate(struct trie *t, struct tnode *tn, int *err)
 {
 	struct tnode *inode;
 	struct tnode *oldtnode = tn;
@@ -632,8 +653,63 @@
 
 	tn = tnode_new(oldtnode->key, oldtnode->pos, oldtnode->bits + 1);
 
-	if (!tn)
-		trie_bug("tnode_new failed");
+	if (!tn) {
+		*err = -ENOMEM;
+		return oldtnode;
+	}
+
+	/*
+	 * Preallocate and store tnodes before the actual work so we 
+	 * don't get into an inconsistent state if memory allocation 
+	 * fails. In case of failure we return the oldnode and  inflate 
+	 * of tnode is ignored.
+	 */
+			
+	for(i = 0; i < olen; i++) {
+		struct tnode *inode = (struct tnode *) tnode_get_child(oldtnode, i);
+
+		if (inode &&
+		    IS_TNODE(inode) &&
+		    inode->pos == oldtnode->pos + oldtnode->bits &&
+		    inode->bits > 1) {
+			struct tnode *left, *right;
+
+			t_key m = TKEY_GET_MASK(inode->pos, 1);
+ 
+			left = tnode_new(inode->key&(~m), inode->pos + 1,
+					 inode->bits - 1);
+
+			if(!left) {
+				*err = -ENOMEM; 
+				break;
+			}
+			
+			right = tnode_new(inode->key|m, inode->pos + 1,
+					  inode->bits - 1);
+
+			if(!right) {
+				*err = -ENOMEM; 
+				break;
+			}
+
+			put_child(t, tn, 2*i, (struct node *) left);
+			put_child(t, tn, 2*i+1, (struct node *) right);
+		}
+	}
+
+	if(*err) {
+		int size = tnode_child_length(tn);
+		int j;
+
+		for(j = 0; j < size; j++) 
+			if( tn->child[j])
+				tnode_free((struct tnode *)tn->child[j]);
+
+		tnode_free(tn);
+		
+		*err = -ENOMEM;
+		return oldtnode;
+	}
 
 	for(i = 0; i < olen; i++) {
 		struct node *node = tnode_get_child(oldtnode, i);
@@ -646,7 +722,7 @@
 
 		if(IS_LEAF(node) || ((struct tnode *) node)->pos >
 		   tn->pos + tn->bits - 1) {
-			if(tkey_extract_bits(node->key, tn->pos + tn->bits - 1,
+			if(tkey_extract_bits(node->key, oldtnode->pos + oldtnode->bits,
 					     1) == 0)
 				put_child(t, tn, 2*i, node);
 			else
@@ -686,27 +762,22 @@
 			 * the position (inode->pos)
 			 */
 
-			t_key m = TKEY_GET_MASK(inode->pos, 1);
- 
 			/* Use the old key, but set the new significant 
 			 *   bit to zero. 
 			 */
-			left = tnode_new(inode->key&(~m), inode->pos + 1,
-					 inode->bits - 1);
 
-			if(!left) 
-				trie_bug("tnode_new failed");
-			
-			
-			/* Use the old key, but set the new significant 
-			 * bit to one. 
-			 */
-			right = tnode_new(inode->key|m, inode->pos + 1,
-					  inode->bits - 1);
+			left = (struct tnode *) tnode_get_child(tn, 2*i);
+			put_child(t, tn, 2*i, NULL);
 
-			if(!right) 
-				trie_bug("tnode_new failed");
-			
+			if(!left)
+				BUG();
+
+			right = (struct tnode *) tnode_get_child(tn, 2*i+1);
+			put_child(t, tn, 2*i+1, NULL);
+
+			if(!right)
+				BUG();
+
 			size = tnode_child_length(left);
 			for(j = 0; j < size; j++) {
 				put_child(t, left, j, inode->child[j]);
@@ -722,7 +793,7 @@
 	return tn;
 }
 
-static struct tnode *halve(struct trie *t, struct tnode *tn)
+static struct tnode *halve(struct trie *t, struct tnode *tn, int *err)
 {
 	struct tnode *oldtnode = tn;
 	struct node *left, *right;
@@ -733,8 +804,48 @@
   
 	tn=tnode_new(oldtnode->key, oldtnode->pos, oldtnode->bits - 1);
 
-	if(!tn) 
-		trie_bug("tnode_new failed");
+	if (!tn) {
+		*err = -ENOMEM;
+		return oldtnode;
+	}
+
+	/*
+	 * Preallocate and store tnodes before the actual work so we 
+	 * don't get into an inconsistent state if memory allocation 
+	 * fails. In case of failure we return the oldnode and halve 
+	 * of tnode is ignored.
+	 */
+
+	for(i = 0; i < olen; i += 2) {
+		left = tnode_get_child(oldtnode, i);
+		right = tnode_get_child(oldtnode, i+1);
+    
+		/* Two nonempty children */
+		if( left && right)  {
+			struct tnode *newBinNode =
+				tnode_new(left->key, tn->pos + tn->bits, 1);
+
+			if(!newBinNode) {
+				*err = -ENOMEM; 
+				break;
+			}
+			put_child(t, tn, i/2, (struct node *)newBinNode);
+		}
+	}
+
+	if(*err) {
+		int size = tnode_child_length(tn);
+		int j;
+
+		for(j = 0; j < size; j++) 
+			if( tn->child[j])
+				tnode_free((struct tnode *)tn->child[j]);
+
+		tnode_free(tn);
+		
+		*err = -ENOMEM;
+		return oldtnode;
+	}
 
 	for(i = 0; i < olen; i += 2) {
 		left = tnode_get_child(oldtnode, i);
@@ -751,10 +862,11 @@
 		/* Two nonempty children */
 		else {
 			struct tnode *newBinNode =
-				tnode_new(left->key, tn->pos + tn->bits, 1);
+				(struct tnode *) tnode_get_child(tn, i/2);
+			put_child(t, tn, i/2, NULL);
 
 			if(!newBinNode) 
-				trie_bug("tnode_new failed");
+				BUG();
 
 			put_child(t, newBinNode, 0, left);
 			put_child(t, newBinNode, 1, right);
@@ -2322,6 +2434,7 @@
 	seq_printf(seq,"semantic match passed = %d\n", t->stats.semantic_match_passed);
 	seq_printf(seq,"semantic match miss = %d\n", t->stats.semantic_match_miss);
 	seq_printf(seq,"null node hit= %d\n", t->stats.null_node_hit);
+	seq_printf(seq,"skipped node resize = %d\n", t->stats.resize_node_skipped);
 #ifdef CLEAR_STATS
 	memset(&(t->stats), 0, sizeof(t->stats));
 #endif