Raise a better error with more info when we see a duplicate axis in a PartitionSpec resulting from a sharding rule.

Previously it was:

    ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x

Now it is:

    TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]

PiperOrigin-RevId: 736657644
commit e615e2acb3
parent 1507754408
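To make the change concrete, here is a hedged repro sketch of the scenario quoted in the commit message. The mesh construction and the explicit-sharding ("sharding in types") setup are assumptions for illustration; the exact APIs for enabling this mode differ across JAX versions, and the commit's own tests use internal helpers such as jtu.with_user_mesh instead.

# Hypothetical repro sketch, not from this commit: a dot_general whose
# inferred output spec would be PartitionSpec('x', 'x').
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Assumes at least 2 devices; the axis_types argument (opting the mesh into
# explicit sharding) is an assumption and varies across JAX versions.
mesh = jax.make_mesh((2,), ('x',),
                     axis_types=(jax.sharding.AxisType.Explicit,))

a = jax.device_put(jnp.zeros((8, 2)), NamedSharding(mesh, P('x', None)))
b = jax.device_put(jnp.zeros((2, 8)), NamedSharding(mesh, P(None, 'x')))

with jax.sharding.use_mesh(mesh):  # also version-dependent
  # The contracting dims are unsharded, so the sharding rule would place 'x'
  # on both output dims. With this change the failure reads (for f32 inputs):
  #   TypeError: dot_general operation with inputs: f32[8@x,2], f32[2,8@x]
  #   produces an illegally sharded result: f32[8@x,8@x]
  jnp.dot(a, b)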
@@ -1894,6 +1894,13 @@ def get_sharding(sharding, shape):
   _check_divisibility(out_s, shape)
   return out_s

+def str_short_aval(shape, dtype, mesh, spec, short_dtypes=False,
+                   mesh_axis_types=False) -> str:
+  dt_str = dtypes.short_dtype_name(dtype) if short_dtypes else dtype.name
+  dt_str = dt_str.replace('void', 'float0')
+  shapestr = _get_shape_sharding_str(shape, spec)
+  mesh_axes = f'({mesh._axis_types_dict})' if mesh_axis_types else ''
+  return f'{dt_str}[{shapestr}]{mesh_axes}'

 class ShapedArray(UnshapedArray):
   __slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent

@@ -1954,17 +1961,9 @@ class ShapedArray(UnshapedArray):
         varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset()))

   def str_short(self, short_dtypes=False, mesh_axis_types=False):
-    dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
-              self.dtype.name)
-    dt_str = dt_str.replace('void', 'float0')
-    if self.sharding is not None:
-      shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
-      mesh_axes = (f'({self.sharding.mesh._axis_types_dict})'
-                   if mesh_axis_types else '')
-      return f'{dt_str}[{shapestr}]{mesh_axes}'
-    else:
-      shapestr = ','.join(map(str, self.shape))
-      return f'{dt_str}[{shapestr}]'
+    return str_short_aval(
+        self.shape, self.dtype, self.sharding.mesh, self.sharding.spec,
+        short_dtypes, mesh_axis_types)

   def _len(self, ignored_tracer):
     try:
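An aside on the notation this helper produces (illustration only, not part of the diff): _get_shape_sharding_str, used in the hunk above, renders each dimension as size@axis when that dimension is sharded over a mesh axis, which is where strings like i64[8@x,8@x] in the new error come from. A hypothetical standalone mimic of the idea:

# Hypothetical re-implementation of the size@axis rendering; the real
# _get_shape_sharding_str (see the hunk above) handles more cases.
def shape_sharding_str(shape, spec):
  dims = [f'{sz}@{ax}' if ax is not None else str(sz)
          for sz, ax in zip(shape, spec)]
  return ','.join(dims)

print(shape_sharding_str((8, 2), ('x', None)))  # -> 8@x,2
print(shape_sharding_str((8, 8), ('x', 'x')))   # -> 8@x,8@x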
@@ -24,7 +24,7 @@ from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src.util import safe_zip
 from jax._src.partition_spec import PartitionSpec as P
-from jax._src.named_sharding import NamedSharding
+from jax._src.named_sharding import NamedSharding, DuplicateSpecError

 zip, unsafe_zip = safe_zip, zip

@@ -81,6 +81,26 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
       ' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`')
   return rule(*avals, **kwargs)

+def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule,
+                                   multi_out, *avals, **kwargs):
+  out_shapes = shape_rule(*avals, **kwargs)
+  out_dtypes = dtype_rule(*avals, **kwargs)
+  num_out = len(out_shapes) if multi_out else None
+  try:
+    out_shardings = call_sharding_rule(
+        prim, sharding_rule, num_out, *avals, **kwargs)
+  except DuplicateSpecError as e:
+    if multi_out:
+      raise
+    avals_str = ', '.join(i.str_short(short_dtypes=True) for i in avals)
+    mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh
+    out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec,
+                                       short_dtypes=True)
+    raise TypeError(
+        f'{prim} operation with inputs: {avals_str} produces an illegally'
+        f' sharded result: {out_aval_str}') from e
+  return out_shapes, out_dtypes, out_shardings
+
 def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                            sharding_rule, *avals, **kwargs):
   assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals

@@ -89,10 +109,11 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
   least_specialized = type(max(avals, key=_get_array_abstraction_level))
   if least_specialized is core.ShapedArray:
     core.check_avals_context_mesh(avals, prim.name)
+    out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule(
+        prim, shape_rule, dtype_rule, sharding_rule, False,
+        *avals, **kwargs)
     out_aval = core.ShapedArray(
-        shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
-        weak_type=weak_type,
-        sharding=call_sharding_rule(prim, sharding_rule, None, *avals, **kwargs))
+        out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding)
     core.check_avals_context_mesh([out_aval], prim.name)
     return out_aval
   elif least_specialized is core.DShapedArray:

@@ -113,11 +134,9 @@ def standard_multi_result_abstract_eval(
   least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
   weak_types = weak_type_rule(*avals, **kwargs)
   if least_specialized is core.ShapedArray:
-    out_shapes = shape_rule(*avals, **kwargs)
-    out_dtypes = dtype_rule(*avals, **kwargs)
     core.check_avals_context_mesh(avals, prim.name)
-    out_shardings = call_sharding_rule(
-        prim, sharding_rule, len(out_shapes), *avals, **kwargs)
+    out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule(
+        prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs)
     if isinstance(weak_types, bool):
       weak_types = (weak_types,) * len(out_shapes)
     out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
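The new helper centralizes a wrap-and-rethrow pattern: the low-level check raises a structured exception, and the single-output path converts it into the user-facing TypeError while chaining the original via raise ... from e. A self-contained sketch of the pattern with simplified stand-in names (not the JAX implementation):

# Standalone sketch of the wrap-and-rethrow pattern used above; names and
# logic are simplified stand-ins, not JAX internals.
class DuplicateSpecError(Exception):
  def __init__(self, message, mesh, pspec):
    super().__init__(message)
    self.mesh, self.pspec = mesh, pspec

def sharding_rule(pspec):
  # Low-level check: raise a structured error carrying the offending spec.
  if len(set(pspec)) != len(pspec):
    raise DuplicateSpecError(f'duplicate entries in {pspec}',
                             mesh=None, pspec=pspec)
  return pspec

def abstract_eval(prim, pspec, multi_out=False):
  try:
    return sharding_rule(pspec)
  except DuplicateSpecError as e:
    if multi_out:  # multi-output rules keep the raw error, as in the diff
      raise
    # Single-output path: rewrap with primitive-level context, chaining `e`.
    raise TypeError(
        f'{prim} produces an illegally sharded result: {e.pspec}') from e

abstract_eval('dot_general', ('x', 'x'))
# TypeError: dot_general produces an illegally sharded result: ('x', 'x')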
@@ -499,16 +499,26 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
     spec = PartitionSpec() if spec is None else spec
     parsed_pspec = ParsedPartitionSpec.from_user_input(
         spec, "NamedSharding spec", allow_unconstrained_dims=True)
-  _check_unique_resources(parsed_pspec, "NamedSharding spec")
+  _check_unique_resources(parsed_pspec, "NamedSharding spec", mesh)
   _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
   return parsed_pspec

 def check_pspec(mesh, spec, _manual_axes=frozenset()):
-  _check_unique_resources(spec, "NamedSharding spec")
+  _check_unique_resources(spec, "NamedSharding spec", mesh)
   _check_mesh_resource_axis(mesh, spec, _manual_axes)

+class DuplicateSpecError(Exception):
+  def __init__(self, message, mesh, pspec):
+    super().__init__(message)
+    self.message = message
+    self.mesh = mesh
+    self.pspec = pspec
+
+  def __str__(self):
+    return f"{self.message}"
+
 def _check_unique_resources(
-    pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str
+    pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None,
 ) -> None:
   resource_counts: dict[MeshAxisName, int] = {}
   duplicate = False

@@ -525,10 +535,12 @@ def _check_unique_resources(
       resource_counts[resource] = count + 1
   if duplicate:
     multiple_uses = [r for r, c in resource_counts.items() if c > 1]
-    raise ValueError(
-        f'A single {arg_name} specification can map every mesh axis to at'
-        f' most one positional dimension, but {pspec} has duplicate entries'
-        f' for {mesh_lib.show_axes(multiple_uses)}')
+    raise DuplicateSpecError(
+        message=(
+            f'A single {arg_name} specification can map every mesh axis to at'
+            f' most one positional dimension, but {pspec} has duplicate entries'
+            f' for {mesh_lib.show_axes(multiple_uses)}'),
+        mesh=mesh, pspec=pspec)

 @cache(max_size=128, trace_context_in_key=False)
 def _check_mesh_resource_axis(mesh, pspec, _manual_axes):
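A small usage sketch (hypothetical) of why the exception carries structured fields: callers such as call_shape_dtype_sharding_rule above can read the mesh and offending spec directly instead of re-parsing the message. The class body is copied from the hunk above so the snippet runs standalone:

class DuplicateSpecError(Exception):  # copied from the hunk above
  def __init__(self, message, mesh, pspec):
    super().__init__(message)
    self.message = message
    self.mesh = mesh
    self.pspec = pspec

  def __str__(self):
    return f"{self.message}"

try:
  raise DuplicateSpecError(
      message="PartitionSpec('x', 'x') has duplicate entries for x",
      mesh=None, pspec=('x', 'x'))
except DuplicateSpecError as e:
  assert str(e) == e.message    # __str__ surfaces the raw message
  assert e.pspec == ('x', 'x')  # .mesh/.pspec give callers structured context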
@@ -55,6 +55,7 @@ from jax._src.sharding_impls import (
     SingleDeviceSharding, parse_flatten_op_sharding)
 from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes,
                            use_auto_axes, use_explicit_axes, reshard)
+from jax._src.named_sharding import DuplicateSpecError
 from jax._src import mesh as mesh_lib
 from jax._src.mesh import AxisTypes
 from jax._src.interpreters import pxla

@@ -5055,7 +5056,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):

   @parameterized.named_parameters(
       ('fail1', P('x', None), P(None, 'x'),
-       "PartitionSpec.*x.*x.*has duplicate entries", ValueError),
+       "dot_general operation.*produces an illegally sharded result",
+       TypeError),
       ('fail2', P('x', 'y'), P('x', 'y'),
        "dot_general requires contracting dimensions to have consistent sharding",
        TypeError),

@@ -6396,9 +6398,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     # Errors out on the intermediate einsum: `bthj,bthD->bthjD`
     # because of a conflict
     with self.assertRaisesRegex(
-        ValueError,
-        'A single NamedSharding spec specification can map every mesh axis to'
-        ' at most one positional dimension'):
+        TypeError,
+        'dot_general operation.*produces an illegally sharded result'):
       f(arr1, arr2, arr3)

   @jtu.with_user_mesh((2, 2), ('x', 'y'),

@@ -7227,7 +7228,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     error = (r"A single in_shardings specification can map every mesh "
              r"axis to at most one positional dimension, but " +
              spec_regex(spec) + " has duplicate entries for `x`")
-    with self.assertRaisesRegex(ValueError, error):
+    with self.assertRaisesRegex(DuplicateSpecError, error):
       pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x)

   @jtu.with_mesh([('x', 2), ('y', 1)])

@@ -7237,7 +7238,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     error = (r"A single out_shardings specification can map every mesh "
              r"axis to at most one positional dimension, but " +
              spec_regex(spec) + " has duplicate entries for `x`")
-    with self.assertRaisesRegex(ValueError, error):
+    with self.assertRaisesRegex(DuplicateSpecError, error):
       pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x)

   def testEmptyMesh(self):