diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py
index e4d2e2855..3021b6a16 100644
--- a/jax/_src/blocked_sampler.py
+++ b/jax/_src/blocked_sampler.py
@@ -29,8 +29,8 @@ class SampleFn(Protocol):
 
 
 def _compute_tile_index(block_index: Sequence[int],
-                        total_size_in_blocks: Shape,
                         block_size_in_tiles: Shape,
+                        total_size_in_tiles: Shape,
                         tile_index_in_block: Sequence[int]) -> int:
   ndims = len(block_index)
   dim_size = 1
@@ -38,7 +38,7 @@ def _compute_tile_index(block_index: Sequence[int],
   for i in range(ndims-1, -1, -1):
     dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
     total_idx += dim_idx * dim_size
-    dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
+    dim_size *= total_size_in_tiles[i]
   return total_idx
 
 
@@ -103,15 +103,17 @@ def blocked_fold_in(
       _shape // _element for _shape, _element in zip(block_size, tile_size)
   )
 
-  total_size_in_blocks = tuple(
-      _shape // _element for _shape, _element in zip(total_size, block_size)
+  # Round up so partial tiles at the array edge still get a unique index.
+  total_size_in_tiles = tuple(
+      (_shape + _element - 1) // _element
+        for _shape, _element in zip(total_size, tile_size)
   )
 
   def _keygen_loop(axis, prefix):
     if axis == len(block_size_in_tiles):
       subtile_key = jax.random.fold_in(
           global_key, _compute_tile_index(
-              block_index, total_size_in_blocks, block_size_in_tiles, prefix))
+              block_index, block_size_in_tiles, total_size_in_tiles, prefix))
       return subtile_key
     else:
       keys = []
diff --git a/tests/blocked_sampler_test.py b/tests/blocked_sampler_test.py
index 4c27e850c..b5f87fe05 100644
--- a/tests/blocked_sampler_test.py
+++ b/tests/blocked_sampler_test.py
@@ -29,16 +29,23 @@ def call_kernel(
   kernel,
   grid: tuple[int, int],
   transpose_grid: bool,
-  *args
+  key: jax.Array,
+  total_size: tuple[int, int],
+  block_size: tuple[int, int],
+  tile_size: tuple[int, int],
   ):
   """Calls a kernel over a grid and concatenates results to a single array."""
   if transpose_grid:
     grid = (grid[1], grid[0])
   m, n = grid
-  return jnp.concatenate([
-      jnp.concatenate([
-          kernel((i, j), *args) for j in range(n)], axis=1)
-      for i in range(m)], axis=0)
+  samples = jnp.concatenate([
+              jnp.concatenate([
+                kernel((i, j), key, total_size, block_size, tile_size)
+                  for j in range(n)], axis=1)
+                    for i in range(m)], axis=0)
+  # Trim rows/columns that overhang total_size (padding from partial blocks).
+  samples = samples[:total_size[0], :total_size[1]]
+  return samples
 
 
 def call_kernel_3d(
@@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size):
                          block_size=block_size,
                          tile_size=tile_size)
   return blocked_sampler.sample_block(jax.random.uniform,
-                                      keys,
-                                      block_size=block_size,
-                                      tile_size=tile_size,
-                                      minval=0.0, maxval=1.0)
+                                         keys,
+                                         block_size=block_size,
+                                         tile_size=tile_size,
+                                         minval=0.0, maxval=1.0)
 
 
 class BlockedSamplerTest(jtu.JaxTestCase):
@@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase):
     dict(testcase_name='16x256_vs_32x128', total_size=(32, 256),
          block_size_a=(16, 256), block_size_b=(32, 128),
          tile_size=(8, 128), transpose_grid=False),
+    dict(testcase_name='128x128_vs_128x256_padding',
+         total_size=(256, 128), block_size_a=(128, 128),
+         block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
+    dict(testcase_name='128x128_vs_128x256_padding2',
+         total_size=(257, 129), block_size_a=(128, 128),
+         block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
   )
   def test_block_shape_invariance(self, total_size, block_size_a,
                           block_size_b, tile_size, transpose_grid):
     global_key = jax.random.key(0)
-    grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
+    ceil_div = lambda x, y: (x + y - 1) // y
+    grid_a = tuple(ceil_div(_tot, _blk)
+                   for _tot, _blk in zip(total_size, block_size_a))
     result_a = call_kernel(
         uniform_kernel, grid_a, transpose_grid, global_key,
         total_size, block_size_a, tile_size)
 
-    grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
+    grid_b = tuple(ceil_div(_tot, _blk)
+                   for _tot, _blk in zip(total_size, block_size_b))
     result_b = call_kernel(
         uniform_kernel, grid_b, transpose_grid, global_key,
         total_size, block_size_b, tile_size)