fix import in api_test.py

This commit is contained in:
Matthew Johnson 2019-02-22 07:56:13 -08:00
parent 72e2e5da06
commit 121e74f5cc

View File

@ -26,7 +26,6 @@ import jax.numpy as np
from jax import jit, grad, device_get, device_put, jacfwd, jacrev, hessian
from jax import api
from jax.core import Primitive
from jax.interpreters.partial_eval import def_abstract_eval
from jax.interpreters.ad import defjvp
from jax.interpreters.xla import DeviceArray
from jax.abstract_arrays import concretization_err_msg
@ -216,7 +215,7 @@ class APITest(jtu.JaxTestCase):
jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
"Forward-mode differentiation rule for 'foo' not implemented")
def_abstract_eval(foo_p, lambda x: x)
foo_p.def_abstract_eval(lambda x: x)
jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
"XLA translation rule for 'foo' not implemented")