add tests for device constants

This commit is contained in:
Matthew Johnson 2018-12-18 22:45:34 -08:00
parent 20ca0bd733
commit a18e3f27ac
2 changed files with 72 additions and 2 deletions

View File

@ -417,14 +417,14 @@ opaque_param_ids = itertools.count()
def tie_in(x, y):
return tie_in_p.bind(x, y)
def full(shape, fill_value, dtype=None):
def full(shape, fill_value, dtype):
if onp.shape(fill_value):
msg = "full must be called with scalar fill_value, got fill_value.shape {}."
raise TypeError(msg.format(onp.shape(fill_value)))
dtype = xla_bridge.canonicalize_dtype(dtype)
# For constants (defined as Python scalars, raw ndarrays, or DeviceValues),
# create a FilledConstant value, otherwise just call broadcast.
dtype = dtype and xla_bridge.canonicalize_dtype(dtype)
if onp.isscalar(fill_value) or type(fill_value) is onp.ndarray:
return FilledConstant(onp.asarray(fill_value, dtype), shape)
elif isinstance(fill_value, xla.DeviceValue):

View File

@ -1346,6 +1346,76 @@ class LaxTest(jtu.JaxTestCase):
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
class DeviceConstantTest(jtu.JaxTestCase):
def _CheckDeviceConstant(self, make_const, expected):
# check casting to ndarray works
asarray_result = onp.asarray(make_const())
# check passing as an argument works (should hit constant handler)
zero = onp.array(0, expected.dtype)
argument_result = lax.add(zero, make_const())
# check looping into a compiled computation works
jit_result = api.jit(lambda x: lax.add(x, make_const()))(zero)
# ensure they're all the same
self.assertAllClose(asarray_result, expected, check_dtypes=True)
self.assertAllClose(argument_result, expected, check_dtypes=True)
self.assertAllClose(jit_result, expected, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_fill={}".format(
jtu.format_shape_dtype_string(shape, dtype), fill_value),
"shape": shape, "dtype": dtype, "fill_value": fill_value}
# for dtype in itertools.chain(all_dtypes, [None])
for dtype in [None]
for shape in [(), (3,), (2, 3), (2, 3, 4)]
for fill_value in [0, 1, onp.pi]))
def testFilledConstant(self, shape, fill_value, dtype):
make_const = lambda: lax.full(shape, fill_value, dtype)
expected = onp.full(shape, fill_value, xla_bridge.canonicalize_dtype(dtype))
self._CheckDeviceConstant(make_const, expected)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_dim={}".format(
jtu.format_shape_dtype_string(shape, dtype), dimension),
"shape": shape, "dtype": dtype, "dimension": dimension}
for dtype in default_dtypes
for shape in [(), (3,), (2, 3), (2, 3, 4)]
for dimension in range(len(shape))))
def testIotaConstant(self, dtype, shape, dimension):
make_const = lambda: lax.broadcasted_iota(dtype, shape, dimension)
arr = onp.arange(shape[dimension], dtype=xla_bridge.canonicalize_dtype(dtype))
singleton_shape = [1] * len(shape)
singleton_shape[dimension] = shape[dimension]
expected = onp.broadcast_to(arr.reshape(singleton_shape), shape)
self._CheckDeviceConstant(make_const, expected)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axes={}".format(
jtu.format_shape_dtype_string(shape, dtype), axes),
"shape": shape, "dtype": dtype, "axes": axes}
for dtype in default_dtypes
for shape, axes in [
[(2, 3), (0, 1)],
[(2, 3, 4), (0, 1)],
[(2, 3, 4), (0, 2)],
[(2, 3, 4), (1, 2)],
[(2, 3, 4), (0, 1, 2)],
[(2, 3, 4, 2), (0, 1, 2)],
[(2, 3, 4, 2), (0, 2, 3)],
]))
def testEyeConstant(self, dtype, shape, axes):
make_const = lambda: lax.broadcasted_eye(dtype, shape, axes)
# don't check the asarray case, just assume it's right
expected = onp.asarray(make_const())
self._CheckDeviceConstant(make_const, expected)
GradTestSpec = collections.namedtuple(
"GradTestSpec", ["op", "nargs", "order", "rng", "dtypes"])