From 92326dbc7116dc6351946433b79b43ba3addc658 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 3 Apr 2024 16:12:43 -0700 Subject: [PATCH] Expose `Layout(device_local_layout, sharding)` class allowing users to specify layouts of Arrays. Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put. Note: This currently only works on TPU. PiperOrigin-RevId: 621668247 --- jax/BUILD | 6 +++- jax/_src/api.py | 5 ++-- jax/_src/array.py | 6 ++-- jax/_src/dispatch.py | 56 ++++++++++++++++++++++++----------- jax/_src/interpreters/pxla.py | 17 +++++++---- jax/_src/layout.py | 40 +++++++++++++++++++++++++ jax/_src/pjit.py | 14 ++++++++- jax/_src/stages.py | 6 ++-- jax/experimental/layout.py | 1 + tests/layout_test.py | 54 ++++++++++++++++++++++++++++++--- 10 files changed, 168 insertions(+), 37 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index d440103af..7998d028f 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -722,7 +722,11 @@ pytype_strict_library( pytype_strict_library( name = "layout", srcs = ["_src/layout.py"], - deps = ["//jax/_src/lib"], + deps = [ + ":sharding", + ":sharding_impls", + "//jax/_src/lib", + ], ) pytype_strict_library( diff --git a/jax/_src/api.py b/jax/_src/api.py index eeb7b38d0..084367f25 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -70,6 +70,7 @@ from jax._src.lib import pmap_lib from jax._src.sharding import Sharding from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind, XLACompatibleSharding) +from jax._src.layout import Layout from jax._src.traceback_util import api_boundary from jax._src import tree_util from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps @@ -2461,8 +2462,8 @@ def _check_sharding(x, s): def device_put( x, - device: None | xc.Device | Sharding | Any | TransferToMemoryKind = None, - *, src: None | xc.Device | Sharding | Any | TransferToMemoryKind = None): + device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None, + *, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None): """Transfers ``x`` to ``device``. Args: diff --git a/jax/_src/array.py b/jax/_src/array.py index a894d5421..4c85f1fad 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -531,13 +531,13 @@ class ArrayImpl(basearray.Array): @property def layout(self): - # TODO(yashkatariya): Remove the try;except when pathways supports layouts. try: - return layout.DeviceLocalLayout(self._pjrt_layout) + return layout.Layout(layout.DeviceLocalLayout(self._pjrt_layout), + self.sharding) except xe.XlaRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - return None + return layout.Layout(None, self.sharding) else: raise diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 550032c9a..1fdaa659a 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -51,6 +51,7 @@ from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding, GSPMDSharding, TransferToMemoryKind) +from jax._src.layout import Layout, DeviceLocalLayout JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" @@ -380,25 +381,9 @@ def _mcjax_reshard(x, target_sharding): pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment -def _device_put_impl( - x, - device: Device | Sharding | None = None, - src: Device | Sharding | None = None): +def _device_put_sharding_impl(x, aval, device): from jax._src import array - if (isinstance(device, TransferToMemoryKind) or - isinstance(src, TransferToMemoryKind)): - raise ValueError( - "TransferToMemoryKind argument to jax.device_put can only be used" - " inside jax.jit. If you are using device_put outside jax.jit, then" - " please provide a concrete Sharding with memory_kind.") - - try: - aval = xla.abstractify(x) - except TypeError as err: - raise TypeError( - f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err - if isinstance(device, Sharding): s = device if getattr(x, 'sharding', None) == s and getattr(x, '_committed', False): @@ -435,6 +420,43 @@ def _device_put_impl( if device is None else device) return _put_x(x, sh, aval, device is not None) +def _device_put_impl( + x, + device: Device | Sharding | Layout | None = None, + src: Device | Sharding | Layout | None = None): + if (isinstance(device, TransferToMemoryKind) or + isinstance(src, TransferToMemoryKind)): + raise ValueError( + "TransferToMemoryKind argument to jax.device_put can only be used" + " inside jax.jit. If you are using device_put outside jax.jit, then" + " please provide a concrete Sharding with memory_kind.") + + try: + aval = xla.abstractify(x) + except TypeError as err: + raise TypeError( + f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err + + if isinstance(device, Layout): + l = device + dll = l.device_local_layout + x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None + if (not isinstance(l.sharding, Sharding) or + not isinstance(dll, (DeviceLocalLayout, type(None)))): + raise ValueError( + "sharding and device_local_layout in `Layout` instance should be" + f" concrete. Got layout: {l}") + if getattr(x, 'layout', None) == l and getattr(x, '_committed', False): + return x + if x_dll is None and dll is None: + return _device_put_sharding_impl(x, aval, l.sharding) + # TODO(yashkatariya): Pass layout to out_shardings directly and remove + # out_layouts from lower. + return api.jit(_identity_fn, out_shardings=l.sharding).lower( + x, _out_layouts=dll).compile()(x) + + return _device_put_sharding_impl(x, aval, device) + device_put_p = core.Primitive('device_put') device_put_p.def_impl(_device_put_impl) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 5dc21a61d..b9cdc58df 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -60,7 +60,7 @@ from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla -from jax._src.layout import DeviceLocalLayout, AutoLayout +from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir @@ -2624,7 +2624,8 @@ def _get_layouts_from_executable( if isinstance(i, DeviceLocalLayout): if i != x: raise AssertionError( - f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)") + f"Unexpected XLA layout override: (XLA) {x} != {i} (User input" + " layout)") new_in_layouts.append(i) else: new_in_layouts.append(x) @@ -2635,7 +2636,8 @@ def _get_layouts_from_executable( if isinstance(o, DeviceLocalLayout): if o != x: raise AssertionError( - f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)") + f"Unexpected XLA layout override: (XLA) {x} != {o} (User output" + " layout)") new_out_layouts.append(o) else: new_out_layouts.append(x) @@ -3072,10 +3074,12 @@ class MeshExecutable(stages.XlaExecutable): return self._out_shardings def input_layouts(self): - return self._in_layouts + return [Layout(l, s) + for l, s in safe_zip(self._in_layouts, self._in_shardings)] def output_layouts(self): - return self._out_layouts + return [Layout(l, s) + for l, s in safe_zip(self._out_layouts, self._out_shardings)] def create_cpp_call(self, no_kwargs, in_tree, out_tree): if not (isinstance(self.unsafe_call, ExecuteReplicated) and @@ -3254,7 +3258,8 @@ def check_array_xla_sharding_layout_match( 'sharding')) if (xla_extension_version >= 249 and not db_xs and arg._committed and - arg.layout is not None and xl is not None and arg.layout != xl): + arg.layout.device_local_layout is not None and xl is not None and + arg.layout.device_local_layout != xl): errors.append( ("Got input layout(s) that compiled object was called with: " f"{arg.layout} and layout(s) the computation was compiled " diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 1a75a9779..cc4200296 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -14,6 +14,10 @@ from __future__ import annotations +from typing import Union + +from jax._src.sharding import Sharding +from jax._src.sharding_impls import AUTO as AutoSharding, is_auto from jax._src.lib import xla_client as xc @@ -45,3 +49,39 @@ class AutoLayout: return "AUTO" AUTO = AutoLayout() + + +LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] +ShardingOptions = Union[Sharding, None, AutoSharding] + + +class Layout: + __slots__ = ['device_local_layout', 'sharding'] + + def __init__(self, device_local_layout: LayoutOptions, + sharding: ShardingOptions): + # If layout is concrete and sharding is not, error. + if (isinstance(device_local_layout, DeviceLocalLayout) and + (sharding is None or is_auto(sharding))): + raise ValueError( + 'Sharding has to be concrete when layout is of type' + f' {type(device_local_layout)}. Please pass a' + ' `jax.sharding.NamedSharding`, `jax.sharding.PositionalSharding` or' + ' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got' + f' sharding {sharding}' + ) + self.device_local_layout = device_local_layout + self.sharding = sharding + + def __repr__(self): + return (f'Layout(device_local_layout={self.device_local_layout},' + f' sharding={self.sharding})') + + def __hash__(self): + return hash((self.device_local_layout, self.sharding)) + + def __eq__(self, other): + if not isinstance(other, Layout): + return False + return (self.device_local_layout == other.device_local_layout and + self.sharding == other.sharding) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a8c0e2041..ab4d870b7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -67,6 +67,7 @@ from jax._src.sharding_impls import ( SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified, is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding) +from jax._src.layout import Layout, LayoutOptions from jax._src.state import discharge as state_discharge, RefEffect from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( @@ -437,6 +438,7 @@ def _make_jit_wrapper(jit_info: PjitInfo): args_flat, params['in_shardings'], params['out_shardings'], mesh) in_layouts_flat = _resolve_in_layouts( args_flat, in_layouts_flat, in_shardings) + out_layouts_flat = _resolve_out_layouts(out_layouts_flat) lowering = _pjit_lower( params['jaxpr'], in_shardings, params['out_shardings'], params['resource_env'], params['donated_invars'], params['name'], @@ -1268,8 +1270,10 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings): resolved_in_layouts = [] for arg, jit_in_l in safe_zip(args, jit_in_layouts): arg_layout, committed = ( - (arg.layout, getattr(arg, '_committed', True)) + (arg.layout.device_local_layout, getattr(arg, '_committed', True)) if getattr(arg, 'layout', None) is not None else (None, False)) + jit_in_l = (jit_in_l.device_local_layout + if isinstance(jit_in_l, Layout) else jit_in_l) if jit_in_l is None: if committed: resolved_in_layouts.append(arg_layout) @@ -1286,6 +1290,14 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings): return tuple(resolved_in_layouts) +def _resolve_out_layouts(out_layouts: Sequence[Layout] + ) -> Sequence[LayoutOptions]: + # TODO(yashkatariya): Remove the if condition when all layouts come via the + # `layout.Layout` API. + return tuple(o.device_local_layout if isinstance(o, Layout) else o + for o in out_layouts) + + def _resolve_in_shardings( args, pjit_in_shardings: Sequence[PjitSharding], out_shardings: Sequence[PjitSharding], diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 74367476b..9dd882d36 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -44,7 +44,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src.tree_util import tree_unflatten, keystr from jax._src import util -from jax._src.layout import DeviceLocalLayout +from jax._src.layout import Layout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib import xla_client as xc @@ -513,7 +513,7 @@ class Compiled(Stage): def _input_layouts(self): layouts_flat = self._executable.input_layouts() - assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat) + assert all(isinstance(l, Layout) for l in layouts_flat) # Some input layouts got DCE'd if self.in_tree.num_leaves > len(layouts_flat): iter_layouts_flat = iter(layouts_flat) @@ -523,7 +523,7 @@ class Compiled(Stage): def _output_layouts(self): layouts_flat = self._executable.output_layouts() - assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat) + assert all(isinstance(l, Layout) for l in layouts_flat) return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error @staticmethod diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index d7c650876..e70b82a4f 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -15,4 +15,5 @@ from jax._src.layout import ( DeviceLocalLayout as DeviceLocalLayout, AUTO as AUTO, + Layout as Layout ) diff --git a/tests/layout_test.py b/tests/layout_test.py index 6bf7e3087..73259d1cf 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -20,9 +20,10 @@ import numpy as np import jax import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec as P +from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding from jax._src import config from jax._src import layout +from jax._src.layout import Layout from jax._src import test_util as jtu from jax._src.util import safe_zip from jax._src import xla_bridge @@ -115,7 +116,7 @@ class LayoutTest(jtu.JaxTestCase): self.assertEqual(init_count[0], 1) self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0]) - self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[0]) + self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[1]) with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) @@ -223,8 +224,10 @@ class LayoutTest(jtu.JaxTestCase): self.assertArraysEqual(out1, out3) self.assertArraysEqual(out2, out4) - # TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array - # and then pass that back into compiled. + arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_layouts)] + out5, out6 = jax.jit(f)(*arrs) + self.assertArraysEqual(out1, out5) + self.assertArraysEqual(out2, out6) def test_aot_layout_mismatch(self): mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) @@ -259,6 +262,49 @@ class LayoutTest(jtu.JaxTestCase): jax.jit(jnp.dot, backend=jax.default_backend()).lower( out_cpu, out_cpu).compile() # doesn't crash + def test_device_put_concrete_layout(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (8, 128) + np_inp = np.arange(math.prod(shape)).reshape(shape) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + compiled = jax.jit( + lambda x: x * 2).lower(arr, _out_layouts=layout.AUTO).compile() + col = compiled._output_layouts() + + out = jax.device_put(np_inp, col) + self.assertEqual(out.layout, col) + self.assertArraysEqual(out, np_inp) + for s in out.addressable_shards: + self.assertEqual(out.layout.device_local_layout, + s.data.layout.device_local_layout) + + def test_device_put_non_concrete_layout_error(self): + np_inp = np.arange(16).reshape(8, 2) + + l1 = Layout(layout.AUTO, SingleDeviceSharding(jax.devices()[0])) + with self.assertRaisesRegex( + ValueError, 'sharding and device_local_layout.*should be concrete'): + jax.device_put(np_inp, l1) + + l2 = Layout(layout.AUTO, None) + with self.assertRaisesRegex( + ValueError, 'sharding and device_local_layout.*should be concrete'): + jax.device_put(np_inp, l2) + + l3 = Layout(None, SingleDeviceSharding(jax.devices()[0])) + out = jax.device_put(np_inp, l3) + self.assertArraysEqual(out, np_inp) + self.assertTrue(out._committed) + + def invalid_layout_spec(self): + x = np.arange(8) + compiled = jax.jit(lambda x: x).lower(x).compile() + with self.assertRaisesRegex( + ValueError, 'Sharding has to be concrete when layout.*'): + Layout(compiled._output_layouts()[0], None) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())