diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 35bb9bab1..acffbbbef 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2197,9 +2197,6 @@ class ShardMapTest(jtu.JaxTestCase): # f(x) # don't crash def test_partial_auto_of_random_keys(self): - if config.use_shardy_partitioner.value: - self.skipTest('Shardy does not support full-to-shard.') - mesh = jtu.create_mesh((4, 2), ('i', 'j')) keys = jax.random.split(jax.random.key(0), 8)