mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Add test that compile_args[tuple_args] does not matter for serialization
PiperOrigin-RevId: 521422653
This commit is contained in:
parent
2ce78ac9a8
commit
bf2c07121b
@ -77,6 +77,7 @@ class Exported:
|
||||
# The in_shardings reflect only the module_kept_var_idx
|
||||
in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
|
||||
out_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
|
||||
|
||||
lowering_platform: str # One of "tpu", "cpu", "cuda", "rocm"
|
||||
|
||||
mlir_module: mlir.ir.Module
|
||||
@ -136,7 +137,8 @@ def serialize_native(fun_jax: Callable,
|
||||
for compile_arg, check_value, err_msg in (
|
||||
("spmd_lowering", lambda v: v, "True"),
|
||||
("auto_spmd_lowering", lambda v: not v, "False"),
|
||||
("tuple_args", lambda v: not v, "False"),
|
||||
# tuple_args is a compilation flag, does not affect lowering.
|
||||
("tuple_args", lambda v: True, "N/A"),
|
||||
# Used for debug(ordered=True), changes the calling convention, but will
|
||||
# also set keepalive to non-empty.
|
||||
("ordered_effects", lambda v: not v, "empty"),
|
||||
|
@ -1621,6 +1621,24 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
"keepalive must be empty"):
|
||||
jax2tf.convert(f_ordered_jax, native_serialization=True)(np.float32(42.))
|
||||
|
||||
def test_tuple_args(self):
|
||||
# On TPU if we have more than 2000 arguments, we pass them as a tuple.
|
||||
# This is a compiler option, and should have no effect on lowering.
|
||||
if jtu.device_under_test() != "tpu":
|
||||
raise unittest.SkipTest("Test enabled on TPU only")
|
||||
def f_jax(*many_args):
|
||||
acc = 0.
|
||||
for a in many_args:
|
||||
acc += a
|
||||
return acc
|
||||
|
||||
many_args = [np.float32(i) for i in range(2001)]
|
||||
# Test that we do set lowered.compile_args[tuple_args]
|
||||
lowered = jax.jit(f_jax).lower(*many_args)
|
||||
self.assertTrue(lowered._lowering.compile_args["tuple_args"])
|
||||
res = jax2tf.convert(f_jax, native_serialization=True)(*many_args)
|
||||
self.assertAllClose(f_jax(*many_args), res)
|
||||
|
||||
|
||||
def get_serialized_computation(
|
||||
f_jax: Callable,
|
||||
|
Loading…
x
Reference in New Issue
Block a user