add test for DeviceConstant repr

This commit is contained in:
Matthew Johnson 2019-05-17 12:38:45 -07:00
parent e3d4213e6d
commit b1fd8e6eb6
2 changed files with 5 additions and 2 deletions

View File

@ -3646,7 +3646,7 @@ class _IotaConstant(xla.DeviceConstant):
def __init__(self, dtype, shape, axis):
self.shape = shape
self.dtype = dtype
self.dtype = onp.dtype(dtype)
self.ndim = len(shape)
self.size = prod(shape)
self._npy_value = None
@ -3675,7 +3675,7 @@ class _EyeConstant(xla.DeviceConstant):
def __init__(self, shape, axes, dtype):
self.shape = shape
self.dtype = dtype
self.dtype = onp.dtype(dtype)
self.ndim = len(shape)
self.size = prod(shape)
self._npy_value = None

View File

@ -1343,6 +1343,9 @@ class DeviceConstantTest(jtu.JaxTestCase):
self.assertAllClose(argument_result, expected, check_dtypes=True)
self.assertAllClose(jit_result, expected, check_dtypes=True)
# ensure repr doesn't crash
repr(make_const())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_fill={}".format(
jtu.format_shape_dtype_string(shape, dtype) if dtype else shape,