Relax condition for partitioning dynamic grid dimensions over cores in pipeline emitter

PiperOrigin-RevId: 645240166
This commit is contained in:
Sharad Vikram 2024-06-20 20:14:12 -07:00 committed by jax authors
parent 787e747364
commit 1eb215eb87

View File

@ -800,9 +800,9 @@ def _partition_grid(
"Cannot partition over cores without parallel grid dimensions:"
f" {dimension_semantics=}"
)
if any(not isinstance(grid[i], int) for i in parallel_dimensions):
if all(not isinstance(grid[i], int) for i in parallel_dimensions):
raise NotImplementedError(
f"Cannot partition over cores with non-static grid dimensions: {grid=}"
f"Cannot partition cores over only dynamic grid dimensions: {grid=}"
)
# Try to find a divisible dimension to partition the grid on
divisible_dimensions = {