dm cache: revert "remove remainder of distinct discard block size"

This reverts commit 64ab346a360a4b15c28fb8531918d4a01f4eabd9 because we
actually do want to allow the discard block size to be larger than the
cache block size.  Subsequent dm-cache discard changes will make this
possible.

Signed-off-by: Joe Thornber <ejt@redhat.com>
Signed-off-by: Mike Snitzer <snitzer@redhat.com>
diff --git a/drivers/md/dm-cache-target.c b/drivers/md/dm-cache-target.c
index 890e2ff..ced7fd4 100644
--- a/drivers/md/dm-cache-target.c
+++ b/drivers/md/dm-cache-target.c
@@ -236,8 +236,9 @@
 	/*
 	 * origin_blocks entries, discarded if set.
 	 */
-	dm_oblock_t discard_nr_blocks;
+	dm_dblock_t discard_nr_blocks;
 	unsigned long *discard_bitset;
+	uint32_t discard_block_size;
 
 	/*
 	 * Rather than reconstructing the table line for the status we just
@@ -524,33 +525,48 @@
 	return b;
 }
 
-static void set_discard(struct cache *cache, dm_oblock_t b)
+static dm_dblock_t oblock_to_dblock(struct cache *cache, dm_oblock_t oblock)
+{
+	uint32_t discard_blocks = cache->discard_block_size;
+	dm_block_t b = from_oblock(oblock);
+
+	if (!block_size_is_power_of_two(cache))
+		discard_blocks = discard_blocks / cache->sectors_per_block;
+	else
+		discard_blocks >>= cache->sectors_per_block_shift;
+
+	b = block_div(b, discard_blocks);
+
+	return to_dblock(b);
+}
+
+static void set_discard(struct cache *cache, dm_dblock_t b)
 {
 	unsigned long flags;
 
 	atomic_inc(&cache->stats.discard_count);
 
 	spin_lock_irqsave(&cache->lock, flags);
-	set_bit(from_oblock(b), cache->discard_bitset);
+	set_bit(from_dblock(b), cache->discard_bitset);
 	spin_unlock_irqrestore(&cache->lock, flags);
 }
 
-static void clear_discard(struct cache *cache, dm_oblock_t b)
+static void clear_discard(struct cache *cache, dm_dblock_t b)
 {
 	unsigned long flags;
 
 	spin_lock_irqsave(&cache->lock, flags);
-	clear_bit(from_oblock(b), cache->discard_bitset);
+	clear_bit(from_dblock(b), cache->discard_bitset);
 	spin_unlock_irqrestore(&cache->lock, flags);
 }
 
-static bool is_discarded(struct cache *cache, dm_oblock_t b)
+static bool is_discarded(struct cache *cache, dm_dblock_t b)
 {
 	int r;
 	unsigned long flags;
 
 	spin_lock_irqsave(&cache->lock, flags);
-	r = test_bit(from_oblock(b), cache->discard_bitset);
+	r = test_bit(from_dblock(b), cache->discard_bitset);
 	spin_unlock_irqrestore(&cache->lock, flags);
 
 	return r;
@@ -562,7 +578,8 @@
 	unsigned long flags;
 
 	spin_lock_irqsave(&cache->lock, flags);
-	r = test_bit(from_oblock(b), cache->discard_bitset);
+	r = test_bit(from_dblock(oblock_to_dblock(cache, b)),
+		     cache->discard_bitset);
 	spin_unlock_irqrestore(&cache->lock, flags);
 
 	return r;
@@ -687,7 +704,7 @@
 	check_if_tick_bio_needed(cache, bio);
 	remap_to_origin(cache, bio);
 	if (bio_data_dir(bio) == WRITE)
-		clear_discard(cache, oblock);
+		clear_discard(cache, oblock_to_dblock(cache, oblock));
 }
 
 static void remap_to_cache_dirty(struct cache *cache, struct bio *bio,
@@ -697,7 +714,7 @@
 	remap_to_cache(cache, bio, cblock);
 	if (bio_data_dir(bio) == WRITE) {
 		set_dirty(cache, oblock, cblock);
-		clear_discard(cache, oblock);
+		clear_discard(cache, oblock_to_dblock(cache, oblock));
 	}
 }
 
@@ -1301,14 +1318,14 @@
 static void process_discard_bio(struct cache *cache, struct bio *bio)
 {
 	dm_block_t start_block = dm_sector_div_up(bio->bi_iter.bi_sector,
-						  cache->sectors_per_block);
+						  cache->discard_block_size);
 	dm_block_t end_block = bio_end_sector(bio);
 	dm_block_t b;
 
-	end_block = block_div(end_block, cache->sectors_per_block);
+	end_block = block_div(end_block, cache->discard_block_size);
 
 	for (b = start_block; b < end_block; b++)
-		set_discard(cache, to_oblock(b));
+		set_discard(cache, to_dblock(b));
 
 	bio_endio(bio, 0);
 }
@@ -2303,13 +2320,14 @@
 	}
 	clear_bitset(cache->dirty_bitset, from_cblock(cache->cache_size));
 
-	cache->discard_nr_blocks = cache->origin_blocks;
-	cache->discard_bitset = alloc_bitset(from_oblock(cache->discard_nr_blocks));
+	cache->discard_block_size = cache->sectors_per_block;
+	cache->discard_nr_blocks = oblock_to_dblock(cache, cache->origin_blocks);
+	cache->discard_bitset = alloc_bitset(from_dblock(cache->discard_nr_blocks));
 	if (!cache->discard_bitset) {
 		*error = "could not allocate discard bitset";
 		goto bad;
 	}
-	clear_bitset(cache->discard_bitset, from_oblock(cache->discard_nr_blocks));
+	clear_bitset(cache->discard_bitset, from_dblock(cache->discard_nr_blocks));
 
 	cache->copier = dm_kcopyd_client_create(&dm_kcopyd_throttle);
 	if (IS_ERR(cache->copier)) {
@@ -2599,16 +2617,16 @@
 {
 	unsigned i, r;
 
-	r = dm_cache_discard_bitset_resize(cache->cmd, cache->sectors_per_block,
-					   cache->origin_blocks);
+	r = dm_cache_discard_bitset_resize(cache->cmd, cache->discard_block_size,
+					   cache->discard_nr_blocks);
 	if (r) {
 		DMERR("could not resize on-disk discard bitset");
 		return r;
 	}
 
-	for (i = 0; i < from_oblock(cache->discard_nr_blocks); i++) {
-		r = dm_cache_set_discard(cache->cmd, to_oblock(i),
-					 is_discarded(cache, to_oblock(i)));
+	for (i = 0; i < from_dblock(cache->discard_nr_blocks); i++) {
+		r = dm_cache_set_discard(cache->cmd, to_dblock(i),
+					 is_discarded(cache, to_dblock(i)));
 		if (r)
 			return r;
 	}
@@ -2681,14 +2699,16 @@
 }
 
 static int load_discard(void *context, sector_t discard_block_size,
-			dm_oblock_t oblock, bool discard)
+			dm_dblock_t dblock, bool discard)
 {
 	struct cache *cache = context;
 
+	/* FIXME: handle mis-matched block size */
+
 	if (discard)
-		set_discard(cache, oblock);
+		set_discard(cache, dblock);
 	else
-		clear_discard(cache, oblock);
+		clear_discard(cache, dblock);
 
 	return 0;
 }
@@ -3079,8 +3099,8 @@
 	/*
 	 * FIXME: these limits may be incompatible with the cache device
 	 */
-	limits->max_discard_sectors = cache->sectors_per_block;
-	limits->discard_granularity = cache->sectors_per_block << SECTOR_SHIFT;
+	limits->max_discard_sectors = cache->discard_block_size;
+	limits->discard_granularity = cache->discard_block_size << SECTOR_SHIFT;
 }
 
 static void cache_io_hints(struct dm_target *ti, struct queue_limits *limits)