mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
[jax2tf] Simplify irrelevant part of call_tf_test.py
PiperOrigin-RevId: 497727816
This commit is contained in:
parent
398aaaacc7
commit
0c8a4fb7cd
@ -542,7 +542,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
def fun_tf(x):
|
||||
begin = 0
|
||||
return x[begin:5] # x must be a compile-time constant
|
||||
return x[begin:5]
|
||||
|
||||
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
|
||||
self.assertIn("(arg0.1: s32[10]) -> s32[5]", hlo)
|
||||
@ -550,16 +550,6 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
# Non-constant slice, but compile-time constant depending only on values.
|
||||
x = np.zeros((10,), dtype=np.int32)
|
||||
|
||||
def fun_tf(x):
|
||||
begin = x[0]
|
||||
return x[begin:5] # x must be a compile-time constant
|
||||
|
||||
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
|
||||
self.assertIn("() -> s32[5]", hlo)
|
||||
x = np.ones((10,), dtype=np.int32)
|
||||
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
|
||||
self.assertIn("() -> s32[4]", hlo)
|
||||
|
||||
# Non-constant slice, but compile-time constant depending only on shapes.
|
||||
x = np.zeros((10,), dtype=np.int32)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user