virtio: Force use of power-of-two for descriptor ring sizes

The virtio descriptor rings of size N-1 were nicely set up to be
aligned to an N-byte boundary.  But as Anthony Liguori points out, the
free-running indices used by virtio require that the sizes be a power
of 2, otherwise we get problems on wrap (demonstrated with lguest).

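So we replace the clever "2^n-1" scheme with a simple "align to page
boundary" scheme: vring_init() and vring_size() now take a pagesize
argument, and the used ring starts at the first pagesize boundary after
the available ring.  This means that all virtio rings take at least two
pages, but it's safer than guessing cache alignment.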

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
diff --git a/include/linux/virtio_ring.h b/include/linux/virtio_ring.h
index 5b88d21..1a4ed49 100644
--- a/include/linux/virtio_ring.h
+++ b/include/linux/virtio_ring.h
@@ -67,7 +67,7 @@
 };
 
 /* The standard layout for the ring is a continuous chunk of memory which looks
- * like this.  The used fields will be aligned to a "num+1" boundary.
+ * like this.  We assume num is a power of 2.
  *
  * struct vring
  * {
@@ -79,8 +79,8 @@
  *	__u16 avail_idx;
  *	__u16 available[num];
  *
- *	// Padding so a correctly-chosen num value will cache-align used_idx.
- *	char pad[sizeof(struct vring_desc) - sizeof(avail_flags)];
+ *	// Padding to the next page boundary.
+ *	char pad[];
  *
  *	// A ring of used descriptor heads with free-running index.
  *	__u16 used_flags;
@@ -88,18 +88,26 @@
  *	struct vring_used_elem used[num];
  * };
  */
-static inline void vring_init(struct vring *vr, unsigned int num, void *p)
+static inline void vring_init(struct vring *vr, unsigned int num, void *p,
+			      unsigned int pagesize)
 {
 	vr->num = num;
 	vr->desc = p;
 	vr->avail = p + num*sizeof(struct vring_desc);
-	vr->used = p + (num+1)*(sizeof(struct vring_desc) + sizeof(__u16));
+	/* Widen pagesize before inverting: in unsigned int, ~(pagesize - 1)
+	 * would zero the upper 32 bits of a 64-bit address.  vring_size()
+	 * only covers this layout if p itself is pagesize-aligned. */
+	vr->used = (void *)(((unsigned long)&vr->avail->ring[num] + pagesize-1)
+			    & ~((unsigned long)pagesize - 1));
 }
 
-static inline unsigned vring_size(unsigned int num)
+static inline unsigned vring_size(unsigned int num, unsigned int pagesize)
 {
-	return (num + 1) * (sizeof(struct vring_desc) + sizeof(__u16))
-		+ sizeof(__u32) + num * sizeof(struct vring_used_elem);
+	/* Descriptor table plus avail ring (flags, idx, ring[num]), padded
+	 * to pagesize, then the used ring (flags, idx, used[num]). */
+	return ((sizeof(struct vring_desc) * num + sizeof(__u16) * (2 + num)
+		 + pagesize - 1) & ~(pagesize - 1))
+		+ sizeof(__u16) * 2 + sizeof(struct vring_used_elem) * num;
 }
 
 #ifdef __KERNEL__