Thread the mesh context manager to the place where we recover out_shardings back from GSPMDShardings. Before if you had a program like this:

```
with mesh:
  out = pjit(lambda: 1)()
```

The sharding of `out` was a `GSPMDSharding` which is not ideal. This change fixes that and returns a `NamedSharding` instead.

This is also required for `Shardy` integration.

PiperOrigin-RevId: 658842350
This commit is contained in:
Yash Katariya 2024-08-02 11:04:01 -07:00 committed by jax authors
parent ac52890e3d
commit 958234a9c1
4 changed files with 18 additions and 9 deletions

View File

@ -875,7 +875,7 @@ def _check_lowering(lowering) -> None:
# Check that we do not see new compile_args. When we add a compile_args it is
# safe to add it to the allowed_compile_args if it does not change the semantics
# or the calling convention of the lowered module.
allowed_compile_args = [
allowed_compile_args = {
"backend", "platforms", "mesh", "global_in_avals",
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
"mut", "spmd_lowering", "auto_spmd_lowering",
@ -883,7 +883,7 @@ def _check_lowering(lowering) -> None:
"keepalive", "host_callbacks", "pmap_nreps", "committed",
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info",
"pgle_profiler", "intermediate_shardings"]
"pgle_profiler", "intermediate_shardings", "context_mesh"}
for compile_arg in lowering.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")

View File

@ -2114,7 +2114,7 @@ def lower_sharding_computation(
donated_invars: Sequence[bool],
*,
keep_unused: bool,
devices_from_context: Sequence[xc.Device] | None,
context_mesh: mesh_lib.Mesh | None,
lowering_platforms: tuple[str, ...] | None,
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None,
@ -2157,6 +2157,8 @@ def lower_sharding_computation(
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))
devices_from_context = (None if context_mesh is None or context_mesh.empty
else context_mesh._flat_devices_tuple)
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
unique_intermediate_shardings = list(util.stable_unique(
@ -2253,7 +2255,8 @@ def lower_sharding_computation(
all_default_mem_kind=all_default_mem_kind,
all_args_info=all_args_info,
pgle_profiler=pgle_profiler,
intermediate_shardings=[s for s, _ in unique_intermediate_shardings])
intermediate_shardings=[s for s, _ in unique_intermediate_shardings],
context_mesh=context_mesh)
def _to_logical_sharding(
@ -2472,7 +2475,7 @@ def _get_out_sharding_from_orig_sharding(
def maybe_recover_user_shardings(
old_shardings, new_shardings, old_avals, new_avals,
intermediate_shardings=None):
intermediate_shardings=None, context_mesh: mesh_lib.Mesh | None = None):
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
return new_shardings
@ -2487,6 +2490,11 @@ def maybe_recover_user_shardings(
return _get_out_sharding_from_orig_sharding(
new_shardings, new_avals, i, None)
if context_mesh is not None and not context_mesh.empty:
return [sharding_impls._gspmd_to_named_sharding_via_mesh(n, context_mesh)
if isinstance(n, GSPMDSharding) else n
for n in new_shardings]
return new_shardings
def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
@ -2775,6 +2783,7 @@ class UnloadedMeshExecutable:
compiler_options=None,
pgle_profiler: profiler.PGLEProfiler | None = None,
intermediate_shardings: Sequence[JSharding] | None = None,
context_mesh: mesh_lib.Mesh | None = None
) -> MeshExecutable:
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
hlo = mlir.refine_polymorphic_shapes(hlo)
@ -2832,7 +2841,7 @@ class UnloadedMeshExecutable:
out_shardings = maybe_recover_user_shardings(
in_shardings, out_shardings, global_in_avals, global_out_avals,
intermediate_shardings)
intermediate_shardings, context_mesh)
out_shardings = finalize_out_shardings(out_shardings, da)

View File

@ -1766,9 +1766,7 @@ def _pjit_lower_cached(
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
in_layouts, out_layouts, tuple(donated_invars),
keep_unused=keep_unused,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
keep_unused=keep_unused, context_mesh=mesh,
lowering_platforms=lowering_platforms,
lowering_parameters=lowering_parameters,
pgle_profiler=pgle_profiler)

View File

@ -2516,6 +2516,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
with mesh:
out = pjit(lambda: 1)()
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
self.assertIsInstance(out, array.ArrayImpl)
self.assertEqual(out, 1)
@ -2873,6 +2874,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
f = pjit(mul, device=jax.devices()[1])
x = jnp.arange(8).reshape(4, 2)
f_out = f(x)
f_out2 = f(f_out)