mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove the cached check in the AOT-compiled call of MeshExecutable because a fast C++ dispatch path exists. This leads to a better error message, which includes the argument's shape and value.
PiperOrigin-RevId: 494815311
This commit is contained in:
parent
23001ae782
commit
d491d9fd3f
@ -3597,15 +3597,12 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
return outs, fastpath_data
|
||||
|
||||
if xc._version < 108:
|
||||
|
||||
def dummy():
|
||||
pass
|
||||
|
||||
dummy.__name__ = self.unsafe_call.name
|
||||
return xc._xla.pjit(dummy, aot_cache_miss, []) # type: ignore
|
||||
return xc._xla.pjit(dummy, aot_cache_miss, []) # type: ignore
|
||||
else:
|
||||
return xc._xla.pjit( # type: ignore
|
||||
self.unsafe_call.name, aot_cache_miss, [])
|
||||
return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, []) # type: ignore
|
||||
|
||||
|
||||
def _out_shardings_for_trivial(
|
||||
@ -3723,24 +3720,27 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax._src.array import ArrayImpl
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):
|
||||
if committed and not are_op_shardings_equal(
|
||||
arg_sharding._to_xla_op_sharding(ndim),
|
||||
in_xla_sharding._to_xla_op_sharding(ndim)):
|
||||
raise ValueError(
|
||||
f"{arg_type} sharding does not match the input sharding. "
|
||||
f"Got {arg_type} sharding: {arg_sharding} and "
|
||||
f"xla sharding: {in_xla_sharding}")
|
||||
|
||||
for arg, xs in safe_zip(args, in_xla_shardings):
|
||||
if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
|
||||
continue
|
||||
if isinstance(arg, GlobalDeviceArray):
|
||||
_cached_check(_create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes), xs,
|
||||
'GDA', arg.ndim, True)
|
||||
arg_sharding = _create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
|
||||
arg_type = 'GDA'
|
||||
committed = True
|
||||
else:
|
||||
_cached_check(arg.sharding, xs, 'Array', arg.ndim, arg._committed)
|
||||
arg_sharding = arg.sharding
|
||||
arg_type = 'Array'
|
||||
committed = arg._committed
|
||||
|
||||
# No need to cache this check since MeshExecutable has a C++ fast path
|
||||
# for AOT compiled call.
|
||||
if committed and not are_op_shardings_equal(
|
||||
arg_sharding._to_xla_op_sharding(arg.ndim),
|
||||
xs._to_xla_op_sharding(arg.ndim)):
|
||||
raise ValueError(
|
||||
f"{arg_type} sharding does not match the input sharding. "
|
||||
f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for "
|
||||
f"arg shape: {arg.shape}, arg value: {arg}")
|
||||
|
||||
|
||||
def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
|
||||
|
Loading…
x
Reference in New Issue
Block a user