mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Test: use context manager to set jax_serialization_version
This commit is contained in:
parent
8f090b3465
commit
9a080f4b83
@ -27,8 +27,8 @@ from jax import dlpack
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax.experimental import export
|
||||
@ -1149,17 +1149,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
_ = tf.add(1, 1)
|
||||
super().setUp()
|
||||
|
||||
def override_serialization_version(self, version_override: int):
|
||||
version = jax.config.jax_serialization_version
|
||||
if version != version_override:
|
||||
self.addCleanup(partial(jax.config.update,
|
||||
"jax_serialization_version",
|
||||
version))
|
||||
jax.config.update("jax_serialization_version", version_override)
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
jax.config.jax_serialization_version)
|
||||
|
||||
def test_alternate(self):
|
||||
# Alternate sin/cos with sin in TF and cos in JAX
|
||||
f_tf_inner = tf.math.sin
|
||||
@ -1660,116 +1649,127 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
_, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt_2, input_args=[])
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[dict(version=version) for version in [8, 9]]
|
||||
kwargs=[dict(version=version) for version in [9]]
|
||||
)
|
||||
def test_call_tf_graph_ordered(self, *, version: int):
|
||||
self.override_serialization_version(version)
|
||||
@tf.function
|
||||
def tf_print(x):
|
||||
tf.print(x)
|
||||
with config.jax_serialization_version(version):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
jax.config.jax_serialization_version)
|
||||
|
||||
call_tf_print = jax2tf.call_tf(
|
||||
tf_print,
|
||||
call_tf_graph=True,
|
||||
ordered=True,
|
||||
)
|
||||
@tf.function
|
||||
def tf_print(x):
|
||||
tf.print(x)
|
||||
|
||||
x = jnp.array(1.0, dtype=jnp.float32)
|
||||
call_tf_print = jax2tf.call_tf(
|
||||
tf_print,
|
||||
call_tf_graph=True,
|
||||
ordered=True,
|
||||
)
|
||||
|
||||
def body(i, x):
|
||||
call_tf_print(x)
|
||||
return x + 1
|
||||
x = jnp.array(1.0, dtype=jnp.float32)
|
||||
|
||||
@jax.jit
|
||||
def f_jax(x):
|
||||
return jax.lax.fori_loop(0, 4, body, x)
|
||||
def body(i, x):
|
||||
call_tf_print(x)
|
||||
return x + 1
|
||||
|
||||
num_custom_calls = 0
|
||||
@jax.jit
|
||||
def f_jax(x):
|
||||
return jax.lax.fori_loop(0, 4, body, x)
|
||||
|
||||
def _check_mlir_ops(op):
|
||||
nonlocal num_custom_calls
|
||||
num_custom_calls = 0
|
||||
|
||||
if (
|
||||
op.operation.name == "stablehlo.custom_call"
|
||||
and ir.StringAttr(op.attributes["call_target_name"]).value
|
||||
== "tf.call_tf_function"
|
||||
def _check_mlir_ops(op):
|
||||
nonlocal num_custom_calls
|
||||
|
||||
if (
|
||||
op.operation.name == "stablehlo.custom_call"
|
||||
and ir.StringAttr(op.attributes["call_target_name"]).value
|
||||
== "tf.call_tf_function"
|
||||
):
|
||||
num_custom_calls += 1
|
||||
|
||||
# The custom call op must have `has_token_input_output` attribute.
|
||||
tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
|
||||
self.assertTrue(
|
||||
ir.BoolAttr(tf_backend_config["has_token_input_output"]).value
|
||||
)
|
||||
|
||||
# Verify that the first argument/result of the custom call op is a token
|
||||
# type. This is a calling convention defined by `has_token_input_output`.
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
|
||||
|
||||
stablehlo_module = None
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"call_tf_graph=True only support exporting by jax2tf.convert currently",
|
||||
):
|
||||
num_custom_calls += 1
|
||||
lower = f_jax.lower(x)
|
||||
self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"])
|
||||
stablehlo_module = lower.compiler_ir("stablehlo")
|
||||
if stablehlo_module:
|
||||
self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops)
|
||||
self.assertEqual(num_custom_calls, 1)
|
||||
|
||||
# The custom call op must have `has_token_input_output` attribute.
|
||||
tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
|
||||
self.assertTrue(
|
||||
ir.BoolAttr(tf_backend_config["has_token_input_output"]).value
|
||||
)
|
||||
|
||||
# Verify that the first argument/result of the custom call op is a token
|
||||
# type. This is a calling convention defined by `has_token_input_output`.
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
|
||||
|
||||
stablehlo_module = None
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"call_tf_graph=True only support exporting by jax2tf.convert currently",
|
||||
):
|
||||
lower = f_jax.lower(x)
|
||||
self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"])
|
||||
stablehlo_module = lower.compiler_ir("stablehlo")
|
||||
if stablehlo_module:
|
||||
self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops)
|
||||
self.assertEqual(num_custom_calls, 1)
|
||||
|
||||
f_tf = jax2tf.convert(
|
||||
f_jax,
|
||||
native_serialization=True,
|
||||
with_gradient=False,
|
||||
)
|
||||
_, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
|
||||
f_tf = jax2tf.convert(
|
||||
f_jax,
|
||||
native_serialization=True,
|
||||
with_gradient=False,
|
||||
)
|
||||
_, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[dict(poly=poly, version=version)
|
||||
for poly in [True, False]
|
||||
for version in [8, 9]]
|
||||
for version in [9]]
|
||||
)
|
||||
def test_call_tf_ordered_dead_inputs(self, *, poly: bool, version: int):
|
||||
self.override_serialization_version(version)
|
||||
def f_jax(x1, x_dead, x3):
|
||||
return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True,
|
||||
call_tf_graph=True)(x3))
|
||||
if poly:
|
||||
polymorphic_shapes = ["b", None, None]
|
||||
else:
|
||||
polymorphic_shapes = None
|
||||
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes)
|
||||
x1 = np.arange(3, dtype=np.float32)
|
||||
x_dead = np.arange(4, dtype=np.float32)
|
||||
x3 = np.arange(5, dtype=np.float32)
|
||||
self.assertAllClose(f_jax(x1, x_dead, x3),
|
||||
f_tf(x1, x_dead, x3))
|
||||
with config.jax_serialization_version(version):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
jax.config.jax_serialization_version)
|
||||
def f_jax(x1, x_dead, x3):
|
||||
return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True,
|
||||
call_tf_graph=True)(x3))
|
||||
if poly:
|
||||
polymorphic_shapes = ["b", None, None]
|
||||
else:
|
||||
polymorphic_shapes = None
|
||||
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes)
|
||||
x1 = np.arange(3, dtype=np.float32)
|
||||
x_dead = np.arange(4, dtype=np.float32)
|
||||
x3 = np.arange(5, dtype=np.float32)
|
||||
self.assertAllClose(f_jax(x1, x_dead, x3),
|
||||
f_tf(x1, x_dead, x3))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[dict(ordered=ordered, version=version)
|
||||
for ordered in [True, False]
|
||||
for version in [8, 9]
|
||||
for version in [9]
|
||||
]
|
||||
)
|
||||
def test_call_tf_graph_polymorphic(self, ordered: bool, version: int):
|
||||
self.override_serialization_version(version)
|
||||
@tf.function(jit_compile=True, autograph=False)
|
||||
@partial(jax2tf.convert,
|
||||
with_gradient=False,
|
||||
native_serialization=True,
|
||||
polymorphic_shapes=["(b)"])
|
||||
@jax.jit
|
||||
def tf_f_2(x):
|
||||
tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world")))
|
||||
jax2tf.call_tf(tf_f,
|
||||
call_tf_graph=True,
|
||||
ordered=ordered)(x)
|
||||
return x
|
||||
with config.jax_serialization_version(version):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
jax.config.jax_serialization_version)
|
||||
|
||||
x = np.arange(3, dtype=np.int32)
|
||||
_ = tf.function(tf_f_2, autograph=False).get_concrete_function(x)
|
||||
@tf.function(jit_compile=True, autograph=False)
|
||||
@partial(jax2tf.convert,
|
||||
with_gradient=False,
|
||||
native_serialization=True,
|
||||
polymorphic_shapes=["(b)"])
|
||||
@jax.jit
|
||||
def tf_f_2(x):
|
||||
tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world")))
|
||||
jax2tf.call_tf(tf_f,
|
||||
call_tf_graph=True,
|
||||
ordered=ordered)(x)
|
||||
return x
|
||||
|
||||
x = np.arange(3, dtype=np.int32)
|
||||
_ = tf.function(tf_f_2, autograph=False).get_concrete_function(x)
|
||||
|
||||
# TODO(b/293927250): call_tf_graph=True only accept concrete_function. The
|
||||
# workaround here is to set `module.call=concrete_fn.`.
|
||||
|
@ -148,15 +148,10 @@ def get_exported(fun, vjp_order=0,
|
||||
return export.deserialize(serialized)
|
||||
return serde_exported
|
||||
|
||||
class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
def override_serialization_version(self, version_override: int):
|
||||
version = config.jax_serialization_version.value
|
||||
if version != version_override:
|
||||
self.enter_context(config.jax_serialization_version(version_override))
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
# Run tests with the maximum supported version by default
|
||||
@jtu.with_config(jax_serialization_version=export.maximum_supported_serialization_version)
|
||||
class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -170,12 +165,6 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
cls.platforms.append(backend)
|
||||
super().setUpClass()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Run tests with the maximum supported version by default
|
||||
self.override_serialization_version(
|
||||
export.maximum_supported_serialization_version)
|
||||
|
||||
def test_basic_export_only(self):
|
||||
def my_fun(x):
|
||||
return jnp.sin(x)
|
||||
@ -563,19 +552,22 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
for v in range(export.minimum_supported_serialization_version - 1,
|
||||
export.maximum_supported_serialization_version + 2)])
|
||||
def test_poly_basic_versions(self, v: int):
|
||||
self.override_serialization_version(v)
|
||||
with contextlib.ExitStack() as e:
|
||||
if not (export.minimum_supported_serialization_version <= v
|
||||
<= export.maximum_supported_serialization_version):
|
||||
e.enter_context(self.assertRaisesRegex(
|
||||
ValueError,
|
||||
f"The requested jax_serialization version {v} is outside the range of supported versions"))
|
||||
with config.jax_serialization_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
with contextlib.ExitStack() as e:
|
||||
if not (export.minimum_supported_serialization_version <= v
|
||||
<= export.maximum_supported_serialization_version):
|
||||
e.enter_context(self.assertRaisesRegex(
|
||||
ValueError,
|
||||
f"The requested jax_serialization version {v} is outside the range of supported versions"))
|
||||
|
||||
exp = get_exported(jnp.sin)(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
|
||||
x = np.arange(30, dtype=np.float32).reshape((5, 6))
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(res, np.sin(x))
|
||||
exp = get_exported(jnp.sin)(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
|
||||
x = np.arange(30, dtype=np.float32).reshape((5, 6))
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(res, np.sin(x))
|
||||
|
||||
# A function is exported with f32[poly_spec] and is called with different arg
|
||||
# shapes. We use export.call_exported and we also run the shape check
|
||||
@ -1375,74 +1367,77 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
for v in range(export.minimum_supported_serialization_version,
|
||||
export.maximum_supported_serialization_version + 1)])
|
||||
def test_ordered_effects_basic(self, *, v: int):
|
||||
self.override_serialization_version(v)
|
||||
x = np.arange(3, dtype=np.float32)
|
||||
def f_jax(x): # x: f32[3]
|
||||
# Test also the calling convention for inner functions
|
||||
def f_jax_inner(x):
|
||||
with config.jax_serialization_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
x = np.arange(3, dtype=np.float32)
|
||||
def f_jax(x): # x: f32[3]
|
||||
# Test also the calling convention for inner functions
|
||||
def f_jax_inner(x):
|
||||
return (
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1"))
|
||||
return (
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1"))
|
||||
return (
|
||||
10. +
|
||||
jax.jit(f_jax_inner)(x) +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
|
||||
10. +
|
||||
jax.jit(f_jax_inner)(x) +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
|
||||
)
|
||||
|
||||
exp = get_exported(f_jax)(x)
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
sorted(str(e) for e in exp.ordered_effects))
|
||||
self.assertEqual(["ForTestingUnorderedEffect1()"],
|
||||
[str(e) for e in exp.unordered_effects])
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
|
||||
# Inner functions use stablehlo.token for all versions
|
||||
inner_fun_expected_re = (
|
||||
r"func.func private @f_jax_inner\("
|
||||
r"%arg0: !stablehlo.token .*jax.token = true.*"
|
||||
r"%arg1: tensor<3xf32>.*->.*"
|
||||
# Results
|
||||
r"!stablehlo.token .*jax.token = true.*"
|
||||
r"tensor<3xf32>"
|
||||
)
|
||||
self.assertRegex(mlir_module_str, inner_fun_expected_re)
|
||||
|
||||
exp = get_exported(f_jax)(x)
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
sorted(str(e) for e in exp.ordered_effects))
|
||||
self.assertEqual(["ForTestingUnorderedEffect1()"],
|
||||
[str(e) for e in exp.unordered_effects])
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
# The wrapped_main function takens tokens after version 9, and takes
|
||||
# i1[0] before version 9.
|
||||
wrapped_main_expected_re = (
|
||||
r"@_wrapped_jax_export_main\("
|
||||
r"%arg0: !stablehlo.token .*jax.token = true.*"
|
||||
r"%arg1: !stablehlo.token .*jax.token = true.*->.*"
|
||||
# Results
|
||||
r"!stablehlo.token .*jax.token = true.*"
|
||||
r"!stablehlo.token .*jax.token = true.*")
|
||||
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
|
||||
|
||||
# Inner functions use stablehlo.token for all versions
|
||||
inner_fun_expected_re = (
|
||||
r"func.func private @f_jax_inner\("
|
||||
r"%arg0: !stablehlo.token .*jax.token = true.*"
|
||||
r"%arg1: tensor<3xf32>.*->.*"
|
||||
# Results
|
||||
r"!stablehlo.token .*jax.token = true.*"
|
||||
r"tensor<3xf32>"
|
||||
)
|
||||
self.assertRegex(mlir_module_str, inner_fun_expected_re)
|
||||
# The main function takes tokens and has the same type as the wrapped main
|
||||
main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
|
||||
# The wrapped_main function takens tokens after version 9, and takes
|
||||
# i1[0] before version 9.
|
||||
wrapped_main_expected_re = (
|
||||
r"@_wrapped_jax_export_main\("
|
||||
r"%arg0: !stablehlo.token .*jax.token = true.*"
|
||||
r"%arg1: !stablehlo.token .*jax.token = true.*->.*"
|
||||
# Results
|
||||
r"!stablehlo.token .*jax.token = true.*"
|
||||
r"!stablehlo.token .*jax.token = true.*")
|
||||
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
|
||||
# Now call the exported from a function that uses its own effects
|
||||
def f_outer(x):
|
||||
return (
|
||||
testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="ForTestingOrderedEffect2") +
|
||||
testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="ForTestingUnorderedEffect1") +
|
||||
export.call_exported(exp)(x))
|
||||
|
||||
# The main function takes tokens and has the same type as the wrapped main
|
||||
main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
lowered_outer = jax.jit(f_outer).lower(x)
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]))
|
||||
self.assertEqual(["ForTestingUnorderedEffect1()"],
|
||||
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
|
||||
|
||||
# Now call the exported from a function that uses its own effects
|
||||
def f_outer(x):
|
||||
return (
|
||||
testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="ForTestingOrderedEffect2") +
|
||||
testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="ForTestingUnorderedEffect1") +
|
||||
export.call_exported(exp)(x))
|
||||
mlir_outer_module_str = str(lowered_outer.compiler_ir())
|
||||
self.assertRegex(mlir_outer_module_str, main_expected_re)
|
||||
|
||||
lowered_outer = jax.jit(f_outer).lower(x)
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]))
|
||||
self.assertEqual(["ForTestingUnorderedEffect1()"],
|
||||
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
|
||||
|
||||
mlir_outer_module_str = str(lowered_outer.compiler_ir())
|
||||
self.assertRegex(mlir_outer_module_str, main_expected_re)
|
||||
|
||||
res = jax.jit(f_outer)(x)
|
||||
self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res)
|
||||
res = jax.jit(f_outer)(x)
|
||||
self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
@ -1450,33 +1445,36 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
for v in range(export.minimum_supported_serialization_version,
|
||||
export.maximum_supported_serialization_version + 1)])
|
||||
def test_ordered_effects_poly(self, *, v: int):
|
||||
self.override_serialization_version(v)
|
||||
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
def f_jax(x): # x: f32[b1, b2]
|
||||
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
|
||||
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(
|
||||
export.symbolic_shape("b2, b1"), x.dtype))
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
wrapped_main_expected_re = (
|
||||
r"@_wrapped_jax_export_main\("
|
||||
r"%arg0: tensor<i..> {jax.global_constant = \"b1\".* "
|
||||
r"%arg1: tensor<i..> {jax.global_constant = \"b2\".* "
|
||||
r"%arg2: !stablehlo.token {jax.token = true.* "
|
||||
r"%arg3: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
|
||||
with config.jax_serialization_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
def f_jax(x): # x: f32[b1, b2]
|
||||
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
|
||||
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(
|
||||
export.symbolic_shape("b2, b1"), x.dtype))
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
wrapped_main_expected_re = (
|
||||
r"@_wrapped_jax_export_main\("
|
||||
r"%arg0: tensor<i..> {jax.global_constant = \"b1\".* "
|
||||
r"%arg1: tensor<i..> {jax.global_constant = \"b2\".* "
|
||||
r"%arg2: !stablehlo.token {jax.token = true.* "
|
||||
r"%arg3: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
|
||||
|
||||
main_expected_re = (
|
||||
r"@main\("
|
||||
r"%arg0: !stablehlo.token {jax.token = true.*, "
|
||||
r"%arg1: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
main_expected_re = (
|
||||
r"@main\("
|
||||
r"%arg0: !stablehlo.token {jax.token = true.*, "
|
||||
r"%arg1: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(10. + 2. * x, res)
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(10. + 2. * x, res)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
@ -1484,41 +1482,44 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
for v in range(export.minimum_supported_serialization_version,
|
||||
export.maximum_supported_serialization_version + 1)])
|
||||
def test_ordered_effects_multi_platform_and_poly(self, *, v: int):
|
||||
self.override_serialization_version(v)
|
||||
if jtu.device_under_test() == "gpu":
|
||||
# The export is not applicable to GPU
|
||||
raise unittest.SkipTest("Not intended for running on GPU")
|
||||
x = np.ones((3, 4), dtype=np.float32)
|
||||
def f_jax(x): # x: f32[b1, b2]
|
||||
return 10. + _testing_multi_platform_func(x,
|
||||
effect_class_name="ForTestingOrderedEffect1")
|
||||
exp = get_exported(
|
||||
f_jax,
|
||||
lowering_platforms=("cpu", "tpu")
|
||||
)(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype))
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
wrapped_main_expected_re = (
|
||||
r"@_wrapped_jax_export_main\("
|
||||
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
|
||||
r"%arg1: tensor<i..> {jax.global_constant = \"b1\".*, "
|
||||
r"%arg2: tensor<i..> {jax.global_constant = \"b2\".*, "
|
||||
r"%arg3: !stablehlo.token {jax.token = true.*, "
|
||||
r"%arg4: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
|
||||
with config.jax_serialization_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
if jtu.device_under_test() == "gpu":
|
||||
# The export is not applicable to GPU
|
||||
raise unittest.SkipTest("Not intended for running on GPU")
|
||||
x = np.ones((3, 4), dtype=np.float32)
|
||||
def f_jax(x): # x: f32[b1, b2]
|
||||
return 10. + _testing_multi_platform_func(x,
|
||||
effect_class_name="ForTestingOrderedEffect1")
|
||||
exp = get_exported(
|
||||
f_jax,
|
||||
lowering_platforms=("cpu", "tpu")
|
||||
)(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype))
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
wrapped_main_expected_re = (
|
||||
r"@_wrapped_jax_export_main\("
|
||||
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
|
||||
r"%arg1: tensor<i..> {jax.global_constant = \"b1\".*, "
|
||||
r"%arg2: tensor<i..> {jax.global_constant = \"b2\".*, "
|
||||
r"%arg3: !stablehlo.token {jax.token = true.*, "
|
||||
r"%arg4: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
|
||||
|
||||
main_expected_re = (
|
||||
r"@main\("
|
||||
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
|
||||
r"%arg1: !stablehlo.token {jax.token = true.*, "
|
||||
r"%arg2: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
|
||||
res)
|
||||
main_expected_re = (
|
||||
r"@main\("
|
||||
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
|
||||
r"%arg1: !stablehlo.token {jax.token = true.*, "
|
||||
r"%arg2: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
|
||||
res)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
@ -1526,19 +1527,23 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
for v in range(export.minimum_supported_serialization_version,
|
||||
export.maximum_supported_serialization_version + 1)])
|
||||
def test_ordered_effects_with_donation(self, *, v: int):
|
||||
self.override_serialization_version(v)
|
||||
x = np.arange(3, dtype=np.float32)
|
||||
with config.jax_serialization_version(v):
|
||||
logging.info(
|
||||
"Using JAX serialization version %s",
|
||||
config.jax_serialization_version.value)
|
||||
|
||||
def f_jax(x):
|
||||
return testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="ForTestingOrderedEffect1"
|
||||
)
|
||||
x = np.arange(3, dtype=np.float32)
|
||||
|
||||
f_jax = jax.jit(f_jax, donate_argnums=(0,))
|
||||
exp = export.export(f_jax)(x)
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1")
|
||||
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
|
||||
def f_jax(x):
|
||||
return testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="ForTestingOrderedEffect1"
|
||||
)
|
||||
|
||||
f_jax = jax.jit(f_jax, donate_argnums=(0,))
|
||||
exp = export.export(f_jax)(x)
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1")
|
||||
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
|
Loading…
x
Reference in New Issue
Block a user