Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.

Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
This commit is contained in:
Yash Katariya 2024-04-03 16:12:43 -07:00 committed by jax authors
parent 24517ca3e0
commit 92326dbc71
10 changed files with 168 additions and 37 deletions

View File

@ -722,7 +722,11 @@ pytype_strict_library(
pytype_strict_library(
name = "layout",
srcs = ["_src/layout.py"],
deps = ["//jax/_src/lib"],
deps = [
":sharding",
":sharding_impls",
"//jax/_src/lib",
],
)
pytype_strict_library(

View File

@ -70,6 +70,7 @@ from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
XLACompatibleSharding)
from jax._src.layout import Layout
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
@ -2461,8 +2462,8 @@ def _check_sharding(x, s):
def device_put(
x,
device: None | xc.Device | Sharding | Any | TransferToMemoryKind = None,
*, src: None | xc.Device | Sharding | Any | TransferToMemoryKind = None):
device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
*, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None):
"""Transfers ``x`` to ``device``.
Args:

View File

@ -531,13 +531,13 @@ class ArrayImpl(basearray.Array):
@property
def layout(self):
# TODO(yashkatariya): Remove the try;except when pathways supports layouts.
try:
return layout.DeviceLocalLayout(self._pjrt_layout)
return layout.Layout(layout.DeviceLocalLayout(self._pjrt_layout),
self.sharding)
except xe.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
return None
return layout.Layout(None, self.sharding)
else:
raise

View File

@ -51,6 +51,7 @@ from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
GSPMDSharding, TransferToMemoryKind)
from jax._src.layout import Layout, DeviceLocalLayout
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@ -380,25 +381,9 @@ def _mcjax_reshard(x, target_sharding):
pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment
def _device_put_impl(
x,
device: Device | Sharding | None = None,
src: Device | Sharding | None = None):
def _device_put_sharding_impl(x, aval, device):
from jax._src import array
if (isinstance(device, TransferToMemoryKind) or
isinstance(src, TransferToMemoryKind)):
raise ValueError(
"TransferToMemoryKind argument to jax.device_put can only be used"
" inside jax.jit. If you are using device_put outside jax.jit, then"
" please provide a concrete Sharding with memory_kind.")
try:
aval = xla.abstractify(x)
except TypeError as err:
raise TypeError(
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
if isinstance(device, Sharding):
s = device
if getattr(x, 'sharding', None) == s and getattr(x, '_committed', False):
@ -435,6 +420,43 @@ def _device_put_impl(
if device is None else device)
return _put_x(x, sh, aval, device is not None)
def _device_put_impl(
    x,
    device: Device | Sharding | Layout | None = None,
    src: Device | Sharding | Layout | None = None):
  """Eager implementation of the ``device_put`` primitive.

  Args:
    x: the value to place; it must be accepted by ``xla.abstractify``.
    device: the destination. A ``Layout`` destination is handled here; any
      other value is forwarded unchanged to ``_device_put_sharding_impl``.
    src: only inspected to reject ``TransferToMemoryKind``.

  Returns:
    ``x`` itself when the destination is a ``Layout`` that ``x`` already has
    and ``x`` is committed; otherwise the result of the transfer.

  Raises:
    ValueError: if ``device`` or ``src`` is a ``TransferToMemoryKind``, or if
      a ``Layout`` destination does not carry a concrete sharding and a
      concrete (or ``None``) device-local layout.
    TypeError: if ``x`` is not a valid JAX type.
  """
  if any(isinstance(d, TransferToMemoryKind) for d in (device, src)):
    raise ValueError(
        "TransferToMemoryKind argument to jax.device_put can only be used"
        " inside jax.jit. If you are using device_put outside jax.jit, then"
        " please provide a concrete Sharding with memory_kind.")

  try:
    aval = xla.abstractify(x)
  except TypeError as err:
    raise TypeError(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err

  # Anything that is not a Layout takes the plain sharding/device path.
  if not isinstance(device, Layout):
    return _device_put_sharding_impl(x, aval, device)

  target = device
  target_dll = target.device_local_layout
  current_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None

  # Both halves of the Layout must be concrete: AUTO (or a missing sharding)
  # cannot be resolved outside of compilation.
  is_concrete = (isinstance(target.sharding, Sharding) and
                 isinstance(target_dll, (DeviceLocalLayout, type(None))))
  if not is_concrete:
    raise ValueError(
        "sharding and device_local_layout in `Layout` instance should be"
        f" concrete. Got layout: {target}")

  # Already committed with exactly this layout: nothing to do.
  if getattr(x, 'layout', None) == target and getattr(x, '_committed', False):
    return x

  # No device-local layout on either side, so this is a pure sharding change.
  if current_dll is None and target_dll is None:
    return _device_put_sharding_impl(x, aval, target.sharding)

  # TODO(yashkatariya): Pass layout to out_shardings directly and remove
  # out_layouts from lower.
  relayout = api.jit(_identity_fn, out_shardings=target.sharding)
  return relayout.lower(x, _out_layouts=target_dll).compile()(x)
device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)

View File

@ -60,7 +60,7 @@ from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
@ -2624,7 +2624,8 @@ def _get_layouts_from_executable(
if isinstance(i, DeviceLocalLayout):
if i != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
f"Unexpected XLA layout override: (XLA) {x} != {i} (User input"
" layout)")
new_in_layouts.append(i)
else:
new_in_layouts.append(x)
@ -2635,7 +2636,8 @@ def _get_layouts_from_executable(
if isinstance(o, DeviceLocalLayout):
if o != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
f"Unexpected XLA layout override: (XLA) {x} != {o} (User output"
" layout)")
new_out_layouts.append(o)
else:
new_out_layouts.append(x)
@ -3072,10 +3074,12 @@ class MeshExecutable(stages.XlaExecutable):
return self._out_shardings
def input_layouts(self):
return self._in_layouts
return [Layout(l, s)
for l, s in safe_zip(self._in_layouts, self._in_shardings)]
def output_layouts(self):
return self._out_layouts
return [Layout(l, s)
for l, s in safe_zip(self._out_layouts, self._out_shardings)]
def create_cpp_call(self, no_kwargs, in_tree, out_tree):
if not (isinstance(self.unsafe_call, ExecuteReplicated) and
@ -3254,7 +3258,8 @@ def check_array_xla_sharding_layout_match(
'sharding'))
if (xla_extension_version >= 249 and not db_xs and arg._committed and
arg.layout is not None and xl is not None and arg.layout != xl):
arg.layout.device_local_layout is not None and xl is not None and
arg.layout.device_local_layout != xl):
errors.append(
("Got input layout(s) that compiled object was called with: "
f"{arg.layout} and layout(s) the computation was compiled "

View File

@ -14,6 +14,10 @@
from __future__ import annotations
from typing import Union
from jax._src.sharding import Sharding
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
from jax._src.lib import xla_client as xc
@ -45,3 +49,39 @@ class AutoLayout:
return "AUTO"
AUTO = AutoLayout()
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout]
ShardingOptions = Union[Sharding, None, AutoSharding]
class Layout:
  """Pairs a device-local layout with the sharding it applies to.

  A concrete ``DeviceLocalLayout`` is only accepted together with a concrete
  sharding; ``None``/AUTO device-local layouts place no constraint on the
  sharding argument.
  """

  __slots__ = ['device_local_layout', 'sharding']

  def __init__(self, device_local_layout: LayoutOptions,
               sharding: ShardingOptions):
    """Store both components after validating the combination.

    Raises:
      ValueError: if ``device_local_layout`` is a ``DeviceLocalLayout`` while
        ``sharding`` is ``None`` or AUTO.
    """
    # A concrete layout is meaningless without a concrete sharding.
    if isinstance(device_local_layout, DeviceLocalLayout):
      if sharding is None or is_auto(sharding):
        raise ValueError(
            'Sharding has to be concrete when layout is of type'
            f' {type(device_local_layout)}. Please pass a'
            ' `jax.sharding.NamedSharding`, `jax.sharding.PositionalSharding` or'
            ' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got'
            f' sharding {sharding}'
        )

    self.device_local_layout = device_local_layout
    self.sharding = sharding

  def __repr__(self):
    dll_part = f'device_local_layout={self.device_local_layout}'
    sharding_part = f' sharding={self.sharding}'
    return f'Layout({dll_part},{sharding_part})'

  def __hash__(self):
    # Hash over the same pair of fields that __eq__ compares.
    return hash((self.device_local_layout, self.sharding))

  def __eq__(self, other):
    if isinstance(other, Layout):
      return (self.device_local_layout == other.device_local_layout and
              self.sharding == other.sharding)
    return False

View File

@ -67,6 +67,7 @@ from jax._src.sharding_impls import (
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
from jax._src.layout import Layout, LayoutOptions
from jax._src.state import discharge as state_discharge, RefEffect
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
@ -437,6 +438,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
args_flat, params['in_shardings'], params['out_shardings'], mesh)
in_layouts_flat = _resolve_in_layouts(
args_flat, in_layouts_flat, in_shardings)
out_layouts_flat = _resolve_out_layouts(out_layouts_flat)
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
@ -1268,8 +1270,10 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
resolved_in_layouts = []
for arg, jit_in_l in safe_zip(args, jit_in_layouts):
arg_layout, committed = (
(arg.layout, getattr(arg, '_committed', True))
(arg.layout.device_local_layout, getattr(arg, '_committed', True))
if getattr(arg, 'layout', None) is not None else (None, False))
jit_in_l = (jit_in_l.device_local_layout
if isinstance(jit_in_l, Layout) else jit_in_l)
if jit_in_l is None:
if committed:
resolved_in_layouts.append(arg_layout)
@ -1286,6 +1290,14 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
return tuple(resolved_in_layouts)
def _resolve_out_layouts(out_layouts: Sequence[Layout]
                         ) -> Sequence[LayoutOptions]:
  """Reduce each output layout to its device-local layout.

  Entries that are ``Layout`` instances contribute their
  ``device_local_layout``; every other entry is passed through unchanged.
  The result is always a tuple.
  """
  # TODO(yashkatariya): Remove the if condition when all layouts come via the
  # `layout.Layout` API.
  resolved = []
  for entry in out_layouts:
    if isinstance(entry, Layout):
      resolved.append(entry.device_local_layout)
    else:
      resolved.append(entry)
  return tuple(resolved)
def _resolve_in_shardings(
args, pjit_in_shardings: Sequence[PjitSharding],
out_shardings: Sequence[PjitSharding],

View File

@ -44,7 +44,7 @@ from jax._src import traceback_util
from jax._src import tree_util
from jax._src.tree_util import tree_unflatten, keystr
from jax._src import util
from jax._src.layout import DeviceLocalLayout
from jax._src.layout import Layout
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
@ -513,7 +513,7 @@ class Compiled(Stage):
def _input_layouts(self):
layouts_flat = self._executable.input_layouts()
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
assert all(isinstance(l, Layout) for l in layouts_flat)
# Some input layouts got DCE'd
if self.in_tree.num_leaves > len(layouts_flat):
iter_layouts_flat = iter(layouts_flat)
@ -523,7 +523,7 @@ class Compiled(Stage):
def _output_layouts(self):
layouts_flat = self._executable.output_layouts()
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
assert all(isinstance(l, Layout) for l in layouts_flat)
return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error
@staticmethod

View File

@ -15,4 +15,5 @@
from jax._src.layout import (
DeviceLocalLayout as DeviceLocalLayout,
AUTO as AUTO,
Layout as Layout
)

View File

@ -20,9 +20,10 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding
from jax._src import config
from jax._src import layout
from jax._src.layout import Layout
from jax._src import test_util as jtu
from jax._src.util import safe_zip
from jax._src import xla_bridge
@ -115,7 +116,7 @@ class LayoutTest(jtu.JaxTestCase):
self.assertEqual(init_count[0], 1)
self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0])
self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[0])
self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[1])
with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
apply_out = compiled_apply(*init_out)
@ -223,8 +224,10 @@ class LayoutTest(jtu.JaxTestCase):
self.assertArraysEqual(out1, out3)
self.assertArraysEqual(out2, out4)
# TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array
# and then pass that back into compiled.
arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_layouts)]
out5, out6 = jax.jit(f)(*arrs)
self.assertArraysEqual(out1, out5)
self.assertArraysEqual(out2, out6)
def test_aot_layout_mismatch(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@ -259,6 +262,49 @@ class LayoutTest(jtu.JaxTestCase):
jax.jit(jnp.dot, backend=jax.default_backend()).lower(
out_cpu, out_cpu).compile() # doesn't crash
def test_device_put_concrete_layout(self):
  """device_put onto a compiled output Layout produces an array with it."""
  mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
  shape = (8, 128)
  np_inp = np.arange(math.prod(shape)).reshape(shape)
  sharding = NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, sharding)

  # Let the compiler pick the output layout, then read back what it chose.
  lowered = jax.jit(lambda x: x * 2).lower(arr, _out_layouts=layout.AUTO)
  col = lowered.compile()._output_layouts()

  out = jax.device_put(np_inp, col)
  self.assertEqual(out.layout, col)
  self.assertArraysEqual(out, np_inp)
  # Every addressable shard must carry the array's device-local layout.
  for shard in out.addressable_shards:
    self.assertEqual(out.layout.device_local_layout,
                     shard.data.layout.device_local_layout)
def test_device_put_non_concrete_layout_error(self):
  """device_put rejects AUTO layouts but accepts a None device-local layout."""
  np_inp = np.arange(16).reshape(8, 2)
  expected_error = 'sharding and device_local_layout.*should be concrete'

  # AUTO device-local layout with a concrete sharding.
  auto_with_sharding = Layout(layout.AUTO,
                              SingleDeviceSharding(jax.devices()[0]))
  with self.assertRaisesRegex(ValueError, expected_error):
    jax.device_put(np_inp, auto_with_sharding)

  # AUTO device-local layout with no sharding at all.
  auto_without_sharding = Layout(layout.AUTO, None)
  with self.assertRaisesRegex(ValueError, expected_error):
    jax.device_put(np_inp, auto_without_sharding)

  # A None device-local layout with a concrete sharding is valid.
  default_layout = Layout(None, SingleDeviceSharding(jax.devices()[0]))
  out = jax.device_put(np_inp, default_layout)
  self.assertArraysEqual(out, np_inp)
  self.assertTrue(out._committed)
def test_invalid_layout_spec(self):
  """Layout() rejects a concrete device-local layout without a sharding.

  The method name carries the ``test_`` prefix so that the unittest/absltest
  loader actually discovers and runs it.
  """
  x = np.arange(8)
  compiled = jax.jit(lambda x: x).lower(x).compile()
  # The jitted function has a single output, so `_output_layouts()` yields one
  # `Layout` rather than a sequence. Take its device-local layout, since
  # `Layout.__init__` only validates `DeviceLocalLayout` instances.
  # NOTE(review): assumes the backend reports a concrete device-local layout
  # for this executable (the feature is TPU-only per the commit message) --
  # confirm the test class is skipped elsewhere.
  concrete_dll = compiled._output_layouts().device_local_layout
  with self.assertRaisesRegex(
      ValueError, 'Sharding has to be concrete when layout.*'):
    Layout(concrete_dll, None)
# Script entry point: run this module's tests through absl, using JAX's test
# loader.
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())