[jax2tf] Simplify irrelevant part of call_tf_test.py

PiperOrigin-RevId: 497727816
This commit is contained in:
George Necula 2022-12-25 20:26:26 -08:00 committed by jax authors
parent 398aaaacc7
commit 0c8a4fb7cd

View File

@ -542,7 +542,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
def fun_tf(x):
begin = 0
return x[begin:5] # x must be a compile-time constant
return x[begin:5]
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
self.assertIn("(arg0.1: s32[10]) -> s32[5]", hlo)
@ -550,16 +550,6 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
# Non-constant slice, but compile-time constant depending only on values.
x = np.zeros((10,), dtype=np.int32)
def fun_tf(x):
begin = x[0]
return x[begin:5] # x must be a compile-time constant
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
self.assertIn("() -> s32[5]", hlo)
x = np.ones((10,), dtype=np.int32)
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
self.assertIn("() -> s32[4]", hlo)
# Non-constant slice, but compile-time constant depending only on shapes.
x = np.zeros((10,), dtype=np.int32)