Add a concrete layout API to JAX. The API takes `major_to_minor: tuple[int, ...]` and `tiling: tuple[tuple[int, ...], ...]` as arguments, and lets users pass layouts to `with_sharding_constraint` to constrain both the layout and the sharding.

`sub_byte_element_size_in_bits` is lowering-only for now: JAX knows the dtype of the aval, so it can fill in the appropriate value. It can be exposed in the user API if required. Memory space is already exposed via the JAX memories API, so it does not need to live in the layout API. Also expose `_xla_layout` as a private API on `PjRtLayout` so that JAX can read its fields when creating JAX layouts, and add constructors to `xla::Layout` so that JAX can create layouts from minor_to_major and tiling information.

PiperOrigin-RevId: 647487510
This commit is contained in:
commit e1a496d3b6 (parent d577e29998)
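For orientation, here is a usage sketch of the API this commit adds, assembled from the tests at the bottom of the diff. The `DLL`/`Layout` aliases, the `jax.experimental.layout` import path, and the mesh helper are assumptions based on JAX at this revision, not part of the diff:

import math
from functools import partial

import numpy as np
import jax
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import test_util as jtu

mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x'))
arr = jax.device_put(np.arange(128 * 128).reshape(128, 128), s)

# A concrete device-local layout: dimension order (major to minor) plus an
# optional tiling, here the (8, 128) tiles used in the tests below.
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))

@partial(jax.jit, out_shardings=Layout(DLL.AUTO))
def f(x):
  # Constrain both the layout and the sharding of the transposed value.
  return jax.lax.with_sharding_constraint(x.T, Layout(custom_dll, s))

out = f(arr)
assert out.layout.device_local_layout == custom_dll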
@@ -800,10 +800,11 @@ pytype_strict_library(
     name = "layout",
     srcs = ["_src/layout.py"],
     deps = [
+        ":dtypes",
         ":sharding",
         ":sharding_impls",
         "//jax/_src/lib",
-    ],
+    ] + py_deps("numpy"),
 )

 pytype_strict_library(
@@ -506,7 +506,8 @@ class ArrayImpl(basearray.Array):
     if self.is_deleted():
       return Layout(None, self.sharding)
     try:
-      return Layout(DeviceLocalLayout(self._pjrt_layout), self.sharding)
+      return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout),
+                    self.sharding)
     except xe.XlaRuntimeError as e:
       msg, *_ = e.args
       if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
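This property backs `jax.Array.layout`; after this change, reading it yields the value-type `DeviceLocalLayout` defined later in this diff. A minimal sketch of reading it (the default layout is backend-dependent, so the printed tuple is only illustrative):

import numpy as np
import jax

x = jax.device_put(np.zeros((8, 128), dtype=np.float32), jax.devices()[0])
l = x.layout                                  # Layout(device_local_layout, sharding)
print(l.device_local_layout.major_to_minor)   # e.g. (0, 1) for row-major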
@@ -824,12 +824,15 @@ def _to_physical_op_sharding(
   return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()  # type: ignore


-def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None:
+def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout,
+                   aval: core.AbstractValue) -> str | None:
   if layout is None:
     return "default"
   if isinstance(layout, AutoLayout):
     return "auto"
-  return layout._to_xla_layout()
+  if aval is core.abstract_token:
+    return "default"
+  return layout._to_xla_layout(aval.dtype)  # type: ignore


 def _get_mem_kind(s: JSharding | None) -> str | None:
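Threading the aval through here is what lets lowering fill in the sub-byte element size mentioned in the commit message. Restated outside the diff as a standalone sketch, the rule that `_to_xla_layout` applies to the dtype (see the `DeviceLocalLayout._to_xla_layout` hunk below) is:

import numpy as np
from jax._src.dtypes import iinfo, issubdtype

def sub_byte_size(dtype) -> int:
  # Sub-byte integer dtypes (e.g. int4) report their bit width; everything
  # else reports 0, which XLA reads as "normal byte-aligned elements".
  if issubdtype(dtype, np.integer):
    bits = iinfo(dtype).bits
    return bits if bits < 8 else 0
  return 0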
@@ -1194,11 +1197,9 @@ def lower_jaxpr_to_fun(

   ir_arg_shardings = None
   if arg_shardings is not None:
-    in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
     ir_arg_shardings = util.flatten(
         [[_to_physical_op_sharding(a, s)] * len(types)
-         for a, s, types in zip(in_avals, arg_shardings, input_types)])
-    del in_avals
+         for a, s, types in zip(input_avals, arg_shardings, input_types)])

   ir_arg_memory_kinds = None
   if arg_memory_kinds is not None:
@ -1208,8 +1209,8 @@ def lower_jaxpr_to_fun(
|
||||
ir_arg_layouts = None
|
||||
if arg_layouts is not None:
|
||||
ir_arg_layouts = util.flatten(
|
||||
[[_to_xla_layout(l)] * len(types)
|
||||
for l, types in zip(arg_layouts, input_types)])
|
||||
[[_to_xla_layout(l, a)] * len(types)
|
||||
for l, a, types in zip(arg_layouts, input_avals, input_types)])
|
||||
|
||||
ir_donated_args = None
|
||||
if xla_donated_args is not None:
|
||||
@ -1244,8 +1245,8 @@ def lower_jaxpr_to_fun(
|
||||
ir_result_layouts = None
|
||||
if result_layouts is not None:
|
||||
ir_result_layouts = util.flatten(
|
||||
[[_to_xla_layout(l)] * len(types)
|
||||
for l, types in zip(result_layouts, output_types)])
|
||||
[[_to_xla_layout(l, a)] * len(types)
|
||||
for l, a, types in zip(result_layouts, output_avals, output_types)])
|
||||
|
||||
if (
|
||||
replicated_args is not None
|
||||
@@ -2171,6 +2172,29 @@ def get_sharding_attr(sharding_proto: xc.OpSharding):
   return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))


+def wrap_with_layout_op(ctx: LoweringRuleContext,
+                        x: ir.Value,
+                        aval_out: core.AbstractValue,
+                        layout: DeviceLocalLayout,
+                        aval_in: core.AbstractValue):
+  result_type = aval_to_ir_type(aval_out)
+  out_shape = core.physical_aval(aval_out).shape  # type: ignore
+  if core.is_constant_shape(out_shape):
+    result_shapes = None
+  else:
+    result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)]
+
+  op = custom_call('LayoutConstraint', result_types=[result_type], operands=[x],
+                   api_version=1,
+                   result_shapes=result_shapes,
+                   # Set operand layouts to anything. XLA will ignore it.
+                   operand_layouts=[list(range(aval_in.ndim))],  # type: ignore
+                   # TODO(yashkatariya): Figure out how to pass tiling to the
+                   # custom call.
+                   result_layouts=[layout.major_to_minor[::-1]])
+  return op.result


 # MLIR lowerings for lax primitives

 def cache_lowering(f):
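Note the `[::-1]` in `result_layouts`: the user-facing field is `major_to_minor`, while XLA layouts (and this custom call) are expressed minor-to-major, so the tuple is simply reversed:

major_to_minor = (0, 1, 2)              # row-major for a rank-3 array
minor_to_major = major_to_minor[::-1]   # what XLA expects
assert minor_to_major == (2, 1, 0)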
@@ -1937,7 +1937,7 @@ def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval
     # first call you pass it a sharded array with layout and on second call you
     # pass a numpy array. The layouts should be the same to get cache hits.
     try:
-      al = DeviceLocalLayout(
+      al = DeviceLocalLayout.from_pjrt_layout(
           d.client.get_default_layout(aval.dtype, shard_shape, d))
     except:
       return None
@@ -2704,7 +2704,7 @@ def _get_layouts_from_executable(

   new_in_layouts = []
   for x, i in safe_zip(in_layouts_xla, in_layouts):
-    x = DeviceLocalLayout(x)
+    x = DeviceLocalLayout.from_pjrt_layout(x)
    if isinstance(i, DeviceLocalLayout):
      if i != x:
        raise AssertionError(
@@ -2716,7 +2716,7 @@ def _get_layouts_from_executable(

   new_out_layouts = []
   for x, o in safe_zip(out_layouts_xla, out_layouts):
-    x = DeviceLocalLayout(x)
+    x = DeviceLocalLayout.from_pjrt_layout(x)
    if isinstance(o, DeviceLocalLayout):
      if o != x:
        raise AssertionError(
@@ -16,9 +16,12 @@ from __future__ import annotations

 from typing import Union

+import numpy as np
+from jax._src.dtypes import iinfo, issubdtype
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version


 class AutoLayout:
@@ -27,31 +30,78 @@ class AutoLayout:
     return "AUTO"


-class DeviceLocalLayout:
-  layout: xc.PjRtLayout
+if xla_extension_version >= 274:
+  class DeviceLocalLayout:
+    major_to_minor: tuple[int, ...]
+    tiling: tuple[tuple[int, ...], ...] | None

-  AUTO = AutoLayout()
+    AUTO = AutoLayout()

-  def __init__(self, layout: xc.PjRtLayout):
-    self._layout = layout
-    self._layout_str = str(self._layout)
+    def __init__(self, major_to_minor: tuple[int, ...],
+                 tiling: tuple[tuple[int, ...], ...] | None = None):
+      self.major_to_minor = tuple(major_to_minor)
+      self.tiling = None if tiling is None else tuple(map(tuple, tiling))

-  def __repr__(self):
-    return f'DeviceLocalLayout({self._layout_str})'
+    @staticmethod
+    def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
+      xla_layout = pjrt_layout._xla_layout()
+      return DeviceLocalLayout(xla_layout.minor_to_major()[::-1],  # pytype: disable=wrong-arg-types
+                               xla_layout.tiling())

-  def __hash__(self):
-    return hash(self._layout)
+    def __repr__(self):
+      return (f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
+              f' tiling={self.tiling})')

-  def __eq__(self, other):
-    if not isinstance(other, DeviceLocalLayout):
-      return False
-    return self._layout == other._layout
+    def __hash__(self):
+      return hash((self.major_to_minor, self.tiling))

-  def _to_xla_layout(self) -> str:
-    return self._layout_str
+    def __eq__(self, other):
+      if not isinstance(other, DeviceLocalLayout):
+        return False
+      return (self.major_to_minor == other.major_to_minor and
+              self.tiling == other.tiling)
+
+    def _to_xla_layout(self, dtype) -> str:
+      if self.tiling is None:
+        xla_layout = xc.Layout(self.major_to_minor[::-1])
+      else:
+        if issubdtype(dtype, np.integer):
+          sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
+        else:
+          sub_byte_size = 0
+        xla_layout = xc.Layout(self.major_to_minor[::-1], self.tiling,  # type: ignore
+                               sub_byte_size)
+      return str(xla_layout)
+else:
+  class DeviceLocalLayout:  # type: ignore
+    layout: xc.PjRtLayout
+
+    AUTO = AutoLayout()
+
+    def __init__(self, layout: xc.PjRtLayout):
+      self._layout = layout
+      self._layout_str = str(self._layout)
+
+    @staticmethod
+    def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
+      return DeviceLocalLayout(pjrt_layout)  # type: ignore
+
+    def __repr__(self):
+      return f'DeviceLocalLayout({self._layout_str})'
+
+    def __hash__(self):
+      return hash(self._layout)
+
+    def __eq__(self, other):
+      if not isinstance(other, DeviceLocalLayout):
+        return False
+      return self._layout == other._layout
+
+    def _to_xla_layout(self, dtype) -> str:
+      return self._layout_str


-LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout]
+LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout]  # pytype: disable=invalid-annotation
 ShardingOptions = Union[Sharding, None, AutoSharding]
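On the new branch, `DeviceLocalLayout` becomes a plain value type: equality and hashing are defined over `(major_to_minor, tiling)` rather than over an opaque `PjRtLayout` handle. A sketch of the resulting semantics, assuming a new enough `xla_extension_version` and the same import path as the first example:

from jax.experimental.layout import DeviceLocalLayout as DLL

a = DLL(major_to_minor=(0, 1), tiling=((8, 128),))
b = DLL(major_to_minor=(0, 1), tiling=((8, 128),))
c = DLL(major_to_minor=(1, 0))            # tiling defaults to None

assert a == b and hash(a) == hash(b)      # value equality: usable as cache keys
assert a != c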
@@ -1816,6 +1816,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
         [tmpvar], [outvar], sharding_constraint_p,
         dict(resource_env=resource_env,
              sharding=gspmd_sharding,
+             layout=None,
             unconstrained_dims=unconstrained_dims),
        set(),
        eqn.source_info, eqn.ctx))
@@ -2488,6 +2488,9 @@ def with_sharding_constraint(x, shardings):
   .. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
   """
   x_flat, tree = tree_flatten(x)
+
+  layouts, shardings = _split_layout_and_sharding(shardings)
+
   user_shardings = prepare_axis_resources(
       shardings, "shardings", allow_unconstrained_dims=True)
   del shardings
@@ -2496,6 +2499,10 @@ def with_sharding_constraint(x, shardings):
       flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
   del user_shardings

+  user_layouts_flat = tuple(
+      flatten_axes("with_sharding_constraint layouts", tree, layouts))
+  del layouts
+
   resource_env = mesh_lib.thread_resources.env
   mesh = resource_env.physical_mesh
@@ -2511,19 +2518,27 @@ def with_sharding_constraint(x, shardings):
       shardings_flat, x_flat, None, "with_sharding_constraint arguments",
       allow_uneven_sharding=True)

-  outs = [sharding_constraint_p.bind(xf, sharding=s,
+  outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l,
                                      resource_env=resource_env,
                                      unconstrained_dims=ud)
-          for xf, s, ud in zip(x_flat, shardings_flat, unconstrained_dims)]
+          for xf, s, l, ud in zip(x_flat, shardings_flat, user_layouts_flat,
+                                  unconstrained_dims)]
   return tree_unflatten(tree, outs)

 def _identity_fn(x): return x

-def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):
-  if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim):
-    return x
-  # Run a jit here to raise good errors when device assignment don't match.
-  return api.jit(_identity_fn, out_shardings=sharding)(x)
+def _sharding_constraint_impl(x, sharding, layout, resource_env,
+                              unconstrained_dims):
+  if layout is None:
+    if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim):
+      return x
+    # Run a jit here to raise good errors when device assignment don't match.
+    return api.jit(_identity_fn, out_shardings=sharding)(x)
+  else:
+    if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and
+        x.sharding.is_equivalent_to(sharding, x.ndim)):
+      return x
+    return api.jit(_identity_fn, out_shardings=Layout(layout, sharding))(x)


 sharding_constraint_p = core.Primitive("sharding_constraint")
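So in eager mode a layout-bearing constraint is a no-op when `x` already matches, and otherwise reduces to a jitted identity with `out_shardings=Layout(layout, sharding)`. Continuing the sketch from the top of this page (same `arr`, `custom_dll`, `s`):

# Outside jit: returns arr unchanged if it already has (custom_dll, s);
# otherwise relayouts/reshards it via an identity jit.
y = jax.lax.with_sharding_constraint(arr, Layout(custom_dll, s))
assert y.layout.device_local_layout == custom_dll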
@@ -2532,7 +2547,7 @@ sharding_constraint_p.def_abstract_eval(lambda x, **_: x)
 ad.deflinear2(sharding_constraint_p,
               lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),))

-def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
+def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout,
                                       resource_env, unconstrained_dims):
   aval, = ctx.avals_in
   out_aval, = ctx.avals_out
@@ -2547,19 +2562,19 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
         sharding._to_xla_hlo_sharding(aval.ndim), mesh)[0]
     sharding = NamedSharding._from_parsed_pspec(
         mesh, parsed_pspec, _manual_axes=axis_ctx.manual_axes)
-  return [
-      mlir.wrap_with_sharding_op(ctx,
-          x_node, out_aval,
-          sharding._to_xla_hlo_sharding(aval.ndim).to_proto(),
-          unspecified_dims=unconstrained_dims)
-  ]
+  out = mlir.wrap_with_sharding_op(
+      ctx, x_node, out_aval, sharding._to_xla_hlo_sharding(aval.ndim).to_proto(),
+      unspecified_dims=unconstrained_dims)
+  if layout is not None:
+    out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval)
+  return [out]
 mlir.register_lowering(sharding_constraint_p,
                        _sharding_constraint_hlo_lowering)


-def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
-                                 axis_name, main_type, vals_in, dims_in,
-                                 sharding, resource_env, unconstrained_dims):
+def _sharding_constraint_batcher(
+    insert_axis, spmd_axis_name, axis_size, axis_name, main_type, vals_in,
+    dims_in, sharding, layout, resource_env, unconstrained_dims):
   if spmd_axis_name is not None and isinstance(sharding, NamedSharding):
     used = {n for ns in sharding.spec
             for n in (ns if isinstance(ns, tuple) else (ns,))}
@@ -2586,9 +2601,14 @@ def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
     vmapped_sharding = NamedSharding(
         vmapped_sharding.mesh, PartitionSpec(*new_spec))

+  # TODO(yashkatariya): Figure out how layouts should change under vmap.
+  if layout is not None:
+    raise NotImplementedError
+
   y = sharding_constraint_p.bind(
       x,
       sharding=vmapped_sharding,
+      layout=layout,
       resource_env=resource_env,
       unconstrained_dims=unconstrained_dims)
   return y, d
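Per the TODO, layouts under `vmap` are not supported yet: a batched constraint with a non-`None` layout raises at trace time. A small illustration, reusing the names assumed in the first sketch:

batched = jax.device_put(np.zeros((4, 128, 128)), s)  # leading batch axis
try:
  jax.vmap(lambda x: jax.lax.with_sharding_constraint(
      x, Layout(custom_dll, s)))(batched)
except NotImplementedError:
  print('layouts under vmap: not implemented yet')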
@@ -17,7 +17,6 @@ import asyncio
 import contextlib
 import math
 from functools import partial
-import re
 import os
 import pathlib
 import tracemalloc as tm
@@ -46,13 +45,6 @@ def tearDownModule():
   _exit_stack.close()


-pattern = re.compile(r"\{(.*?):")
-
-def extract_minor_to_major(l):
-  match = re.search(pattern, str(l))
-  return tuple(int(i) for i in match.groups()[0].split(','))
-
-
 class CheckpointTest(jtu.JaxTestCase):

   def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir):
@@ -436,8 +428,8 @@ class CheckpointTest(jtu.JaxTestCase):

     out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower(
         arr).compile().output_layouts()
-    self.assertEqual(extract_minor_to_major(arr.layout),
-                     extract_minor_to_major(out_layout)[::-1])
+    self.assertEqual(arr.layout.device_local_layout.major_to_minor,
+                     out_layout.device_local_layout.major_to_minor[::-1])

     ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
     ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path)
@@ -14,9 +14,9 @@

 import contextlib
 import math
-import re
 from absl.testing import absltest
 import numpy as np
+from functools import partial

 import jax
 import jax.numpy as jnp
@@ -37,15 +37,6 @@ def tearDownModule():
   _exit_stack.close()


-pattern = re.compile(r"\{(.*?):")
-
-# Extract minor_to_major from str(layout) because layout doesn't have a
-# minor_to_major property yet.
-def extract_minor_to_major(l):
-  match = re.search(pattern, str(l))
-  return tuple(int(i) for i in match.groups()[0].split(','))
-
-
 class LayoutTest(jtu.JaxTestCase):

   def setUp(self):
@@ -79,8 +70,8 @@ class LayoutTest(jtu.JaxTestCase):
     self.assertEmpty(kw_layouts)

     for i, o in zip(arg_layouts, compiled_apply.output_layouts()):
-      self.assertEqual(extract_minor_to_major(i),
-                       extract_minor_to_major(o)[::-1])
+      self.assertEqual(i.device_local_layout.major_to_minor,
+                       o.device_local_layout.major_to_minor[::-1])

     init_compiled = jax.jit(
         init, out_shardings=arg_layouts).lower(sds1, sds2).compile()
@@ -108,10 +99,10 @@ class LayoutTest(jtu.JaxTestCase):
     self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0])
     self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1])

-    self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout),
-                          extract_minor_to_major(init_out[0].layout)[::-1])
-    self.assertTupleEqual(extract_minor_to_major(apply_out[1].layout),
-                          extract_minor_to_major(init_out[1].layout)[::-1])
+    self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor,
+                          init_out[0].layout.device_local_layout.major_to_minor[::-1])
+    self.assertTupleEqual(apply_out[1].layout.device_local_layout.major_to_minor,
+                          init_out[1].layout.device_local_layout.major_to_minor[::-1])

     self.assertArraysEqual(init_out[0], np_inp1 * 2)
     self.assertArraysEqual(init_out[1], np_inp2 * 2)
@@ -135,18 +126,22 @@ class LayoutTest(jtu.JaxTestCase):
     out = compiled(arr)

     self.assertTupleEqual(
-        extract_minor_to_major(compiled.input_layouts()[0][0]), (2, 1, 0))
+        compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1],
+        (2, 1, 0))
     self.assertTupleEqual(
-        extract_minor_to_major(compiled.output_layouts()), (2, 1, 0))
+        compiled.output_layouts().device_local_layout.major_to_minor[::-1],
+        (2, 1, 0))
     self.assertArraysEqual(out, np_inp.T)
     self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

     compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO),
                             out_shardings=Layout(DLL.AUTO)).lower(sds).compile()
     self.assertTupleEqual(
-        extract_minor_to_major(compiled_auto.input_layouts()[0][0]), (2, 1, 0))
+        compiled_auto.input_layouts()[0][0].device_local_layout.major_to_minor[::-1],
+        (2, 1, 0))
     self.assertTupleEqual(
-        extract_minor_to_major(compiled_auto.output_layouts()), (0, 1, 2))
+        compiled_auto.output_layouts().device_local_layout.major_to_minor[::-1],
+        (0, 1, 2))

     with self.assertRaisesRegex(
         ValueError, "jax.jit` does not accept device-local layouts directly"):
@@ -166,9 +161,11 @@ class LayoutTest(jtu.JaxTestCase):
     compiled = jax.jit(f, in_shardings=Layout(),
                        out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
     self.assertTupleEqual(
-        extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
+        compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1],
+        (1, 0))
     self.assertTupleEqual(
-        extract_minor_to_major(compiled.output_layouts()), (0, 1))
+        compiled.output_layouts().device_local_layout.major_to_minor[::-1],
+        (0, 1))

     out = compiled(arr)
     self.assertArraysEqual(out, np_inp.T)
@@ -185,9 +182,11 @@ class LayoutTest(jtu.JaxTestCase):
         out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile()
     out = compiled(np_inp)
     self.assertTupleEqual(
-        extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
+        compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1],
+        (1, 0))
     self.assertTupleEqual(
-        extract_minor_to_major(compiled.output_layouts()), (0, 1))
+        compiled.output_layouts().device_local_layout.major_to_minor[::-1],
+        (0, 1))
     self.assertArraysEqual(out, np_inp.T)
     self.assertEqual(out.sharding, s)
@@ -358,6 +357,64 @@ class LayoutTest(jtu.JaxTestCase):
       jax.make_array_from_callback(
           np_inp.shape, Layout(None, None), lambda idx: np_inp[idx])

+  def test_wsc_concrete_layout(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    shape = (128, 128)
+    s = NamedSharding(mesh, P('x'))
+    np_inp = np.arange(math.prod(shape)).reshape(shape)
+    arr = jax.device_put(np_inp, s)
+
+    # Create a custom layout instead of using `arr.layout` to test the API.
+    custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))
+
+    # We need AUTO so that XLA can override the entry computation layout set.
+    # TODO(yashkatariya): Expose a config that sets out_shardings to AUTO by
+    # default instead of `None` i.e. default layout and let the compiler choose
+    # the layout or try setting it to AUTO by default and see if there is chaos.
+    @partial(jax.jit, out_shardings=Layout(DLL.AUTO))
+    def f(x):
+      y = x.T
+      # Constrain `y` to the original layout of `arr` because without it,
+      # the layout of `y` would be the transpose of `arr`.
+      return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))
+
+    out = f(arr)
+    self.assertEqual(out.layout, Layout(custom_dll, s))
+    self.assertEqual(out.layout, arr.layout)
+    self.assertArraysEqual(out, np_inp.T)
+
+  def test_wsc_concrete_layout_bfloat16(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    shape = (128, 128)
+    s = NamedSharding(mesh, P('x'))
+    inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
+    arr = jax.device_put(inp, s)
+
+    # Create a custom layout instead of using `arr.layout` to test the API.
+    custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1)))
+
+    @partial(jax.jit, out_shardings=Layout(DLL.AUTO))
+    def f(x):
+      y = x.T
+      # Constrain `y` to the original layout of `arr` because without it,
+      # the layout of `y` would be the transpose of `arr`.
+      return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))
+
+    out = f(arr)
+    self.assertEqual(out.layout, Layout(custom_dll, s))
+    self.assertEqual(out.layout, arr.layout)
+    self.assertArraysEqual(out, inp.T)
+
+  def test_device_put_user_concrete_layout(self):
+    shape = (8, 128)
+    np_inp = np.arange(math.prod(shape)).reshape(shape)
+    dll = DLL(major_to_minor=(1, 0), tiling=((8, 128),))
+    s = SingleDeviceSharding(jax.devices()[0])
+
+    out = jax.device_put(np_inp, Layout(dll, s))
+    self.assertEqual(out.layout, Layout(dll, s))
+    self.assertArraysEqual(out, np_inp)
+

 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())