validate shape & dtype in ShapeDtypeStruct

This commit is contained in:
Jake VanderPlas 2022-12-29 14:29:54 -08:00
parent f624b6e751
commit 4b7e72c218
2 changed files with 11 additions and 1 deletions

View File

@ -2988,7 +2988,9 @@ def device_get(x: Any):
class ShapeDtypeStruct:
__slots__ = ["shape", "dtype", "named_shape", "sharding"]
def __init__(self, shape, dtype, named_shape=None, sharding=None):
self.shape = shape
self.shape = tuple(shape)
if dtype is None:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype)
if sharding is not None:
self.sharding = sharding

View File

@ -2252,6 +2252,14 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(hash(s1), hash(s2))
self.assertNotEqual(hash(s1), hash(s3))
def test_shape_dtype_struct_invalid_shape(self):
with self.assertRaisesRegex(TypeError, "'int' object is not iterable"):
api.ShapeDtypeStruct(shape=4, dtype='float32')
def test_shape_dtype_struct_dtype_none(self):
with self.assertRaisesRegex(ValueError, "dtype must be specified"):
api.ShapeDtypeStruct(shape=(), dtype=None)
def test_eval_shape(self):
def fun(x, y):
return jnp.tanh(jnp.dot(x, y) + 3.)