mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[export] Disable serialization in export_test if flatbuffers is not installed
This allows one to run most of export_test even if flatbuffers is not installed. Only the serialization and deserialization are skipped.
This commit is contained in:
parent
92ebb533bd
commit
cfa3c91c32
@ -45,6 +45,13 @@ from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
import numpy as np
|
||||
|
||||
# ruff: noqa: F401
|
||||
try:
|
||||
import flatbuffers
|
||||
CAN_SERIALIZE = True
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
CAN_SERIALIZE = False
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
@ -139,12 +146,15 @@ def _testing_multi_platform_fun_expected(x,
|
||||
|
||||
|
||||
def get_exported(fun: Callable, vjp_order=0,
|
||||
**export_kwargs):
|
||||
**export_kwargs) -> Callable[[...], export.Exported]:
|
||||
"""Like export.export but with serialization + deserialization."""
|
||||
def serde_exported(*fun_args, **fun_kwargs):
|
||||
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
|
||||
serialized = exp.serialize(vjp_order=vjp_order)
|
||||
return export.deserialize(serialized)
|
||||
if CAN_SERIALIZE:
|
||||
serialized = exp.serialize(vjp_order=vjp_order)
|
||||
return export.deserialize(serialized)
|
||||
else:
|
||||
return exp
|
||||
return serde_exported
|
||||
|
||||
|
||||
@ -234,6 +244,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="The jax.experimental.export module is deprecated")
|
||||
def test_export_experimental_back_compat(self):
|
||||
if not CAN_SERIALIZE:
|
||||
self.skipTest("serialization disabled")
|
||||
from jax.experimental import export
|
||||
# Can export a lambda, without jit
|
||||
exp = export.export(lambda x: jnp.sin(x))(.1)
|
||||
@ -1328,8 +1340,9 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp = export.export(pjit.pjit(f, in_shardings=shardings))(input)
|
||||
exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards)
|
||||
|
||||
_ = exp.serialize(vjp_order=1)
|
||||
_ = exp_rev.serialize(vjp_order=1)
|
||||
if CAN_SERIALIZE:
|
||||
_ = exp.serialize(vjp_order=1)
|
||||
_ = exp_rev.serialize(vjp_order=1)
|
||||
|
||||
g = jax.grad(exp_rev.call)(input_rev)
|
||||
g_rev = jax.grad(exp.call)(input)
|
||||
@ -1725,6 +1738,9 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
)
|
||||
])
|
||||
def test_ordered_effects_error(self, *, name: str, expect_error: str):
|
||||
if not CAN_SERIALIZE:
|
||||
# These errors arise during serialization
|
||||
self.skipTest("serialization is disabled")
|
||||
x = np.ones((3, 4), dtype=np.float32)
|
||||
def f_jax(x):
|
||||
return 10. + _testing_multi_platform_func(
|
||||
|
Loading…
x
Reference in New Issue
Block a user