This commit is contained in:
James Bradbury 2019-10-09 17:02:11 -07:00
parent fb433fb9d2
commit 9d2f25cf1a
2 changed files with 11 additions and 0 deletions

View File

@ -400,6 +400,7 @@ def concatenate(operands, dimension):
operand_shapes=tuple(o.shape for o in operands))
Precision = xla_client.PrecisionConfig.Precision
Precision.__str__ = lambda precision: precision.name
def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
rhs_dilation=None, dimension_numbers=None,

View File

@ -1812,6 +1812,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=tol, rtol=tol)
# check that precision config is preserved
result, pullback = api.vjp(dot, lhs, rhs)
gresult = lax.zeros_like_array(result)
s = str(api.make_jaxpr(pullback)(gresult))
assert "precision=HIGHEST" in s
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -1837,6 +1842,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
precision=lax.Precision.HIGHEST)
check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=tol, rtol=tol)
# check that precision config is preserved
result, pullback = api.vjp(dot_general, lhs, rhs)
gresult = lax.zeros_like_array(result)
s = str(api.make_jaxpr(pullback)(gresult))
assert "precision=HIGHEST" in s
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(