[jax2tf] Add test that compile_args[tuple_args] does not matter for serialization

PiperOrigin-RevId: 521422653
This commit is contained in:
George Necula 2023-04-03 04:31:55 -07:00 committed by jax authors
parent 2ce78ac9a8
commit bf2c07121b
2 changed files with 21 additions and 1 deletions

View File

@ -77,6 +77,7 @@ class Exported:
# The in_shardings reflect only the module_kept_var_idx
in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
out_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
lowering_platform: str # One of "tpu", "cpu", "cuda", "rocm"
mlir_module: mlir.ir.Module
@ -136,7 +137,8 @@ def serialize_native(fun_jax: Callable,
for compile_arg, check_value, err_msg in (
("spmd_lowering", lambda v: v, "True"),
("auto_spmd_lowering", lambda v: not v, "False"),
("tuple_args", lambda v: not v, "False"),
# tuple_args is a compilation flag, does not affect lowering.
("tuple_args", lambda v: True, "N/A"),
# Used for debug(ordered=True), changes the calling convention, but will
# also set keepalive to non-empty.
("ordered_effects", lambda v: not v, "empty"),

View File

@ -1621,6 +1621,24 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
"keepalive must be empty"):
jax2tf.convert(f_ordered_jax, native_serialization=True)(np.float32(42.))
def test_tuple_args(self):
# On TPU if we have more than 2000 arguments, we pass them as a tuple.
# This is a compiler option, and should have no effect on lowering.
if jtu.device_under_test() != "tpu":
raise unittest.SkipTest("Test enabled on TPU only")
def f_jax(*many_args):
acc = 0.
for a in many_args:
acc += a
return acc
many_args = [np.float32(i) for i in range(2001)]
# Test that we do set lowered.compile_args[tuple_args]
lowered = jax.jit(f_jax).lower(*many_args)
self.assertTrue(lowered._lowering.compile_args["tuple_args"])
res = jax2tf.convert(f_jax, native_serialization=True)(*many_args)
self.assertAllClose(f_jax(*many_args), res)
def get_serialized_computation(
f_jax: Callable,