Merge pull request #18956 from gnecula:export_effects

PiperOrigin-RevId: 591134940
This commit is contained in:
jax authors 2023-12-14 21:03:11 -08:00
commit 6cd7adac99
3 changed files with 76 additions and 39 deletions

View File

@ -386,17 +386,24 @@ def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding:
def _serialize_effect(builder: flatbuffers.Builder, eff: core.Effect) -> int:
# TODO(necula): for now serialize just the name of the class
try:
_ = eff.__class__()
except:
eff_replica = eff.__class__()
except Exception:
raise NotImplementedError(
f"serializing effect {eff} that does not have a nullary class"
" constructor"
f"Effect {eff} must have a nullary constructor to be serializable"
)
try:
hash_eff = hash(eff)
hash_eff_replica = hash(eff_replica)
except Exception:
raise NotImplementedError(
f"Effect {eff} must be hashable to be serializable"
)
if eff != eff_replica or hash_eff != hash_eff_replica:
raise NotImplementedError(
f"Effect {eff} must have a nullary class constructor that produces an "
"equal effect object."
)
# TODO: fix the effects serialization and deserialization, to ensure that
# upon deserialization we reconstruct an effect that compares equal to the
# one that was serialized.
effect_type_name = str(eff.__class__)
effect_type_name_offset = builder.CreateString(effect_type_name)
ser_flatbuf.EffectStart(builder)

View File

@ -26,6 +26,7 @@ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#callin
from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import functools
from typing import Any, Callable, Optional
@ -35,20 +36,16 @@ from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import util
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
from jax.interpreters import mlir
from jax.interpreters import xla
import numpy as np
import tensorflow as tf
@ -376,6 +373,7 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc
# Mark the effectful instances of call_tf
@dataclasses.dataclass(frozen=True)
class CallTfEffect(effects.Effect):
__str__ = lambda _: "CallTfEffect"

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import contextlib
import dataclasses
import functools
import logging
import math
@ -57,19 +58,35 @@ def tearDownModule():
prev_xla_flags()
### Setup for testing lowering with effects
class TestingOrderedEffect1(effects.Effect):
__str__ = lambda _: "TestingOrderedEffect1"
@dataclasses.dataclass(frozen=True)
class ForTestingOrderedEffect1(effects.Effect):
pass
class TestingOrderedEffect2(effects.Effect):
__str__ = lambda _: "TestingOrderedEffect2"
@dataclasses.dataclass(frozen=True)
class ForTestingOrderedEffect2(effects.Effect):
pass
@dataclasses.dataclass(frozen=True)
class ForTestingUnorderedEffect1(effects.Effect):
pass
class ForTestingOrderedEffect4NoNullary(effects.Effect):
def __init__(self, _):
pass
@dataclasses.dataclass(eq=False)
class ForTestingOrderedEffect5NoEq(effects.Effect):
pass
class TestingUnorderedEffect1(effects.Effect):
__str__ = lambda _: "TestingUnorderedEffect1"
_testing_effects = dict(
TestingOrderedEffect1=TestingOrderedEffect1(),
TestingOrderedEffect2=TestingOrderedEffect2(),
TestingUnorderedEffect1=TestingUnorderedEffect1())
ForTestingOrderedEffect1=ForTestingOrderedEffect1(),
ForTestingOrderedEffect2=ForTestingOrderedEffect2(),
ForTestingUnorderedEffect1=ForTestingUnorderedEffect1(),
ForTestingOrderedEffect4NoNullary=ForTestingOrderedEffect4NoNullary(42),
ForTestingOrderedEffect5NoEq=ForTestingOrderedEffect5NoEq(),
)
# Register the effects
for effect in _testing_effects.values():
effect_class = effect.__class__
@ -1015,23 +1032,20 @@ class JaxExportTest(jtu.JaxTestCase):
# Test also the calling convention for inner functions
def f_jax_inner(x):
return (
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingUnorderedEffect1"))
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1"))
return (
10. +
jax.jit(f_jax_inner)(x) +
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1") +
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2")
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
)
# TODO(necula): at the moment serializing and deserializing effects breaks
# the effect equality, and this results in this test failing. So, for now
# we disable the serization round-trip
exp = export.export(f_jax)(x) # get_exported(f_jax)(x)
exp = get_exported(f_jax)(x)
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"],
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in exp.ordered_effects))
self.assertEqual(["TestingUnorderedEffect1"],
self.assertEqual(["ForTestingUnorderedEffect1()"],
[str(e) for e in exp.unordered_effects])
else:
self.assertEqual([], [str(e) for e in exp.ordered_effects])
@ -1074,19 +1088,19 @@ class JaxExportTest(jtu.JaxTestCase):
def f_outer(x):
return (
testing_primitive_with_effect_p.bind(
x, effect_class_name="TestingOrderedEffect2") +
x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(
x, effect_class_name="TestingUnorderedEffect1") +
x, effect_class_name="ForTestingUnorderedEffect1") +
export.call_exported(exp)(x))
lowered_outer = jax.jit(f_outer).lower(x)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["TestingOrderedEffect2"],
self.assertEqual(["ForTestingOrderedEffect2()"],
[str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]])
else:
self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"],
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]))
self.assertEqual(["TestingUnorderedEffect1"],
self.assertEqual(["ForTestingUnorderedEffect1()"],
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
mlir_outer_module_str = str(lowered_outer.compiler_ir())
@ -1106,7 +1120,7 @@ class JaxExportTest(jtu.JaxTestCase):
self.override_serialization_version(v)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
def f_jax(x): # x: f32[b1, b2]
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1")
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(
export.symbolic_shape("b2, b1"), x.dtype))
mlir_module_str = str(exp.mlir_module())
@ -1150,7 +1164,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.ones((3, 4), dtype=np.float32)
def f_jax(x): # x: f32[b1, b2]
return 10. + _testing_multi_platform_func(x,
effect_class_name="TestingOrderedEffect1")
effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(
f_jax,
lowering_platforms=("cpu", "tpu")
@ -1196,7 +1210,7 @@ class JaxExportTest(jtu.JaxTestCase):
def f_jax(x):
return testing_primitive_with_effect_p.bind(
x, effect_class_name="TestingOrderedEffect1"
x, effect_class_name="ForTestingOrderedEffect1"
)
f_jax = jax.jit(f_jax, donate_argnums=(0,))
@ -1209,6 +1223,24 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1")
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
@jtu.parameterized_filterable(
kwargs=[
dict(name=name, expect_error=expect_error)
# name is the suffix for event name: ForTestingOrderedEffectxxx
for name, expect_error in (
("4NoNullary", "must have a nullary constructor"),
("5NoEq", "must have a nullary class constructor that produces an "
"equal effect object"),
)
])
def test_ordered_effects_error(self, *, name: str, expect_error: str):
x = np.ones((3, 4), dtype=np.float32)
def f_jax(x):
return 10. + _testing_multi_platform_func(
x,
effect_class_name="ForTestingOrderedEffect" + name)
with self.assertRaisesRegex(Exception, expect_error):
_ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())