[Pallas] Use core_map instead of shard_map for Shmallas

- core_map is like a shard_map, but it takes no inputs and returns no outputs; the mapped function instead mutates Refs it closes over
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores on a TPU or SMs on a GPU)
- how the function is mapped over the device is specified by a `mesh` object, which is also a convenient mechanism for picking the backend Pallas targets (see the usage sketch below)
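
A minimal usage sketch, mirroring the updated tests at the bottom of this change (it assumes a TPU backend so that `pltpu.create_tensorcore_mesh`, `pltpu.async_copy`, and DMA semaphores are available):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# One named axis ("x") spanning the chip's TensorCores; the mesh type also
# selects the backend (here, the Mosaic TPU backend).
mesh = pltpu.create_tensorcore_mesh("x")

@jax.jit
def f(x):
  y = jnp.zeros_like(x)
  def inner(refs):
    x_ref, y_ref = refs
    # The mapped body takes no arguments and returns nothing; it runs once per
    # core and communicates only through the Refs it closes over.
    @pl.core_map(mesh)
    def _():
      def alloc(sem):
        pltpu.async_copy(x_ref, y_ref, sem).wait()
      pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
  # run_state threads concrete values through the Refs and returns the results.
  _, y = pl.run_state(inner)((x, y))
  return y
```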

PiperOrigin-RevId: 686036101
Sharad Vikram 2024-10-15 03:26:28 -07:00 committed by jax authors
parent b0768906db
commit cd78c653e7
7 changed files with 148 additions and 95 deletions


@@ -243,7 +243,7 @@ def get_intermediate_shardings(
out.extend((i, source_info) for i in eqn.params['in_shardings'])
out.extend((o, source_info) for o in eqn.params['out_shardings'])
elif eqn.primitive is shard_map.shard_map_p:
if not eqn.params['mesh']._is_jax_device_mesh:
if isinstance(eqn.params['mesh'], AbstractMesh):
continue
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
def _names_to_pspec(names):


@@ -270,11 +270,6 @@ class Mesh(contextlib.ContextDecorator):
def _local_mesh(self, process_index):
return _get_local_mesh(self, process_index)
@property
def _is_jax_device_mesh(self):
# Returns if the mesh contains JAX devices or not
return True
@functools.cached_property
def device_ids(self):
assert not self.empty
@@ -377,10 +372,6 @@ class AbstractMesh:
def shape(self):
return collections.OrderedDict(self.shape_tuple)
@property
def _is_jax_device_mesh(self):
return False
@property
def _internal_device_list(self):
return None


@@ -31,7 +31,6 @@ from jax._src import config
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import state
from jax._src import tree_util
from jax._src import util
@@ -1019,14 +1018,6 @@ def pytreedef_mismatch_err_msg(
return "\n".join(msg)
class PallasMesh(mesh_lib.Mesh):
"""A specialized mesh used for lowering shard_map -> pallas_call."""
@property
def _is_jax_device_mesh(self):
return False
@dataclasses.dataclass(frozen=True)
class CostEstimate:
flops: int
@@ -1038,3 +1029,75 @@ class CostEstimate:
f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},'
f' "bytes_accessed": {self.bytes_accessed}}}'
).encode("ascii")
core_map_p = jax_core.Primitive("core_map")
core_map_p.multiple_results = True
def core_map(mesh):
"""Runs a function on a mesh, mapping it over the devices in the mesh.
The function should be stateful in that it takes in no inputs and returns
no outputs but can mutate closed-over Refs, for example.
"""
def wrapped(f):
flat_args, in_tree = tree_util.tree_flatten(((), {}))
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
with jax_core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args)
out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh)
if out:
raise ValueError("core_map-ped functions must not return any outputs.")
return tree_util.tree_unflatten(out_tree_thunk(), out)
return wrapped
@core_map_p.def_effectful_abstract_eval
def _core_map_abstract_eval(*args, jaxpr, mesh):
del args
if jaxpr.outvars:
raise ValueError("core_map must not return any outputs.")
effs = set()
for eff in jaxpr.effects:
if not isinstance(eff, jax_core.NamedAxisEffect):
effs.add(eff)
continue
if eff.name not in mesh.shape:
effs.add(eff)
return [], effs
_core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}
@state_discharge.register_discharge_rule(core_map_p)
def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwargs):
if type(mesh) not in _core_map_mesh_rules:
raise NotImplementedError(f"Mesh type {type(mesh)} not supported.")
return _core_map_mesh_rules[type(mesh)](
in_avals, out_avals, *args_flat, jaxpr=jaxpr, mesh=mesh, **kwargs
)
def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh):
del in_atoms
with jax_core.extend_axis_env_nd(tuple(mesh.shape.items())):
jax_core.check_jaxpr(jaxpr)
effs = set()
for eff in jaxpr.effects:
if not isinstance(eff, jax_core.NamedAxisEffect):
effs.add(eff)
continue
if eff.name not in mesh.shape:
effs.add(eff)
return [], effs
jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule
def _core_map_axis_subst(params, subst, traverse):
if not traverse:
return params
def shadowed_subst(name):
return (name,) if name in params['mesh'].shape else subst(name)
with jax_core.extend_axis_env_nd(params['mesh'].shape.items()):
new_jaxpr = jax_core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst)
return dict(params, jaxpr=new_jaxpr)
jax_core.axis_substitution_rules[core_map_p] = _core_map_axis_subst
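
New backends plug into core_map through the `_core_map_mesh_rules` table above: discharging `core_map_p` dispatches on the concrete mesh type. A hypothetical sketch of that registration pattern (`MyMesh` and `_my_mesh_discharge_rule` are illustrative names, not part of this change; the real TPU registration appears in the next file):

```python
import collections
import dataclasses

from jax._src.pallas import core as pallas_core


# Hypothetical mesh type -- illustrative only, not part of this change.
@dataclasses.dataclass(frozen=True)
class MyMesh:
  axis_names: tuple[str, ...]
  sizes: tuple[int, ...]

  @property
  def shape(self):
    # core_map expects an ordered mapping from axis name to axis size.
    return collections.OrderedDict(zip(self.axis_names, self.sizes))


def _my_mesh_discharge_rule(in_avals, out_avals, *args, jaxpr, mesh):
  # A real rule turns `jaxpr` into backend-specific code and returns
  # (new values for the closed-over refs, ()), cf. the TPU rule below.
  raise NotImplementedError("backend-specific lowering goes here")


# Discharging core_map_p dispatches on type(mesh) through this table.
pallas_core._core_map_mesh_rules[MyMesh] = _my_mesh_discharge_rule
```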


@@ -15,6 +15,7 @@
"""Contains TPU-specific Pallas abstractions."""
from __future__ import annotations
import collections
from collections.abc import Sequence
import dataclasses
import enum
@@ -27,6 +28,7 @@ from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import util
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
import jax.numpy as jnp
import numpy as np
@@ -208,14 +210,69 @@ class TensorCore:
id: int
def create_tensorcore_mesh(axis_name: str) -> pallas_core.PallasMesh:
@dataclasses.dataclass(frozen=True)
class TensorCoreMesh:
"""A mesh of TensorCores."""
devices: np.ndarray
axis_names: Sequence[str]
@property
def shape(self):
return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
def create_tensorcore_mesh(
axis_name: str, devices: Sequence[jax.Device] | None = None
) -> TensorCoreMesh:
# TODO(b/355036384): emit a better error if we don't have tensorcores.
num_cores = jax.devices()[0].num_cores
return pallas_core.PallasMesh(
if devices is None:
devices = jax.devices()
num_cores = devices[0].num_cores
return TensorCoreMesh(
np.array([TensorCore(i) for i in range(num_cores)]),
[axis_name],
)
def runtime_assert_enabled() -> bool:
"""Returns whether runtime asserts are enabled."""
return _ENABLE_RUNTIME_ASSERT.value
def _tensorcore_mesh_discharge_rule(
in_avals,
out_avals,
*args,
mesh,
jaxpr,
):
del out_avals
assert isinstance(mesh, TensorCoreMesh)
if len(mesh.shape) > 1:
raise NotImplementedError("Mesh must be 1D")
core_axis_name, num_cores = list(mesh.shape.items())[0]
def body(*args):
# Due to aliasing, args contains aliased inputs and outputs so we remove
# outputs.
in_refs = args[:len(in_avals)]
jax_core.eval_jaxpr(jaxpr, in_refs)
assert len(jaxpr.outvars) == 0
out = pallas_call.pallas_call(
body,
out_shape=in_avals,
in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)]
* len(in_avals),
out_specs=[pallas_core.BlockSpec(
memory_space=pallas_core.MemorySpace.ANY)]
* len(in_avals),
input_output_aliases={i: i for i in range(len(in_avals))},
grid=((core_axis_name, num_cores),),
compiler_params=dict(
mosaic=dict(dimension_semantics=("parallel",)),
),
)(*args)
return out, ()
pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
_tensorcore_mesh_discharge_rule
)
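
For orientation, a small sketch of what the new mesh object looks like (it assumes a TPU runtime, since `create_tensorcore_mesh` queries `num_cores` on the local device):

```python
from jax.experimental.pallas import tpu as pltpu

mesh = pltpu.create_tensorcore_mesh("x")
# A 1D mesh over the local chip's TensorCores, e.g. OrderedDict([('x', 2)]).
print(mesh.shape)
# Unlike the old PallasMesh, the "devices" are lightweight TensorCore
# descriptors, not jax.Devices.
print(mesh.devices)
```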


@@ -50,7 +50,6 @@ from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
@@ -3078,54 +3077,3 @@ def _iota_2x32_shape_lowering(ctx, *, shape):
lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering
# Lowering for shard_map
# Technically this is not a lowering rule, but a discharge rule. When we use
# a special pallas mesh for a shard_map inside of a run_state, we turn it into
# a pallas call. The pallas_call has named grid axes corresponding to the names
# in the pallas mesh. It also sets up input/output aliasing automatically.
def _shard_map_discharge_rule(
in_avals,
out_avals,
*args,
mesh,
auto,
in_names,
out_names,
jaxpr,
check_rep,
rewrite,
):
del out_avals, auto, in_names, out_names, check_rep, rewrite
if not isinstance(mesh, pallas_core.PallasMesh):
raise NotImplementedError("Mesh must be a PallasMesh")
if len(mesh.shape) > 1:
raise NotImplementedError("Mesh must be 1D")
core_axis_name, num_cores = list(mesh.shape.items())[0]
def body(*args):
in_refs = args[:len(in_avals)]
jax_core.eval_jaxpr(jaxpr, (), *in_refs)
assert len(jaxpr.outvars) == 0
out = pallas_call.pallas_call(
body,
out_shape=in_avals,
in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)]
* len(in_avals),
out_specs=[pallas_core.BlockSpec(
memory_space=pallas_core.MemorySpace.ANY)]
* len(in_avals),
input_output_aliases={i: i for i in range(len(in_avals))},
grid=((core_axis_name, num_cores),),
compiler_params=dict(
mosaic=dict(dimension_semantics=("parallel",)),
),
)(*args)
return out, ()
from jax.experimental import shard_map
state_discharge.register_discharge_rule(shard_map.shard_map_p)(
_shard_map_discharge_rule
)


@@ -21,6 +21,7 @@ https://jax.readthedocs.io/en/latest/pallas.html.
from jax._src.pallas.core import Blocked
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.core import CompilerParams
from jax._src.pallas.core import core_map
from jax._src.pallas.core import CostEstimate
from jax._src.pallas.core import GridSpec
from jax._src.pallas.core import IndexingMode
@@ -53,6 +54,7 @@ from jax._src.pallas.utils import cdiv
from jax._src.pallas.utils import next_power_of_2
from jax._src.pallas.utils import strides_from_shape
from jax._src.pallas.utils import when
from jax._src.state.discharge import run_state
from jax._src.state.indexing import ds
from jax._src.state.indexing import dslice
from jax._src.state.indexing import Slice
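
Both names added to the public pallas namespace here are plain re-exports, which is what lets the tests below drop their private imports; a quick check (sketch):

```python
from jax._src.pallas import core as pallas_core
from jax._src.state import discharge as state_discharge
from jax.experimental import pallas as pl

# pl.core_map and pl.run_state alias the existing implementations.
assert pl.core_map is pallas_core.core_map
assert pl.run_state is state_discharge.run_state
```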


@@ -17,9 +17,7 @@ import functools
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax._src.state import discharge as state_discharge
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
@@ -51,7 +49,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):
@jax.jit
def f(x):
_, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x)))
_, y = pl.run_state(f_stateful)((x, jnp.zeros_like(x)))
return y
x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
@@ -73,7 +71,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):
@jax.jit
def f(x):
_, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x)))
_, y = pl.run_state(f_stateful)((x, jnp.zeros_like(x)))
return y
x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
@@ -101,7 +99,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):
@jax.jit
def f(x):
_, y = state_discharge.run_state(f_stateful)((x[None], jnp.zeros_like(x)))
_, y = pl.run_state(f_stateful)((x[None], jnp.zeros_like(x)))
return y
x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
@@ -128,7 +126,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):
@jax.jit
def f(x):
_, y, o = state_discharge.run_state(f_stateful)(
_, y, o = pl.run_state(f_stateful)(
(x, jnp.zeros_like(x), jnp.zeros_like(x))
)
return y, o
@@ -178,7 +176,7 @@ class PallasCallStatefulTest(jtu.JaxTestCase):
scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
)()
_, _, o = state_discharge.run_state(run_matmul)(
_, _, o = pl.run_state(run_matmul)(
(x, y, jnp.ones((m, n), dtype=x.dtype))
)
return o
@@ -202,11 +200,7 @@ class ShmallasTest(jtu.JaxTestCase):
def test_can_create_tensorcore_mesh(self):
_ = pltpu.create_tensorcore_mesh("x")
def test_can_trivially_shard_map_with_pallas_mesh(self):
mesh = pltpu.create_tensorcore_mesh("x")
_ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)()
def test_can_run_basic_pallas_kernel_with_shard_map(self):
def test_can_run_basic_pallas_kernel_with_core_map(self):
mesh = pltpu.create_tensorcore_mesh("x")
@jax.jit
@@ -214,19 +208,18 @@ class ShmallasTest(jtu.JaxTestCase):
y = jnp.zeros_like(x)
def inner(refs):
x_ref, y_ref = refs
def kernel():
@pl.core_map(mesh)
def _():
def alloc(sem):
pltpu.async_copy(x_ref, y_ref, sem).wait()
pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
check_rep=False)()
_, y = state_discharge.run_state(inner)((x, y))
_, y = pl.run_state(inner)((x, y))
return y
x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
y = f(x)
np.testing.assert_array_equal(y, x)
def test_can_query_core_index_pallas_kernel_with_shard_map(self):
def test_can_query_core_index_pallas_kernel_with_core_map(self):
mesh = pltpu.create_tensorcore_mesh("x")
@jax.jit
@@ -234,7 +227,8 @@ class ShmallasTest(jtu.JaxTestCase):
y = jnp.zeros_like(x)
def inner(refs):
x_ref, y_ref = refs
def kernel():
@pl.core_map(mesh)
def _():
num_cores = jax.lax.psum(1, "x")
slc_size = 16 // num_cores
def alloc(x_vmem_ref, y_vmem_ref, sem):
@@ -254,9 +248,7 @@ class ShmallasTest(jtu.JaxTestCase):
pltpu.VMEM((slc_size, 128), y_ref.dtype),
pltpu.SemaphoreType.DMA,
)
shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
check_rep=False)()
_, y = state_discharge.run_state(inner)((x, y))
_, y = pl.run_state(inner)((x, y))
return y
num_cores = jax.devices()[0].num_cores
x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128))