[sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.

PiperOrigin-RevId: 684465123
Yash Katariya 2024-10-10 09:06:24 -07:00 committed by jax authors
parent 66f526894f
commit 8ef41a6e14
4 changed files with 17 additions and 7 deletions
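
For illustration (not part of this commit): the normalization pads a partition spec with None entries up to the array's rank, so an empty PartitionSpec and one spelled out with explicit Nones describe the same fully replicated placement. A minimal sketch of that padding idea using only the public PartitionSpec class; pad_spec is a hypothetical helper, not a JAX API:

from jax.sharding import PartitionSpec as P

def pad_spec(spec, ndim):
  # Pad missing trailing dimensions with None, i.e. treat them as replicated.
  return P(*list(spec) + [None] * (ndim - len(spec)))

# For a rank-2 array, P() and P(None, None) normalize to the same spec,
# even though plain PartitionSpec equality is positional and sees them as different.
assert P() != P(None, None)
assert pad_spec(P(), 2) == pad_spec(P(None, None), 2) == P(None, None)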

View File

@@ -1029,7 +1029,8 @@ xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
 def _get_aval_array(self):
   if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
     return self.aval.update(sharding=NamedSharding(
-        self.sharding.mesh.abstract_mesh, self.sharding.spec))
+        self.sharding.mesh.abstract_mesh,
+        self.sharding.normalized_spec(self.ndim)))
   else:
     return self.aval
 api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array

View File

@@ -2073,7 +2073,7 @@ def broadcasting_sharding_rule(name, *avals):
     msg = '{}: arrays must have same number of dimensions, got {}.'
     raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
 
-  specs = [a.sharding.normalized_spec for a in avals if a.shape]
+  specs = [a.sharding.spec for a in avals if a.shape]
 
   mesh = None
   for a in avals:
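
Context for the hunk above (an editorial sketch, not code from the commit): once avals carry normalized specs, the broadcasting rule can read a.sharding.spec directly; a replicated operand shows up as explicit Nones rather than an empty spec of a different length. Roughly the comparison this enables, with hypothetical specs for two rank-2 operands:

from jax.sharding import PartitionSpec as P

spec_lhs = P('x', 'y')      # operand sharded over both mesh axes
spec_rhs = P(None, None)    # replicated operand; P() would normalize to this

# With equal-length specs, a per-dimension walk is enough; None (replicated)
# is compatible with any axis name. This mirrors the idea, not the exact rule.
out_spec = tuple(l if l is not None else r for l, r in zip(spec_lhs, spec_rhs))
assert out_spec == ('x', 'y')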

View File

@@ -307,9 +307,8 @@ class NamedSharding(sharding.Sharding):
   def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)
 
-  @functools.cached_property
-  def normalized_spec(self):
-    out = []
+  def normalized_spec(self, ndim: int) -> PartitionSpec:
+    out = []  # type: ignore
     for p in self._parsed_pspec:
       if p is None:
         raise ValueError("UNCONSTRAINED is not supported yet.")
@@ -319,7 +318,9 @@ class NamedSharding(sharding.Sharding):
         out.append(p[0])
       else:
         out.append(p)
-    return tuple(out)
+    if len(out) < ndim:
+      out.extend([None] * (ndim - len(out)))
+    return PartitionSpec(*out)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
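
A usage sketch for the reworked normalized_spec (illustrative only; the single-device 1x1 mesh is just to keep the example runnable anywhere, and the method is the one defined in the hunk above rather than a documented public API):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))

# An empty spec is padded with None up to the requested rank.
print(NamedSharding(mesh, P()).normalized_spec(2))                # PartitionSpec(None, None)
# Single-element tuples of axis names are unwrapped.
print(NamedSharding(mesh, P(('x',), ('y',))).normalized_spec(2))  # PartitionSpec('x', 'y')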

View File

@@ -4652,10 +4652,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, (np_inp1 * np_inp2))
 
-    out = f(arr1, arr1)
+    out = f(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x',), ('y',)))))
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, (np_inp1 * np_inp1))
 
+    out = f(arr1, jax.device_put(np_inp2, NamedSharding(mesh, P())))
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, (np_inp1 * np_inp2))
+
     @jax.jit
     def g(x, y):
       return x * y
@@ -4664,6 +4668,10 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         TypeError, "mul got incompatible shardings for broadcasting"):
       g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x'))))
 
+    with self.assertRaisesRegex(
+        TypeError, "mul got incompatible shardings for broadcasting"):
+      g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y')))))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):