mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove the cached check in the AOT-compiled call path of MeshExecutable because a fast C++ dispatch path already exists, so caching the check in Python is unnecessary. This also leads to a better error message, which now includes the argument's shape and value.
PiperOrigin-RevId: 494815311
This commit is contained in:
parent
23001ae782
commit
d491d9fd3f
@ -3597,15 +3597,12 @@ class MeshExecutable(stages.XlaExecutable):
|
|||||||
return outs, fastpath_data
|
return outs, fastpath_data
|
||||||
|
|
||||||
if xc._version < 108:
|
if xc._version < 108:
|
||||||
|
|
||||||
def dummy():
|
def dummy():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
dummy.__name__ = self.unsafe_call.name
|
dummy.__name__ = self.unsafe_call.name
|
||||||
return xc._xla.pjit(dummy, aot_cache_miss, []) # type: ignore
|
return xc._xla.pjit(dummy, aot_cache_miss, []) # type: ignore
|
||||||
else:
|
else:
|
||||||
return xc._xla.pjit( # type: ignore
|
return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, []) # type: ignore
|
||||||
self.unsafe_call.name, aot_cache_miss, [])
|
|
||||||
|
|
||||||
|
|
||||||
def _out_shardings_for_trivial(
|
def _out_shardings_for_trivial(
|
||||||
@ -3723,24 +3720,27 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
|
|||||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||||
from jax._src.array import ArrayImpl
|
from jax._src.array import ArrayImpl
|
||||||
|
|
||||||
@lru_cache(maxsize=4096)
|
|
||||||
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):
|
|
||||||
if committed and not are_op_shardings_equal(
|
|
||||||
arg_sharding._to_xla_op_sharding(ndim),
|
|
||||||
in_xla_sharding._to_xla_op_sharding(ndim)):
|
|
||||||
raise ValueError(
|
|
||||||
f"{arg_type} sharding does not match the input sharding. "
|
|
||||||
f"Got {arg_type} sharding: {arg_sharding} and "
|
|
||||||
f"xla sharding: {in_xla_sharding}")
|
|
||||||
|
|
||||||
for arg, xs in safe_zip(args, in_xla_shardings):
|
for arg, xs in safe_zip(args, in_xla_shardings):
|
||||||
if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
|
if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
|
||||||
continue
|
continue
|
||||||
if isinstance(arg, GlobalDeviceArray):
|
if isinstance(arg, GlobalDeviceArray):
|
||||||
_cached_check(_create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes), xs,
|
arg_sharding = _create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
|
||||||
'GDA', arg.ndim, True)
|
arg_type = 'GDA'
|
||||||
|
committed = True
|
||||||
else:
|
else:
|
||||||
_cached_check(arg.sharding, xs, 'Array', arg.ndim, arg._committed)
|
arg_sharding = arg.sharding
|
||||||
|
arg_type = 'Array'
|
||||||
|
committed = arg._committed
|
||||||
|
|
||||||
|
# No need to cache this check since MeshExecutable has a C++ fast path
|
||||||
|
# for AOT compiled call.
|
||||||
|
if committed and not are_op_shardings_equal(
|
||||||
|
arg_sharding._to_xla_op_sharding(arg.ndim),
|
||||||
|
xs._to_xla_op_sharding(arg.ndim)):
|
||||||
|
raise ValueError(
|
||||||
|
f"{arg_type} sharding does not match the input sharding. "
|
||||||
|
f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for "
|
||||||
|
f"arg shape: {arg.shape}, arg value: {arg}")
|
||||||
|
|
||||||
|
|
||||||
def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
|
def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user