Raise a better error with more info when we see a duplicate mesh axis in a PartitionSpec resulting from a sharding rule.

Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
This commit is contained in:
Yash Katariya 2025-03-13 15:22:35 -07:00 committed by jax authors
parent 1507754408
commit e615e2acb3
4 changed files with 63 additions and 32 deletions

View File

@ -1894,6 +1894,13 @@ def get_sharding(sharding, shape):
_check_divisibility(out_s, shape)
return out_s
def str_short_aval(shape, dtype, mesh, spec, short_dtypes=False,
                   mesh_axis_types=False) -> str:
  """Format a compact `dtype[dims]` description of an abstract value.

  Args:
    shape: shape forwarded to `_get_shape_sharding_str`.
    dtype: dtype whose name is printed; must expose `.name`.
    mesh: mesh object; only consulted (via `_axis_types_dict`) when
      `mesh_axis_types` is truthy.
    spec: partition spec forwarded to `_get_shape_sharding_str`.
    short_dtypes: if truthy, use `dtypes.short_dtype_name(dtype)` instead of
      `dtype.name`.
    mesh_axis_types: if truthy, append the mesh's axis-types dict in
      parentheses.

  Returns:
    A string of the form `<dtype>[<dims>]` optionally followed by
    `(<axis types dict>)`.
  """
  if short_dtypes:
    dtype_name = dtypes.short_dtype_name(dtype)
  else:
    dtype_name = dtype.name
  # 'void' is displayed as 'float0'.
  dtype_name = dtype_name.replace('void', 'float0')
  dims = _get_shape_sharding_str(shape, spec)
  suffix = ''
  if mesh_axis_types:
    suffix = f'({mesh._axis_types_dict})'
  return f'{dtype_name}[{dims}]{suffix}'
class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent
@ -1954,17 +1961,9 @@ class ShapedArray(UnshapedArray):
varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset()))
def str_short(self, short_dtypes=False, mesh_axis_types=False):
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
self.dtype.name)
dt_str = dt_str.replace('void', 'float0')
if self.sharding is not None:
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
mesh_axes = (f'({self.sharding.mesh._axis_types_dict})'
if mesh_axis_types else '')
return f'{dt_str}[{shapestr}]{mesh_axes}'
else:
shapestr = ','.join(map(str, self.shape))
return f'{dt_str}[{shapestr}]'
return str_short_aval(
self.shape, self.dtype, self.sharding.mesh, self.sharding.spec,
short_dtypes, mesh_axis_types)
def _len(self, ignored_tracer):
try:

View File

