mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove jax.xla_computation
tests from jax2tf. api_test.py
has enough coverage for jax.xla_computation
PiperOrigin-RevId: 644605636
This commit is contained in:
parent
103d620856
commit
d3bfd32667
@ -674,7 +674,6 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
lower_no_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=False)).lower(x)
|
||||
self.assertEmpty(lower_no_effect._lowering.compile_args["unordered_effects"])
|
||||
|
||||
@jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation')
|
||||
def test_module_documentation(self):
|
||||
def cos_tf(x):
|
||||
return tf.math.cos(x)
|
||||
@ -696,7 +695,6 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
jax.grad(cos_tf_sin_jax)(x)
|
||||
|
||||
logging.info(jax.make_jaxpr(cos_tf_sin_jax)(x))
|
||||
logging.info(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text())
|
||||
|
||||
def test_tf_gather(self):
|
||||
"""tf_gather gradient output is tf.IndexSlices."""
|
||||
|
@ -1314,7 +1314,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
shape = (3, 2)
|
||||
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
|
||||
jax_comp = jax.xla_computation(f_while)(x)
|
||||
jax_comp = jax.jit(f_while).lower(x).compiler_ir('hlo')
|
||||
backend = xb.get_backend()
|
||||
modules = backend.compile(jax_comp).hlo_modules()
|
||||
jax_opt_hlo = modules[0].to_string()
|
||||
|
Loading…
x
Reference in New Issue
Block a user