Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Allow tracing and lowering (with lowering_platforms specified) to work with an AbstractMesh. Such a computation cannot be compiled.
This is useful for `jax.export`, e.g., for cross-platform export when we do not have access to the actual devices for which this computation is lowered.

PiperOrigin-RevId: 705764178
Commit: d0f63da4b5 (parent: 0e7f218eb0)
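For orientation, below is a minimal sketch of the workflow this change enables, mirroring the new test added at the bottom of the diff. It is not part of the commit; the two-device 'x' mesh and the 'tpu' lowering target are assumptions chosen purely for illustration.

# Sketch: trace and lower a jitted function whose input sharding refers only
# to an AbstractMesh, so no concrete target devices are needed at lowering time.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a concrete 2-device mesh just to obtain its AbstractMesh; the abstract
# mesh keeps only axis names and sizes (assumes at least 2 local devices).
mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
abstract_in = jax.ShapeDtypeStruct(
    (8, 2), jnp.float32, sharding=NamedSharding(mesh.abstract_mesh, P('x')))

@jax.jit
def f(x):
  return x * 2

# Lowering works as long as the target platforms are given explicitly.
lowered = f.trace(abstract_in).lower(lowering_platforms=('tpu',))
print(lowered.as_text())  # StableHLO text with num_partitions = 2

# Compiling such a lowering is expected to fail: no concrete devices exist.
try:
  lowered.compile()
except RuntimeError as err:
  print(err)  # "A jitted computation cannot contain AbstractMesh ..."

The export path mentioned in the commit message follows the same idea: lower against an AbstractMesh plus explicit platforms, serialize the result, and only bind concrete devices at compile time on the target machine.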
@@ -41,7 +41,6 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import effects
 from jax._src import linear_util as lu
-from jax._src import mesh as mesh_lib
 from jax._src import op_shardings
 from jax._src import sharding_specs
 from jax._src import profiler
@@ -65,6 +64,7 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
 from jax._src.sharding import Sharding as JSharding
+from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
     UnspecifiedValue, get_array_mapping as _get_array_mapping,
@@ -98,7 +98,6 @@ ShardedAxis = sharding_specs.ShardedAxis
 Replicated = sharding_specs.Replicated

 AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
-Mesh = mesh_lib.Mesh
 MeshAxisName = sharding_impls.MeshAxisName
 MeshDimAssignment = Union[ShardedAxis, Replicated]
 ShardingSpec = sharding_specs.ShardingSpec
@@ -1723,20 +1722,19 @@ def _get_and_check_device_assignment(
     devices: Sequence[xc.Device] | None,
 ) -> tuple[xc.Client, tuple[xc.Device, ...]]:
   first_sharding_info = None
-  if devices is None:
-    devices = ()
-  else:
-    devices = tuple(devices)
+  devices = () if devices is None else tuple(devices)

-  for i, s_type, source_info in shardings:
-    if isinstance(i, UnspecifiedValue):
+  for sh, s_type, source_info in shardings:
+    if isinstance(sh, UnspecifiedValue):
       continue
+    if isinstance(sh, NamedSharding) and isinstance(sh.mesh, AbstractMesh):
+      continue

     if first_sharding_info is None:
       first_sharding_info = (
-          (i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
-          else (i._device_assignment, s_type, source_info))
-      arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
+          (sh.mesh._flat_devices_tuple, s_type, source_info) if isinstance(sh, AUTO)
+          else (sh._device_assignment, s_type, source_info))
+    arr_device_assignment = (sh.mesh._flat_devices_tuple if isinstance(sh, AUTO)
+                             else sh._device_assignment)
     if not devices:
       if first_sharding_info[0] != arr_device_assignment:
         raise DeviceAssignmentMismatchError([
@@ -1837,7 +1835,8 @@ class SemanticallyEqualShardings:
   def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
                avals: tuple[core.AbstractValue]):
     gspmd_shardings = [
-        s if isinstance(s, (UnspecifiedValue, AUTO))
+        s if (isinstance(s, (UnspecifiedValue, AUTO)) or
+              (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)))
         else to_gspmd_sharding(s, a.ndim)  # pytype: disable=attribute-error
         for s, a in zip(shardings, avals)]
     self._gspmd_shardings = gspmd_shardings
@@ -1895,7 +1894,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             propagated_out_mem_kinds: tuple[None | str, ...],
                             platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters,
-                            abstract_mesh: mesh_lib.AbstractMesh | None):
+                            abstract_mesh: AbstractMesh | None):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -2082,6 +2081,40 @@ def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
   return tuple(safe_map(read, jaxpr.outvars))


+def _get_num_devices(shardings, device_assignment, lowering_platforms,
+                     prim_requires_devices) -> int:
+  ext_abstract_mesh, concrete_sharding = None, False
+  for s in shardings:
+    if isinstance(s, UnspecifiedValue):
+      continue
+    elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
+      if ext_abstract_mesh is not None and ext_abstract_mesh != s.mesh:
+        raise ValueError("AbstractMesh should be the same across all "
+                         f"shardings. Got {ext_abstract_mesh} and {s.mesh}")
+      ext_abstract_mesh = s.mesh
+    else:
+      concrete_sharding = True
+  if (concrete_sharding and ext_abstract_mesh is not None and
+      len(device_assignment) != ext_abstract_mesh.size):
+    raise ValueError(
+        f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
+        f" device assignment size: {len(device_assignment)}")
+  if concrete_sharding:
+    return len(device_assignment)
+  if ext_abstract_mesh is None:
+    return len(device_assignment)
+  if lowering_platforms is None:
+    raise ValueError(
+        "Passing lowering_platforms via"
+        " jit(f).trace(*args).lower(lowering_platforms=...) is required when"
+        " only AbstractMesh exists in a jitted computation.")
+  if prim_requires_devices:
+    raise ValueError(
+        "AbstractMesh cannot be used when jaxpr contains primitives that"
+        " require devices to be present during lowering.")
+  return ext_abstract_mesh.size
+
+
 MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]


@@ -2126,7 +2159,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):

   @lru_cache(maxsize=128)
   def _abstract_to_concrete_mesh(abstract_mesh):
-    return mesh_lib.Mesh(
+    return Mesh(
         np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
         axis_types=abstract_mesh.axis_types)

@@ -2153,7 +2186,7 @@ def lower_sharding_computation(
     donated_invars: Sequence[bool],
     *,
     keep_unused: bool,
-    context_mesh: mesh_lib.Mesh | None,
+    context_mesh: Mesh | None,
     compiler_options_kvs: tuple[tuple[str, Any], ...],
     lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
@@ -2211,6 +2244,7 @@ def lower_sharding_computation(
       ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
        for js, source_info in unique_intermediate_shardings)),
       devices_from_context)
+  unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]

   if config.sharding_in_types.value:
     out_shardings = _concretize_abstract_shardings(
@@ -2221,12 +2255,23 @@ def lower_sharding_computation(
   platforms = lowering_platforms or (
       getattr(backend, "_raw_platform", backend.platform),)

+  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
+
+  # TODO(yashkatariya): All device specific logic should go in compilation
+  # but this requires a big refactor. The current `_get_num_devices` logic
+  # is good enough to lower with AbstractMesh but cannot be compiled. Once
+  # I refactor, this will also work well with mesh being provided at
+  # compile time.
+  num_devices = _get_num_devices(
+      it.chain(unique_in_shardings, unique_out_shardings,
+               unique_intermediate_shardings),
+      device_assignment, lowering_platforms, prim_requires_devices)
+
   committed = bool(
-      devices_from_context or
-      len(device_assignment) > 1 or
-      any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
-      any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
-      any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))
+      devices_from_context
+      or num_devices > 1
+      or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
+          unique_in_shardings, unique_out_shardings, unique_intermediate_shardings)))

   da_object = _create_da_object(tuple(device_assignment))

@@ -2234,8 +2279,7 @@ def lower_sharding_computation(
   all_default_mem_kind = are_all_shardings_default_mem_kind(
       da_object,
       it.chain(unique_in_shardings, unique_out_shardings,
-               [js for js, _ in unique_intermediate_shardings],
-               transfer_mem_kind_in_jaxpr))  # pytype: disable=wrong-arg-types
+               unique_intermediate_shardings, transfer_mem_kind_in_jaxpr))  # pytype: disable=wrong-arg-types

   if all_default_mem_kind:
     propagated_out_mem_kinds = (None,) * len(global_out_avals)
@@ -2244,12 +2288,11 @@ def lower_sharding_computation(
       closed_jaxpr, in_shardings)

   # 2. Build up the HLO
-  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)

   abstract_mesh = None
   if prim_requires_devices:
     for sharding in it.chain(unique_in_shardings, unique_out_shardings,
-                             [js for js, _ in unique_intermediate_shardings]):
+                             unique_intermediate_shardings):
       if isinstance(sharding, NamedSharding):
         if (abstract_mesh is not None and
             abstract_mesh != sharding.mesh.abstract_mesh):
@@ -2267,7 +2310,7 @@ def lower_sharding_computation(
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
-       semantic_out_shardings, in_layouts, out_layouts, len(da_object),
+       semantic_out_shardings, in_layouts, out_layouts, num_devices,
        tuple(da_object) if prim_requires_devices else None, donated_invars,
        name_stack, all_default_mem_kind, inout_aliases,
        propagated_out_mem_kinds, platforms,
@@ -2310,7 +2353,7 @@ def lower_sharding_computation(
       all_default_mem_kind=all_default_mem_kind,
       all_args_info=all_args_info,
       pgle_profiler=pgle_profiler,
-      intermediate_shardings=[s for s, _ in unique_intermediate_shardings],
+      intermediate_shardings=unique_intermediate_shardings,
       context_mesh=context_mesh)


@@ -2480,7 +2523,7 @@ def _register_out_sharding_handler(

 def _gspmd_to_named_sharding(
     out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
-  assert isinstance(orig_in_s.mesh, mesh_lib.Mesh)
+  assert isinstance(orig_in_s.mesh, Mesh)
   return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)

 _register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)
@@ -2532,7 +2575,7 @@ def _get_out_sharding_from_orig_sharding(

 def maybe_recover_user_shardings(
     old_shardings, new_shardings, old_avals, new_avals,
-    intermediate_shardings=None, context_mesh: mesh_lib.Mesh | None = None):
+    intermediate_shardings=None, context_mesh: Mesh | None = None):
   if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
     return new_shardings

@@ -2832,8 +2875,14 @@ class UnloadedMeshExecutable:
       all_args_info: AllArgsInfo | None = None,
       pgle_profiler: profiler.PGLEProfiler | None = None,
       intermediate_shardings: Sequence[JSharding] | None = None,
-      context_mesh: mesh_lib.Mesh | None = None
+      context_mesh: Mesh | None = None,
   ) -> MeshExecutable:
+    if any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
+           for s in it.chain(in_shardings, out_shardings)):
+      raise RuntimeError(
+          "A jitted computation cannot contain AbstractMesh in in_shardings and"
+          " out_shardings during compilation. You can use `jax.export` to "
+          " lower with an AbstractMesh and later compile with concrete devices.")
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
     if isinstance(device_assignment, xc.DeviceList):
@@ -498,7 +498,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
     donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
     args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
     lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
-                             pgle_profiler=None)
+                             pgle_profiler=None)
     return stages.Traced(
         p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
         lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts)
@@ -4631,6 +4631,54 @@ class ArrayPjitTest(jtu.JaxTestCase):
     ins, _ = f.lower(np.arange(8)).compile().input_shardings
     self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))

+  def test_abstract_mesh_lower(self):
+    mesh = jtu.create_mesh((2,), 'x')
+    mesh2 = jtu.create_mesh((1,), 'x')
+
+    abstract_sds = jax.ShapeDtypeStruct(
+        (8, 2), jnp.float32, sharding=NamedSharding(mesh.abstract_mesh, P('x')))
+    abstract_sds2 = jax.ShapeDtypeStruct(
+        (8, 2), jnp.float32, sharding=NamedSharding(mesh2.abstract_mesh, P('x')))
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    lowered = f.trace(abstract_sds).lower(lowering_platforms=('tpu',))
+    self.assertIn('num_partitions = 2', lowered.as_text())
+
+    with self.assertRaisesRegex(
+        RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
+      lowered.compile()
+
+    @jax.jit
+    def g(x, y):
+      return x, y
+
+    concrete_s = NamedSharding(mesh, P('x'))
+    concrete_sds = jax.ShapeDtypeStruct((8,), jnp.float32, sharding=concrete_s)
+    with self.assertRaisesRegex(
+        ValueError,
+        'AbstractMesh size: 1 does not match the device assignment size: 2'):
+      g.lower(abstract_sds2, concrete_sds)
+
+    with self.assertRaisesRegex(
+        ValueError, "Passing lowering_platforms.*is required"):
+      g.lower(abstract_sds, np.arange(8))
+
+    lowered2 = g.trace(abstract_sds, np.arange(8)).lower(
+        lowering_platforms=('tpu',))
+    self.assertIn('num_partitions = 2', lowered2.as_text())
+    with self.assertRaisesRegex(
+        RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
+      lowered2.compile()
+
+    lowered3 = g.lower(abstract_sds, concrete_sds)
+    self.assertIn('num_partitions = 2', lowered3.as_text())
+    with self.assertRaisesRegex(
+        RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
+      lowered3.compile()
+
+
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
