mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #18956 from gnecula:export_effects
PiperOrigin-RevId: 591134940
This commit is contained in:
commit
6cd7adac99
@ -386,17 +386,24 @@ def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding:
|
||||
|
||||
|
||||
def _serialize_effect(builder: flatbuffers.Builder, eff: core.Effect) -> int:
|
||||
# TODO(necula): for now serialize just the name of the class
|
||||
try:
|
||||
_ = eff.__class__()
|
||||
except:
|
||||
eff_replica = eff.__class__()
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
f"serializing effect {eff} that does not have a nullary class"
|
||||
" constructor"
|
||||
f"Effect {eff} must have a nullary constructor to be serializable"
|
||||
)
|
||||
try:
|
||||
hash_eff = hash(eff)
|
||||
hash_eff_replica = hash(eff_replica)
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
f"Effect {eff} must be hashable to be serializable"
|
||||
)
|
||||
if eff != eff_replica or hash_eff != hash_eff_replica:
|
||||
raise NotImplementedError(
|
||||
f"Effect {eff} must have a nullary class constructor that produces an "
|
||||
"equal effect object."
|
||||
)
|
||||
# TODO: fix the effects serialization and deserialization, to ensure that
|
||||
# upon deserialization we reconstruct an effect that compares equal to the
|
||||
# one that was serialized.
|
||||
effect_type_name = str(eff.__class__)
|
||||
effect_type_name_offset = builder.CreateString(effect_type_name)
|
||||
ser_flatbuf.EffectStart(builder)
|
||||
|
@ -26,6 +26,7 @@ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#callin
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
@ -35,20 +36,16 @@ from jax import dlpack
|
||||
from jax import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import effects
|
||||
from jax._src import util
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@ -376,6 +373,7 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc
|
||||
|
||||
|
||||
# Mark the effectful instances of call_tf
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CallTfEffect(effects.Effect):
|
||||
__str__ = lambda _: "CallTfEffect"
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
@ -57,19 +58,35 @@ def tearDownModule():
|
||||
prev_xla_flags()
|
||||
|
||||
### Setup for testing lowering with effects
|
||||
class TestingOrderedEffect1(effects.Effect):
|
||||
__str__ = lambda _: "TestingOrderedEffect1"
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ForTestingOrderedEffect1(effects.Effect):
|
||||
pass
|
||||
|
||||
class TestingOrderedEffect2(effects.Effect):
|
||||
__str__ = lambda _: "TestingOrderedEffect2"
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ForTestingOrderedEffect2(effects.Effect):
|
||||
pass
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ForTestingUnorderedEffect1(effects.Effect):
|
||||
pass
|
||||
|
||||
|
||||
class ForTestingOrderedEffect4NoNullary(effects.Effect):
|
||||
def __init__(self, _):
|
||||
pass
|
||||
|
||||
@dataclasses.dataclass(eq=False)
|
||||
class ForTestingOrderedEffect5NoEq(effects.Effect):
|
||||
pass
|
||||
|
||||
class TestingUnorderedEffect1(effects.Effect):
|
||||
__str__ = lambda _: "TestingUnorderedEffect1"
|
||||
|
||||
_testing_effects = dict(
|
||||
TestingOrderedEffect1=TestingOrderedEffect1(),
|
||||
TestingOrderedEffect2=TestingOrderedEffect2(),
|
||||
TestingUnorderedEffect1=TestingUnorderedEffect1())
|
||||
ForTestingOrderedEffect1=ForTestingOrderedEffect1(),
|
||||
ForTestingOrderedEffect2=ForTestingOrderedEffect2(),
|
||||
ForTestingUnorderedEffect1=ForTestingUnorderedEffect1(),
|
||||
ForTestingOrderedEffect4NoNullary=ForTestingOrderedEffect4NoNullary(42),
|
||||
ForTestingOrderedEffect5NoEq=ForTestingOrderedEffect5NoEq(),
|
||||
)
|
||||
# Register the effects
|
||||
for effect in _testing_effects.values():
|
||||
effect_class = effect.__class__
|
||||
@ -1015,23 +1032,20 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# Test also the calling convention for inner functions
|
||||
def f_jax_inner(x):
|
||||
return (
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2") +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingUnorderedEffect1"))
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1"))
|
||||
return (
|
||||
10. +
|
||||
jax.jit(f_jax_inner)(x) +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1") +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2")
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") +
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
|
||||
)
|
||||
|
||||
# TODO(necula): at the moment serializing and deserializing effects breaks
|
||||
# the effect equality, and this results in this test failing. So, for now
|
||||
# we disable the serization round-trip
|
||||
exp = export.export(f_jax)(x) # get_exported(f_jax)(x)
|
||||
exp = get_exported(f_jax)(x)
|
||||
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"],
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
sorted(str(e) for e in exp.ordered_effects))
|
||||
self.assertEqual(["TestingUnorderedEffect1"],
|
||||
self.assertEqual(["ForTestingUnorderedEffect1()"],
|
||||
[str(e) for e in exp.unordered_effects])
|
||||
else:
|
||||
self.assertEqual([], [str(e) for e in exp.ordered_effects])
|
||||
@ -1074,19 +1088,19 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f_outer(x):
|
||||
return (
|
||||
testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="TestingOrderedEffect2") +
|
||||
x, effect_class_name="ForTestingOrderedEffect2") +
|
||||
testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="TestingUnorderedEffect1") +
|
||||
x, effect_class_name="ForTestingUnorderedEffect1") +
|
||||
export.call_exported(exp)(x))
|
||||
|
||||
lowered_outer = jax.jit(f_outer).lower(x)
|
||||
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
self.assertEqual(["TestingOrderedEffect2"],
|
||||
self.assertEqual(["ForTestingOrderedEffect2()"],
|
||||
[str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]])
|
||||
else:
|
||||
self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"],
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]))
|
||||
self.assertEqual(["TestingUnorderedEffect1"],
|
||||
self.assertEqual(["ForTestingUnorderedEffect1()"],
|
||||
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
|
||||
|
||||
mlir_outer_module_str = str(lowered_outer.compiler_ir())
|
||||
@ -1106,7 +1120,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.override_serialization_version(v)
|
||||
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
def f_jax(x): # x: f32[b1, b2]
|
||||
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1")
|
||||
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
|
||||
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(
|
||||
export.symbolic_shape("b2, b1"), x.dtype))
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
@ -1150,7 +1164,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.ones((3, 4), dtype=np.float32)
|
||||
def f_jax(x): # x: f32[b1, b2]
|
||||
return 10. + _testing_multi_platform_func(x,
|
||||
effect_class_name="TestingOrderedEffect1")
|
||||
effect_class_name="ForTestingOrderedEffect1")
|
||||
exp = get_exported(
|
||||
f_jax,
|
||||
lowering_platforms=("cpu", "tpu")
|
||||
@ -1196,7 +1210,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
def f_jax(x):
|
||||
return testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="TestingOrderedEffect1"
|
||||
x, effect_class_name="ForTestingOrderedEffect1"
|
||||
)
|
||||
|
||||
f_jax = jax.jit(f_jax, donate_argnums=(0,))
|
||||
@ -1209,6 +1223,24 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1")
|
||||
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(name=name, expect_error=expect_error)
|
||||
# name is the suffix for event name: ForTestingOrderedEffectxxx
|
||||
for name, expect_error in (
|
||||
("4NoNullary", "must have a nullary constructor"),
|
||||
("5NoEq", "must have a nullary class constructor that produces an "
|
||||
"equal effect object"),
|
||||
)
|
||||
])
|
||||
def test_ordered_effects_error(self, *, name: str, expect_error: str):
|
||||
x = np.ones((3, 4), dtype=np.float32)
|
||||
def f_jax(x):
|
||||
return 10. + _testing_multi_platform_func(
|
||||
x,
|
||||
effect_class_name="ForTestingOrderedEffect" + name)
|
||||
with self.assertRaisesRegex(Exception, expect_error):
|
||||
_ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype))
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user