[pallas:mosaic_gpu] Removed unnecessarily strict check in emit_pipeline

PiperOrigin-RevId: 703117465
This commit is contained in:
Sergei Lebedev 2024-12-05 08:03:37 -08:00 committed by jax authors
parent 5fe5206b6a
commit 4a41aa0a46

View File

@ -181,17 +181,6 @@ def emit_pipeline(
delay_release = 0 # No need to delay anything.
def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
if any(
spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore
for idx in range(1, len(grid) + 1)
if spec.block_shape is not None
):
raise NotImplementedError(
f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
f" shape {spec.block_shape}."
)
in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
in_smem_refs, out_smem_refs = util.split_list(
[