Remove global_str since all avals in pjit are global

PiperOrigin-RevId: 522443476
This commit is contained in:
Yash Katariya 2023-04-06 14:51:30 -07:00 committed by jax authors
parent b4402185db
commit 038ac445c2
2 changed files with 5 additions and 6 deletions

View File

@ -965,7 +965,6 @@ def pjit_check_aval_sharding(
for aval, s in zip(flat_avals, shardings):
if _is_unspecified_or_auto(s):
continue
global_str = "" if s.is_fully_addressable else " global"
shape = aval.shape
try:
# Sharding interfaces can implement `is_compatible_aval` as an optional
@ -988,7 +987,7 @@ def pjit_check_aval_sharding(
if not allow_uneven_sharding and shape[i] % size != 0:
raise ValueError(f"One of {what_aval} was given the sharding "
f"of {s}, which implies that "
f"the{global_str} size of its dimension {i} should be "
f"the global size of its dimension {i} should be "
f"divisible by {size}, but it is equal to {shape[i]} "
f"(full shape: {shape}) ")

View File

@ -2792,13 +2792,13 @@ class ArrayPjitTest(jtu.JaxTestCase):
x = jnp.ones((1,))
with self.assertRaisesRegex(
ValueError, 'implies that the size of its dimension 0 should be '
ValueError, 'implies that the global size of its dimension 0 should be '
'divisible by 2, but it is equal to 1 '):
jax.device_put(x, s)
y = jnp.ones((2,))
with self.assertRaisesRegex(
ValueError, 'implies that the size of its dimension 0 should be '
ValueError, 'implies that the global size of its dimension 0 should be '
'divisible by 2, but it is equal to 1 '):
jax.device_put((y, x), s)
@ -2963,7 +2963,7 @@ class PJitErrorTest(jtu.JaxTestCase):
mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
error = re.compile(
r"One of pjit arguments.*" + spec_regex(spec) + r".*"
r"implies that the size of its dimension 0 should be "
r"implies that the global size of its dimension 0 should be "
r"divisible by " + mesh_size + r", but it is equal to 3 "
r"\(full shape: \(3, 2\)\)", re.M | re.S)
with self.assertRaisesRegex(ValueError, error):
@ -2976,7 +2976,7 @@ class PJitErrorTest(jtu.JaxTestCase):
mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
error = re.compile(
r"One of pjit outputs.*" + spec_regex(spec) + r".*"
r"implies that the size of its dimension 0 should be "
r"implies that the global size of its dimension 0 should be "
r"divisible by " + mesh_size + r", but it is equal to 3", re.M | re.S)
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, in_shardings=None, out_shardings=P(resources, None))(x)