Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
[JAX] Allow pallas to accept scalar shape semaphores.
PiperOrigin-RevId: 727198066
parent: df135d2f0b
commit: 9a8c9a56cf
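For context, a minimal sketch of the usage this change enables: a pltpu.SemaphoreType.DMA entry can now be passed directly in out_shape, so a kernel can return the DMA semaphore it starts as a scalar-shaped output. This closely mirrors the test added at the bottom of this diff; the kernel body and shapes are illustrative only, and a TPU backend is assumed.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_start_kernel(x_hbm_ref, y_hbm_ref, sem_out):
  # Start an async HBM-to-HBM copy and hand its DMA semaphore back to the
  # caller via the scalar-shaped semaphore output.
  pltpu.make_async_copy(x_hbm_ref, y_hbm_ref, sem_out).start()

x = jnp.arange(8 * 128.0).reshape((8, 128))
y, sem_out = pl.pallas_call(
    copy_start_kernel,
    in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
    out_specs=[
        pl.BlockSpec(memory_space=pl.ANY),
        pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
    ],
    out_shape=[
        jax.ShapeDtypeStruct((8, 128), jnp.float32),
        pltpu.SemaphoreType.DMA,  # scalar-shaped semaphore output
    ],
)(x)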
@@ -40,10 +40,15 @@ py_library(
    deps = [
        "//jax",
        "//jax:ad_util",
        "//jax:api_util",
        "//jax:config",
        "//jax:core",
        "//jax:dtypes",
        "//jax:effects",
        "//jax:mlir",
        "//jax:partial_eval",
        "//jax:pretty_printer",
        "//jax:source_info_util",
        "//jax:tree_util",
        "//jax:util",
        "//jax/_src/lib",
@@ -1191,3 +1191,7 @@ def lower_as_mlir(
  stablehlo = lowered.compiler_ir(dialect="stablehlo")

  return stablehlo  # type: ignore[return-value]

_out_shape_to_aval_mapping: dict[
    type[Any], Callable[[Any], jax_core.AbstractValue]
] = {}
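The hunk above introduces _out_shape_to_aval_mapping as a small registry from an out_shape type to a function that produces its abstract value; the TPU hunk further down registers SemaphoreType in it. A hedged sketch of the same registration pattern for a hypothetical custom out_shape type follows. MyOutShape is a placeholder, not a real JAX type; the get_array_aval name simply follows the pattern this diff uses for SemaphoreType and MemoryRef.

import jax.numpy as jnp
from jax._src import core as jax_core
from jax._src.pallas import core as pallas_core


class MyOutShape:
  # Hypothetical out_shape type without .shape/.dtype attributes.

  def get_array_aval(self) -> jax_core.AbstractValue:
    # Placeholder aval; a real type would describe its runtime representation.
    return jax_core.ShapedArray((), jnp.int32)


def _convert_my_out_shape_to_aval(out_shape: MyOutShape) -> jax_core.AbstractValue:
  return out_shape.get_array_aval()


# Registering the converter lets pallas_call's _convert_out_shape_to_aval
# handle MyOutShape entries in out_shape.
pallas_core._out_shape_to_aval_mapping[MyOutShape] = _convert_my_out_shape_to_aval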
@@ -46,6 +46,7 @@ ScratchShapeTree = pallas_core.ScratchShapeTree
AbstractMemoryRef = pallas_core.AbstractMemoryRef
no_block_spec = pallas_core.no_block_spec
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
_out_shape_to_aval_mapping = pallas_core._out_shape_to_aval_mapping
split_list = util.split_list

_ENABLE_RUNTIME_ASSERT = config.bool_state(
@@ -278,3 +279,14 @@ def _tensorcore_mesh_discharge_rule(
pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
    _tensorcore_mesh_discharge_rule
)


def _convert_semaphore_type_to_aval(
    out_shape: SemaphoreType,
) -> jax_core.AbstractValue:
  return out_shape.get_array_aval()


pallas_core._out_shape_to_aval_mapping[SemaphoreType] = (
    _convert_semaphore_type_to_aval
)
@@ -38,9 +38,9 @@ from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import helpers as pallas_helpers
from jax._src.pallas import hlo_interpreter
from jax._src.pallas import primitives
from jax._src.state import discharge as state_discharge
from jax._src.state import types as state_types
from jax._src.util import (
@@ -1337,6 +1337,10 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
    case pallas_core.MemoryRef():
      return out_shape.get_array_aval()
    case _:
      if type(out_shape) in pallas_core._out_shape_to_aval_mapping:
        return pallas_core._out_shape_to_aval_mapping[type(out_shape)](
            out_shape
        )
      if not (hasattr(out_shape, "shape") and hasattr(out_shape, "dtype")):
        raise ValueError(f"Invalid out_shape type: {type(out_shape)}")
      return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
@@ -1188,6 +1188,55 @@ class PallasCallDMATest(PallasBaseTest):
    )(x)
    np.testing.assert_array_equal(y, x)

  def test_output_dma_semaphore_ref(self):
    if self.INTERPRET:
      self.skipTest('TODO(sharadmv, justinfu): Add interpret support for DMA.')

    def kernel(x_hbm_ref, y_hbm_ref, sem_out):
      pltpu.make_async_copy(
          x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem_out
      ).start()

    def kernel2(x_hbm_ref, y_hbm_ref, sem_in, y_hbm_out):
      del y_hbm_out
      pltpu.make_async_copy(
          x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem_in
      ).wait()

    x = jnp.arange(8 * 128.0).reshape((8, 128))

    @jax.jit
    def body(x):
      y, sem_out = self.pallas_call(
          kernel,
          in_specs=[
              pl.BlockSpec(memory_space=pl.ANY),
          ],
          out_specs=[
              pl.BlockSpec(memory_space=pl.ANY),
              pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          ],
          out_shape=[
              jax.ShapeDtypeStruct((8, 128), jnp.float32),
              pltpu.SemaphoreType.DMA,
          ],
      )(x)

      y = self.pallas_call(
          kernel2,
          in_specs=[
              pl.BlockSpec(memory_space=pl.ANY),
              pl.BlockSpec(memory_space=pl.ANY),
              pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          ],
          out_specs=pl.BlockSpec(memory_space=pl.ANY),
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
          input_output_aliases={1: 0},
      )(x, y, sem_out)
      return y

    np.testing.assert_array_equal(body(x), x)

  def test_hbm_hbm_grid_dma(self):
    # When using the grid, we have to emit Mosaic window_params. Test that they
    # work correctly with ANY memory space operands.