mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
validate shape & dtype in ShapeDtypeStruct
This commit is contained in:
parent
f624b6e751
commit
4b7e72c218
@ -2988,7 +2988,9 @@ def device_get(x: Any):
|
||||
class ShapeDtypeStruct:
|
||||
__slots__ = ["shape", "dtype", "named_shape", "sharding"]
|
||||
def __init__(self, shape, dtype, named_shape=None, sharding=None):
|
||||
self.shape = shape
|
||||
self.shape = tuple(shape)
|
||||
if dtype is None:
|
||||
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
|
||||
self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype)
|
||||
if sharding is not None:
|
||||
self.sharding = sharding
|
||||
|
@ -2252,6 +2252,14 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertEqual(hash(s1), hash(s2))
|
||||
self.assertNotEqual(hash(s1), hash(s3))
|
||||
|
||||
def test_shape_dtype_struct_invalid_shape(self):
|
||||
with self.assertRaisesRegex(TypeError, "'int' object is not iterable"):
|
||||
api.ShapeDtypeStruct(shape=4, dtype='float32')
|
||||
|
||||
def test_shape_dtype_struct_dtype_none(self):
|
||||
with self.assertRaisesRegex(ValueError, "dtype must be specified"):
|
||||
api.ShapeDtypeStruct(shape=(), dtype=None)
|
||||
|
||||
def test_eval_shape(self):
|
||||
def fun(x, y):
|
||||
return jnp.tanh(jnp.dot(x, y) + 3.)
|
||||
|
Loading…
x
Reference in New Issue
Block a user