Remove jax.xla_computation tests from jax2tf. api_test.py has enough coverage for jax.xla_computation

PiperOrigin-RevId: 644605636
This commit is contained in:
Yash Katariya 2024-06-18 21:01:08 -07:00 committed by jax authors
parent 103d620856
commit d3bfd32667
2 changed files with 1 additions and 3 deletions

View File

@ -674,7 +674,6 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
lower_no_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=False)).lower(x)
self.assertEmpty(lower_no_effect._lowering.compile_args["unordered_effects"])
@jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation')
def test_module_documentation(self):
def cos_tf(x):
return tf.math.cos(x)
@ -696,7 +695,6 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
jax.grad(cos_tf_sin_jax)(x)
logging.info(jax.make_jaxpr(cos_tf_sin_jax)(x))
logging.info(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text())
def test_tf_gather(self):
"""tf_gather gradient output is tf.IndexSlices."""

View File

@ -1314,7 +1314,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
shape = (3, 2)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
jax_comp = jax.xla_computation(f_while)(x)
jax_comp = jax.jit(f_while).lower(x).compiler_ir('hlo')
backend = xb.get_backend()
modules = backend.compile(jax_comp).hlo_modules()
jax_opt_hlo = modules[0].to_string()