Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00

commit e854f1657a (parent b56dc63160)

Allow P.UNCONSTRAINED in out_shardings at top level jit. This is required for
sharding in types to work properly when out_avals contain UNCONSTRAINED specs.
This also simplifies the `impl` rule of `sharding_cast`.

PiperOrigin-RevId: 707349491
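Example (not part of the diff): a minimal sketch of what this change enables, modeled on the new test added below. It assumes a runtime with at least four devices; the mesh axis names, array shape, and the doubling function are illustrative only.

import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# 2x2 device mesh; requires at least 4 devices.
mesh = jax.make_mesh((2, 2), ('x', 'y'))
arr = jax.device_put(np.arange(16).reshape(8, 2),
                     NamedSharding(mesh, P('x', 'y')))

# Leave dim 0 unconstrained (the partitioner chooses its sharding); dim 1 stays on 'y'.
# Before this change, P.UNCONSTRAINED was not supported in out_shardings of a top-level jit.
out_s = NamedSharding(mesh, P(P.UNCONSTRAINED, 'y'))

f = jax.jit(lambda x: x * 2, out_shardings=out_s)
out = f(arr)
print(out.sharding)  # dim 1 pinned to 'y'; dim 0 chosen by SPMD propagation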
@@ -49,7 +49,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
-from jax._src.sharding_impls import AUTO
+from jax._src.sharding_impls import AUTO, NamedSharding
 from jax._src.partition_spec import UnconstrainedSingleton
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
@@ -1055,6 +1055,21 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
   assert isinstance(s, JSharding)
   return s.memory_kind

+
+def contains_unconstrained(s):
+  return isinstance(s, NamedSharding) and None in s._parsed_pspec
+
+
+def all_unconstrained(s, aval):
+  if isinstance(s, NamedSharding):
+    if aval.ndim != len(s._parsed_pspec):
+      return False
+    return all(p is None for p in s._parsed_pspec)
+  return False
+
+
+def _get_unconstrained_dimensions(s, aval):
+  us = contains_unconstrained(s)
+  return (us, all_unconstrained(s, aval),
+          ({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None))
+

 def lower_jaxpr_to_module(
     module_name: str,
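Aside (not part of the diff): the helpers above inspect NamedSharding internals, where an unconstrained dim is stored as None in `_parsed_pspec`. A standalone re-expression of the same three predicates against a public PartitionSpec, for illustration only:

from jax.sharding import PartitionSpec as P

def has_unconstrained(spec):
  # Any dim left for the partitioner to decide?
  return any(p is P.UNCONSTRAINED for p in spec)

def is_fully_unconstrained(spec, ndim):
  # Every dim of a rank-`ndim` value left unconstrained?
  return len(spec) == ndim and all(p is P.UNCONSTRAINED for p in spec)

def unconstrained_dims(spec):
  # Indices of the unconstrained dims.
  return {i for i, p in enumerate(spec) if p is P.UNCONSTRAINED}

spec = P(P.UNCONSTRAINED, 'y')
print(has_unconstrained(spec), is_fully_unconstrained(spec, 2), unconstrained_dims(spec))
# True False {0}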
@@ -1114,7 +1129,8 @@ def lower_jaxpr_to_module(
         f"only {platforms_with_donation} support donation")
   if (num_partitions > 1 and
       (result_shardings is None or
-       all(s is None or isinstance(s, AUTO) for s in result_shardings))):
+       all(s is None or isinstance(s, AUTO) or contains_unconstrained(s)
+           for s in result_shardings))):
     xla_donated_args = donated_args
     donated_args = [False] * len(donated_args)
   if xla_donated_args is None:
@@ -1448,7 +1464,8 @@ def lower_jaxpr_to_fun(
   ir_arg_memory_kinds = None
   if arg_memory_kinds is not None:
     ir_arg_memory_kinds = util.flatten(
-        [[mk] * len_ir_types(types) for mk, types in zip(arg_memory_kinds, input_types)])
+        [[mk] * len_ir_types(types)
+         for mk, types in zip(arg_memory_kinds, input_types)])

   ir_arg_layouts = None
   if arg_layouts is not None:
@@ -1459,13 +1476,18 @@ def lower_jaxpr_to_fun(
   ir_donated_args = None
   if xla_donated_args is not None:
     ir_donated_args = util.flatten(
-        [[is_donated] * len_ir_types(types) for is_donated, types in zip(xla_donated_args, input_types)])
+        [[is_donated] * len_ir_types(types)
+         for is_donated, types in zip(xla_donated_args, input_types)])

   ir_result_shardings = None
+  unconstrained_shardings = None
   if result_shardings is not None:
     ir_result_shardings = util.flatten(
         [[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types)
          for a, s, types in zip(output_avals, result_shardings, output_types)])
+    unconstrained_shardings = util.flatten(
+        [[_get_unconstrained_dimensions(s, a)] * len_ir_types(types)
+         for a, s, types in zip(output_avals, result_shardings, output_types)])

   ir_result_memory_kinds = None
   custom_call_ir_result_memory_kinds = None
@@ -1580,8 +1602,9 @@ def lower_jaxpr_to_fun(
       attrs['jax.result_info'] = ir.StringAttr.get(name_)

   if use_sharding_annotations and ir_result_shardings is not None:
-    for attrs, sharding in zip(result_attrs, ir_result_shardings):
-      if sharding is not None:
+    for attrs, sharding, us in zip(result_attrs, ir_result_shardings,
+                                   unconstrained_shardings):  # type: ignore
+      if sharding is not None and not us[0]:
         if config.use_shardy_partitioner.value:
           attrs["sdy.sharding"] = get_sharding_attr(sharding)
         else:
@@ -1658,6 +1681,15 @@ def lower_jaxpr_to_fun(
           o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
          for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

+    if ir_result_shardings is not None:
+      flat_outputs = [
+          wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s,
+                                unspecified_dims=us[2])
+          if us[0] and not us[1] else o
+          for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
+                                      output_avals, unconstrained_shardings)  # type: ignore
+      ]
+
     # Insert a custom call if output is on host because XLA needs that to do the
     # transfer.
     if custom_call_ir_result_memory_kinds is not None and name == "main":
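Aside (not part of the diff): a plain-Python restatement of the per-output policy the two hunks above implement, driven by the (has_unconstrained, all_unconstrained, unconstrained_dims) triple from `_get_unconstrained_dimensions`; the function name and return strings here are illustrative only.

def output_annotation(sharding, us):
  has_unc, all_unc, unc_dims = us
  if sharding is None:
    return 'no annotation'
  if not has_unc:
    return 'result attribute carries the full sharding'
  if all_unc:
    return 'nothing attached; SPMD propagation picks the sharding'
  # Partially unconstrained: annotate only the constrained dims.
  return 'wrapped in a sharding op with unspecified_dims=%s' % sorted(unc_dims)

print(output_annotation(object(), (True, False, {0})))
# wrapped in a sharding op with unspecified_dims=[0]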
@@ -62,7 +62,7 @@ from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
+from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding as JSharding
 from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.sharding_impls import (
@@ -2162,10 +2162,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):

   out = []
   for s, a in zip(shardings, avals):
-    # Remove the `UnconstrainedSingleton` logic after UNCONSTRAINED is supported
-    # in out_shardings at top level jit.
-    if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
-        all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
+    if isinstance(s, UnspecifiedValue) and a.sharding is not None:
       out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
                                a.sharding.spec))
     else:
@@ -2794,6 +2791,11 @@ def _maybe_get_and_check_out_shardings(
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         xla_s = sharding_impls.logical_sharding(aval, xla_s)
       new_out_shardings.append(xla_s)
+    elif mlir.contains_unconstrained(orig):
+      if (aval is not core.abstract_token and
+          dtypes.issubdtype(aval.dtype, dtypes.extended)):
+        xla_s = sharding_impls.logical_sharding(aval, xla_s)
+      new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig))  # type: ignore
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
       orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # pytype: disable=attribute-error
@@ -2909,8 +2911,9 @@ class UnloadedMeshExecutable:

     allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
                                  for i in in_shardings)
-    allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO))
-                                  for o in out_shardings)
+    allow_prop_to_outputs = tuple(
+        isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o)
+        for o in out_shardings)

     mesh = None
     if auto_spmd_lowering:
@@ -353,6 +353,22 @@ class Mesh(contextlib.ContextDecorator):
   def with_axis_types(self, new_axis_types) -> Mesh:
     return Mesh(self.devices, self.axis_names, axis_types=new_axis_types)

+  @functools.cached_property
+  def _are_all_axes_collective(self) -> bool:
+    return all(t == AxisTypes.Collective for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _are_all_axes_auto(self) -> bool:
+    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_collective(self) -> bool:
+    return any(t == AxisTypes.Collective for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_auto(self) -> bool:
+    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+

 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
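Aside (not part of the diff): a self-contained illustration of the all/any queries added to Mesh above, using a stand-in AxisTypes enum and a hypothetical axis_types mapping from axis type to axis names; the real Mesh keeps this state internally.

import enum

class AxisTypes(enum.Enum):
  Auto = 'auto'
  Collective = 'collective'

# Hypothetical mapping: axis 'x' is Auto, axis 'y' is Collective.
axis_types = {AxisTypes.Auto: ('x',), AxisTypes.Collective: ('y',)}

print(all(t == AxisTypes.Auto for t in axis_types.keys()))        # _are_all_axes_auto -> False
print(any(t == AxisTypes.Collective for t in axis_types.keys()))  # _any_axis_collective -> True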
@@ -2675,13 +2675,6 @@ batching.skippable_batchers[sharding_constraint_p] = lambda _: ()

 # -------------------- sharding_cast ---------------------------

-def _check_mesh_shape_same(src_sharding, dst_sharding, aval):
-  if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
-    raise ValueError(
-        f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
-        ' match the mesh shape of the target sharding'
-        f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')
-
 def sharding_cast(xs, shardings):
   if isinstance(shardings, NamedSharding):
     return tree_map(lambda x: sharding_cast_p.bind(
@@ -2695,17 +2688,17 @@ def sharding_cast(xs, shardings):

 sharding_cast_p = core.Primitive('sharding_cast')
 def _sharding_cast_abstract_eval(aval, src_sharding, dst_sharding):
-  _check_mesh_shape_same(src_sharding, dst_sharding, aval)
+  if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
+    raise ValueError(
+        f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
+        ' match the mesh shape of the target sharding'
+        f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')
   return aval.update(sharding=dst_sharding)
 sharding_cast_p.def_abstract_eval(_sharding_cast_abstract_eval)

 def _sharding_cast_impl(x, src_sharding, dst_sharding):
-  aval = shaped_abstractify(x)
-  _check_mesh_shape_same(x.sharding, dst_sharding, aval)
-  new_mesh = x.sharding.mesh.with_axis_types(dst_sharding.mesh.axis_types)
-  concrete_dst_sharding = NamedSharding(new_mesh, dst_sharding.spec)
-  # TODO(yashkatariya): Replace this with `dispatch.apply_primitive(...)`
-  return api.jit(_identity_fn, out_shardings=concrete_dst_sharding)(x)
+  return dispatch.apply_primitive(sharding_cast_p, x, src_sharding=src_sharding,
+                                  dst_sharding=dst_sharding)
 sharding_cast_p.def_impl(_sharding_cast_impl)

 def _sharding_cast_transpose_rule(ct, _, src_sharding, dst_sharding):
@@ -4680,6 +4680,34 @@ class ArrayPjitTest(jtu.JaxTestCase):
         RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
       lowered3.compile()

+  def test_jit_out_shardings_unconstrained(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    arr = jax.device_put(np_inp, s)
+
+    out_s = NamedSharding(mesh, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
+    @partial(jax.jit, out_shardings=out_s)
+    def f(x):
+      return x * 2
+
+    out = f(arr)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, np_inp * 2)
+
+    @partial(jax.jit, out_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'y')))
+    def g(x):
+      return x * 3
+
+    out = g(arr)
+    self.assertArraysEqual(out, np_inp * 3)
+    self.assertEqual(out.sharding, s)
+    lowered_text = g.lower(arr).as_text()
+    if config.use_shardy_partitioner.value:
+      self.assertIn('<@mesh, [{?}, {"y"}]>', lowered_text)
+    else:
+      self.assertIn("unspecified_dims=[0]", lowered_text)
+

 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@@ -5548,7 +5576,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       return a

     out = f(arr, arr.T)
-    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))

   def test_auto_user(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'),