mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Relax condition for partitioning dynamic grid dimensions over cores in pipeline emitter
PiperOrigin-RevId: 645240166
This commit is contained in:
parent
787e747364
commit
1eb215eb87
@ -800,9 +800,9 @@ def _partition_grid(
|
||||
"Cannot partition over cores without parallel grid dimensions:"
|
||||
f" {dimension_semantics=}"
|
||||
)
|
||||
if any(not isinstance(grid[i], int) for i in parallel_dimensions):
|
||||
if all(not isinstance(grid[i], int) for i in parallel_dimensions):
|
||||
raise NotImplementedError(
|
||||
f"Cannot partition over cores with non-static grid dimensions: {grid=}"
|
||||
f"Cannot partition cores over only dynamic grid dimensions: {grid=}"
|
||||
)
|
||||
# Try to find a divisible dimension to partition the grid on
|
||||
divisible_dimensions = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user