Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 04:16:07 +00:00
[sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.
PiperOrigin-RevId: 684465123
parent 66f526894f
commit 8ef41a6e14
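The commit message above describes padding partition specs with None up to the array's rank when avals are created, so that P() and P(None, None) describe the same fully replicated layout. As a rough standalone illustration only (not code from this change, and assuming the behavior at this revision), using the public jax.sharding.PartitionSpec API with a hypothetical pad_spec helper:

# A rough illustration of the padding idea; pad_spec is a hypothetical
# helper for this sketch, not part of JAX.
from jax.sharding import PartitionSpec as P

def pad_spec(spec, ndim):
  # Pad with None so the spec names every one of the ndim dimensions;
  # unnamed trailing dimensions are replicated.
  return P(*(tuple(spec) + (None,) * (ndim - len(spec))))

# Written out directly, the two specs compare unequal...
print(P() == P(None, None))                 # False at this revision
# ...but padded to the array's rank they become the same replicated spec.
print(pad_spec(P(), 2) == P(None, None))    # True
print(pad_spec(P("x"), 3))                  # PartitionSpec('x', None, None)

Because the padding happens once, when the aval is built, downstream sharding rules can compare specs directly instead of re-normalizing them, which is what the hunks below do.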
@@ -1029,7 +1029,8 @@ xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
 def _get_aval_array(self):
   if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
     return self.aval.update(sharding=NamedSharding(
-        self.sharding.mesh.abstract_mesh, self.sharding.spec))
+        self.sharding.mesh.abstract_mesh,
+        self.sharding.normalized_spec(self.ndim)))
   else:
     return self.aval
 api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
@@ -2073,7 +2073,7 @@ def broadcasting_sharding_rule(name, *avals):
     msg = '{}: arrays must have same number of dimensions, got {}.'
     raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
 
-  specs = [a.sharding.normalized_spec for a in avals if a.shape]
+  specs = [a.sharding.spec for a in avals if a.shape]
 
   mesh = None
   for a in avals:
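The hunk above switches broadcasting_sharding_rule from a.sharding.normalized_spec to a.sharding.spec, since avals created under sharding_in_types now already carry rank-length, normalized specs (see the _get_aval_array hunk). A hedged, simplified sketch of the kind of per-dimension comparison this enables; check_broadcast_specs is illustrative only and is not how JAX implements the rule:

# Illustrative stand-in for the broadcasting check, assuming every operand
# spec has already been normalized to full rank (same length for all).
from jax.sharding import PartitionSpec as P

def check_broadcast_specs(name, *specs):
  ndim = len(specs[0])
  out = []
  for d in range(ndim):
    # Axis names proposed for dimension d; None means replicated and is
    # compatible with anything.
    names = {s[d] for s in specs if s[d] is not None}
    if len(names) > 1:
      raise TypeError(f"{name} got incompatible shardings for broadcasting")
    out.append(names.pop() if names else None)
  return P(*out)

print(check_broadcast_specs("mul", P("x", "y"), P(None, None)))  # PartitionSpec('x', 'y')
# check_broadcast_specs("mul", P("x", "y"), P("y", "x"))  # raises TypeError

This mirrors the behavior exercised by the tests further down: a replicated operand broadcasts against a sharded one, while conflicting axis names raise the "incompatible shardings for broadcasting" error.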
@@ -307,9 +307,8 @@ class NamedSharding(sharding.Sharding):
   def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)
 
-  @functools.cached_property
-  def normalized_spec(self):
-    out = []
+  def normalized_spec(self, ndim: int) -> PartitionSpec:
+    out = []  # type: ignore
     for p in self._parsed_pspec:
       if p is None:
         raise ValueError("UNCONSTRAINED is not supported yet.")
@@ -319,7 +318,9 @@ class NamedSharding(sharding.Sharding):
         out.append(p[0])
       else:
         out.append(p)
-    return tuple(out)
+    if len(out) < ndim:
+      out.extend([None] * (ndim - len(out)))
+    return PartitionSpec(*out)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
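The two hunks above turn normalized_spec from a cached property into a method that takes the array rank, pads the result with None, and returns a PartitionSpec instead of a plain tuple. A hedged usage sketch, assuming the JAX revision of this commit (normalized_spec is an internal helper; its signature may differ or be absent in other versions):

# Hedged usage sketch of NamedSharding.normalized_spec(ndim) at this revision.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A trivial 1x1 mesh over a single available device, just to build shardings.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("x", "y"))

# A partially specified spec is padded with None up to the requested rank.
print(NamedSharding(mesh, P("x")).normalized_spec(2))  # PartitionSpec('x', None)
# An empty spec on a rank-2 array becomes explicitly replicated on both dims.
print(NamedSharding(mesh, P()).normalized_spec(2))     # PartitionSpec(None, None)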
@@ -4652,10 +4652,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, (np_inp1 * np_inp2))
 
-    out = f(arr1, arr1)
+    out = f(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x',), ('y',)))))
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, (np_inp1 * np_inp1))
 
+    out = f(arr1, jax.device_put(np_inp2, NamedSharding(mesh, P())))
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, (np_inp1 * np_inp2))
+
     @jax.jit
     def g(x, y):
       return x * y
@@ -4664,6 +4668,10 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         TypeError, "mul got incompatible shardings for broadcasting"):
       g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x'))))
 
+    with self.assertRaisesRegex(
+        TypeError, "mul got incompatible shardings for broadcasting"):
+      g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y')))))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):