mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add test
This commit is contained in:
parent
fb433fb9d2
commit
9d2f25cf1a
@ -400,6 +400,7 @@ def concatenate(operands, dimension):
|
||||
operand_shapes=tuple(o.shape for o in operands))
|
||||
|
||||
Precision = xla_client.PrecisionConfig.Precision
|
||||
Precision.__str__ = lambda precision: precision.name
|
||||
|
||||
def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
|
||||
rhs_dilation=None, dimension_numbers=None,
|
||||
|
@ -1812,6 +1812,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
|
||||
check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
|
||||
atol=tol, rtol=tol)
|
||||
# check that precision config is preserved
|
||||
result, pullback = api.vjp(dot, lhs, rhs)
|
||||
gresult = lax.zeros_like_array(result)
|
||||
s = str(api.make_jaxpr(pullback)(gresult))
|
||||
assert "precision=HIGHEST" in s
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -1837,6 +1842,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
precision=lax.Precision.HIGHEST)
|
||||
check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"],
|
||||
atol=tol, rtol=tol)
|
||||
# check that precision config is preserved
|
||||
result, pullback = api.vjp(dot_general, lhs, rhs)
|
||||
gresult = lax.zeros_like_array(result)
|
||||
s = str(api.make_jaxpr(pullback)(gresult))
|
||||
assert "precision=HIGHEST" in s
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
|
||||
|
Loading…
x
Reference in New Issue
Block a user