mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add tests for device constants
This commit is contained in:
parent
20ca0bd733
commit
a18e3f27ac
@ -417,14 +417,14 @@ opaque_param_ids = itertools.count()
|
||||
def tie_in(x, y):
|
||||
return tie_in_p.bind(x, y)
|
||||
|
||||
def full(shape, fill_value, dtype=None):
|
||||
def full(shape, fill_value, dtype):
|
||||
if onp.shape(fill_value):
|
||||
msg = "full must be called with scalar fill_value, got fill_value.shape {}."
|
||||
raise TypeError(msg.format(onp.shape(fill_value)))
|
||||
dtype = xla_bridge.canonicalize_dtype(dtype)
|
||||
|
||||
# For constants (defined as Python scalars, raw ndarrays, or DeviceValues),
|
||||
# create a FilledConstant value, otherwise just call broadcast.
|
||||
dtype = dtype and xla_bridge.canonicalize_dtype(dtype)
|
||||
if onp.isscalar(fill_value) or type(fill_value) is onp.ndarray:
|
||||
return FilledConstant(onp.asarray(fill_value, dtype), shape)
|
||||
elif isinstance(fill_value, xla.DeviceValue):
|
||||
|
@ -1346,6 +1346,76 @@ class LaxTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
|
||||
|
||||
class DeviceConstantTest(jtu.JaxTestCase):
|
||||
def _CheckDeviceConstant(self, make_const, expected):
|
||||
# check casting to ndarray works
|
||||
asarray_result = onp.asarray(make_const())
|
||||
|
||||
# check passing as an argument works (should hit constant handler)
|
||||
zero = onp.array(0, expected.dtype)
|
||||
argument_result = lax.add(zero, make_const())
|
||||
|
||||
# check looping into a compiled computation works
|
||||
jit_result = api.jit(lambda x: lax.add(x, make_const()))(zero)
|
||||
|
||||
# ensure they're all the same
|
||||
self.assertAllClose(asarray_result, expected, check_dtypes=True)
|
||||
self.assertAllClose(argument_result, expected, check_dtypes=True)
|
||||
self.assertAllClose(jit_result, expected, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_fill={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), fill_value),
|
||||
"shape": shape, "dtype": dtype, "fill_value": fill_value}
|
||||
# for dtype in itertools.chain(all_dtypes, [None])
|
||||
for dtype in [None]
|
||||
for shape in [(), (3,), (2, 3), (2, 3, 4)]
|
||||
for fill_value in [0, 1, onp.pi]))
|
||||
def testFilledConstant(self, shape, fill_value, dtype):
|
||||
make_const = lambda: lax.full(shape, fill_value, dtype)
|
||||
expected = onp.full(shape, fill_value, xla_bridge.canonicalize_dtype(dtype))
|
||||
self._CheckDeviceConstant(make_const, expected)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_dim={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), dimension),
|
||||
"shape": shape, "dtype": dtype, "dimension": dimension}
|
||||
for dtype in default_dtypes
|
||||
for shape in [(), (3,), (2, 3), (2, 3, 4)]
|
||||
for dimension in range(len(shape))))
|
||||
def testIotaConstant(self, dtype, shape, dimension):
|
||||
make_const = lambda: lax.broadcasted_iota(dtype, shape, dimension)
|
||||
|
||||
arr = onp.arange(shape[dimension], dtype=xla_bridge.canonicalize_dtype(dtype))
|
||||
singleton_shape = [1] * len(shape)
|
||||
singleton_shape[dimension] = shape[dimension]
|
||||
expected = onp.broadcast_to(arr.reshape(singleton_shape), shape)
|
||||
|
||||
self._CheckDeviceConstant(make_const, expected)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axes={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axes),
|
||||
"shape": shape, "dtype": dtype, "axes": axes}
|
||||
for dtype in default_dtypes
|
||||
for shape, axes in [
|
||||
[(2, 3), (0, 1)],
|
||||
[(2, 3, 4), (0, 1)],
|
||||
[(2, 3, 4), (0, 2)],
|
||||
[(2, 3, 4), (1, 2)],
|
||||
[(2, 3, 4), (0, 1, 2)],
|
||||
[(2, 3, 4, 2), (0, 1, 2)],
|
||||
[(2, 3, 4, 2), (0, 2, 3)],
|
||||
]))
|
||||
def testEyeConstant(self, dtype, shape, axes):
|
||||
make_const = lambda: lax.broadcasted_eye(dtype, shape, axes)
|
||||
|
||||
# don't check the asarray case, just assume it's right
|
||||
expected = onp.asarray(make_const())
|
||||
|
||||
self._CheckDeviceConstant(make_const, expected)
|
||||
|
||||
|
||||
GradTestSpec = collections.namedtuple(
|
||||
"GradTestSpec", ["op", "nargs", "order", "rng", "dtypes"])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user