@ -24,7 +24,7 @@ from jax._src import dtypes
from jax._src import mesh as mesh_lib
from jax._src.util import safe_zip
from jax._src.partition_spec import PartitionSpec as P
from jax._src.named_sharding import NamedSharding
from jax._src.named_sharding import NamedSharding, DuplicateSpecError
zip, unsafe_zip = safe_zip, zip
@ -81,6 +81,26 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`')
return rule(*avals, **kwargs)
def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule,
                                   multi_out, *avals, **kwargs):
  """Evaluate the shape, dtype and sharding rules of a primitive together.

  Args:
    prim: the primitive; forwarded to `call_sharding_rule` and interpolated
      into the error message.
    shape_rule: callable invoked as `shape_rule(*avals, **kwargs)`.
    dtype_rule: callable invoked as `dtype_rule(*avals, **kwargs)`.
    sharding_rule: rule forwarded to `call_sharding_rule`.
    multi_out: whether the primitive has multiple results. When truthy the
      number of outputs is taken from `len(out_shapes)`; otherwise `None` is
      passed as the output count.
    *avals: input abstract values.
    **kwargs: primitive parameters, forwarded to every rule.

  Returns:
    A `(out_shapes, out_dtypes, out_shardings)` tuple.

  Raises:
    TypeError: for a single-result primitive whose sharding rule raised
      `DuplicateSpecError`; the message names the primitive, its inputs and
      the illegally sharded result.
    DuplicateSpecError: re-raised unchanged when `multi_out` is truthy.
  """
  out_shapes = shape_rule(*avals, **kwargs)
  out_dtypes = dtype_rule(*avals, **kwargs)
  if multi_out:
    num_out = len(out_shapes)
  else:
    num_out = None
  try:
    out_shardings = call_sharding_rule(
        prim, sharding_rule, num_out, *avals, **kwargs)
  except DuplicateSpecError as dup_err:
    if multi_out:
      # No single result aval to print, so surface the original error.
      raise
    inputs_str = ', '.join([a.str_short(short_dtypes=True) for a in avals])
    mesh = dup_err.mesh
    if mesh is None:
      mesh = mesh_lib.empty_abstract_mesh
    result_str = core.str_short_aval(out_shapes, out_dtypes, mesh,
                                     dup_err.pspec, short_dtypes=True)
    msg = (f'{prim} operation with inputs: {inputs_str} produces an illegally'
           f' sharded result: {result_str}')
    raise TypeError(msg) from dup_err
  return out_shapes, out_dtypes, out_shardings
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
sharding_rule, *avals, **kwargs):
assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
@ -89,10 +109,11 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
least_specialized = type(max(avals, key=_get_array_abstraction_level))
if least_specialized is core.ShapedArray:
core.check_avals_context_mesh(avals, prim.name)
out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule(
prim, shape_rule, dtype_rule, sharding_rule, False,
*avals, **kwargs)
out_aval = core.ShapedArray(
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
weak_type=weak_type,
sharding=call_sharding_rule(prim, sharding_rule, None, *avals, **kwargs))
out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding)
core.check_avals_context_mesh([out_aval], prim.name)
return out_aval
elif least_specialized is core.DShapedArray:
@ -113,11 +134,9 @@ def standard_multi_result_abstract_eval(
least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
weak_types = weak_type_rule(*avals, **kwargs)
if least_specialized is core.ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)
out_dtypes = dtype_rule(*avals, **kwargs)
core.check_avals_context_mesh(avals, prim.name)
out_shardings = call_sharding_rule(
prim, sharding_rule, len(out_shapes), *avals, **kwargs)
out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule(
prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs)
if isinstance(weak_types, bool):
weak_types = (weak_types,) * len(out_shapes)
out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)

View File

@ -499,16 +499,26 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
spec = PartitionSpec() if spec is None else spec
parsed_pspec = ParsedPartitionSpec.from_user_input(
spec, "NamedSharding spec", allow_unconstrained_dims=True)
_check_unique_resources(parsed_pspec, "NamedSharding spec")
_check_unique_resources(parsed_pspec, "NamedSharding spec", mesh)
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
return parsed_pspec
def check_pspec(mesh, spec, _manual_axes=frozenset()):
_check_unique_resources(spec, "NamedSharding spec")
_check_unique_resources(spec, "NamedSharding spec", mesh)
_check_mesh_resource_axis(mesh, spec, _manual_axes)
class DuplicateSpecError(ValueError):
  """Error for a partition spec that uses the same mesh axis more than once.

  Derives from `ValueError`, the type historically raised for this condition,
  so handlers written as `except ValueError` continue to catch it.

  Attributes:
    message: the human-readable error text; also what `str(err)` returns.
    mesh: the mesh supplied by the raiser (may be None).
    pspec: the offending partition spec.
  """

  def __init__(self, message, mesh, pspec):
    super().__init__(message)
    self.message = message
    self.mesh = mesh
    self.pspec = pspec

  def __str__(self):
    return f"{self.message}"

  def __reduce__(self):
    # The default exception reduction rebuilds the instance as
    # `type(self)(*self.args)`. `args` only holds `message`, which would fail
    # against the three-argument constructor, so spell out all of them.
    return (type(self), (self.message, self.mesh, self.pspec), self.__dict__)
def _check_unique_resources(
pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str
pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None,
) -> None:
resource_counts: dict[MeshAxisName, int] = {}
duplicate = False
@ -525,10 +535,12 @@ def _check_unique_resources(
resource_counts[resource] = count + 1
if duplicate:
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
raise ValueError(
f'A single {arg_name} specification can map every mesh axis to at'
f' most one positional dimension, but {pspec} has duplicate entries'
f' for {mesh_lib.show_axes(multiple_uses)}')
raise DuplicateSpecError(
message=(
f'A single {arg_name} specification can map every mesh axis to at'
f' most one positional dimension, but {pspec} has duplicate entries'
f' for {mesh_lib.show_axes(multiple_uses)}'),
mesh=mesh, pspec=pspec)
@cache(max_size=128, trace_context_in_key=False)
def _check_mesh_resource_axis(mesh, pspec, _manual_axes):

View File

@ -55,6 +55,7 @@ from jax._src.sharding_impls import (
SingleDeviceSharding, parse_flatten_op_sharding)
from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes,
use_auto_axes, use_explicit_axes, reshard)
from jax._src.named_sharding import DuplicateSpecError
from jax._src import mesh as mesh_lib
from jax._src.mesh import AxisTypes
from jax._src.interpreters import pxla
@ -5055,7 +5056,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@parameterized.named_parameters(
('fail1', P('x', None), P(None, 'x'),
"PartitionSpec.*x.*x.*has duplicate entries", ValueError),
"dot_general operation.*produces an illegally sharded result",
TypeError),
('fail2', P('x', 'y'), P('x', 'y'),
"dot_general requires contracting dimensions to have consistent sharding",
TypeError),
@ -6396,9 +6398,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
# Errors out on the intermediate einsum: `bthj,bthD->bthjD`
# because of a conflict
with self.assertRaisesRegex(
ValueError,
'A single NamedSharding spec specification can map every mesh axis to'
' at most one positional dimension'):
TypeError,
'dot_general operation.*produces an illegally sharded result'):
f(arr1, arr2, arr3)
@jtu.with_user_mesh((2, 2), ('x', 'y'),
@ -7227,7 +7228,7 @@ class PJitErrorTest(jtu.JaxTestCase):
error = (r"A single in_shardings specification can map every mesh "
r"axis to at most one positional dimension, but " +
spec_regex(spec) + " has duplicate entries for `x`")
with self.assertRaisesRegex(ValueError, error):
with self.assertRaisesRegex(DuplicateSpecError, error):
pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x)
@jtu.with_mesh([('x', 2), ('y', 1)])
@ -7237,7 +7238,7 @@ class PJitErrorTest(jtu.JaxTestCase):
error = (r"A single out_shardings specification can map every mesh "
r"axis to at most one positional dimension, but " +
spec_regex(spec) + " has duplicate entries for `x`")
with self.assertRaisesRegex(ValueError, error):
with self.assertRaisesRegex(DuplicateSpecError, error):
pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x)
def testEmptyMesh(self):