mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Remove global_str since all avals in pjit are global
PiperOrigin-RevId: 522443476
This commit is contained in:
parent
b4402185db
commit
038ac445c2
@ -965,7 +965,6 @@ def pjit_check_aval_sharding(
|
||||
for aval, s in zip(flat_avals, shardings):
|
||||
if _is_unspecified_or_auto(s):
|
||||
continue
|
||||
global_str = "" if s.is_fully_addressable else " global"
|
||||
shape = aval.shape
|
||||
try:
|
||||
# Sharding interfaces can implement `is_compatible_aval` as an optional
|
||||
@ -988,7 +987,7 @@ def pjit_check_aval_sharding(
|
||||
if not allow_uneven_sharding and shape[i] % size != 0:
|
||||
raise ValueError(f"One of {what_aval} was given the sharding "
|
||||
f"of {s}, which implies that "
|
||||
f"the{global_str} size of its dimension {i} should be "
|
||||
f"the global size of its dimension {i} should be "
|
||||
f"divisible by {size}, but it is equal to {shape[i]} "
|
||||
f"(full shape: {shape}) ")
|
||||
|
||||
|
@ -2792,13 +2792,13 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
x = jnp.ones((1,))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'implies that the size of its dimension 0 should be '
|
||||
ValueError, 'implies that the global size of its dimension 0 should be '
|
||||
'divisible by 2, but it is equal to 1 '):
|
||||
jax.device_put(x, s)
|
||||
|
||||
y = jnp.ones((2,))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'implies that the size of its dimension 0 should be '
|
||||
ValueError, 'implies that the global size of its dimension 0 should be '
|
||||
'divisible by 2, but it is equal to 1 '):
|
||||
jax.device_put((y, x), s)
|
||||
|
||||
@ -2963,7 +2963,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
|
||||
error = re.compile(
|
||||
r"One of pjit arguments.*" + spec_regex(spec) + r".*"
|
||||
r"implies that the size of its dimension 0 should be "
|
||||
r"implies that the global size of its dimension 0 should be "
|
||||
r"divisible by " + mesh_size + r", but it is equal to 3 "
|
||||
r"\(full shape: \(3, 2\)\)", re.M | re.S)
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
@ -2976,7 +2976,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
|
||||
error = re.compile(
|
||||
r"One of pjit outputs.*" + spec_regex(spec) + r".*"
|
||||
r"implies that the size of its dimension 0 should be "
|
||||
r"implies that the global size of its dimension 0 should be "
|
||||
r"divisible by " + mesh_size + r", but it is equal to 3", re.M | re.S)
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
pjit(lambda x: x, in_shardings=None, out_shardings=P(resources, None))(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user