From 884487773e58bc25ddd30cbe7080f043c1a591f6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 2 Jul 2024 19:12:27 -0700 Subject: [PATCH] Read the layout set by `with_sharding_constraint` and set the top module level `out_layout` to `AUTO` if wsc layout is not None. This will allow XLA to override the entry_computation_layout with the layout set via custom call (i.e. via wsc). PiperOrigin-RevId: 648911765 --- jax/_src/interpreters/pxla.py | 37 +++++++++++++++++++++++++++++++++++ jax/_src/pjit.py | 8 +++----- tests/layout_test.py | 15 +++++--------- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 378aa2c5f..7a40fec27 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2108,6 +2108,36 @@ def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr, return tuple(safe_map(read, jaxpr.outvars)) +@weakref_lru_cache +def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr + ) -> tuple[None | DeviceLocalLayout]: + from jax._src import pjit + + env = {} # type: ignore + jaxpr = closed_jaxpr.jaxpr + + def read(var): + if type(var) is core.Literal: + return None + return env[var] + + def write(var, val): + env[var] = val + + safe_map(write, jaxpr.invars, [None] * len(jaxpr.invars)) + safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars)) + + for eqn in jaxpr.eqns: + # TODO(yashkatariya): Replace this with a registration system when there are + # more primitives for layout propagation. + if eqn.primitive is pjit.sharding_constraint_p: + out_eqn_layouts = [eqn.params['layout']] + else: + out_eqn_layouts = [None] * len(eqn.outvars) + safe_map(write, eqn.outvars, out_eqn_layouts) + return tuple(safe_map(read, jaxpr.outvars)) + + MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]] @@ -2199,6 +2229,13 @@ def lower_sharding_computation( global_in_avals = closed_jaxpr.in_avals global_out_avals = closed_jaxpr.out_avals + # If layout is propagated, then set the out_layout in the top module to AUTO + # so that XLA can override the entry_computation_layout. The propagated + # layout will be set via a custom call. + out_layouts_via_prop = get_out_layouts_via_propagation(closed_jaxpr) + out_layouts = tuple(DeviceLocalLayout.AUTO if p is not None else o + for o, p in safe_zip(out_layouts, out_layouts_via_prop)) + assert len(out_shardings) == len(out_layouts) == len(global_out_avals), ( len(out_shardings), len(out_layouts), len(global_out_avals)) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 454611424..de87b417c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -369,10 +369,7 @@ def _split_layout_and_sharding(entries): layouts, shardings = [], [] for e in entries_flat: - if e is None or is_unspecified_or_auto(e): - layouts.append(None) - shardings.append(e) - elif isinstance(e, Layout): + if isinstance(e, Layout): layouts.append(e.device_local_layout) shardings.append(e.sharding) elif isinstance(e, (DeviceLocalLayout, AutoLayout)): @@ -1430,7 +1427,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): for arg, jit_in_l, rs, aval in safe_zip( args, jit_in_layouts, resolved_in_shardings, in_avals): arg_layout, committed = ( - pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, rs, aval), + pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, + rs, aval), getattr(arg, '_committed', True)) # Sharding can be unspecified when array is committed if it's a PmapSharding. is_pmap_sharding = (is_unspecified(rs) or diff --git a/tests/layout_test.py b/tests/layout_test.py index d0d0a27b8..f2a29960c 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -16,7 +16,6 @@ import contextlib import math from absl.testing import absltest import numpy as np -from functools import partial import jax import jax.numpy as jnp @@ -359,7 +358,7 @@ class LayoutTest(jtu.JaxTestCase): def test_wsc_concrete_layout(self): mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - shape = (128, 128) + shape = (16, 128) s = NamedSharding(mesh, P('x')) np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) @@ -367,11 +366,7 @@ class LayoutTest(jtu.JaxTestCase): # Create a custom layout instead of using `arr.layout` to test the API. custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),)) - # We need AUTO so that XLA can override the entry computation layout set. - # TODO(yashkatariya): Expose a config that sets out_shardings to AUTO by - # default instead of `None` i.e. default layout and let the compiler choose - # the layout or try setting it to AUTO by default and see if there is chaos. - @partial(jax.jit, out_shardings=Layout(DLL.AUTO)) + @jax.jit def f(x): y = x.T # Constrain `y` to the original layout of `arr` because without it, @@ -383,9 +378,9 @@ class LayoutTest(jtu.JaxTestCase): self.assertEqual(out.layout, arr.layout) self.assertArraysEqual(out, np_inp.T) - def test_wsc_concrete_layout_bfloat16(self): + def test_wsc_bfloat16_concrete_layout(self): mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - shape = (128, 128) + shape = (16, 128) s = NamedSharding(mesh, P('x')) inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) arr = jax.device_put(inp, s) @@ -393,7 +388,7 @@ class LayoutTest(jtu.JaxTestCase): # Create a custom layout instead of using `arr.layout` to test the API. custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1))) - @partial(jax.jit, out_shardings=Layout(DLL.AUTO)) + @jax.jit def f(x): y = x.T # Constrain `y` to the original layout of `arr` because without it,