mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Read the layout set by with_sharding_constraint
and set the top module level out_layout
to AUTO
if wsc layout is not None.
This will allow XLA to override the entry_computation_layout with the layout set via custom call (i.e. via wsc). PiperOrigin-RevId: 648911765
This commit is contained in:
parent
f089ecc47a
commit
884487773e
@ -2108,6 +2108,36 @@ def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr,
|
||||
return tuple(safe_map(read, jaxpr.outvars))
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
|
||||
) -> tuple[None | DeviceLocalLayout]:
|
||||
from jax._src import pjit
|
||||
|
||||
env = {} # type: ignore
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
|
||||
def read(var):
|
||||
if type(var) is core.Literal:
|
||||
return None
|
||||
return env[var]
|
||||
|
||||
def write(var, val):
|
||||
env[var] = val
|
||||
|
||||
safe_map(write, jaxpr.invars, [None] * len(jaxpr.invars))
|
||||
safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars))
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
# TODO(yashkatariya): Replace this with a registration system when there are
|
||||
# more primitives for layout propagation.
|
||||
if eqn.primitive is pjit.sharding_constraint_p:
|
||||
out_eqn_layouts = [eqn.params['layout']]
|
||||
else:
|
||||
out_eqn_layouts = [None] * len(eqn.outvars)
|
||||
safe_map(write, eqn.outvars, out_eqn_layouts)
|
||||
return tuple(safe_map(read, jaxpr.outvars))
|
||||
|
||||
|
||||
MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
|
||||
|
||||
|
||||
@ -2199,6 +2229,13 @@ def lower_sharding_computation(
|
||||
global_in_avals = closed_jaxpr.in_avals
|
||||
global_out_avals = closed_jaxpr.out_avals
|
||||
|
||||
# If layout is propagated, then set the out_layout in the top module to AUTO
|
||||
# so that XLA can override the entry_computation_layout. The propagated
|
||||
# layout will be set via a custom call.
|
||||
out_layouts_via_prop = get_out_layouts_via_propagation(closed_jaxpr)
|
||||
out_layouts = tuple(DeviceLocalLayout.AUTO if p is not None else o
|
||||
for o, p in safe_zip(out_layouts, out_layouts_via_prop))
|
||||
|
||||
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
|
||||
len(out_shardings), len(out_layouts), len(global_out_avals))
|
||||
|
||||
|
@ -369,10 +369,7 @@ def _split_layout_and_sharding(entries):
|
||||
layouts, shardings = [], []
|
||||
|
||||
for e in entries_flat:
|
||||
if e is None or is_unspecified_or_auto(e):
|
||||
layouts.append(None)
|
||||
shardings.append(e)
|
||||
elif isinstance(e, Layout):
|
||||
if isinstance(e, Layout):
|
||||
layouts.append(e.device_local_layout)
|
||||
shardings.append(e.sharding)
|
||||
elif isinstance(e, (DeviceLocalLayout, AutoLayout)):
|
||||
@ -1430,7 +1427,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
|
||||
for arg, jit_in_l, rs, aval in safe_zip(
|
||||
args, jit_in_layouts, resolved_in_shardings, in_avals):
|
||||
arg_layout, committed = (
|
||||
pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, rs, aval),
|
||||
pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l,
|
||||
rs, aval),
|
||||
getattr(arg, '_committed', True))
|
||||
# Sharding can be unspecified when array is committed if it's a PmapSharding.
|
||||
is_pmap_sharding = (is_unspecified(rs) or
|
||||
|
@ -16,7 +16,6 @@ import contextlib
|
||||
import math
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@ -359,7 +358,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
|
||||
def test_wsc_concrete_layout(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
shape = (128, 128)
|
||||
shape = (16, 128)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -367,11 +366,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
# Create a custom layout instead of using `arr.layout` to test the API.
|
||||
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))
|
||||
|
||||
# We need AUTO so that XLA can override the entry computation layout set.
|
||||
# TODO(yashkatariya): Expose a config that sets out_shardings to AUTO by
|
||||
# default instead of `None` i.e. default layout and let the compiler choose
|
||||
# the layout or try setting it to AUTO by default and see if there is chaos.
|
||||
@partial(jax.jit, out_shardings=Layout(DLL.AUTO))
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x.T
|
||||
# Constrain `y` to the original layout of `arr` because without it,
|
||||
@ -383,9 +378,9 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.layout, arr.layout)
|
||||
self.assertArraysEqual(out, np_inp.T)
|
||||
|
||||
def test_wsc_concrete_layout_bfloat16(self):
|
||||
def test_wsc_bfloat16_concrete_layout(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
shape = (128, 128)
|
||||
shape = (16, 128)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
|
||||
arr = jax.device_put(inp, s)
|
||||
@ -393,7 +388,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
# Create a custom layout instead of using `arr.layout` to test the API.
|
||||
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1)))
|
||||
|
||||
@partial(jax.jit, out_shardings=Layout(DLL.AUTO))
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x.T
|
||||
# Constrain `y` to the original layout of `arr` because without it,
|
||||
|
Loading…
x
Reference in New Issue
Block a user