mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Expose Layout(device_local_layout, sharding)
class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put. Note: This currently only works on TPU. PiperOrigin-RevId: 621668247
This commit is contained in:
parent
24517ca3e0
commit
92326dbc71
@ -722,7 +722,11 @@ pytype_strict_library(
|
||||
pytype_strict_library(
|
||||
name = "layout",
|
||||
srcs = ["_src/layout.py"],
|
||||
deps = ["//jax/_src/lib"],
|
||||
deps = [
|
||||
":sharding",
|
||||
":sharding_impls",
|
||||
"//jax/_src/lib",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
|
@ -70,6 +70,7 @@ from jax._src.lib import pmap_lib
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
|
||||
XLACompatibleSharding)
|
||||
from jax._src.layout import Layout
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src import tree_util
|
||||
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
|
||||
@ -2461,8 +2462,8 @@ def _check_sharding(x, s):
|
||||
|
||||
def device_put(
|
||||
x,
|
||||
device: None | xc.Device | Sharding | Any | TransferToMemoryKind = None,
|
||||
*, src: None | xc.Device | Sharding | Any | TransferToMemoryKind = None):
|
||||
device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
|
||||
*, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None):
|
||||
"""Transfers ``x`` to ``device``.
|
||||
|
||||
Args:
|
||||
|
@ -531,13 +531,13 @@ class ArrayImpl(basearray.Array):
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
# TODO(yashkatariya): Remove the try;except when pathways supports layouts.
|
||||
try:
|
||||
return layout.DeviceLocalLayout(self._pjrt_layout)
|
||||
return layout.Layout(layout.DeviceLocalLayout(self._pjrt_layout),
|
||||
self.sharding)
|
||||
except xe.XlaRuntimeError as e:
|
||||
msg, *_ = e.args
|
||||
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
|
||||
return None
|
||||
return layout.Layout(None, self.sharding)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
@ -51,6 +51,7 @@ from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
|
||||
GSPMDSharding, TransferToMemoryKind)
|
||||
from jax._src.layout import Layout, DeviceLocalLayout
|
||||
|
||||
|
||||
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
||||
@ -380,25 +381,9 @@ def _mcjax_reshard(x, target_sharding):
|
||||
pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment
|
||||
|
||||
|
||||
def _device_put_impl(
|
||||
x,
|
||||
device: Device | Sharding | None = None,
|
||||
src: Device | Sharding | None = None):
|
||||
def _device_put_sharding_impl(x, aval, device):
|
||||
from jax._src import array
|
||||
|
||||
if (isinstance(device, TransferToMemoryKind) or
|
||||
isinstance(src, TransferToMemoryKind)):
|
||||
raise ValueError(
|
||||
"TransferToMemoryKind argument to jax.device_put can only be used"
|
||||
" inside jax.jit. If you are using device_put outside jax.jit, then"
|
||||
" please provide a concrete Sharding with memory_kind.")
|
||||
|
||||
try:
|
||||
aval = xla.abstractify(x)
|
||||
except TypeError as err:
|
||||
raise TypeError(
|
||||
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
|
||||
|
||||
if isinstance(device, Sharding):
|
||||
s = device
|
||||
if getattr(x, 'sharding', None) == s and getattr(x, '_committed', False):
|
||||
@ -435,6 +420,43 @@ def _device_put_impl(
|
||||
if device is None else device)
|
||||
return _put_x(x, sh, aval, device is not None)
|
||||
|
||||
def _device_put_impl(
|
||||
x,
|
||||
device: Device | Sharding | Layout | None = None,
|
||||
src: Device | Sharding | Layout | None = None):
|
||||
if (isinstance(device, TransferToMemoryKind) or
|
||||
isinstance(src, TransferToMemoryKind)):
|
||||
raise ValueError(
|
||||
"TransferToMemoryKind argument to jax.device_put can only be used"
|
||||
" inside jax.jit. If you are using device_put outside jax.jit, then"
|
||||
" please provide a concrete Sharding with memory_kind.")
|
||||
|
||||
try:
|
||||
aval = xla.abstractify(x)
|
||||
except TypeError as err:
|
||||
raise TypeError(
|
||||
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
|
||||
|
||||
if isinstance(device, Layout):
|
||||
l = device
|
||||
dll = l.device_local_layout
|
||||
x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None
|
||||
if (not isinstance(l.sharding, Sharding) or
|
||||
not isinstance(dll, (DeviceLocalLayout, type(None)))):
|
||||
raise ValueError(
|
||||
"sharding and device_local_layout in `Layout` instance should be"
|
||||
f" concrete. Got layout: {l}")
|
||||
if getattr(x, 'layout', None) == l and getattr(x, '_committed', False):
|
||||
return x
|
||||
if x_dll is None and dll is None:
|
||||
return _device_put_sharding_impl(x, aval, l.sharding)
|
||||
# TODO(yashkatariya): Pass layout to out_shardings directly and remove
|
||||
# out_layouts from lower.
|
||||
return api.jit(_identity_fn, out_shardings=l.sharding).lower(
|
||||
x, _out_layouts=dll).compile()(x)
|
||||
|
||||
return _device_put_sharding_impl(x, aval, device)
|
||||
|
||||
|
||||
device_put_p = core.Primitive('device_put')
|
||||
device_put_p.def_impl(_device_put_impl)
|
||||
|
@ -60,7 +60,7 @@ from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.layout import DeviceLocalLayout, AutoLayout
|
||||
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
@ -2624,7 +2624,8 @@ def _get_layouts_from_executable(
|
||||
if isinstance(i, DeviceLocalLayout):
|
||||
if i != x:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
|
||||
f"Unexpected XLA layout override: (XLA) {x} != {i} (User input"
|
||||
" layout)")
|
||||
new_in_layouts.append(i)
|
||||
else:
|
||||
new_in_layouts.append(x)
|
||||
@ -2635,7 +2636,8 @@ def _get_layouts_from_executable(
|
||||
if isinstance(o, DeviceLocalLayout):
|
||||
if o != x:
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
|
||||
f"Unexpected XLA layout override: (XLA) {x} != {o} (User output"
|
||||
" layout)")
|
||||
new_out_layouts.append(o)
|
||||
else:
|
||||
new_out_layouts.append(x)
|
||||
@ -3072,10 +3074,12 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
return self._out_shardings
|
||||
|
||||
def input_layouts(self):
|
||||
return self._in_layouts
|
||||
return [Layout(l, s)
|
||||
for l, s in safe_zip(self._in_layouts, self._in_shardings)]
|
||||
|
||||
def output_layouts(self):
|
||||
return self._out_layouts
|
||||
return [Layout(l, s)
|
||||
for l, s in safe_zip(self._out_layouts, self._out_shardings)]
|
||||
|
||||
def create_cpp_call(self, no_kwargs, in_tree, out_tree):
|
||||
if not (isinstance(self.unsafe_call, ExecuteReplicated) and
|
||||
@ -3254,7 +3258,8 @@ def check_array_xla_sharding_layout_match(
|
||||
'sharding'))
|
||||
|
||||
if (xla_extension_version >= 249 and not db_xs and arg._committed and
|
||||
arg.layout is not None and xl is not None and arg.layout != xl):
|
||||
arg.layout.device_local_layout is not None and xl is not None and
|
||||
arg.layout.device_local_layout != xl):
|
||||
errors.append(
|
||||
("Got input layout(s) that compiled object was called with: "
|
||||
f"{arg.layout} and layout(s) the computation was compiled "
|
||||
|
@ -14,6 +14,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union
|
||||
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
|
||||
@ -45,3 +49,39 @@ class AutoLayout:
|
||||
return "AUTO"
|
||||
|
||||
AUTO = AutoLayout()
|
||||
|
||||
|
||||
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout]
|
||||
ShardingOptions = Union[Sharding, None, AutoSharding]
|
||||
|
||||
|
||||
class Layout:
|
||||
__slots__ = ['device_local_layout', 'sharding']
|
||||
|
||||
def __init__(self, device_local_layout: LayoutOptions,
|
||||
sharding: ShardingOptions):
|
||||
# If layout is concrete and sharding is not, error.
|
||||
if (isinstance(device_local_layout, DeviceLocalLayout) and
|
||||
(sharding is None or is_auto(sharding))):
|
||||
raise ValueError(
|
||||
'Sharding has to be concrete when layout is of type'
|
||||
f' {type(device_local_layout)}. Please pass a'
|
||||
' `jax.sharding.NamedSharding`, `jax.sharding.PositionalSharding` or'
|
||||
' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got'
|
||||
f' sharding {sharding}'
|
||||
)
|
||||
self.device_local_layout = device_local_layout
|
||||
self.sharding = sharding
|
||||
|
||||
def __repr__(self):
|
||||
return (f'Layout(device_local_layout={self.device_local_layout},'
|
||||
f' sharding={self.sharding})')
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.device_local_layout, self.sharding))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Layout):
|
||||
return False
|
||||
return (self.device_local_layout == other.device_local_layout and
|
||||
self.sharding == other.sharding)
|
||||
|
@ -67,6 +67,7 @@ from jax._src.sharding_impls import (
|
||||
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
|
||||
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
|
||||
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
|
||||
from jax._src.layout import Layout, LayoutOptions
|
||||
from jax._src.state import discharge as state_discharge, RefEffect
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import (
|
||||
@ -437,6 +438,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
|
||||
args_flat, params['in_shardings'], params['out_shardings'], mesh)
|
||||
in_layouts_flat = _resolve_in_layouts(
|
||||
args_flat, in_layouts_flat, in_shardings)
|
||||
out_layouts_flat = _resolve_out_layouts(out_layouts_flat)
|
||||
lowering = _pjit_lower(
|
||||
params['jaxpr'], in_shardings, params['out_shardings'],
|
||||
params['resource_env'], params['donated_invars'], params['name'],
|
||||
@ -1268,8 +1270,10 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
|
||||
resolved_in_layouts = []
|
||||
for arg, jit_in_l in safe_zip(args, jit_in_layouts):
|
||||
arg_layout, committed = (
|
||||
(arg.layout, getattr(arg, '_committed', True))
|
||||
(arg.layout.device_local_layout, getattr(arg, '_committed', True))
|
||||
if getattr(arg, 'layout', None) is not None else (None, False))
|
||||
jit_in_l = (jit_in_l.device_local_layout
|
||||
if isinstance(jit_in_l, Layout) else jit_in_l)
|
||||
if jit_in_l is None:
|
||||
if committed:
|
||||
resolved_in_layouts.append(arg_layout)
|
||||
@ -1286,6 +1290,14 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
|
||||
return tuple(resolved_in_layouts)
|
||||
|
||||
|
||||
def _resolve_out_layouts(out_layouts: Sequence[Layout]
|
||||
) -> Sequence[LayoutOptions]:
|
||||
# TODO(yashkatariya): Remove the if condition when all layouts come via the
|
||||
# `layout.Layout` API.
|
||||
return tuple(o.device_local_layout if isinstance(o, Layout) else o
|
||||
for o in out_layouts)
|
||||
|
||||
|
||||
def _resolve_in_shardings(
|
||||
args, pjit_in_shardings: Sequence[PjitSharding],
|
||||
out_shardings: Sequence[PjitSharding],
|
||||
|
@ -44,7 +44,7 @@ from jax._src import traceback_util
|
||||
from jax._src import tree_util
|
||||
from jax._src.tree_util import tree_unflatten, keystr
|
||||
from jax._src import util
|
||||
from jax._src.layout import DeviceLocalLayout
|
||||
from jax._src.layout import Layout
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -513,7 +513,7 @@ class Compiled(Stage):
|
||||
|
||||
def _input_layouts(self):
|
||||
layouts_flat = self._executable.input_layouts()
|
||||
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
|
||||
assert all(isinstance(l, Layout) for l in layouts_flat)
|
||||
# Some input layouts got DCE'd
|
||||
if self.in_tree.num_leaves > len(layouts_flat):
|
||||
iter_layouts_flat = iter(layouts_flat)
|
||||
@ -523,7 +523,7 @@ class Compiled(Stage):
|
||||
|
||||
def _output_layouts(self):
|
||||
layouts_flat = self._executable.output_layouts()
|
||||
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
|
||||
assert all(isinstance(l, Layout) for l in layouts_flat)
|
||||
return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error
|
||||
|
||||
@staticmethod
|
||||
|
@ -15,4 +15,5 @@
|
||||
from jax._src.layout import (
|
||||
DeviceLocalLayout as DeviceLocalLayout,
|
||||
AUTO as AUTO,
|
||||
Layout as Layout
|
||||
)
|
||||
|
@ -20,9 +20,10 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding, PartitionSpec as P
|
||||
from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding
|
||||
from jax._src import config
|
||||
from jax._src import layout
|
||||
from jax._src.layout import Layout
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src import xla_bridge
|
||||
@ -115,7 +116,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertEqual(init_count[0], 1)
|
||||
|
||||
self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0])
|
||||
self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[0])
|
||||
self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[1])
|
||||
|
||||
with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
|
||||
apply_out = compiled_apply(*init_out)
|
||||
@ -223,8 +224,10 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out1, out3)
|
||||
self.assertArraysEqual(out2, out4)
|
||||
|
||||
# TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array
|
||||
# and then pass that back into compiled.
|
||||
arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_layouts)]
|
||||
out5, out6 = jax.jit(f)(*arrs)
|
||||
self.assertArraysEqual(out1, out5)
|
||||
self.assertArraysEqual(out2, out6)
|
||||
|
||||
def test_aot_layout_mismatch(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
@ -259,6 +262,49 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
jax.jit(jnp.dot, backend=jax.default_backend()).lower(
|
||||
out_cpu, out_cpu).compile() # doesn't crash
|
||||
|
||||
def test_device_put_concrete_layout(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
shape = (8, 128)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
compiled = jax.jit(
|
||||
lambda x: x * 2).lower(arr, _out_layouts=layout.AUTO).compile()
|
||||
col = compiled._output_layouts()
|
||||
|
||||
out = jax.device_put(np_inp, col)
|
||||
self.assertEqual(out.layout, col)
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
for s in out.addressable_shards:
|
||||
self.assertEqual(out.layout.device_local_layout,
|
||||
s.data.layout.device_local_layout)
|
||||
|
||||
def test_device_put_non_concrete_layout_error(self):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
|
||||
l1 = Layout(layout.AUTO, SingleDeviceSharding(jax.devices()[0]))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'sharding and device_local_layout.*should be concrete'):
|
||||
jax.device_put(np_inp, l1)
|
||||
|
||||
l2 = Layout(layout.AUTO, None)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'sharding and device_local_layout.*should be concrete'):
|
||||
jax.device_put(np_inp, l2)
|
||||
|
||||
l3 = Layout(None, SingleDeviceSharding(jax.devices()[0]))
|
||||
out = jax.device_put(np_inp, l3)
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
self.assertTrue(out._committed)
|
||||
|
||||
def invalid_layout_spec(self):
|
||||
x = np.arange(8)
|
||||
compiled = jax.jit(lambda x: x).lower(x).compile()
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Sharding has to be concrete when layout.*'):
|
||||
Layout(compiled._output_layouts()[0], None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user