Mirror of https://github.com/ROCm/jax.git
[Pallas] Use core_map instead of shard_map for Shmallas
- core_map is like shard_map, but it takes no inputs and produces no outputs; we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores on a TPU or SMs on a GPU).
- We specify how the function is mapped over the device with a `mesh` object. This is also a convenient mechanism for picking the backend for Pallas to target.

PiperOrigin-RevId: 686036101
parent b0768906db
commit cd78c653e7
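
Before the diff itself, here is a minimal usage sketch of the new API, adapted from the tests at the end of this diff (a TPU backend is assumed; `pl` and `pltpu` are the usual Pallas import aliases):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# A mesh over the TensorCores of a single TPU chip; the axis name "x" is
# visible inside the mapped function (e.g. jax.lax.psum(1, "x") counts cores,
# as in the tests below).
mesh = pltpu.create_tensorcore_mesh("x")

@jax.jit
def f(x):
  y = jnp.zeros_like(x)
  def inner(refs):
    x_ref, y_ref = refs
    # core_map takes no inputs and returns no outputs; the body runs on each
    # core and communicates only by mutating the closed-over refs.
    @pl.core_map(mesh)
    def _():
      def alloc(sem):
        pltpu.async_copy(x_ref, y_ref, sem).wait()
      pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
  # run_state materializes (x, y) as refs, runs inner, and returns the final
  # values; the core_map is discharged into a pallas_call along the way.
  _, y = pl.run_state(inner)((x, y))
  return y

x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
assert jnp.array_equal(f(x), x)
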
@@ -243,7 +243,7 @@ def get_intermediate_shardings(
       out.extend((i, source_info) for i in eqn.params['in_shardings'])
       out.extend((o, source_info) for o in eqn.params['out_shardings'])
     elif eqn.primitive is shard_map.shard_map_p:
-      if not eqn.params['mesh']._is_jax_device_mesh:
+      if isinstance(eqn.params['mesh'], AbstractMesh):
         continue
       source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
       def _names_to_pspec(names):

@@ -270,11 +270,6 @@ class Mesh(contextlib.ContextDecorator):
   def _local_mesh(self, process_index):
     return _get_local_mesh(self, process_index)

-  @property
-  def _is_jax_device_mesh(self):
-    # Returns if the mesh contains JAX devices or not
-    return True
-
   @functools.cached_property
   def device_ids(self):
     assert not self.empty

@@ -377,10 +372,6 @@ class AbstractMesh:
   def shape(self):
     return collections.OrderedDict(self.shape_tuple)

-  @property
-  def _is_jax_device_mesh(self):
-    return False
-
   @property
   def _internal_device_list(self):
     return None

@@ -31,7 +31,6 @@ from jax._src import config
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import linear_util as lu
-from jax._src import mesh as mesh_lib
 from jax._src import state
 from jax._src import tree_util
 from jax._src import util

@@ -1019,14 +1018,6 @@ def pytreedef_mismatch_err_msg(
   return "\n".join(msg)


-class PallasMesh(mesh_lib.Mesh):
-  """A specialized mesh used for lowering shard_map -> pallas_call."""
-
-  @property
-  def _is_jax_device_mesh(self):
-    return False
-
-
 @dataclasses.dataclass(frozen=True)
 class CostEstimate:
   flops: int

@@ -1038,3 +1029,75 @@ class CostEstimate:
         f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},'
         f' "bytes_accessed": {self.bytes_accessed}}}'
     ).encode("ascii")
+
+
+core_map_p = jax_core.Primitive("core_map")
+core_map_p.multiple_results = True
+
+def core_map(mesh):
+  """Runs a function on a mesh, mapping it over the devices in the mesh.
+
+  The function should be stateful in that it takes in no inputs and returns
+  no outputs but can mutate closed-over Refs, for example.
+  """
+  def wrapped(f):
+    flat_args, in_tree = tree_util.tree_flatten(((), {}))
+    flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
+    with jax_core.extend_axis_env_nd(mesh.shape.items()):
+      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args)
+    out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh)
+    if out:
+      raise ValueError("core_map-ped functions must not return any outputs.")
+    return tree_util.tree_unflatten(out_tree_thunk(), out)
+  return wrapped
+
+
+@core_map_p.def_effectful_abstract_eval
+def _core_map_abstract_eval(*args, jaxpr, mesh):
+  del args
+  if jaxpr.outvars:
+    raise ValueError("core_map must not return any outputs.")
+  effs = set()
+  for eff in jaxpr.effects:
+    if not isinstance(eff, jax_core.NamedAxisEffect):
+      effs.add(eff)
+      continue
+    if eff.name not in mesh.shape:
+      effs.add(eff)
+  return [], effs
+
+
+_core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}
+@state_discharge.register_discharge_rule(core_map_p)
+def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwargs):
+  if type(mesh) not in _core_map_mesh_rules:
+    raise NotImplementedError(f"Mesh type {type(mesh)} not supported.")
+  return _core_map_mesh_rules[type(mesh)](
+      in_avals, out_avals, *args_flat, jaxpr=jaxpr, mesh=mesh, **kwargs
+  )
+
+
+def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh):
+  del in_atoms
+  with jax_core.extend_axis_env_nd(tuple(mesh.shape.items())):
+    jax_core.check_jaxpr(jaxpr)
+  effs = set()
+  for eff in jaxpr.effects:
+    if not isinstance(eff, jax_core.NamedAxisEffect):
+      effs.add(eff)
+      continue
+    if eff.name not in mesh.shape:
+      effs.add(eff)
+  return [], effs
+jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule
+
+
+def _core_map_axis_subst(params, subst, traverse):
+  if not traverse:
+    return params
+  def shadowed_subst(name):
+    return (name,) if name in params['mesh'].shape else subst(name)
+  with jax_core.extend_axis_env_nd(params['mesh'].shape.items()):
+    new_jaxpr = jax_core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst)
+  return dict(params, jaxpr=new_jaxpr)
+jax_core.axis_substitution_rules[core_map_p] = _core_map_axis_subst

@@ -15,6 +15,7 @@
 """Contains TPU-specific Pallas abstractions."""
 from __future__ import annotations

+import collections
 from collections.abc import Sequence
 import dataclasses
 import enum

@@ -27,6 +28,7 @@ from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import util
 from jax._src.pallas import core as pallas_core
+from jax._src.pallas import pallas_call
 import jax.numpy as jnp
 import numpy as np

@@ -208,14 +210,69 @@ class TensorCore:
   id: int


-def create_tensorcore_mesh(axis_name: str) -> pallas_core.PallasMesh:
+@dataclasses.dataclass(frozen=True)
+class TensorCoreMesh:
+  """A mesh of TensorCores."""
+  devices: np.ndarray
+  axis_names: Sequence[str]
+
+  @property
+  def shape(self):
+    return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
+
+
+def create_tensorcore_mesh(
+    axis_name: str, devices: Sequence[jax.Device] | None = None
+) -> TensorCoreMesh:
   # TODO(b/355036384): emit a better error if we don't have tensorcores.
-  num_cores = jax.devices()[0].num_cores
-  return pallas_core.PallasMesh(
+  if devices is None:
+    devices = jax.devices()
+  num_cores = devices[0].num_cores
+  return TensorCoreMesh(
       np.array([TensorCore(i) for i in range(num_cores)]),
       [axis_name],
   )


 def runtime_assert_enabled() -> bool:
   """Returns whether runtime asserts are enabled."""
   return _ENABLE_RUNTIME_ASSERT.value
+
+
+def _tensorcore_mesh_discharge_rule(
+    in_avals,
+    out_avals,
+    *args,
+    mesh,
+    jaxpr,
+):
+  del out_avals
+  assert isinstance(mesh, TensorCoreMesh)
+  if len(mesh.shape) > 1:
+    raise NotImplementedError("Mesh must be 1D")
+  core_axis_name, num_cores = list(mesh.shape.items())[0]
+  def body(*args):
+    # Due to aliasing, args contains aliased inputs and outputs so we remove
+    # outputs.
+    in_refs = args[:len(in_avals)]
+    jax_core.eval_jaxpr(jaxpr, in_refs)
+  assert len(jaxpr.outvars) == 0
+  out = pallas_call.pallas_call(
+      body,
+      out_shape=in_avals,
+      in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)]
+      * len(in_avals),
+      out_specs=[pallas_core.BlockSpec(
+          memory_space=pallas_core.MemorySpace.ANY)]
+      * len(in_avals),
+      input_output_aliases={i: i for i in range(len(in_avals))},
+      grid=((core_axis_name, num_cores),),
+      compiler_params=dict(
+          mosaic=dict(dimension_semantics=("parallel",)),
+      ),
+  )(*args)
+  return out, ()
+
+pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
+    _tensorcore_mesh_discharge_rule
+)

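As a quick check of the new mesh type above (a sketch; the reported core count depends on the TPU generation, and a TPU backend is assumed):

from jax.experimental.pallas import tpu as pltpu

mesh = pltpu.create_tensorcore_mesh("x")
# TensorCoreMesh.shape maps axis names to sizes; on a chip with two
# TensorCores this prints OrderedDict([('x', 2)]).
print(mesh.shape)
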
@@ -50,7 +50,6 @@ from jax._src.lib.mlir.dialects import memref
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
 from jax._src.pallas import core as pallas_core
-from jax._src.pallas import pallas_call
 from jax._src.pallas import primitives
 from jax._src.pallas import utils as pallas_utils
 from jax._src.pallas.mosaic import core as tpu_core

@@ -3078,54 +3077,3 @@ def _iota_2x32_shape_lowering(ctx, *, shape):


 lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering
-
-# Lowering for shard_map
-
-# Technically this is not a lowering rule, but a discharge rule. When we use
-# a special pallas mesh for a shard_map inside of a run_state, we turn it into
-# a pallas call. The pallas_call has named grid axes corresponding to the names
-# in the pallas mesh. It also sets up input/output aliasing automatically.
-
-def _shard_map_discharge_rule(
-    in_avals,
-    out_avals,
-    *args,
-    mesh,
-    auto,
-    in_names,
-    out_names,
-    jaxpr,
-    check_rep,
-    rewrite,
-):
-  del out_avals, auto, in_names, out_names, check_rep, rewrite
-  if not isinstance(mesh, pallas_core.PallasMesh):
-    raise NotImplementedError("Mesh must be a PallasMesh")
-  if len(mesh.shape) > 1:
-    raise NotImplementedError("Mesh must be 1D")
-  core_axis_name, num_cores = list(mesh.shape.items())[0]
-  def body(*args):
-    in_refs = args[:len(in_avals)]
-    jax_core.eval_jaxpr(jaxpr, (), *in_refs)
-  assert len(jaxpr.outvars) == 0
-  out = pallas_call.pallas_call(
-      body,
-      out_shape=in_avals,
-      in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)]
-      * len(in_avals),
-      out_specs=[pallas_core.BlockSpec(
-          memory_space=pallas_core.MemorySpace.ANY)]
-      * len(in_avals),
-      input_output_aliases={i: i for i in range(len(in_avals))},
-      grid=((core_axis_name, num_cores),),
-      compiler_params=dict(
-          mosaic=dict(dimension_semantics=("parallel",)),
-      ),
-  )(*args)
-  return out, ()
-
-
-from jax.experimental import shard_map
-state_discharge.register_discharge_rule(shard_map.shard_map_p)(
-    _shard_map_discharge_rule
-)

@@ -21,6 +21,7 @@ https://jax.readthedocs.io/en/latest/pallas.html.
 from jax._src.pallas.core import Blocked
 from jax._src.pallas.core import BlockSpec
 from jax._src.pallas.core import CompilerParams
+from jax._src.pallas.core import core_map
 from jax._src.pallas.core import CostEstimate
 from jax._src.pallas.core import GridSpec
 from jax._src.pallas.core import IndexingMode

@@ -53,6 +54,7 @@ from jax._src.pallas.utils import cdiv
 from jax._src.pallas.utils import next_power_of_2
 from jax._src.pallas.utils import strides_from_shape
 from jax._src.pallas.utils import when
+from jax._src.state.discharge import run_state
 from jax._src.state.indexing import ds
 from jax._src.state.indexing import dslice
 from jax._src.state.indexing import Slice

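With these re-exports, both halves of the new workflow are reachable from the public Pallas namespace; a trivial check (nothing backend-specific is assumed here):

from jax.experimental import pallas as pl

# pl.core_map maps a stateful, no-input/no-output function over a mesh, and
# pl.run_state supplies the refs that the mapped body mutates.
assert callable(pl.core_map)
assert callable(pl.run_state)
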
@@ -17,9 +17,7 @@ import functools
 from absl.testing import absltest
 import jax
 from jax._src import test_util as jtu
-from jax._src.state import discharge as state_discharge
 from jax.experimental import pallas as pl
-from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
 import numpy as np

@@ -51,7 +49,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):

     @jax.jit
     def f(x):
-      _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x)))
+      _, y = pl.run_state(f_stateful)((x, jnp.zeros_like(x)))
       return y

     x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))

@@ -73,7 +71,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):

     @jax.jit
     def f(x):
-      _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x)))
+      _, y = pl.run_state(f_stateful)((x, jnp.zeros_like(x)))
       return y

     x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))

@@ -101,7 +99,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):

     @jax.jit
     def f(x):
-      _, y = state_discharge.run_state(f_stateful)((x[None], jnp.zeros_like(x)))
+      _, y = pl.run_state(f_stateful)((x[None], jnp.zeros_like(x)))
       return y

     x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))

@@ -128,7 +126,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):

     @jax.jit
     def f(x):
-      _, y, o = state_discharge.run_state(f_stateful)(
+      _, y, o = pl.run_state(f_stateful)(
           (x, jnp.zeros_like(x), jnp.zeros_like(x))
       )
       return y, o

@@ -178,7 +176,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):
           scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
       )()

-      _, _, o = state_discharge.run_state(run_matmul)(
+      _, _, o = pl.run_state(run_matmul)(
          (x, y, jnp.ones((m, n), dtype=x.dtype))
      )
      return o

@@ -202,11 +200,7 @@ class ShmallasTest(jtu.JaxTestCase):
   def test_can_create_tensorcore_mesh(self):
     _ = pltpu.create_tensorcore_mesh("x")

-  def test_can_trivially_shard_map_with_pallas_mesh(self):
-    mesh = pltpu.create_tensorcore_mesh("x")
-    _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)()
-
-  def test_can_run_basic_pallas_kernel_with_shard_map(self):
+  def test_can_run_basic_pallas_kernel_with_core_map(self):
     mesh = pltpu.create_tensorcore_mesh("x")

     @jax.jit

@@ -214,19 +208,18 @@ class ShmallasTest(jtu.JaxTestCase):
       y = jnp.zeros_like(x)
       def inner(refs):
         x_ref, y_ref = refs
-        def kernel():
+        @pl.core_map(mesh)
+        def _():
           def alloc(sem):
             pltpu.async_copy(x_ref, y_ref, sem).wait()
           pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
-        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
-                            check_rep=False)()
-      _, y = state_discharge.run_state(inner)((x, y))
+      _, y = pl.run_state(inner)((x, y))
       return y
     x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
     y = f(x)
     np.testing.assert_array_equal(y, x)

-  def test_can_query_core_index_pallas_kernel_with_shard_map(self):
+  def test_can_query_core_index_pallas_kernel_with_core_map(self):
     mesh = pltpu.create_tensorcore_mesh("x")

     @jax.jit

@@ -234,7 +227,8 @@ class ShmallasTest(jtu.JaxTestCase):
       y = jnp.zeros_like(x)
       def inner(refs):
         x_ref, y_ref = refs
-        def kernel():
+        @pl.core_map(mesh)
+        def _():
           num_cores = jax.lax.psum(1, "x")
           slc_size = 16 // num_cores
           def alloc(x_vmem_ref, y_vmem_ref, sem):

@@ -254,9 +248,7 @@ class ShmallasTest(jtu.JaxTestCase):
               pltpu.VMEM((slc_size, 128), y_ref.dtype),
               pltpu.SemaphoreType.DMA,
           )
-        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
-                            check_rep=False)()
-      _, y = state_discharge.run_state(inner)((x, y))
+      _, y = pl.run_state(inner)((x, y))
       return y
     num_cores = jax.devices()[0].num_cores
     x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128))
