mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas:mosaic_gpu] Removed unnecessarily strict check in emit_pipeline
PiperOrigin-RevId: 703117465
This commit is contained in:
parent
5fe5206b6a
commit
4a41aa0a46
@ -181,17 +181,6 @@ def emit_pipeline(
|
||||
delay_release = 0 # No need to delay anything.
|
||||
|
||||
def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
|
||||
for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
|
||||
if any(
|
||||
spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore
|
||||
for idx in range(1, len(grid) + 1)
|
||||
if spec.block_shape is not None
|
||||
):
|
||||
raise NotImplementedError(
|
||||
f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
|
||||
f" shape {spec.block_shape}."
|
||||
)
|
||||
|
||||
in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
|
||||
in_smem_refs, out_smem_refs = util.split_list(
|
||||
[
|
||||
|
Loading…
x
Reference in New Issue
Block a user