[jax2tf] Add checks that we do not see unexpected lowered.compiler_args

Some of those compile_args change the semantics and the calling convention
for the lowered module. We want to be explicit about the ones that we
are handling.

PiperOrigin-RevId: 521419681
This commit is contained in:
George Necula 2023-04-03 04:12:50 -07:00 committed by jax authors
parent b0a6cdbf24
commit 2ce78ac9a8
2 changed files with 62 additions and 2 deletions

View File

@ -16,9 +16,8 @@
This module is used with jax2tf, but should have no TensorFlow dependencies.
"""
import dataclasses
from functools import partial
import re
from typing import Callable, Dict, List, Optional, Sequence, Set, Union
from typing import Callable, List, Optional, Sequence, Union
from absl import logging
@ -117,6 +116,50 @@ def serialize_native(fun_jax: Callable,
*arg_specs_jax,
_experimental_lowering_platform=lowering_platform)._lowering # type: ignore
if not isinstance(lowered, pxla.MeshComputation):
raise NotImplementedError(f"serialization is supported only for pjit. {lowered}")
# Check that we do not see new compile_args. When we add a compile_args it is
# safe to add it to the allowed_compile_args if it does not change the semantics
# or the calling convention of the lowered module.
allowed_compile_args = ["backend", "mesh", "global_in_avals",
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
"spmd_lowering", "auto_spmd_lowering",
"tuple_args", "ordered_effects", "unordered_effects",
"host_callbacks", "keepalive", "pmap_nreps", "committed", "device_assignment"]
for compile_arg in lowered.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
# We have not implemented support for some of the compile_args.
not_implemented_msgs = []
for compile_arg, check_value, err_msg in (
("spmd_lowering", lambda v: v, "True"),
("auto_spmd_lowering", lambda v: not v, "False"),
("tuple_args", lambda v: not v, "False"),
# Used for debug(ordered=True), changes the calling convention, but will
# also set keepalive to non-empty.
("ordered_effects", lambda v: not v, "empty"),
# unordered_effects do not change the calling convention. Those from
# jax.debug will also result in keepalive being non-empty and unsupported
# custom calls. The CallTfEffect is an exception, but we want to allow
# that one.
("unordered_effects", lambda v: True, "N/A"),
# used for TPU jax.debug, send/recv. Not supported yet.
("host_callbacks", lambda v: not v, "empty"),
# used on all platforms for callbacks. Not supported yet.
("keepalive", lambda v: not v, "empty"),
("pmap_nreps", lambda v: v == 1, "1"),
):
if compile_arg in lowered.compile_args:
if not check_value(lowered.compile_args[compile_arg]):
not_implemented_msgs.append(
f"{compile_arg} must be {err_msg} and it is {lowered.compile_args[compile_arg]}")
if not_implemented_msgs:
raise NotImplementedError(
"serialization error, unimplemented lowered.compile_args:\n" +
"\n".join(not_implemented_msgs))
mlir_module = lowered.stablehlo()
if "kept_var_idx" in lowered.compile_args:
module_kept_var_idx = tuple(sorted(lowered.compile_args["kept_var_idx"]))

View File

@ -1604,6 +1604,23 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
r"The current platform .* is not among the platforms required by the module: \[TPU\]"):
f_grad_tf(x_v)
def test_effects_error(self):
def f_jax(x):
jax.debug.print("{}", x)
return jnp.sin(x)
with self.assertRaisesRegex(NotImplementedError,
"keepalive must be empty"):
jax2tf.convert(f_jax, native_serialization=True)(np.float32(42.))
def f_ordered_jax(x):
jax.debug.print("{}", x, ordered=True)
return jnp.sin(x)
with self.assertRaisesRegex(NotImplementedError,
"keepalive must be empty"):
jax2tf.convert(f_ordered_jax, native_serialization=True)(np.float32(42.))
def get_serialized_computation(
f_jax: Callable,