Check if the sharding input to ShapeDtypeStruct is an instance of Sharding

PiperOrigin-RevId: 502652848
This commit is contained in:
Yash Katariya 2023-01-17 12:08:06 -08:00 committed by jax authors
parent 85654ceeab
commit cb9a9952fe
2 changed files with 20 additions and 4 deletions

View File

@ -3079,6 +3079,10 @@ class ShapeDtypeStruct:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype)
if sharding is not None:
if not isinstance(sharding, jax.sharding.Sharding):
raise ValueError(
"sharding should be an instance of `jax.sharding.Sharding`. "
f"Got {sharding} of type {type(sharding)}.")
self.sharding = sharding
self.named_shape = {} if named_shape is None else dict(named_shape)
@ -3093,7 +3097,9 @@ class ShapeDtypeStruct:
def __repr__(self):
ns = f", named_shape={self.named_shape}" if self.named_shape else ""
return f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype.name}{ns})"
sh = f", sharding={self.sharding}" if hasattr(self, "sharding") else ""
return (f"{type(self).__name__}(shape={self.shape}, "
f"dtype={self.dtype.name}{ns}{sh})")
__str__ = __repr__
@ -3101,14 +3107,17 @@ class ShapeDtypeStruct:
if not isinstance(other, ShapeDtypeStruct):
return False
else:
return (other.shape, other.dtype, other.named_shape) == (
self.shape, self.dtype, self.named_shape)
other_sh = other.sharding if hasattr(other, "sharding") else None
sh = self.sharding if hasattr(self, "sharding") else None
return ((other.shape, other.dtype, other.named_shape, other_sh) ==
(self.shape, self.dtype, self.named_shape, sh))
def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing
# https://github.com/google/jax/issues/8182
named = frozenset(self.named_shape.items())
return hash((self.shape, self.dtype, named))
sh = self.sharding if hasattr(self, "sharding") else None
return hash((self.shape, self.dtype, named, sh))
core.pytype_aval_mappings[ShapeDtypeStruct] = (
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype),

View File

@ -4153,6 +4153,13 @@ class APITest(jtu.JaxTestCase):
check=True, capture_output=True)
assert expected in result.stderr.decode()
def test_shapedtypestruct_sharding_error(self):
with self.assertRaisesRegex(
ValueError,
"sharding should be an instance of `jax.sharding.Sharding`."):
jax.ShapeDtypeStruct((8, 2), np.float32,
sharding=jax.sharding.PartitionSpec('x'))
@jtu.with_config(jax_experimental_subjaxpr_lowering_cache=True)
class SubcallTraceCacheTest(jtu.JaxTestCase):