Make `device_local_layout` and `sharding` optional in `Layout`. Also accept only `Layout` instances for `_in_layouts` and `_out_layouts`.

This is in preparation to get `jax.jit` to accept `Layout`.

PiperOrigin-RevId: 621697750
Yash Katariya authored 2024-04-03 18:36:44 -07:00, committed by jax authors
parent d790c88da9
commit 5cbb26f36d
5 changed files with 38 additions and 23 deletions
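In short: both fields of `Layout` now default to `None`, and the private `_in_layouts`/`_out_layouts` hooks on `lower()` take `Layout` instances instead of bare `DeviceLocalLayout` values or `None`. A minimal sketch of the new surface, assuming the internal import path `jax._src.layout` (these underscore-prefixed hooks are private and, per the TODOs below, slated to move onto `jax.jit` proper):

    import numpy as np
    import jax
    from jax._src.layout import Layout, DeviceLocalLayout as DLL

    # Both constructor arguments are now optional.
    unspecified = Layout()     # device_local_layout=None, sharding=None
    auto = Layout(DLL.AUTO)    # let the compiler pick the device-local layout

    # lower() now expects Layout instances, not bare DLL.AUTO or None.
    compiled = jax.jit(lambda x: x.T).lower(
        np.ones((4, 2)), _in_layouts=Layout(DLL.AUTO),
        _out_layouts=Layout(DLL.AUTO)).compile()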

jax/_src/dispatch.py

@@ -441,11 +441,13 @@ def _device_put_impl(
     l = device
     dll = l.device_local_layout
     x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None
+    if dll is None and l.sharding is None:
+      return _device_put_sharding_impl(x, aval, l.sharding)
     if (not isinstance(l.sharding, Sharding) or
         not isinstance(dll, (DeviceLocalLayout, type(None)))):
       raise ValueError(
           "sharding and device_local_layout in `Layout` instance should be"
-          f" concrete. Got layout: {l}")
+          f" concrete. Got layout: {l} for input {aval.str_short()}")
     if getattr(x, 'layout', None) == l and getattr(x, '_committed', False):
       return x
     if x_dll is None and dll is None:
@@ -453,7 +455,7 @@ def _device_put_impl(
     # TODO(yashkatariya): Pass layout to out_shardings directly and remove
     # out_layouts from lower.
     return api.jit(_identity_fn, out_shardings=l.sharding).lower(
-        x, _out_layouts=dll).compile()(x)
+        x, _out_layouts=l).compile()(x)
   return _device_put_sharding_impl(x, aval, device)
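The new early return means a fully-unspecified `Layout` now flows through `device_put` instead of tripping the concreteness check. A sketch of the behavior this enables (assuming the `jax._src.layout` import path):

    import numpy as np
    import jax
    from jax._src.layout import Layout

    x = np.arange(8.0)
    # Both fields None: routed to the plain sharding path with sharding=None,
    # so this should behave like jax.device_put(x) with no placement given.
    y = jax.device_put(x, Layout())

    # A concrete sharding with an unspecified device-local layout also passes
    # the concreteness check, since dll is allowed to be None.
    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    z = jax.device_put(x, Layout(None, s))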

jax/_src/layout.py

@@ -58,8 +58,8 @@ ShardingOptions = Union[Sharding, None, AutoSharding]
 class Layout:
   __slots__ = ['device_local_layout', 'sharding']

-  def __init__(self, device_local_layout: LayoutOptions,
-               sharding: ShardingOptions):
+  def __init__(self, device_local_layout: LayoutOptions = None,
+               sharding: ShardingOptions = None):
     # If layout is concrete and sharding is not, error.
     if (isinstance(device_local_layout, DeviceLocalLayout) and
         (sharding is None or is_auto(sharding))):
@@ -70,6 +70,19 @@ class Layout:
           ' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got'
           f' sharding {sharding}'
       )
+
+    if not isinstance(
+        device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)):
+      raise ValueError(
+          'Invalid value received for the device_local_layout argument.'
+          ' Expected values are `None`, `DeviceLocalLayout.AUTO` or an instance'
+          f' of `DeviceLocalLayout`. Got {device_local_layout}')
+    if not isinstance(
+        sharding, (Sharding, type(None), AutoSharding)):
+      raise ValueError(
+          'Invalid value received for the sharding argument. Expected values'
+          ' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got'
+          f' {sharding}')

     self.device_local_layout = device_local_layout
     self.sharding = sharding
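To make the accepted and rejected inputs concrete, a few constructor calls and their outcomes under the checks above (a sketch; the error cases mirror the new ValueErrors):

    from jax._src.layout import Layout, DeviceLocalLayout as DLL

    Layout()          # ok: everything unspecified
    Layout(DLL.AUTO)  # ok: AUTO is not a *concrete* DeviceLocalLayout, so the
                      # concrete-layout-needs-concrete-sharding check passes
    Layout(device_local_layout=(1, 0))  # ValueError: expected None, DLL.AUTO,
                                        # or a DeviceLocalLayout instance
    Layout(sharding='cpu')              # ValueError: expected None, pjit.AUTO,
                                        # or a jax.Sharding instance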

jax/_src/pjit.py

@@ -425,8 +425,8 @@ def _make_jit_wrapper(jit_info: PjitInfo):
     lowering_parameters = kwargs.pop(
         '_experimental_lowering_parameters', mlir.LoweringParameters())
     # TODO(yashkatariya): Remove this when it's added on jit.
-    in_layouts = kwargs.pop('_in_layouts', None)
-    out_layouts = kwargs.pop('_out_layouts', None)
+    in_layouts = kwargs.pop('_in_layouts', Layout())
+    out_layouts = kwargs.pop('_out_layouts', Layout())
     (args_flat, flat_global_in_avals, params, in_tree, out_tree,
      donated_invars, in_layouts_flat, out_layouts_flat,
      arg_names, ()) = _infer_params(
@@ -1272,8 +1272,7 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
     arg_layout, committed = (
         (arg.layout.device_local_layout, getattr(arg, '_committed', True))
         if getattr(arg, 'layout', None) is not None else (None, False))
-    jit_in_l = (jit_in_l.device_local_layout
-                if isinstance(jit_in_l, Layout) else jit_in_l)
+    jit_in_l = None if jit_in_l is None else jit_in_l.device_local_layout
     if jit_in_l is None:
       if committed:
         resolved_in_layouts.append(arg_layout)
@@ -1293,9 +1292,8 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
 def _resolve_out_layouts(out_layouts: Sequence[Layout]
                          ) -> Sequence[LayoutOptions]:
   # TODO(yashkatariya): Remove the if condition when all layouts come via the
-  # `layout.Layout` API.
-  return tuple(o.device_local_layout if isinstance(o, Layout) else o
-               for o in out_layouts)
+  # `layout.Layout` API or handle this properly when layout is on jit.
+  return tuple(None if o is None else o.device_local_layout for o in out_layouts)


 def _resolve_in_shardings(
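Since every entry reaching these resolvers is now either `None` or a `Layout` (never a bare `DeviceLocalLayout`), resolution reduces to unwrapping `.device_local_layout`. A standalone restatement of that idiom (names hypothetical, not part of the commit):

    from typing import Optional, Sequence
    from jax._src.layout import Layout

    def unwrap(entries: Sequence[Optional[Layout]]):
      # None stays None; a Layout contributes its device-local component.
      return tuple(None if e is None else e.device_local_layout for e in entries)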

jax/_src/stages.py

@@ -518,7 +518,7 @@ class Compiled(Stage):
     if self.in_tree.num_leaves > len(layouts_flat):
       iter_layouts_flat = iter(layouts_flat)
       layouts_flat = [next(iter_layouts_flat) if i in self._executable._kept_var_idx
-                      else None for i in range(self.in_tree.num_leaves)]
+                      else Layout() for i in range(self.in_tree.num_leaves)]
     return tree_util.tree_unflatten(self.in_tree, layouts_flat)  # pytype: disable=attribute-error

   def _output_layouts(self):
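The effect: input leaves that the compiler dropped as unused now report a neutral `Layout()` instead of `None`. A standalone sketch with made-up values (`kept_var_idx` contents and the leaf count are hypothetical):

    from jax._src.layout import Layout

    computed = ['L0', 'L2']    # layouts computed for the kept arguments only
    kept_var_idx = {0, 2}      # pretend arguments 1 and 3 were optimized away
    it = iter(computed)
    padded = [next(it) if i in kept_var_idx else Layout()
              for i in range(4)]
    # padded -> ['L0', Layout(), 'L2', Layout()]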

tests/layout_test.py

@@ -89,7 +89,7 @@ class LayoutTest(jtu.JaxTestCase):
     sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)

     lowered_apply = jax.jit(apply).lower(
-        sds1, sds2, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO)
+        sds1, sds2, _in_layouts=Layout(DLL.AUTO), _out_layouts=Layout(DLL.AUTO))
     compiled_apply = lowered_apply.compile()

     arg_layouts, kw_layouts = compiled_apply._input_layouts()
@@ -158,8 +158,8 @@ class LayoutTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp.T)
     self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

-    compiled_auto = jax.jit(f).lower(sds, _in_layouts=DLL.AUTO,
-                                     _out_layouts=DLL.AUTO).compile()
+    compiled_auto = jax.jit(f).lower(sds, _in_layouts=Layout(DLL.AUTO),
+                                     _out_layouts=Layout(DLL.AUTO)).compile()
     self.assertTupleEqual(
         extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0))
     self.assertTupleEqual(
@@ -176,7 +176,7 @@ class LayoutTest(jtu.JaxTestCase):
       return x.T

     compiled = jax.jit(f).lower(
-        arr, _in_layouts=None, _out_layouts=DLL.AUTO).compile()
+        arr, _in_layouts=Layout(), _out_layouts=Layout(DLL.AUTO)).compile()
     self.assertTupleEqual(
         extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
     self.assertTupleEqual(
@@ -194,7 +194,8 @@ class LayoutTest(jtu.JaxTestCase):
     s = NamedSharding(mesh, P('x', 'y'))

     compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
-        np_inp, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO).compile()
+        np_inp, _in_layouts=Layout(DLL.AUTO),
+        _out_layouts=Layout(DLL.AUTO)).compile()
     out = compiled(np_inp)
     self.assertTupleEqual(
         extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
@@ -209,8 +210,8 @@ class LayoutTest(jtu.JaxTestCase):
     shape = (8, 2)
     inps = [np.arange(math.prod(shape)).reshape(shape)] * 6

-    compiled = jax.jit(f).lower(*inps, _in_layouts=DLL.AUTO,
-                                _out_layouts=DLL.AUTO).compile()
+    compiled = jax.jit(f).lower(*inps, _in_layouts=Layout(DLL.AUTO),
+                                _out_layouts=Layout(DLL.AUTO)).compile()
     arg_layouts, _ = compiled._input_layouts()
     out1, out2 = compiled(*inps)

@@ -243,10 +244,11 @@ class LayoutTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         'Layout passed to jit does not match the layout on the respective arg'):
-      jax.jit(f).lower(arr, _in_layouts=DLL.AUTO)
+      jax.jit(f).lower(arr, _in_layouts=Layout(DLL.AUTO))

     compiled = jax.jit(f).lower(
-        sds, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO).compile()
+        sds, _in_layouts=Layout(DLL.AUTO),
+        _out_layouts=Layout(DLL.AUTO)).compile()

     with self.assertRaisesRegex(
         ValueError,
@@ -269,7 +271,7 @@ class LayoutTest(jtu.JaxTestCase):
     arr = jax.device_put(np_inp, s)

     compiled = jax.jit(
-        lambda x: x * 2).lower(arr, _out_layouts=DLL.AUTO).compile()
+        lambda x: x * 2).lower(arr, _out_layouts=Layout(DLL.AUTO)).compile()
     col = compiled._output_layouts()

     out = jax.device_put(np_inp, col)
@@ -287,7 +289,7 @@ class LayoutTest(jtu.JaxTestCase):
         ValueError, 'sharding and device_local_layout.*should be concrete'):
       jax.device_put(np_inp, l1)

-    l2 = Layout(DLL.AUTO, None)
+    l2 = Layout(DLL.AUTO)
     with self.assertRaisesRegex(
         ValueError, 'sharding and device_local_layout.*should be concrete'):
       jax.device_put(np_inp, l2)