mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add test for DeviceConstant repr
This commit is contained in:
parent
e3d4213e6d
commit
b1fd8e6eb6
@ -3646,7 +3646,7 @@ class _IotaConstant(xla.DeviceConstant):
|
||||
|
||||
def __init__(self, dtype, shape, axis):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.dtype = onp.dtype(dtype)
|
||||
self.ndim = len(shape)
|
||||
self.size = prod(shape)
|
||||
self._npy_value = None
|
||||
@ -3675,7 +3675,7 @@ class _EyeConstant(xla.DeviceConstant):
|
||||
|
||||
def __init__(self, shape, axes, dtype):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.dtype = onp.dtype(dtype)
|
||||
self.ndim = len(shape)
|
||||
self.size = prod(shape)
|
||||
self._npy_value = None
|
||||
|
@ -1343,6 +1343,9 @@ class DeviceConstantTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(argument_result, expected, check_dtypes=True)
|
||||
self.assertAllClose(jit_result, expected, check_dtypes=True)
|
||||
|
||||
# ensure repr doesn't crash
|
||||
repr(make_const())
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_fill={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype) if dtype else shape,
|
||||
|
Loading…
x
Reference in New Issue
Block a user