mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
Make `device_local_layout` and `sharding` optional in `Layout`. Also only accept the `Layout` class for `_in_layouts` and `_out_layouts`.
This is in preparation to get `jax.jit` to accept `Layout`.

PiperOrigin-RevId: 621697750
This commit is contained in:
parent
d790c88da9
commit
5cbb26f36d
@ -441,11 +441,13 @@ def _device_put_impl(
|
||||
l = device
|
||||
dll = l.device_local_layout
|
||||
x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None
|
||||
if dll is None and l.sharding is None:
|
||||
return _device_put_sharding_impl(x, aval, l.sharding)
|
||||
if (not isinstance(l.sharding, Sharding) or
|
||||
not isinstance(dll, (DeviceLocalLayout, type(None)))):
|
||||
raise ValueError(
|
||||
"sharding and device_local_layout in `Layout` instance should be"
|
||||
f" concrete. Got layout: {l}")
|
||||
f" concrete. Got layout: {l} for input {aval.str_short()}")
|
||||
if getattr(x, 'layout', None) == l and getattr(x, '_committed', False):
|
||||
return x
|
||||
if x_dll is None and dll is None:
|
||||
@ -453,7 +455,7 @@ def _device_put_impl(
|
||||
# TODO(yashkatariya): Pass layout to out_shardings directly and remove
|
||||
# out_layouts from lower.
|
||||
return api.jit(_identity_fn, out_shardings=l.sharding).lower(
|
||||
x, _out_layouts=dll).compile()(x)
|
||||
x, _out_layouts=l).compile()(x)
|
||||
|
||||
return _device_put_sharding_impl(x, aval, device)
|
||||
|
||||
|
@ -58,8 +58,8 @@ ShardingOptions = Union[Sharding, None, AutoSharding]
|
||||
class Layout:
|
||||
__slots__ = ['device_local_layout', 'sharding']
|
||||
|
||||
def __init__(self, device_local_layout: LayoutOptions,
|
||||
sharding: ShardingOptions):
|
||||
def __init__(self, device_local_layout: LayoutOptions = None,
|
||||
sharding: ShardingOptions = None):
|
||||
# If layout is concrete and sharding is not, error.
|
||||
if (isinstance(device_local_layout, DeviceLocalLayout) and
|
||||
(sharding is None or is_auto(sharding))):
|
||||
@ -70,6 +70,19 @@ class Layout:
|
||||
' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got'
|
||||
f' sharding {sharding}'
|
||||
)
|
||||
if not isinstance(
|
||||
device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)):
|
||||
raise ValueError(
|
||||
'Invalid value received for the device_local_layout argument.'
|
||||
' Expected values are `None`, `DeviceLocalLayout.AUTO` or an instance'
|
||||
f' of `DeviceLocalLayout`. Got {device_local_layout}')
|
||||
if not isinstance(
|
||||
sharding, (Sharding, type(None), AutoSharding)):
|
||||
raise ValueError(
|
||||
'Invalid value received for the sharding argument. Expected values'
|
||||
' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got'
|
||||
f' {sharding}')
|
||||
|
||||
self.device_local_layout = device_local_layout
|
||||
self.sharding = sharding
|
||||
|
||||
|
@ -425,8 +425,8 @@ def _make_jit_wrapper(jit_info: PjitInfo):
|
||||
lowering_parameters = kwargs.pop(
|
||||
'_experimental_lowering_parameters', mlir.LoweringParameters())
|
||||
# TODO(yashkatariya): Remove this when it's added on jit.
|
||||
in_layouts = kwargs.pop('_in_layouts', None)
|
||||
out_layouts = kwargs.pop('_out_layouts', None)
|
||||
in_layouts = kwargs.pop('_in_layouts', Layout())
|
||||
out_layouts = kwargs.pop('_out_layouts', Layout())
|
||||
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
|
||||
donated_invars, in_layouts_flat, out_layouts_flat,
|
||||
arg_names, ()) = _infer_params(
|
||||
@ -1272,8 +1272,7 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
|
||||
arg_layout, committed = (
|
||||
(arg.layout.device_local_layout, getattr(arg, '_committed', True))
|
||||
if getattr(arg, 'layout', None) is not None else (None, False))
|
||||
jit_in_l = (jit_in_l.device_local_layout
|
||||
if isinstance(jit_in_l, Layout) else jit_in_l)
|
||||
jit_in_l = None if jit_in_l is None else jit_in_l.device_local_layout
|
||||
if jit_in_l is None:
|
||||
if committed:
|
||||
resolved_in_layouts.append(arg_layout)
|
||||
@ -1293,9 +1292,8 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
|
||||
def _resolve_out_layouts(out_layouts: Sequence[Layout]
|
||||
) -> Sequence[LayoutOptions]:
|
||||
# TODO(yashkatariya): Remove the if condition when all layouts come via the
|
||||
# `layout.Layout` API.
|
||||
return tuple(o.device_local_layout if isinstance(o, Layout) else o
|
||||
for o in out_layouts)
|
||||
# `layout.Layout` API or handle this properly when layout is on jit.
|
||||
return tuple(None if o is None else o.device_local_layout for o in out_layouts)
|
||||
|
||||
|
||||
def _resolve_in_shardings(
|
||||
|
@ -518,7 +518,7 @@ class Compiled(Stage):
|
||||
if self.in_tree.num_leaves > len(layouts_flat):
|
||||
iter_layouts_flat = iter(layouts_flat)
|
||||
layouts_flat = [next(iter_layouts_flat) if i in self._executable._kept_var_idx
|
||||
else None for i in range(self.in_tree.num_leaves)]
|
||||
else Layout() for i in range(self.in_tree.num_leaves)]
|
||||
return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error
|
||||
|
||||
def _output_layouts(self):
|
||||
|
@ -89,7 +89,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)
|
||||
|
||||
lowered_apply = jax.jit(apply).lower(
|
||||
sds1, sds2, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO)
|
||||
sds1, sds2, _in_layouts=Layout(DLL.AUTO), _out_layouts=Layout(DLL.AUTO))
|
||||
compiled_apply = lowered_apply.compile()
|
||||
|
||||
arg_layouts, kw_layouts = compiled_apply._input_layouts()
|
||||
@ -158,8 +158,8 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np_inp.T)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
|
||||
|
||||
compiled_auto = jax.jit(f).lower(sds, _in_layouts=DLL.AUTO,
|
||||
_out_layouts=DLL.AUTO).compile()
|
||||
compiled_auto = jax.jit(f).lower(sds, _in_layouts=Layout(DLL.AUTO),
|
||||
_out_layouts=Layout(DLL.AUTO)).compile()
|
||||
self.assertTupleEqual(
|
||||
extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0))
|
||||
self.assertTupleEqual(
|
||||
@ -176,7 +176,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
return x.T
|
||||
|
||||
compiled = jax.jit(f).lower(
|
||||
arr, _in_layouts=None, _out_layouts=DLL.AUTO).compile()
|
||||
arr, _in_layouts=Layout(), _out_layouts=Layout(DLL.AUTO)).compile()
|
||||
self.assertTupleEqual(
|
||||
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
|
||||
self.assertTupleEqual(
|
||||
@ -194,7 +194,8 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
|
||||
compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
|
||||
np_inp, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO).compile()
|
||||
np_inp, _in_layouts=Layout(DLL.AUTO),
|
||||
_out_layouts=Layout(DLL.AUTO)).compile()
|
||||
out = compiled(np_inp)
|
||||
self.assertTupleEqual(
|
||||
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
|
||||
@ -209,8 +210,8 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
|
||||
shape = (8, 2)
|
||||
inps = [np.arange(math.prod(shape)).reshape(shape)] * 6
|
||||
compiled = jax.jit(f).lower(*inps, _in_layouts=DLL.AUTO,
|
||||
_out_layouts=DLL.AUTO).compile()
|
||||
compiled = jax.jit(f).lower(*inps, _in_layouts=Layout(DLL.AUTO),
|
||||
_out_layouts=Layout(DLL.AUTO)).compile()
|
||||
arg_layouts, _ = compiled._input_layouts()
|
||||
out1, out2 = compiled(*inps)
|
||||
|
||||
@ -243,10 +244,11 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Layout passed to jit does not match the layout on the respective arg'):
|
||||
jax.jit(f).lower(arr, _in_layouts=DLL.AUTO)
|
||||
jax.jit(f).lower(arr, _in_layouts=Layout(DLL.AUTO))
|
||||
|
||||
compiled = jax.jit(f).lower(
|
||||
sds, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO).compile()
|
||||
sds, _in_layouts=Layout(DLL.AUTO),
|
||||
_out_layouts=Layout(DLL.AUTO)).compile()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -269,7 +271,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
compiled = jax.jit(
|
||||
lambda x: x * 2).lower(arr, _out_layouts=DLL.AUTO).compile()
|
||||
lambda x: x * 2).lower(arr, _out_layouts=Layout(DLL.AUTO)).compile()
|
||||
col = compiled._output_layouts()
|
||||
|
||||
out = jax.device_put(np_inp, col)
|
||||
@ -287,7 +289,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
ValueError, 'sharding and device_local_layout.*should be concrete'):
|
||||
jax.device_put(np_inp, l1)
|
||||
|
||||
l2 = Layout(DLL.AUTO, None)
|
||||
l2 = Layout(DLL.AUTO)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'sharding and device_local_layout.*should be concrete'):
|
||||
jax.device_put(np_inp, l2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user