From e615e2acb38410f35f8873a79bf08ce90f681a0d Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 13 Mar 2025 15:22:35 -0700
Subject: [PATCH] Raise a better error with more info when we see a duplicate
 axis in a PartitionSpec resulting from a sharding rule.

Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
---
 jax/_src/core.py           | 21 ++++++++++-----------
 jax/_src/lax/utils.py      | 35 +++++++++++++++++++++++++++--------
 jax/_src/named_sharding.py | 26 +++++++++++++++++++-------
 tests/pjit_test.py         | 13 +++++++------
 4 files changed, 63 insertions(+), 32 deletions(-)

diff --git a/jax/_src/core.py b/jax/_src/core.py
index 46ff1ce97..a5472f01d 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1894,6 +1894,13 @@ def get_sharding(sharding, shape):
   _check_divisibility(out_s, shape)
   return out_s
 
+def str_short_aval(shape, dtype, mesh, spec, short_dtypes=False,
+                   mesh_axis_types=False) -> str:
+  dt_str = dtypes.short_dtype_name(dtype) if short_dtypes else dtype.name
+  dt_str = dt_str.replace('void', 'float0')
+  shapestr = _get_shape_sharding_str(shape, spec)
+  mesh_axes = f'({mesh._axis_types_dict})' if mesh_axis_types else ''
+  return f'{dt_str}[{shapestr}]{mesh_axes}'
 
 class ShapedArray(UnshapedArray):
   __slots__ = ['shape', 'sharding', 'varying_manual_axes']  # inherits slots from parent
@@ -1954,17 +1961,9 @@ class ShapedArray(UnshapedArray):
         varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset()))
 
   def str_short(self, short_dtypes=False, mesh_axis_types=False):
-    dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
-              self.dtype.name)
-    dt_str = dt_str.replace('void', 'float0')
-    if self.sharding is not None:
-      shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
-      mesh_axes = (f'({self.sharding.mesh._axis_types_dict})'
-                   if mesh_axis_types else '')
-      return f'{dt_str}[{shapestr}]{mesh_axes}'
-    else:
-      shapestr = ','.join(map(str, self.shape))
-      return f'{dt_str}[{shapestr}]'
+    return str_short_aval(
+        self.shape, self.dtype, self.sharding.mesh, self.sharding.spec,
+        short_dtypes, mesh_axis_types)
 
   def _len(self, ignored_tracer):
     try:
diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py
index 32ccd00b5..f39d925ac 100644
--- a/jax/_src/lax/utils.py
+++ b/jax/_src/lax/utils.py
@@ -24,7 +24,7 @@ from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src.util import safe_zip
 from jax._src.partition_spec import PartitionSpec as P
-from jax._src.named_sharding import NamedSharding
+from jax._src.named_sharding import NamedSharding, DuplicateSpecError
 
 zip, unsafe_zip = safe_zip, zip
 
@@ -81,6 +81,26 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
         ' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`')
   return rule(*avals, **kwargs)
 
+def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule,
+                                   multi_out, *avals, **kwargs):
+  out_shapes = shape_rule(*avals, **kwargs)
+  out_dtypes = dtype_rule(*avals, **kwargs)
+  num_out = len(out_shapes) if multi_out else None
+  try:
+    out_shardings = call_sharding_rule(
+        prim, sharding_rule, num_out, *avals, **kwargs)
+  except DuplicateSpecError as e:
+    if multi_out:
+      raise
+    avals_str = ', '.join(i.str_short(short_dtypes=True) for i in avals)
+    mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh
+    out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec,
+                                       short_dtypes=True)
+    raise TypeError(
+        f'{prim} operation with inputs: {avals_str} produces an illegally'
+        f' sharded result: {out_aval_str}') from e
+  return out_shapes, out_dtypes, out_shardings
+
 def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                            sharding_rule, *avals, **kwargs):
   assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
@@ -89,10 +109,11 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
   least_specialized = type(max(avals, key=_get_array_abstraction_level))
   if least_specialized is core.ShapedArray:
     core.check_avals_context_mesh(avals, prim.name)
+    out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule(
+        prim, shape_rule, dtype_rule, sharding_rule, False,
+        *avals, **kwargs)
     out_aval = core.ShapedArray(
-        shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
-        weak_type=weak_type,
-        sharding=call_sharding_rule(prim, sharding_rule, None, *avals, **kwargs))
+        out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding)
     core.check_avals_context_mesh([out_aval], prim.name)
     return out_aval
   elif least_specialized is core.DShapedArray:
@@ -113,11 +134,9 @@ def standard_multi_result_abstract_eval(
   least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
   weak_types = weak_type_rule(*avals, **kwargs)
   if least_specialized is core.ShapedArray:
-    out_shapes = shape_rule(*avals, **kwargs)
-    out_dtypes = dtype_rule(*avals, **kwargs)
     core.check_avals_context_mesh(avals, prim.name)
-    out_shardings = call_sharding_rule(
-        prim, sharding_rule, len(out_shapes), *avals, **kwargs)
+    out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule(
+        prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs)
     if isinstance(weak_types, bool):
       weak_types = (weak_types,) * len(out_shapes)
     out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py
index b7e7539af..f05e83f08 100644
--- a/jax/_src/named_sharding.py
+++ b/jax/_src/named_sharding.py
@@ -499,16 +499,26 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
     spec = PartitionSpec() if spec is None else spec
     parsed_pspec = ParsedPartitionSpec.from_user_input(
         spec, "NamedSharding spec", allow_unconstrained_dims=True)
-  _check_unique_resources(parsed_pspec, "NamedSharding spec")
+  _check_unique_resources(parsed_pspec, "NamedSharding spec", mesh)
   _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
   return parsed_pspec
 
 def check_pspec(mesh, spec, _manual_axes=frozenset()):
-  _check_unique_resources(spec, "NamedSharding spec")
+  _check_unique_resources(spec, "NamedSharding spec", mesh)
   _check_mesh_resource_axis(mesh, spec, _manual_axes)
 
+class DuplicateSpecError(Exception):
+  def __init__(self, message, mesh, pspec):
+    super().__init__(message)
+    self.message = message
+    self.mesh = mesh
+    self.pspec = pspec
+
+  def __str__(self):
+    return f"{self.message}"
+
 def _check_unique_resources(
-    pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str
+    pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None,
 ) -> None:
   resource_counts: dict[MeshAxisName, int] = {}
   duplicate = False
@@ -525,10 +535,12 @@ def _check_unique_resources(
       resource_counts[resource] = count + 1
   if duplicate:
     multiple_uses = [r for r, c in resource_counts.items() if c > 1]
-    raise ValueError(
-        f'A single {arg_name} specification can map every mesh axis to at'
-        f' most one positional dimension, but {pspec} has duplicate entries'
-        f' for {mesh_lib.show_axes(multiple_uses)}')
+    raise DuplicateSpecError(
+        message=(
+            f'A single {arg_name} specification can map every mesh axis to at'
+            f' most one positional dimension, but {pspec} has duplicate entries'
+            f' for {mesh_lib.show_axes(multiple_uses)}'),
+        mesh=mesh, pspec=pspec)
 
 @cache(max_size=128, trace_context_in_key=False)
 def _check_mesh_resource_axis(mesh, pspec, _manual_axes):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index b09d8bc49..5d4e1939d 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -55,6 +55,7 @@ from jax._src.sharding_impls import (
     SingleDeviceSharding, parse_flatten_op_sharding)
 from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes,
                            use_auto_axes, use_explicit_axes, reshard)
+from jax._src.named_sharding import DuplicateSpecError
 from jax._src import mesh as mesh_lib
 from jax._src.mesh import AxisTypes
 from jax._src.interpreters import pxla
@@ -5055,7 +5056,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(
       ('fail1', P('x', None), P(None, 'x'),
-       "PartitionSpec.*x.*x.*has duplicate entries", ValueError),
+       "dot_general operation.*produces an illegally sharded result",
+       TypeError),
       ('fail2', P('x', 'y'), P('x', 'y'),
        "dot_general requires contracting dimensions to have consistent sharding",
        TypeError),
@@ -6396,9 +6398,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     # Errors out on the intermediate einsum: `bthj,bthD->bthjD`
     # because of a conflict
     with self.assertRaisesRegex(
-        ValueError,
-        'A single NamedSharding spec specification can map every mesh axis to'
-        ' at most one positional dimension'):
+        TypeError,
+        'dot_general operation.*produces an illegally sharded result'):
       f(arr1, arr2, arr3)
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'),
@@ -7227,7 +7228,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     error = (r"A single in_shardings specification can map every mesh "
              r"axis to at most one positional dimension, but " +
              spec_regex(spec) + " has duplicate entries for `x`")
-    with self.assertRaisesRegex(ValueError, error):
+    with self.assertRaisesRegex(DuplicateSpecError, error):
       pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x)
 
   @jtu.with_mesh([('x', 2), ('y', 1)])
@@ -7237,7 +7238,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     error = (r"A single out_shardings specification can map every mesh "
             r"axis to at most one positional dimension, but " +
             spec_regex(spec) + " has duplicate entries for `x`")
-    with self.assertRaisesRegex(ValueError, error):
+    with self.assertRaisesRegex(DuplicateSpecError, error):
       pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x)
 
   def testEmptyMesh(self):
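-- 
A minimal repro sketch of the new error path (trailing text after the last
hunk, which `git am`/`git apply` ignore). The sketch assumes the current
public API spellings (`jax.make_mesh(..., axis_types=...)`,
`jax.sharding.AxisType.Explicit`, `jax.sharding.use_mesh`); these may be
spelled differently at this revision, where the tests drive the same path
through `jtu.with_user_mesh` and `jax._src.mesh.AxisTypes`. x64 is enabled
only so the avals print as i64, matching the error quoted in the commit
message.

    import jax
    import jax.numpy as jnp
    from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

    jax.config.update('jax_enable_x64', True)  # avals print as i64, not i32

    # Two devices along one explicitly-sharded mesh axis 'x'. On CPU, run
    # with XLA_FLAGS=--xla_force_host_platform_device_count=2.
    mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))

    a = jax.device_put(jnp.arange(16).reshape(8, 2),
                       NamedSharding(mesh, P('x', None)))
    b = jax.device_put(jnp.arange(16).reshape(2, 8),
                       NamedSharding(mesh, P(None, 'x')))

    with jax.sharding.use_mesh(mesh):
      # dot_general's sharding rule infers the output spec
      # PartitionSpec('x', 'x'), which maps mesh axis 'x' to two positional
      # dimensions. The internal DuplicateSpecError is caught by
      # call_shape_dtype_sharding_rule above and re-raised as:
      #   TypeError: dot_general operation with inputs: i64[8@x,2],
      #   i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]
      a @ b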