Standardize default layout to None in internals (dispatch, lowering and compilation) and non-default layouts to concrete layouts.

This massively simplifies the number of checks we need and improves dispatch time too. It also fixes a donation bug hit in serving code, caused by layouts and the non-standardized handling of default layouts in JAX.

PiperOrigin-RevId: 668527139
commit ef33cf5ace
parent 46957052c5
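The convention the change adopts can be summarized with a small sketch. This is an illustrative, self-contained approximation rather than the actual jax._src code; `normalize_for_dispatch` and `layouts_match` are hypothetical helper names:

def normalize_for_dispatch(xla_layouts, shardings, avals, is_default_layout):
  # Internally, a default layout is represented as None; only non-default
  # layouts are kept as concrete DeviceLocalLayout values.
  return [None if is_default_layout(l, s, a) else l
          for l, s, a in zip(xla_layouts, shardings, avals)]

def layouts_match(arg_layout, expected_layout):
  # Dispatch-time comparison: None means "default layout" and always matches,
  # so the hot path never has to query the backend for the default layout.
  if expected_layout is None:
    return True
  return arg_layout == expected_layout

The user-facing path (e.g. `compiled.input_layouts`) keeps the concrete XLA-reported layouts, while the dispatch path uses the normalized list, as the diff below does with `xla_in_layouts` versus `dispatch_in_layouts`.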
@@ -751,8 +751,7 @@ def make_array_from_callback(
       and sharding.is_fully_replicated
       and first_value.is_fully_replicated
       and first_value.sharding._device_assignment == tuple(devices)
-      and (first_value.layout.device_local_layout ==
-           pxla._maybe_get_default_layout(Layout(dll, sharding), None, sharding, aval))):
+      and first_value.layout.device_local_layout == dll):
     return first_value
 
   if dtypes.issubdtype(aval.dtype, dtypes.extended):
@@ -1105,11 +1104,6 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
   dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
   return dst_indices, tuple(src_indices) == tuple(dst_indices)
 
-def _layout_eq(x, dst_layout, sharding):
-  if pxla.is_default_layout(dst_layout, sharding, x.aval):
-    return True
-  return x.layout.device_local_layout == dst_layout
-
 
 def _array_shard_arg(xs, shardings, layouts):
   results = []
@@ -1118,7 +1112,8 @@ def _array_shard_arg(xs, shardings, layouts):
   for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)):
     x._check_if_deleted()
     indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
-    same_layout = _layout_eq(x, layout, sharding)
+    same_layout = (True if layout is None else
+                   x.layout.device_local_layout == layout)
 
     if not x.is_fully_addressable:
       if same_indices and same_layout:
@@ -1053,6 +1053,7 @@ def lower_jaxpr_to_module(
   result_memory_kinds = (map(_get_mem_kind, result_shardings)
                          if result_shardings is not None else None)
 
+  # TODO(yashkatariya): Simplify the donation logic.
   xla_donated_args = None
   platforms_with_donation = [p for p in platforms
                              if p in _platforms_with_donation]
@@ -1071,9 +1072,6 @@ lower_jaxpr_to_module(
   input_output_aliases, donated_args, xla_donated_args = _set_up_aliases(
       input_output_aliases, in_avals, out_avals, donated_args,
       arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts)
-  unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
-  if unlowerable_effects:
-    raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
   if any(donated_args):
     unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
     msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
@@ -1082,10 +1080,13 @@ lower_jaxpr_to_module(
     if unused_donations:
       warnings.warn("Some donated buffers were not usable:"
                     f" {', '.join(unused_donations)}.\n{msg}")
-
   # Delete donated_args by default here, since it's not needed beyond this point
   del donated_args
 
+  unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
+  if unlowerable_effects:
+    raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
+
   # HLO channels need to start at 1
   channel_iter = itertools.count(1)
   # Create a keepalives list that will be mutated during the lowering.
@@ -1167,8 +1168,7 @@ lower_jaxpr_to_module(
 
 
 def _set_up_aliases(input_output_aliases, avals_in, avals_out,
-                    donated_args,
-                    arg_memory_kinds, result_memory_kinds,
+                    donated_args, arg_memory_kinds, result_memory_kinds,
                     in_layouts, out_layouts):
   if input_output_aliases is None:
     input_output_aliases = [None] * len(avals_in)
@@ -1207,15 +1207,14 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out,
     if donations.get(key, ()):
       input_id = donations[key].popleft()
       out_donated_args[input_id] = False
+      # We can alias if XLA performs layout assignment because XLA will
+      # respect the aliases when assigning layouts. Its only for two
+      # mismatched explicitly assigned layouts that XLA will certainly fail.
       if (in_layouts is None or
           out_layouts is None or
           in_layouts[input_id] == out_layouts[i] or
-          # We can alias if XLA performs layout assignment because XLA will
-          # respect the aliases when assigning layouts. Its only for two
-          # mismatched explicitly assigned layouts that XLA will certainly
-          # fail.
-          isinstance(in_layouts[input_id], (AutoLayout, type(None))) or
-          isinstance(out_layouts[i], (AutoLayout, type(None)))):
+          isinstance(in_layouts[input_id], AutoLayout) or
+          isinstance(out_layouts[i], AutoLayout)):
         input_output_aliases[input_id] = i
       else:
         # Fallback to xla donation if layouts don't match.
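The donation fix comes down to the aliasing predicate above: input/output aliasing is only kept when XLA is free to assign the layouts itself or when the explicit layouts agree; otherwise donation falls back to plain XLA donation. A hedged stand-alone sketch of that rule (the `AutoLayout` class below is a local stand-in for JAX's auto-layout marker, not the real type):

class AutoLayout:
  # Hypothetical stand-in for JAX's AutoLayout sentinel.
  pass

def can_alias(in_layout, out_layout):
  # Safe to alias when either side lets XLA pick the layout (AutoLayout) or
  # both explicit layouts are identical (None == None covers default layouts).
  # Two mismatched explicit layouts skip aliasing and use XLA donation instead.
  return (in_layout == out_layout
          or isinstance(in_layout, AutoLayout)
          or isinstance(out_layout, AutoLayout))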
@@ -150,7 +150,7 @@ shard_arg_handlers: dict[
 
 @lru_cache(maxsize=2048)
 def is_default_layout(curr_layout, sharding, aval):
-  if curr_layout is None or sharding is None:
+  if curr_layout is None or sharding is None or is_unspecified(sharding):
     return True
   if (aval is core.abstract_token or aval.dtype == dtypes.float0 or
       dtypes.issubdtype(aval.dtype, dtypes.extended)):
@@ -191,7 +191,7 @@ def _shard_np_array(xs, shardings, layouts):
     if x.dtype == dtypes.float0:
       x = np.zeros(x.shape, dtype=np.dtype(bool))
     aval = api_util.shaped_abstractify(x)
-    if not is_default_layout(layout, sharding, aval):
+    if layout is not None:
       results.append(api.device_put(x, Layout(layout, sharding)))
     else:
       if sharding.is_fully_replicated:
@@ -1884,35 +1884,6 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
         "extra data movement anyway, so maybe you don't want it after all).")
 
 
-@lru_cache(maxsize=2048)
-def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval
-                              ) -> DeviceLocalLayout | None:
-  if is_unspecified_or_auto(sharding):
-    return None
-  # TODO(yashkatariya): Figure out how layouts work with extended dtypes.
-  if aval is core.abstract_token or dtypes.issubdtype(aval.dtype, dtypes.extended):
-    return None
-  if not core.is_constant_shape(aval.shape):
-    return None
-  shard_shape = sharding.shard_shape(aval.shape)
-  d = sharding._device_assignment[0]
-  # If a backend doesn't implement `get_default_layout` return `None` to avoid
-  # cache misses. This can happen when you have `jit(f, in_shardings=s)`. On
-  # first call you pass it a sharded array with layout and on second call you
-  # pass a numpy array. The layouts should be the same to get cache hits.
-  try:
-    al = DeviceLocalLayout.from_pjrt_layout(
-        d.client.get_default_layout(aval.dtype, shard_shape, d))
-  except:
-    return None
-  # argument does not have `.layout` property. ShapedArray, numpy array, etc
-  # are some examples.
-  if arg_layout is None:
-    return al if jit_in_layout is None else arg_layout  # arg_layout is None
-  # If arg has a `.layout` property, then return device_local_layout as is.
-  return arg_layout.device_local_layout
-
-
 @weakref_lru_cache
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
@@ -2775,13 +2746,14 @@ class UnloadedMeshExecutable:
   kept_var_idx: set[int]
   mut: MutationData | None
   auto_spmd_lowering: bool
-  in_layouts: Sequence[DeviceLocalLayout | None]
-  out_layouts: Sequence[DeviceLocalLayout | None]
+  xla_in_layouts: Sequence[DeviceLocalLayout | None]
+  dispatch_in_layouts: Sequence[DeviceLocalLayout | None]
+  xla_out_layouts: Sequence[DeviceLocalLayout | None]
   all_args_info: AllArgsInfo | None
   pgle_profiler: profiler.PGLEProfiler | None
 
   def build_unsafe_call(self):
-    handle_args = InputsHandler(self.input_shardings, self.in_layouts)
+    handle_args = InputsHandler(self.input_shardings, self.dispatch_in_layouts)
     handle_outs = global_avals_to_results_handler(
         self.output_avals, self.output_shardings, self.committed)
 
@@ -2797,8 +2769,8 @@ class UnloadedMeshExecutable:
         self.input_avals, self.output_avals,
         self.input_shardings, self.output_shardings,
         self.auto_spmd_lowering, self.kept_var_idx,
-        self.in_layouts, self.out_layouts,
-        self.all_args_info, self)
+        self.xla_in_layouts, self.dispatch_in_layouts,
+        self.xla_out_layouts, self.all_args_info, self)
 
   @staticmethod
   def from_hlo(name: str,
@@ -2881,8 +2853,18 @@ class UnloadedMeshExecutable:
       in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
           xla_executable.local_devices(), len(in_shardings), len(out_shardings))
 
-    in_layouts, out_layouts = _get_layouts_from_executable(
+    # xla_in_layouts are all either None or DeviceLocalLayout. Even default
+    # layout are concrete layouts and they are used in `compiled.input_layouts`
+    # to return concrete layouts to users.
+    # `dispatch_in_layouts` replaces default layouts with `None` to simplify
+    # dispatch logic downstream.
+    xla_in_layouts, xla_out_layouts = _get_layouts_from_executable(
         xla_executable, in_layouts, out_layouts, len(ordered_effects))
+    del in_layouts, out_layouts
+    dispatch_in_layouts = [
+        None if is_default_layout(l, s, a) else l
+        for l, s, a, in safe_zip(xla_in_layouts, in_shardings, global_in_avals)
+    ]
 
     out_shardings = maybe_recover_user_shardings(
         in_shardings, out_shardings, global_in_avals, global_out_avals,
@@ -2907,8 +2889,9 @@ class UnloadedMeshExecutable:
         kept_var_idx=kept_var_idx,
         mut=mut,
         auto_spmd_lowering=auto_spmd_lowering,
-        in_layouts=in_layouts,
-        out_layouts=out_layouts,
+        xla_in_layouts=xla_in_layouts,
+        dispatch_in_layouts=dispatch_in_layouts,
+        xla_out_layouts=xla_out_layouts,
        all_args_info=all_args_info,
        pgle_profiler=pgle_profiler).load()
 
@@ -2964,13 +2947,13 @@ class MeshExecutable(stages.XlaExecutable):
   __slots__ = [
       "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
       "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering",
-      "_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info",
-      "_unloaded_executable",
+      "_kept_var_idx", "_xla_in_layouts", "_dispatch_in_layouts",
+      "_xla_out_layouts", "_all_args_info", "_unloaded_executable",
   ]
 
   def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
                in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
-               in_layouts, out_layouts,
+               xla_in_layouts, dispatch_in_layouts, xla_out_layouts,
                all_args_info: AllArgsInfo | None = None,
                unloaded_executable=None):
     self.xla_executable = xla_executable
@@ -2984,8 +2967,9 @@ class MeshExecutable(stages.XlaExecutable):
     self._out_shardings = out_shardings
     self._auto_spmd_lowering = auto_spmd_lowering
     self._kept_var_idx = kept_var_idx
-    self._in_layouts = in_layouts
-    self._out_layouts = out_layouts
+    self._xla_in_layouts = xla_in_layouts
+    self._dispatch_in_layouts = dispatch_in_layouts
+    self._xla_out_layouts = xla_out_layouts
     self._all_args_info = all_args_info
     self._unloaded_executable = unloaded_executable
 
@@ -3013,9 +2997,8 @@ class MeshExecutable(stages.XlaExecutable):
 
       all_arg_avals = map(xla.abstractify, kept_args)
       check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
-      # Check the GDA sharding and the input sharding.
       check_array_xla_sharding_layout_match(
-          args_after_dce, self._in_shardings, self._in_layouts, debug_info,
+          args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info,
          self._kept_var_idx)
     return self.unsafe_call(*args)  # pylint: disable=not-callable
 
@@ -3027,11 +3010,11 @@ class MeshExecutable(stages.XlaExecutable):
 
   def input_layouts(self):
     return [Layout(l, s)
-            for l, s in safe_zip(self._in_layouts, self._in_shardings)]
+            for l, s in safe_zip(self._xla_in_layouts, self._in_shardings)]
 
   def output_layouts(self):
     return [Layout(l, s)
-            for l, s in safe_zip(self._out_layouts, self._out_shardings)]
+            for l, s in safe_zip(self._xla_out_layouts, self._out_shardings)]
 
   def create_cpp_call(self, no_kwargs, in_tree, out_tree):
     if not (isinstance(self.unsafe_call, ExecuteReplicated) and
@@ -3057,12 +3040,10 @@ class MeshExecutable(stages.XlaExecutable):
           else s
           for s, a in zip(self._in_shardings, self.in_avals)
       ]
-      in_dlls = get_layouts_for_fasthpath_data(
-          self._in_layouts, in_shardings, self.in_avals)
       fastpath_data = MeshExecutableFastpathData(
           self.xla_executable, out_tree_dispatch, in_shardings,
           self._out_shardings, out_avals, out_committed, kept_var_bitvec,
-          in_dlls)
+          self._dispatch_in_layouts)
     else:
       fastpath_data = None
     return outs, fastpath_data, False  # Do not remove cache entry
@@ -3084,16 +3065,6 @@ else:
     return shard_args([sharding], [layout], [x])[0]
 
 
-def get_layouts_for_fasthpath_data(in_layouts, in_shardings, in_avals):
-  in_dlls = []
-  for l, s, a in zip(in_layouts, in_shardings, in_avals):
-    if is_default_layout(l, s, a):
-      in_dlls.append(None)
-    else:
-      in_dlls.append(l)
-  return in_dlls
-
-
 def check_arg_avals_for_call(ref_avals, arg_avals,
                              jaxpr_debug_info: core.JaxprDebugInfo | None = None):
   if len(ref_avals) != len(arg_avals):
@@ -279,12 +279,10 @@ def _get_fastpath_data(
         else s
         for s, a in zip(executable._in_shardings, executable.in_avals)
     ]
-    in_dlls = pxla.get_layouts_for_fasthpath_data(
-        executable._in_layouts, in_shardings, executable.in_avals)
     fastpath_data = pxla.MeshExecutableFastpathData(
         executable.xla_executable, out_tree, in_shardings,
         executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
-        in_dlls)
+        executable._dispatch_in_layouts)
   else:
     fastpath_data = None
   return fastpath_data
@@ -1479,10 +1477,17 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
   resolved_in_layouts = []
   for arg, jit_in_l, rs, aval in safe_zip(
       args, jit_in_layouts, resolved_in_shardings, in_avals):
-    arg_layout, committed = (
-        pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l,
-                                       rs, aval),
-        getattr(arg, '_committed', True))
+    committed = getattr(arg, '_committed', True)
+    # `arg_layout` is only used for checking purposes in the `else` branch
+    # below. We cannot replace default layout with None to raise nicer errors.
+    # `dispatch_arg_layout` replaces default layouts with `None` to simplify
+    # dispatch and lowering logic downstream.
+    if hasattr(arg, 'layout'):
+      arg_layout = arg.layout.device_local_layout
+      dispatch_arg_layout = (None if pxla.is_default_layout(arg_layout, rs, aval)
+                             else arg_layout)
+    else:
+      arg_layout, dispatch_arg_layout = None, None
     # Sharding can be unspecified when array is committed if it's a PmapSharding.
     is_pmap_sharding = (is_unspecified(rs) or
                         isinstance(getattr(arg, 'sharding', None), PmapSharding))
@@ -1491,7 +1496,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
       if is_pmap_sharding:
         resolved_in_layouts.append(None)
       else:
-        resolved_in_layouts.append(arg_layout)
+        resolved_in_layouts.append(dispatch_arg_layout)
     else:
       resolved_in_layouts.append(None)
   else:
@@ -540,7 +540,7 @@ class LayoutTest(jtu.JaxTestCase):
     def f(x):
       return x
 
-    out = f(arr)
+    f(arr)
     self.assertTrue(arr.is_deleted())
 
   def test_layout_donation_auto(self):
@@ -555,7 +555,7 @@ class LayoutTest(jtu.JaxTestCase):
     def f(x):
       return x * x
 
-    out = f(arr)
+    f(arr)
     self.assertTrue(arr.is_deleted())
 
   def test_layout_donation_matching_in_and_out(self):
@@ -572,9 +572,27 @@ class LayoutTest(jtu.JaxTestCase):
     def f(x):
       return x * x
 
-    out = f(arr)
+    f(arr)
     self.assertTrue(arr.is_deleted())
 
+  def test_layout_donation_mismatching_in_and_out_fails(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    shape = (16*2, 32016*2)
+    np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
+
+    custom_dll1 = DLL(major_to_minor=(1, 0), _tiling=((8,128), (2,1)))
+    l1 = Layout(custom_dll1, s)
+    arr = jax.device_put(np_inp, s)
+
+    @partial(jax.jit, out_shardings=l1, donate_argnums=0)
+    def f(x):
+      return x * x
+
+    sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
+    f.lower(sds).compile()(arr)
+    self.assertFalse(arr.is_deleted())
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())