mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix import in api_test.py
This commit is contained in:
parent
72e2e5da06
commit
121e74f5cc
@ -26,7 +26,6 @@ import jax.numpy as np
|
||||
from jax import jit, grad, device_get, device_put, jacfwd, jacrev, hessian
|
||||
from jax import api
|
||||
from jax.core import Primitive
|
||||
from jax.interpreters.partial_eval import def_abstract_eval
|
||||
from jax.interpreters.ad import defjvp
|
||||
from jax.interpreters.xla import DeviceArray
|
||||
from jax.abstract_arrays import concretization_err_msg
|
||||
@ -216,7 +215,7 @@ class APITest(jtu.JaxTestCase):
|
||||
jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
|
||||
"Forward-mode differentiation rule for 'foo' not implemented")
|
||||
|
||||
def_abstract_eval(foo_p, lambda x: x)
|
||||
foo_p.def_abstract_eval(lambda x: x)
|
||||
|
||||
jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
|
||||
"XLA translation rule for 'foo' not implemented")
|
||||
|
Loading…
x
Reference in New Issue
Block a user