[export] Disable serialization in export_test if flatbuffers is not installed

This allows one to run most of export_test even if flatbuffers
is not installed. Only the serialization and deserialization are
skipped.
This commit is contained in:
George Necula 2024-06-29 15:00:12 +03:00
parent 92ebb533bd
commit cfa3c91c32

View File

@ -45,6 +45,13 @@ from jax._src.lib.mlir.dialects import hlo
import numpy as np
# ruff: noqa: F401
try:
import flatbuffers
CAN_SERIALIZE = True
except (ModuleNotFoundError, ImportError):
CAN_SERIALIZE = False
config.parse_flags_with_absl()
_exit_stack = contextlib.ExitStack()
@ -139,12 +146,15 @@ def _testing_multi_platform_fun_expected(x,
def get_exported(fun: Callable, vjp_order=0,
**export_kwargs):
**export_kwargs) -> Callable[[...], export.Exported]:
"""Like export.export but with serialization + deserialization."""
def serde_exported(*fun_args, **fun_kwargs):
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
serialized = exp.serialize(vjp_order=vjp_order)
return export.deserialize(serialized)
if CAN_SERIALIZE:
serialized = exp.serialize(vjp_order=vjp_order)
return export.deserialize(serialized)
else:
return exp
return serde_exported
@ -234,6 +244,8 @@ class JaxExportTest(jtu.JaxTestCase):
@jtu.ignore_warning(category=DeprecationWarning,
message="The jax.experimental.export module is deprecated")
def test_export_experimental_back_compat(self):
if not CAN_SERIALIZE:
self.skipTest("serialization disabled")
from jax.experimental import export
# Can export a lambda, without jit
exp = export.export(lambda x: jnp.sin(x))(.1)
@ -1328,8 +1340,9 @@ class JaxExportTest(jtu.JaxTestCase):
exp = export.export(pjit.pjit(f, in_shardings=shardings))(input)
exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards)
_ = exp.serialize(vjp_order=1)
_ = exp_rev.serialize(vjp_order=1)
if CAN_SERIALIZE:
_ = exp.serialize(vjp_order=1)
_ = exp_rev.serialize(vjp_order=1)
g = jax.grad(exp_rev.call)(input_rev)
g_rev = jax.grad(exp.call)(input)
@ -1725,6 +1738,9 @@ class JaxExportTest(jtu.JaxTestCase):
)
])
def test_ordered_effects_error(self, *, name: str, expect_error: str):
if not CAN_SERIALIZE:
# These errors arise during serialization
self.skipTest("serialization is disabled")
x = np.ones((3, 4), dtype=np.float32)
def f_jax(x):
return 10. + _testing_multi_platform_func(