mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
PiperOrigin-RevId: 502652848
This commit is contained in:
parent
85654ceeab
commit
cb9a9952fe
@ -3079,6 +3079,10 @@ class ShapeDtypeStruct:
|
||||
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
|
||||
self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype)
|
||||
if sharding is not None:
|
||||
if not isinstance(sharding, jax.sharding.Sharding):
|
||||
raise ValueError(
|
||||
"sharding should be an instance of `jax.sharding.Sharding`. "
|
||||
f"Got {sharding} of type {type(sharding)}.")
|
||||
self.sharding = sharding
|
||||
self.named_shape = {} if named_shape is None else dict(named_shape)
|
||||
|
||||
@ -3093,7 +3097,9 @@ class ShapeDtypeStruct:
|
||||
|
||||
def __repr__(self):
|
||||
ns = f", named_shape={self.named_shape}" if self.named_shape else ""
|
||||
return f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype.name}{ns})"
|
||||
sh = f", sharding={self.sharding}" if hasattr(self, "sharding") else ""
|
||||
return (f"{type(self).__name__}(shape={self.shape}, "
|
||||
f"dtype={self.dtype.name}{ns}{sh})")
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@ -3101,14 +3107,17 @@ class ShapeDtypeStruct:
|
||||
if not isinstance(other, ShapeDtypeStruct):
|
||||
return False
|
||||
else:
|
||||
return (other.shape, other.dtype, other.named_shape) == (
|
||||
self.shape, self.dtype, self.named_shape)
|
||||
other_sh = other.sharding if hasattr(other, "sharding") else None
|
||||
sh = self.sharding if hasattr(self, "sharding") else None
|
||||
return ((other.shape, other.dtype, other.named_shape, other_sh) ==
|
||||
(self.shape, self.dtype, self.named_shape, sh))
|
||||
|
||||
def __hash__(self):
|
||||
# TODO(frostig): avoid the conversion from dict by addressing
|
||||
# https://github.com/google/jax/issues/8182
|
||||
named = frozenset(self.named_shape.items())
|
||||
return hash((self.shape, self.dtype, named))
|
||||
sh = self.sharding if hasattr(self, "sharding") else None
|
||||
return hash((self.shape, self.dtype, named, sh))
|
||||
|
||||
core.pytype_aval_mappings[ShapeDtypeStruct] = (
|
||||
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype),
|
||||
|
@ -4153,6 +4153,13 @@ class APITest(jtu.JaxTestCase):
|
||||
check=True, capture_output=True)
|
||||
assert expected in result.stderr.decode()
|
||||
|
||||
def test_shapedtypestruct_sharding_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"sharding should be an instance of `jax.sharding.Sharding`."):
|
||||
jax.ShapeDtypeStruct((8, 2), np.float32,
|
||||
sharding=jax.sharding.PartitionSpec('x'))
|
||||
|
||||
|
||||
@jtu.with_config(jax_experimental_subjaxpr_lowering_cache=True)
|
||||
class SubcallTraceCacheTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user