Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)
AOT sharding mismatch error shouldn't have GSPMDSharding in it.
PiperOrigin-RevId: 576668290
commit 4d15375596
parent ba9fd7744e
@@ -2921,22 +2921,25 @@ def check_gda_or_array_xla_sharding_match(
     if not isinstance(arg, ArrayImpl):
       continue

     db_xs = check_device_backend_on_shardings([xs])
     if not db_xs:
       xs = getattr(xs, '_original_sharding', xs)

     # Raise memory kind mismatch error even if the arg is uncommitted.
     if arg.sharding.memory_kind != xs.memory_kind:
       errors.append(
-          f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
-          f"arg {name} with shape: {arg.aval.str_short()}")
+          "Got input sharding(s) that compiled object was called with: "
+          f"{arg.sharding} and sharding(s) the computation was compiled "
+          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")

     # No need to cache this check since MeshExecutable has a C++ fast path
     # for AOT compiled call.
-    if (not check_device_backend_on_shardings([xs]) and
-        arg._committed and
+    if (not db_xs and arg._committed and
         not op_shardings.are_op_shardings_equal(
             arg.sharding._to_xla_hlo_sharding(arg.ndim),
             xs._to_xla_hlo_sharding(arg.ndim))):
       errors.append(
-          f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
-          f"arg {name} with shape: {arg.aval.str_short()}")
+          "Got input sharding(s) that compiled object was called with: "
+          f"{arg.sharding} and sharding(s) the computation was compiled "
+          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")

   if errors:
     str_errors = '\n'.join(errors[:num_errors])
@@ -2944,7 +2947,8 @@ def check_gda_or_array_xla_sharding_match(
         f'the {len(errors)} mismatches' if len(errors) < num_errors else
         f"{num_errors} mismatches out of {len(errors)}")
     raise ValueError(
-        "Array(s) sharding does not match the input(s) sharding. "
+        "Compiled object called with input sharding(s) does not match the "
+        "sharding(s) the computation was compiled with. "
         f"Here are {num_mismatch_str}:\n{str_errors}")
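
For context, here is a minimal sketch (not part of this commit; it assumes at least two visible devices and uses only public JAX APIs) of how the reworded error surfaces when an AOT-compiled object is called with a differently sharded argument:

  import numpy as np
  import jax
  import jax.numpy as jnp
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  # Sharding the computation is compiled with vs. a different one used at call time.
  mesh = Mesh(np.array(jax.devices()), ('x',))
  sharded = NamedSharding(mesh, P('x'))
  replicated = NamedSharding(mesh, P())

  x = jax.device_put(jnp.arange(8.), sharded)
  compiled = jax.jit(lambda a: a * 2).lower(x).compile()  # AOT-compiled executable

  try:
    # A committed array with a mismatched sharding hits the check above.
    compiled(jax.device_put(jnp.arange(8.), replicated))
  except ValueError as e:
    # The message now begins with "Compiled object called with input sharding(s)
    # does not match the sharding(s) the computation was compiled with." and
    # reports the user-facing shardings rather than GSPMDSharding.
    print(e)

On a single device the two shardings collapse to the same placement and no error is raised, which is why the sketch assumes multiple devices.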
@@ -1491,8 +1491,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
         input_data)
     with self.assertRaisesRegex(
         ValueError,
-        r"Array\(s\) sharding does not match the input\(s\) "
-        r"sharding.*\n.*for arg x"):
+        r"Compiled object called with input sharding\(s\) does not match the "
+        r"sharding\(s\) the computation was compiled with.*\n.*for arg x"):
       compiled(arr)

   def test_gda_auto_shardings_len(self):
@@ -1806,7 +1806,8 @@ class ArrayPjitTest(jtu.JaxTestCase):

     with self.assertRaisesRegex(
         ValueError,
-        r"Array\(s\) sharding does not match the input\(s\) sharding. "
+        r"Compiled object called with input sharding\(s\) does not match the "
+        r"sharding\(s\) the computation was compiled with. "
         "Here are 5 mismatches out of 6"):
       compiled(a2, a2, a2, a2, a2, a2)
@@ -1819,7 +1820,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
     inp2 = {'x': a2, 'y': {'y1': a2}}
     with self.assertRaisesRegex(
         ValueError,
-        r"Array\(s\) sharding does not match the input\(s\) sharding. "
+        r"Compiled object called with input sharding\(s\) does not match the "
+        r"sharding\(s\) the computation was compiled with. "
         "Here are the 2 mismatches"):
       compiled(inp2)
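
A side note on the test updates above: assertRaisesRegex interprets the expected message as a regular expression, which is why the parentheses in "sharding(s)" are escaped in the new patterns. A small illustrative check (not part of the commit):

  import re

  msg = ("Compiled object called with input sharding(s) does not match the "
         "sharding(s) the computation was compiled with. "
         "Here are the 2 mismatches")
  pattern = (r"Compiled object called with input sharding\(s\) does not match the "
             r"sharding\(s\) the computation was compiled with. ")
  assert re.search(pattern, msg)  # unescaped "(s)" would be parsed as a capture group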