mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Add checks that we do not see unexpected lowered.compiler_args
Some of those compile_args change the semantics and the calling convention for the lowered module. We want to be explicit about the ones that we are handling. PiperOrigin-RevId: 521419681
This commit is contained in:
parent
b0a6cdbf24
commit
2ce78ac9a8
@ -16,9 +16,8 @@
|
||||
This module is used with jax2tf, but should have no TensorFlow dependencies.
|
||||
"""
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import re
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Set, Union
|
||||
from typing import Callable, List, Optional, Sequence, Union
|
||||
|
||||
from absl import logging
|
||||
|
||||
@ -117,6 +116,50 @@ def serialize_native(fun_jax: Callable,
|
||||
*arg_specs_jax,
|
||||
_experimental_lowering_platform=lowering_platform)._lowering # type: ignore
|
||||
|
||||
if not isinstance(lowered, pxla.MeshComputation):
|
||||
raise NotImplementedError(f"serialization is supported only for pjit. {lowered}")
|
||||
|
||||
# Check that we do not see new compile_args. When we add a compile_args it is
|
||||
# safe to add it to the allowed_compile_args if it does not change the semantics
|
||||
# or the calling convention of the lowered module.
|
||||
allowed_compile_args = ["backend", "mesh", "global_in_avals",
|
||||
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
|
||||
"spmd_lowering", "auto_spmd_lowering",
|
||||
"tuple_args", "ordered_effects", "unordered_effects",
|
||||
"host_callbacks", "keepalive", "pmap_nreps", "committed", "device_assignment"]
|
||||
for compile_arg in lowered.compile_args.keys():
|
||||
if compile_arg not in allowed_compile_args:
|
||||
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
|
||||
|
||||
# We have not implemented support for some of the compile_args.
|
||||
not_implemented_msgs = []
|
||||
for compile_arg, check_value, err_msg in (
|
||||
("spmd_lowering", lambda v: v, "True"),
|
||||
("auto_spmd_lowering", lambda v: not v, "False"),
|
||||
("tuple_args", lambda v: not v, "False"),
|
||||
# Used for debug(ordered=True), changes the calling convention, but will
|
||||
# also set keepalive to non-empty.
|
||||
("ordered_effects", lambda v: not v, "empty"),
|
||||
# unordered_effects do not change the calling convention. Those from
|
||||
# jax.debug will also result in keepalive being non-empty and unsupported
|
||||
# custom calls. The CallTfEffect is an exception, but we want to allow
|
||||
# that one.
|
||||
("unordered_effects", lambda v: True, "N/A"),
|
||||
# used for TPU jax.debug, send/recv. Not supported yet.
|
||||
("host_callbacks", lambda v: not v, "empty"),
|
||||
# used on all platforms for callbacks. Not supported yet.
|
||||
("keepalive", lambda v: not v, "empty"),
|
||||
("pmap_nreps", lambda v: v == 1, "1"),
|
||||
):
|
||||
if compile_arg in lowered.compile_args:
|
||||
if not check_value(lowered.compile_args[compile_arg]):
|
||||
not_implemented_msgs.append(
|
||||
f"{compile_arg} must be {err_msg} and it is {lowered.compile_args[compile_arg]}")
|
||||
if not_implemented_msgs:
|
||||
raise NotImplementedError(
|
||||
"serialization error, unimplemented lowered.compile_args:\n" +
|
||||
"\n".join(not_implemented_msgs))
|
||||
|
||||
mlir_module = lowered.stablehlo()
|
||||
if "kept_var_idx" in lowered.compile_args:
|
||||
module_kept_var_idx = tuple(sorted(lowered.compile_args["kept_var_idx"]))
|
||||
|
@ -1604,6 +1604,23 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
r"The current platform .* is not among the platforms required by the module: \[TPU\]"):
|
||||
f_grad_tf(x_v)
|
||||
|
||||
def test_effects_error(self):
|
||||
def f_jax(x):
|
||||
jax.debug.print("{}", x)
|
||||
return jnp.sin(x)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"keepalive must be empty"):
|
||||
jax2tf.convert(f_jax, native_serialization=True)(np.float32(42.))
|
||||
|
||||
def f_ordered_jax(x):
|
||||
jax.debug.print("{}", x, ordered=True)
|
||||
return jnp.sin(x)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"keepalive must be empty"):
|
||||
jax2tf.convert(f_ordered_jax, native_serialization=True)(np.float32(42.))
|
||||
|
||||
|
||||
def get_serialized_computation(
|
||||
f_jax: Callable,
|
||||
|
Loading…
x
Reference in New Issue
Block a user