Remove the cached check in the AOT-compiled call path of MeshExecutable; caching it is unnecessary because a fast C++ dispatch path exists. With the check inlined, it has access to the argument itself, so the error message can now include the arg's shape and value.

PiperOrigin-RevId: 494815311
This commit is contained in:
Yash Katariya 2022-12-12 13:36:38 -08:00 committed by jax authors
parent 23001ae782
commit d491d9fd3f

View File

@ -3597,15 +3597,12 @@ class MeshExecutable(stages.XlaExecutable):
return outs, fastpath_data return outs, fastpath_data
if xc._version < 108: if xc._version < 108:
def dummy(): def dummy():
pass pass
dummy.__name__ = self.unsafe_call.name dummy.__name__ = self.unsafe_call.name
return xc._xla.pjit(dummy, aot_cache_miss, []) # type: ignore return xc._xla.pjit(dummy, aot_cache_miss, []) # type: ignore
else: else:
return xc._xla.pjit( # type: ignore return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, []) # type: ignore
self.unsafe_call.name, aot_cache_miss, [])
def _out_shardings_for_trivial( def _out_shardings_for_trivial(
@ -3723,24 +3720,27 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
from jax.experimental.global_device_array import GlobalDeviceArray from jax.experimental.global_device_array import GlobalDeviceArray
from jax._src.array import ArrayImpl from jax._src.array import ArrayImpl
@lru_cache(maxsize=4096)
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):
if committed and not are_op_shardings_equal(
arg_sharding._to_xla_op_sharding(ndim),
in_xla_sharding._to_xla_op_sharding(ndim)):
raise ValueError(
f"{arg_type} sharding does not match the input sharding. "
f"Got {arg_type} sharding: {arg_sharding} and "
f"xla sharding: {in_xla_sharding}")
for arg, xs in safe_zip(args, in_xla_shardings): for arg, xs in safe_zip(args, in_xla_shardings):
if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)): if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
continue continue
if isinstance(arg, GlobalDeviceArray): if isinstance(arg, GlobalDeviceArray):
_cached_check(_create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes), xs, arg_sharding = _create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
'GDA', arg.ndim, True) arg_type = 'GDA'
committed = True
else: else:
_cached_check(arg.sharding, xs, 'Array', arg.ndim, arg._committed) arg_sharding = arg.sharding
arg_type = 'Array'
committed = arg._committed
# No need to cache this check since MeshExecutable has a C++ fast path
# for AOT compiled call.
if committed and not are_op_shardings_equal(
arg_sharding._to_xla_op_sharding(arg.ndim),
xs._to_xla_op_sharding(arg.ndim)):
raise ValueError(
f"{arg_type} sharding does not match the input sharding. "
f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for "
f"arg shape: {arg.shape}, arg value: {arg}")
def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: