[jax2tf] Disable jax2tf with non-native serialization.

jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024.

This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.

PiperOrigin-RevId: 689708392
This commit is contained in:
George Necula 2024-10-25 02:30:17 -07:00 committed by jax authors
parent 0bc70bbd73
commit 9088adda68
5 changed files with 50 additions and 51 deletions

View File

@ -12,6 +12,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.36
* Breaking Changes
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` have been deprecated since July 2024, with
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
with native serialization will still be supported.
## jax 0.4.35 (Oct 22, 2024)
* Breaking Changes
@ -21,6 +27,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* `jax.experimental.host_callback` has been deprecated since March 2024, with
JAX version 0.4.26. Now we removed it.
See {jax-issue}`#20385` for a discussion of alternatives.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` have been deprecated since July 2024, with
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
with native serialization is still supported.
* Changes:
* `jax.lax.FftType` was introduced as a public name for the enum of FFT

View File

@ -119,6 +119,10 @@ def _sanitize_scope_name(name):
# Line below is different externally and internally.
allow_enable_xla_false = lambda: True
# TODO(b/353437398): Deprecate support for `native_serialization=False`.
# Line below is different externally and internally.
allow_native_serialization_false = lambda: True
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
@ -294,8 +298,8 @@ def convert(fun_jax: Callable,
See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of
the form `e1 >= e2` or `e1 <= e2`.
polymorphic_constraints: a sequence of constraints on symbolic dimension
expressions, of the form `e1 >= e2` or `e1 <= e2`.
See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
with_gradient: if set (default), add a tf.custom_gradient to the lowered
function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
@ -332,28 +336,38 @@ def convert(fun_jax: Callable,
tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
only TensorFlow ops and thus can be called from a TensorFlow program.
"""
if not enable_xla:
if allow_enable_xla_false():
warnings.warn("jax2tf.convert with enable_xla=False is deprecated.",
DeprecationWarning,
stacklevel=2)
else:
raise ValueError("jax2tf.convert with enable_xla=False is not supported.")
if native_serialization is DEFAULT_NATIVE_SERIALIZATION:
if not enable_xla:
native_serialization = False
else:
native_serialization = config.jax2tf_default_native_serialization.value
if not native_serialization:
warnings.warn(
"jax2tf.convert with native_serialization=False is deprecated.",
DeprecationWarning,
stacklevel=2)
if native_serialization and not enable_xla:
raise ValueError(
"native_serialization is not supported with enable_xla=False")
if not enable_xla:
if allow_enable_xla_false():
warnings.warn(
"jax2tf.convert with enable_xla=False has been deprecated "
"since July 2024.",
DeprecationWarning,
stacklevel=2)
if native_serialization:
raise ValueError(
"native_serialization is not supported with enable_xla=False")
else:
raise ValueError(
"jax2tf.convert with enable_xla=False has been deprecated "
"since July 2024 and it is not supported anymore.")
elif not native_serialization:
if allow_native_serialization_false():
warnings.warn(
"jax2tf.convert with native_serialization=False has been deprecated "
"since July 2024.",
DeprecationWarning,
stacklevel=2)
else:
raise ValueError(
"jax2tf.convert with native_serialization=False has been deprecated "
"since July 2024 and it is not supported anymore.")
if not native_serialization and polymorphic_constraints:
raise ValueError(
@ -2188,7 +2202,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
_out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
# TODO(b/293247337): we ought to turn on this safety check, but this leads to
# failures. Since we are going to turn on native serializaton soon, wait
# failures. Since we are going to turn on native serialization soon, wait
# until then to turn on this check.
# lhs_aval, rhs_aval = _in_avals
# if lhs_aval.dtype != rhs_aval.dtype:

View File

@ -90,7 +90,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"(jax2tf.convert with native_serialization=False has been deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
@ -897,7 +897,7 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"(jax2tf.convert with native_serialization=False has been deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
@ -1203,7 +1203,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"(jax2tf.convert with native_serialization=False has been deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)

View File

@ -79,7 +79,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
message="jax2tf.convert with native_serialization=False has been deprecated"
)
self.warning_ctx.__enter__()
@ -1722,7 +1722,7 @@ class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
message="jax2tf.convert with native_serialization=False has been deprecated"
)
self.warning_ctx.__enter__()
@ -1763,7 +1763,7 @@ class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
super().setUp()
@jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
message="jax2tf.convert with native_serialization=False has been deprecated"
)
def test_simple(self):
self.ConvertAndCompare(jnp.sin, 0.7)

View File

@ -1031,7 +1031,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(f_jax(x), restored_f(x))
@jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
message="jax2tf.convert with native_serialization=False has been deprecated"
)
def test_readme_examples(self):
"""Some of the examples from the README."""
@ -1124,31 +1124,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# JAX with static shapes sees that x.shape[0] != x.shape[1]
self.assertEqual(jnp.sum(x45), f2_jax(x45))
# In graph serialization eager mode, we catch the broken assumption b >= 1
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
re.escape(
"Found inconsistency between dimension size args[0].shape[1] (= 5) "
"and the specification 'b' (= 4)")):
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False)(x45)
# In graph serialization graph mode we also catch it (except on TPU, where
# the behavior is as for jit_compile=1)
f2_tf = tf.function(
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False),
autograph=False,
).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
if jtu.test_device_matches(["tpu"]):
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
else:
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
r"Found inconsistency"):
_ = f2_tf(x45)
# We also catch the error with native serialization
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,