[JAX] Allow pallas to accept scalar shape semaphores.

PiperOrigin-RevId: 727198066
This commit is contained in:
Marcello Maggioni 2025-02-14 23:19:50 -08:00 committed by jax authors
parent df135d2f0b
commit 9a8c9a56cf
5 changed files with 75 additions and 1 deletion

View File

@ -40,10 +40,15 @@ py_library(
deps = [
"//jax",
"//jax:ad_util",
"//jax:api_util",
"//jax:config",
"//jax:core",
"//jax:dtypes",
"//jax:effects",
"//jax:mlir",
"//jax:partial_eval",
"//jax:pretty_printer",
"//jax:source_info_util",
"//jax:tree_util",
"//jax:util",
"//jax/_src/lib",

View File

@ -1191,3 +1191,7 @@ def lower_as_mlir(
stablehlo = lowered.compiler_ir(dialect="stablehlo")
return stablehlo # type: ignore[return-value]
# Registry mapping a Python type to a callable that converts an `out_shape`
# value of that type into a `jax_core.AbstractValue`. It starts out empty;
# entries are added by assigning `mapping[SomeType] = converter`.
_out_shape_to_aval_mapping: dict[
type[Any], Callable[[Any], jax_core.AbstractValue]
] = {}

View File

@ -46,6 +46,7 @@ ScratchShapeTree = pallas_core.ScratchShapeTree
# Module-level aliases that bind names from `pallas_core` (and `util`) in this
# module's namespace.
AbstractMemoryRef = pallas_core.AbstractMemoryRef
no_block_spec = pallas_core.no_block_spec
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
_out_shape_to_aval_mapping = pallas_core._out_shape_to_aval_mapping
split_list = util.split_list
_ENABLE_RUNTIME_ASSERT = config.bool_state(
@ -278,3 +279,14 @@ def _tensorcore_mesh_discharge_rule(
# Store `_tensorcore_mesh_discharge_rule` in `pallas_core._core_map_mesh_rules`
# under the `TensorCoreMesh` key.
pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
_tensorcore_mesh_discharge_rule
)
def _convert_semaphore_type_to_aval(
    out_shape: SemaphoreType,
) -> jax_core.AbstractValue:
  """Returns the abstract value for a `SemaphoreType` given as an out_shape.

  Args:
    out_shape: The `SemaphoreType` value to convert.

  Returns:
    Whatever `out_shape.get_array_aval()` produces.
  """
  semaphore_aval = out_shape.get_array_aval()
  return semaphore_aval


# Register the converter under the `SemaphoreType` key of
# `pallas_core._out_shape_to_aval_mapping`.
pallas_core._out_shape_to_aval_mapping.update(
    {SemaphoreType: _convert_semaphore_type_to_aval}
)

View File

@ -38,9 +38,9 @@ from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import helpers as pallas_helpers
from jax._src.pallas import hlo_interpreter
from jax._src.pallas import primitives
from jax._src.state import discharge as state_discharge
from jax._src.state import types as state_types
from jax._src.util import (
@ -1337,6 +1337,10 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
case pallas_core.MemoryRef():
return out_shape.get_array_aval()
case _:
if type(out_shape) in pallas_core._out_shape_to_aval_mapping:
return pallas_core._out_shape_to_aval_mapping[type(out_shape)](
out_shape
)
if not (hasattr(out_shape, "shape") and hasattr(out_shape, "dtype")):
raise ValueError(f"Invalid out_shape type: {type(out_shape)}")
return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)

View File

@ -1188,6 +1188,55 @@ class PallasCallDMATest(PallasBaseTest):
)(x)
np.testing.assert_array_equal(y, x)
def test_output_dma_semaphore_ref(self):
  """Passes a DMA semaphore out of one pallas_call and into a second one.

  The first call starts an async copy from its input into its array output
  and lists a `pltpu.SemaphoreType.DMA` entry among its outputs. The second
  call takes that semaphore as an input and waits on an identically
  described copy. The final array must equal the original input.
  """
  if self.INTERPRET:
    self.skipTest('TODO(sharadmv, justinfu): Add interpret support for DMA.')

  def describe_copy(src_ref, dst_ref, sem):
    # Both kernels build the same copy descriptor; the first starts it and
    # the second waits on it.
    return pltpu.make_async_copy(
        src_ref.at[pl.ds(8), :], dst_ref.at[:, pl.ds(128)], sem
    )

  def start_kernel(x_hbm_ref, y_hbm_ref, sem_out):
    describe_copy(x_hbm_ref, y_hbm_ref, sem_out).start()

  def wait_kernel(x_hbm_ref, y_hbm_ref, sem_in, y_hbm_out):
    # The output ref is not touched directly by the kernel body.
    del y_hbm_out
    describe_copy(x_hbm_ref, y_hbm_ref, sem_in).wait()

  @jax.jit
  def start_then_wait(inp):
    # First call: outputs are the destination array and the DMA semaphore.
    start_call = self.pallas_call(
        start_kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
            pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
        ],
        out_shape=[
            jax.ShapeDtypeStruct((8, 128), jnp.float32),
            pltpu.SemaphoreType.DMA,
        ],
    )
    y, sem_out = start_call(inp)

    # Second call: consumes the semaphore as an input. The alias entry
    # {1: 0} ties input index 1 (`y`) to output index 0.
    wait_call = self.pallas_call(
        wait_kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
            pl.BlockSpec(memory_space=pl.ANY),
            pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        input_output_aliases={1: 0},
    )
    y = wait_call(inp, y, sem_out)
    return y

  x = jnp.arange(8 * 128.0).reshape((8, 128))
  np.testing.assert_array_equal(start_then_wait(x), x)
def test_hbm_hbm_grid_dma(self):
# When using the grid, we have to emit Mosaic window_params. Test that they
# work correctly with ANY memory space operands.