Test: use context manager to set jax_serialization_version

Jake VanderPlas 2024-06-04 16:08:16 -07:00
parent 8f090b3465
commit 9a080f4b83
2 changed files with 258 additions and 253 deletions
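For context, the commit replaces a hand-rolled per-test override helper with the context-manager form that JAX config options provide: entering the context sets jax_serialization_version and exiting restores the previous value, even if the test fails. A minimal sketch of the before/after patterns, assuming a plain unittest.TestCase; the class name and test body are illustrative and not part of the commit:

from functools import partial
import unittest

import jax
from jax._src import config


class SerializationVersionExample(unittest.TestCase):

  def old_style_override(self, version_override: int):
    # Pattern removed by this commit: mutate the global flag and register
    # a cleanup that restores the previous value when the test finishes.
    previous = jax.config.jax_serialization_version
    if previous != version_override:
      self.addCleanup(partial(jax.config.update,
                              "jax_serialization_version", previous))
      jax.config.update("jax_serialization_version", version_override)

  def test_new_style(self):
    # Pattern introduced by this commit: the config option doubles as a
    # context manager that sets the value on entry and restores it on exit.
    with config.jax_serialization_version(9):
      self.assertEqual(config.jax_serialization_version.value, 9)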


@@ -27,8 +27,8 @@ from jax import dlpack
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax.experimental import export
@@ -1149,17 +1149,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
_ = tf.add(1, 1)
super().setUp()
def override_serialization_version(self, version_override: int):
version = jax.config.jax_serialization_version
if version != version_override:
self.addCleanup(partial(jax.config.update,
"jax_serialization_version",
version))
jax.config.update("jax_serialization_version", version_override)
logging.info(
"Using JAX serialization version %s",
jax.config.jax_serialization_version)
def test_alternate(self):
# Alternate sin/cos with sin in TF and cos in JAX
f_tf_inner = tf.math.sin
@@ -1660,116 +1649,127 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
_, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt_2, input_args=[])
@jtu.parameterized_filterable(
kwargs=[dict(version=version) for version in [8, 9]]
kwargs=[dict(version=version) for version in [9]]
)
def test_call_tf_graph_ordered(self, *, version: int):
self.override_serialization_version(version)
@tf.function
def tf_print(x):
tf.print(x)
with config.jax_serialization_version(version):
logging.info(
"Using JAX serialization version %s",
jax.config.jax_serialization_version)
call_tf_print = jax2tf.call_tf(
tf_print,
call_tf_graph=True,
ordered=True,
)
@tf.function
def tf_print(x):
tf.print(x)
x = jnp.array(1.0, dtype=jnp.float32)
call_tf_print = jax2tf.call_tf(
tf_print,
call_tf_graph=True,
ordered=True,
)
def body(i, x):
call_tf_print(x)
return x + 1
x = jnp.array(1.0, dtype=jnp.float32)
@jax.jit
def f_jax(x):
return jax.lax.fori_loop(0, 4, body, x)
def body(i, x):
call_tf_print(x)
return x + 1
num_custom_calls = 0
@jax.jit
def f_jax(x):
return jax.lax.fori_loop(0, 4, body, x)
def _check_mlir_ops(op):
nonlocal num_custom_calls
num_custom_calls = 0
if (
op.operation.name == "stablehlo.custom_call"
and ir.StringAttr(op.attributes["call_target_name"]).value
== "tf.call_tf_function"
def _check_mlir_ops(op):
nonlocal num_custom_calls
if (
op.operation.name == "stablehlo.custom_call"
and ir.StringAttr(op.attributes["call_target_name"]).value
== "tf.call_tf_function"
):
num_custom_calls += 1
# The custom call op must have the `has_token_input_output` attribute.
tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
self.assertTrue(
ir.BoolAttr(tf_backend_config["has_token_input_output"]).value
)
# Verify that the first argument/result of the custom call op is a token
# type. This is a calling convention defined by `has_token_input_output`.
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
stablehlo_module = None
with self.assertRaisesRegex(
ValueError,
"call_tf_graph=True only support exporting by jax2tf.convert currently",
):
num_custom_calls += 1
lower = f_jax.lower(x)
self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"])
stablehlo_module = lower.compiler_ir("stablehlo")
if stablehlo_module:
self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops)
self.assertEqual(num_custom_calls, 1)
# The custom call op must have the `has_token_input_output` attribute.
tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
self.assertTrue(
ir.BoolAttr(tf_backend_config["has_token_input_output"]).value
)
# Verify that the first argument/result of the custom call op is a token
# type. This is a calling convention defined by `has_token_input_output`.
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
stablehlo_module = None
with self.assertRaisesRegex(
ValueError,
"call_tf_graph=True only support exporting by jax2tf.convert currently",
):
lower = f_jax.lower(x)
self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"])
stablehlo_module = lower.compiler_ir("stablehlo")
if stablehlo_module:
self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops)
self.assertEqual(num_custom_calls, 1)
f_tf = jax2tf.convert(
f_jax,
native_serialization=True,
with_gradient=False,
)
_, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
f_tf = jax2tf.convert(
f_jax,
native_serialization=True,
with_gradient=False,
)
_, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
@jtu.parameterized_filterable(
kwargs=[dict(poly=poly, version=version)
for poly in [True, False]
for version in [8, 9]]
for version in [9]]
)
def test_call_tf_ordered_dead_inputs(self, *, poly: bool, version: int):
self.override_serialization_version(version)
def f_jax(x1, x_dead, x3):
return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True,
call_tf_graph=True)(x3))
if poly:
polymorphic_shapes = ["b", None, None]
else:
polymorphic_shapes = None
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes)
x1 = np.arange(3, dtype=np.float32)
x_dead = np.arange(4, dtype=np.float32)
x3 = np.arange(5, dtype=np.float32)
self.assertAllClose(f_jax(x1, x_dead, x3),
f_tf(x1, x_dead, x3))
with config.jax_serialization_version(version):
logging.info(
"Using JAX serialization version %s",
jax.config.jax_serialization_version)
def f_jax(x1, x_dead, x3):
return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True,
call_tf_graph=True)(x3))
if poly:
polymorphic_shapes = ["b", None, None]
else:
polymorphic_shapes = None
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes)
x1 = np.arange(3, dtype=np.float32)
x_dead = np.arange(4, dtype=np.float32)
x3 = np.arange(5, dtype=np.float32)
self.assertAllClose(f_jax(x1, x_dead, x3),
f_tf(x1, x_dead, x3))
@jtu.parameterized_filterable(
kwargs=[dict(ordered=ordered, version=version)
for ordered in [True, False]
for version in [8, 9]
for version in [9]
]
)
def test_call_tf_graph_polymorphic(self, ordered: bool, version: int):
self.override_serialization_version(version)
@tf.function(jit_compile=True, autograph=False)
@partial(jax2tf.convert,
with_gradient=False,
native_serialization=True,
polymorphic_shapes=["(b)"])
@jax.jit
def tf_f_2(x):
tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world")))
jax2tf.call_tf(tf_f,
call_tf_graph=True,
ordered=ordered)(x)
return x
with config.jax_serialization_version(version):
logging.info(
"Using JAX serialization version %s",
jax.config.jax_serialization_version)
x = np.arange(3, dtype=np.int32)
_ = tf.function(tf_f_2, autograph=False).get_concrete_function(x)
@tf.function(jit_compile=True, autograph=False)
@partial(jax2tf.convert,
with_gradient=False,
native_serialization=True,
polymorphic_shapes=["(b)"])
@jax.jit
def tf_f_2(x):
tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world")))
jax2tf.call_tf(tf_f,
call_tf_graph=True,
ordered=ordered)(x)
return x
x = np.arange(3, dtype=np.int32)
_ = tf.function(tf_f_2, autograph=False).get_concrete_function(x)
# TODO(b/293927250): call_tf_graph=True only accepts a concrete_function. The
# workaround here is to set `module.call = concrete_fn`.

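The export tests below pair the same context-manager switch with a class-level default: the per-test setUp override is dropped in favor of a @jtu.with_config decorator on the test class. A minimal sketch of that combination; MyExportTest is an illustrative class name, not part of the commit:

from jax._src import config
from jax._src import test_util as jtu
from jax.experimental import export


# Class-level default: every test in this class runs with the maximum
# supported serialization version unless a test overrides it locally.
@jtu.with_config(
    jax_serialization_version=export.maximum_supported_serialization_version)
class MyExportTest(jtu.JaxTestCase):

  def test_with_local_override(self):
    # Narrow the version for a single test with the context-manager form.
    with config.jax_serialization_version(
        export.minimum_supported_serialization_version):
      self.assertEqual(
          config.jax_serialization_version.value,
          export.minimum_supported_serialization_version)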

@@ -148,15 +148,10 @@ def get_exported(fun, vjp_order=0,
return export.deserialize(serialized)
return serde_exported
class JaxExportTest(jtu.JaxTestCase):
def override_serialization_version(self, version_override: int):
version = config.jax_serialization_version.value
if version != version_override:
self.enter_context(config.jax_serialization_version(version_override))
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
# Run tests with the maximum supported version by default
@jtu.with_config(jax_serialization_version=export.maximum_supported_serialization_version)
class JaxExportTest(jtu.JaxTestCase):
@classmethod
def setUpClass(cls):
@@ -170,12 +165,6 @@ class JaxExportTest(jtu.JaxTestCase):
cls.platforms.append(backend)
super().setUpClass()
def setUp(self):
super().setUp()
# Run tests with the maximum supported version by default
self.override_serialization_version(
export.maximum_supported_serialization_version)
def test_basic_export_only(self):
def my_fun(x):
return jnp.sin(x)
@@ -563,19 +552,22 @@ class JaxExportTest(jtu.JaxTestCase):
for v in range(export.minimum_supported_serialization_version - 1,
export.maximum_supported_serialization_version + 2)])
def test_poly_basic_versions(self, v: int):
self.override_serialization_version(v)
with contextlib.ExitStack() as e:
if not (export.minimum_supported_serialization_version <= v
<= export.maximum_supported_serialization_version):
e.enter_context(self.assertRaisesRegex(
ValueError,
f"The requested jax_serialization version {v} is outside the range of supported versions"))
with config.jax_serialization_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
with contextlib.ExitStack() as e:
if not (export.minimum_supported_serialization_version <= v
<= export.maximum_supported_serialization_version):
e.enter_context(self.assertRaisesRegex(
ValueError,
f"The requested jax_serialization version {v} is outside the range of supported versions"))
exp = get_exported(jnp.sin)(
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
x = np.arange(30, dtype=np.float32).reshape((5, 6))
res = export.call_exported(exp)(x)
self.assertAllClose(res, np.sin(x))
exp = get_exported(jnp.sin)(
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
x = np.arange(30, dtype=np.float32).reshape((5, 6))
res = export.call_exported(exp)(x)
self.assertAllClose(res, np.sin(x))
# A function is exported with f32[poly_spec] and is called with different arg
# shapes. We use export.call_exported and we also run the shape check
@@ -1375,74 +1367,77 @@ class JaxExportTest(jtu.JaxTestCase):
for v in range(export.minimum_supported_serialization_version,
export.maximum_supported_serialization_version + 1)])
def test_ordered_effects_basic(self, *, v: int):
self.override_serialization_version(v)
x = np.arange(3, dtype=np.float32)
def f_jax(x): # x: f32[3]
# Test also the calling convention for inner functions
def f_jax_inner(x):
with config.jax_serialization_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
x = np.arange(3, dtype=np.float32)
def f_jax(x): # x: f32[3]
# Test also the calling convention for inner functions
def f_jax_inner(x):
return (
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1"))
return (
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1"))
return (
10. +
jax.jit(f_jax_inner)(x) +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
10. +
jax.jit(f_jax_inner)(x) +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
)
exp = get_exported(f_jax)(x)
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in exp.ordered_effects))
self.assertEqual(["ForTestingUnorderedEffect1()"],
[str(e) for e in exp.unordered_effects])
mlir_module_str = str(exp.mlir_module())
# Inner functions use stablehlo.token for all versions
inner_fun_expected_re = (
r"func.func private @f_jax_inner\("
r"%arg0: !stablehlo.token .*jax.token = true.*"
r"%arg1: tensor<3xf32>.*->.*"
# Results
r"!stablehlo.token .*jax.token = true.*"
r"tensor<3xf32>"
)
self.assertRegex(mlir_module_str, inner_fun_expected_re)
exp = get_exported(f_jax)(x)
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in exp.ordered_effects))
self.assertEqual(["ForTestingUnorderedEffect1()"],
[str(e) for e in exp.unordered_effects])
mlir_module_str = str(exp.mlir_module())
# The wrapped_main function takes tokens at version 9 and higher, and takes
# i1[0] before version 9.
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: !stablehlo.token .*jax.token = true.*"
r"%arg1: !stablehlo.token .*jax.token = true.*->.*"
# Results
r"!stablehlo.token .*jax.token = true.*"
r"!stablehlo.token .*jax.token = true.*")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
# Inner functions use stablehlo.token for all versions
inner_fun_expected_re = (
r"func.func private @f_jax_inner\("
r"%arg0: !stablehlo.token .*jax.token = true.*"
r"%arg1: tensor<3xf32>.*->.*"
# Results
r"!stablehlo.token .*jax.token = true.*"
r"tensor<3xf32>"
)
self.assertRegex(mlir_module_str, inner_fun_expected_re)
# The main function takes tokens and has the same type as the wrapped main
main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main")
self.assertRegex(mlir_module_str, main_expected_re)
# The wrapped_main function takes tokens at version 9 and higher, and takes
# i1[0] before version 9.
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: !stablehlo.token .*jax.token = true.*"
r"%arg1: !stablehlo.token .*jax.token = true.*->.*"
# Results
r"!stablehlo.token .*jax.token = true.*"
r"!stablehlo.token .*jax.token = true.*")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
# Now call the exported from a function that uses its own effects
def f_outer(x):
return (
testing_primitive_with_effect_p.bind(
x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(
x, effect_class_name="ForTestingUnorderedEffect1") +
export.call_exported(exp)(x))
# The main function takes tokens and has the same type as the wrapped main
main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main")
self.assertRegex(mlir_module_str, main_expected_re)
lowered_outer = jax.jit(f_outer).lower(x)
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]))
self.assertEqual(["ForTestingUnorderedEffect1()"],
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
# Now call the exported from a function that uses its own effects
def f_outer(x):
return (
testing_primitive_with_effect_p.bind(
x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(
x, effect_class_name="ForTestingUnorderedEffect1") +
export.call_exported(exp)(x))
mlir_outer_module_str = str(lowered_outer.compiler_ir())
self.assertRegex(mlir_outer_module_str, main_expected_re)
lowered_outer = jax.jit(f_outer).lower(x)
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]))
self.assertEqual(["ForTestingUnorderedEffect1()"],
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
mlir_outer_module_str = str(lowered_outer.compiler_ir())
self.assertRegex(mlir_outer_module_str, main_expected_re)
res = jax.jit(f_outer)(x)
self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res)
res = jax.jit(f_outer)(x)
self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res)
@jtu.parameterized_filterable(
kwargs=[
@@ -1450,33 +1445,36 @@ class JaxExportTest(jtu.JaxTestCase):
for v in range(export.minimum_supported_serialization_version,
export.maximum_supported_serialization_version + 1)])
def test_ordered_effects_poly(self, *, v: int):
self.override_serialization_version(v)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
def f_jax(x): # x: f32[b1, b2]
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(
export.symbolic_shape("b2, b1"), x.dtype))
mlir_module_str = str(exp.mlir_module())
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"b1\".* "
r"%arg1: tensor<i..> {jax.global_constant = \"b2\".* "
r"%arg2: !stablehlo.token {jax.token = true.* "
r"%arg3: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
with config.jax_serialization_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
def f_jax(x): # x: f32[b1, b2]
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(
export.symbolic_shape("b2, b1"), x.dtype))
mlir_module_str = str(exp.mlir_module())
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"b1\".* "
r"%arg1: tensor<i..> {jax.global_constant = \"b2\".* "
r"%arg2: !stablehlo.token {jax.token = true.* "
r"%arg3: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
main_expected_re = (
r"@main\("
r"%arg0: !stablehlo.token {jax.token = true.*, "
r"%arg1: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
main_expected_re = (
r"@main\("
r"%arg0: !stablehlo.token {jax.token = true.*, "
r"%arg1: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x)
self.assertAllClose(10. + 2. * x, res)
res = export.call_exported(exp)(x)
self.assertAllClose(10. + 2. * x, res)
@jtu.parameterized_filterable(
kwargs=[
@@ -1484,41 +1482,44 @@ class JaxExportTest(jtu.JaxTestCase):
for v in range(export.minimum_supported_serialization_version,
export.maximum_supported_serialization_version + 1)])
def test_ordered_effects_multi_platform_and_poly(self, *, v: int):
self.override_serialization_version(v)
if jtu.device_under_test() == "gpu":
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
x = np.ones((3, 4), dtype=np.float32)
def f_jax(x): # x: f32[b1, b2]
return 10. + _testing_multi_platform_func(x,
effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(
f_jax,
lowering_platforms=("cpu", "tpu")
)(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype))
mlir_module_str = str(exp.mlir_module())
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
r"%arg1: tensor<i..> {jax.global_constant = \"b1\".*, "
r"%arg2: tensor<i..> {jax.global_constant = \"b2\".*, "
r"%arg3: !stablehlo.token {jax.token = true.*, "
r"%arg4: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
with config.jax_serialization_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
if jtu.device_under_test() == "gpu":
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
x = np.ones((3, 4), dtype=np.float32)
def f_jax(x): # x: f32[b1, b2]
return 10. + _testing_multi_platform_func(x,
effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(
f_jax,
lowering_platforms=("cpu", "tpu")
)(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype))
mlir_module_str = str(exp.mlir_module())
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
r"%arg1: tensor<i..> {jax.global_constant = \"b1\".*, "
r"%arg2: tensor<i..> {jax.global_constant = \"b2\".*, "
r"%arg3: !stablehlo.token {jax.token = true.*, "
r"%arg4: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
main_expected_re = (
r"@main\("
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
r"%arg1: !stablehlo.token {jax.token = true.*, "
r"%arg2: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x)
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
res)
main_expected_re = (
r"@main\("
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
r"%arg1: !stablehlo.token {jax.token = true.*, "
r"%arg2: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x)
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
res)
@jtu.parameterized_filterable(
kwargs=[
@@ -1526,19 +1527,23 @@ class JaxExportTest(jtu.JaxTestCase):
for v in range(export.minimum_supported_serialization_version,
export.maximum_supported_serialization_version + 1)])
def test_ordered_effects_with_donation(self, *, v: int):
self.override_serialization_version(v)
x = np.arange(3, dtype=np.float32)
with config.jax_serialization_version(v):
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version.value)
def f_jax(x):
return testing_primitive_with_effect_p.bind(
x, effect_class_name="ForTestingOrderedEffect1"
)
x = np.arange(3, dtype=np.float32)
f_jax = jax.jit(f_jax, donate_argnums=(0,))
exp = export.export(f_jax)(x)
mlir_module_str = str(exp.mlir_module())
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1")
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
def f_jax(x):
return testing_primitive_with_effect_p.bind(
x, effect_class_name="ForTestingOrderedEffect1"
)
f_jax = jax.jit(f_jax, donate_argnums=(0,))
exp = export.export(f_jax)(x)
mlir_module_str = str(exp.mlir_module())
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1")
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
@jtu.parameterized_filterable(
kwargs=[