mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add special value grad tests, sinh failing
This commit is contained in:
parent
1147eac3b3
commit
4225857166
@ -1399,6 +1399,18 @@ class LaxTest(jtu.JaxTestCase):
|
||||
|
||||
api.jit(f)(1.) # doesn't crash
|
||||
|
||||
def testReshapeWithUnusualShapes(self):
|
||||
ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
|
||||
self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
|
||||
|
||||
jtu.check_raises_regexp(
|
||||
lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)), TypeError,
|
||||
"Shapes must be 1D sequences of concrete values of integer type.*")
|
||||
|
||||
jtu.check_raises_regexp(
|
||||
lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)), TypeError,
|
||||
"Shapes must be 1D sequences of concrete values of integer type.*")
|
||||
|
||||
|
||||
class DeviceConstantTest(jtu.JaxTestCase):
|
||||
def _CheckDeviceConstant(self, make_const, expected):
|
||||
@ -1502,12 +1514,28 @@ LAX_GRAD_OPS = [
|
||||
dtypes=[onp.float64, onp.complex64]),
|
||||
grad_test_spec(lax.log1p, nargs=1, order=2, rng=jtu.rand_positive(),
|
||||
dtypes=[onp.float64, onp.complex64]),
|
||||
grad_test_spec(lax.sinh, nargs=1, order=2, rng=jtu.rand_default(),
|
||||
dtypes=[onp.float64, onp.complex64], tol=1e-5),
|
||||
grad_test_spec(lax.cosh, nargs=1, order=2, rng=jtu.rand_default(),
|
||||
dtypes=[onp.float64, onp.complex64], tol=1e-5),
|
||||
grad_test_spec(lax.tanh, nargs=1, order=2, rng=jtu.rand_default(),
|
||||
dtypes=[onp.float64, onp.complex64], tol=1e-5),
|
||||
grad_test_spec(lax.asinh, nargs=1, order=2, rng=jtu.rand_positive(),
|
||||
dtypes=[onp.float64, onp.complex64], tol=1e-5),
|
||||
grad_test_spec(lax.acosh, nargs=1, order=2, rng=jtu.rand_positive(),
|
||||
dtypes=[onp.float64, onp.complex64], tol=1e-5),
|
||||
grad_test_spec(lax.atanh, nargs=1, order=2, rng=jtu.rand_uniform(-0.9, 0.9),
|
||||
dtypes=[onp.float64, onp.complex64], tol=1e-5),
|
||||
grad_test_spec(lax.sin, nargs=1, order=2, rng=jtu.rand_default(),
|
||||
dtypes=[onp.float64, onp.complex64]),
|
||||
grad_test_spec(lax.cos, nargs=1, order=2, rng=jtu.rand_default(),
|
||||
dtypes=[onp.float64, onp.complex64]),
|
||||
grad_test_spec(lax.tan, nargs=1, order=2, rng=jtu.rand_uniform(-1.3, 1.3),
|
||||
dtypes=[onp.float64, onp.complex64]),
|
||||
grad_test_spec(lax.asin, nargs=1, order=2, rng=jtu.rand_uniform(-1., 1.),
|
||||
dtypes=[onp.float64]),
|
||||
grad_test_spec(lax.acos, nargs=1, order=2, rng=jtu.rand_uniform(-1., 1.),
|
||||
dtypes=[onp.float64]),
|
||||
# TODO(proteneer): atan2 input is already a representation of a
|
||||
# complex number. Need to think harder about what this even means
|
||||
# if each input itself is a complex number.
|
||||
@ -1556,6 +1584,26 @@ LAX_GRAD_OPS = [
|
||||
# dtypes=[onp.float64], name="MinSomeEqual"),
|
||||
]
|
||||
|
||||
GradSpecialValuesTestSpec = collections.namedtuple(
|
||||
"GradSpecialValuesTestSpec", ["op", "values"])
|
||||
|
||||
LAX_GRAD_SPECIAL_VALUE_TESTS = [
|
||||
GradSpecialValuesTestSpec(lax.sinh, [0.]),
|
||||
GradSpecialValuesTestSpec(lax.cosh, [0.]),
|
||||
GradSpecialValuesTestSpec(lax.tanh, [0., 1000.]),
|
||||
GradSpecialValuesTestSpec(lax.asinh, [0., 1000.]),
|
||||
GradSpecialValuesTestSpec(lax.acosh, [1000.]),
|
||||
GradSpecialValuesTestSpec(lax.atanh, [0.]),
|
||||
GradSpecialValuesTestSpec(lax.sin, [0., onp.pi, onp.pi/2., onp.pi/4.]),
|
||||
GradSpecialValuesTestSpec(lax.cos, [0., onp.pi, onp.pi/2., onp.pi/4.]),
|
||||
GradSpecialValuesTestSpec(lax.tan, [0.]),
|
||||
GradSpecialValuesTestSpec(lax.asin, [0.]),
|
||||
GradSpecialValuesTestSpec(lax.acos, [0.]),
|
||||
GradSpecialValuesTestSpec(lax.atan, [0., 1000.]),
|
||||
GradSpecialValuesTestSpec(lax.erf, [0., 10.]),
|
||||
GradSpecialValuesTestSpec(lax.erfc, [0., 10.]),
|
||||
]
|
||||
|
||||
|
||||
def check_grads_bilinear(f, args, order,
|
||||
modes=["fwd", "rev"], atol=None, rtol=None):
|
||||
@ -1581,13 +1629,21 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
for dtype in rec.dtypes)
|
||||
for rec in LAX_GRAD_OPS))
|
||||
def testOpGrad(self, op, rng, shapes, dtype, order, tol):
|
||||
if jtu.device_under_test() == "tpu":
|
||||
if op is lax.pow:
|
||||
raise SkipTest("pow grad imprecise on tpu")
|
||||
if jtu.device_under_test() == "tpu" and op is lax.pow:
|
||||
raise SkipTest("pow grad imprecise on tpu")
|
||||
tol = 1e-1 if num_float_bits(dtype) == 32 else tol
|
||||
args = tuple(rng(shape, dtype) for shape in shapes)
|
||||
check_grads(op, args, order, ["fwd", "rev"], tol, tol)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
|
||||
"op": rec.op, "special_value": special_value}
|
||||
for special_value in rec.values)
|
||||
for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
|
||||
def testOpGradSpecialValue(self, op, special_value):
|
||||
check_grads(op, (special_value,), 2, ["fwd", "rev"])
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_from_dtype={}_to_dtype={}".format(
|
||||
jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
|
||||
@ -2242,17 +2298,6 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
expected = onp.array(0.0)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testReshapeWithUnusualShapes(self):
|
||||
ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
|
||||
self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)
|
||||
|
||||
jtu.check_raises_regexp(
|
||||
lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)), TypeError,
|
||||
"Shapes must be 1D sequences of concrete values of integer type.*")
|
||||
|
||||
jtu.check_raises_regexp(
|
||||
lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)), TypeError,
|
||||
"Shapes must be 1D sequences of concrete values of integer type.*")
|
||||
|
||||
def all_bdims(*shapes):
|
||||
bdims = (itertools.chain([None], range(len(shape) + 1)) for shape in shapes)
|
||||
|
Loading…
x
Reference in New Issue
Block a user