[jax2tf] Disable jax2tf with non-native serialization.

jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024.

This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.

PiperOrigin-RevId: 689708392
This commit is contained in:
George Necula 2024-10-25 02:30:17 -07:00 committed by jax authors
parent 0bc70bbd73
commit 9088adda68
5 changed files with 50 additions and 51 deletions

View File

@ -12,6 +12,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.36 ## jax 0.4.36
* Breaking Changes
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` have been deprecated since July 2024, with
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
with native serialization will still be supported.
## jax 0.4.35 (Oct 22, 2024) ## jax 0.4.35 (Oct 22, 2024)
* Breaking Changes * Breaking Changes
@ -21,6 +27,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* `jax.experimental.host_callback` has been deprecated since March 2024, with * `jax.experimental.host_callback` has been deprecated since March 2024, with
JAX version 0.4.26. Now we removed it. JAX version 0.4.26. Now we removed it.
See {jax-issue}`#20385` for a discussion of alternatives. See {jax-issue}`#20385` for a discussion of alternatives.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` have been deprecated since July 2024, with
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
with native serialization is still supported.
* Changes: * Changes:
* `jax.lax.FftType` was introduced as a public name for the enum of FFT * `jax.lax.FftType` was introduced as a public name for the enum of FFT

View File

@ -119,6 +119,10 @@ def _sanitize_scope_name(name):
# Line below is different externally and internally. # Line below is different externally and internally.
allow_enable_xla_false = lambda: True allow_enable_xla_false = lambda: True
# TODO(b/353437398): Deprecate support for `native_serialization=False`.
# Line below is different externally and internally.
allow_native_serialization_false = lambda: True
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable, # A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.) # or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any TfVal = Any
@ -294,8 +298,8 @@ def convert(fun_jax: Callable,
See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details. for more details.
polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of polymorphic_constraints: a sequence of constraints on symbolic dimension
the form `e1 >= e2` or `e1 <= e2`. expressions, of the form `e1 >= e2` or `e1 <= e2`.
See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
with_gradient: if set (default), add a tf.custom_gradient to the lowered with_gradient: if set (default), add a tf.custom_gradient to the lowered
function, by converting the ``jax.vjp(fun)``. This means that reverse-mode function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
@ -332,28 +336,38 @@ def convert(fun_jax: Callable,
tuple/lists/dicts thereof), and returns TfVals as outputs, and uses tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
only TensorFlow ops and thus can be called from a TensorFlow program. only TensorFlow ops and thus can be called from a TensorFlow program.
""" """
if not enable_xla:
if allow_enable_xla_false():
warnings.warn("jax2tf.convert with enable_xla=False is deprecated.",
DeprecationWarning,
stacklevel=2)
else:
raise ValueError("jax2tf.convert with enable_xla=False is not supported.")
if native_serialization is DEFAULT_NATIVE_SERIALIZATION: if native_serialization is DEFAULT_NATIVE_SERIALIZATION:
if not enable_xla: if not enable_xla:
native_serialization = False native_serialization = False
else: else:
native_serialization = config.jax2tf_default_native_serialization.value native_serialization = config.jax2tf_default_native_serialization.value
if not native_serialization: if not enable_xla:
warnings.warn( if allow_enable_xla_false():
"jax2tf.convert with native_serialization=False is deprecated.", warnings.warn(
DeprecationWarning, "jax2tf.convert with enable_xla=False has been deprecated "
stacklevel=2) "since July 2024.",
if native_serialization and not enable_xla: DeprecationWarning,
raise ValueError( stacklevel=2)
"native_serialization is not supported with enable_xla=False") if native_serialization:
raise ValueError(
"native_serialization is not supported with enable_xla=False")
else:
raise ValueError(
"jax2tf.convert with enable_xla=False has been deprecated "
"since July 2024 and it is not supported anymore.")
elif not native_serialization:
if allow_native_serialization_false():
warnings.warn(
"jax2tf.convert with native_serialization=False has been deprecated "
"since July 2024.",
DeprecationWarning,
stacklevel=2)
else:
raise ValueError(
"jax2tf.convert with native_serialization=False has been deprecated "
"since July 2024 and it is not supported anymore.")
if not native_serialization and polymorphic_constraints: if not native_serialization and polymorphic_constraints:
raise ValueError( raise ValueError(
@ -2188,7 +2202,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
_out_aval: core.ShapedArray): _out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
# TODO(b/293247337): we ought to turn on this safety check, but this leads to # TODO(b/293247337): we ought to turn on this safety check, but this leads to
# failures. Since we are going to turn on native serializaton soon, wait # failures. Since we are going to turn on native serialization soon, wait
# until then to turn on this check. # until then to turn on this check.
# lhs_aval, rhs_aval = _in_avals # lhs_aval, rhs_aval = _in_avals
# if lhs_aval.dtype != rhs_aval.dtype: # if lhs_aval.dtype != rhs_aval.dtype:

View File

@ -90,7 +90,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
super().setUp() super().setUp()
self.warning_ctx = jtu.ignore_warning( self.warning_ctx = jtu.ignore_warning(
message=( message=(
"(jax2tf.convert with native_serialization=False is deprecated" "(jax2tf.convert with native_serialization=False has been deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)" "|Calling from_dlpack with a DLPack tensor is deprecated)"
) )
) )
@ -897,7 +897,7 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
super().setUp() super().setUp()
self.warning_ctx = jtu.ignore_warning( self.warning_ctx = jtu.ignore_warning(
message=( message=(
"(jax2tf.convert with native_serialization=False is deprecated" "(jax2tf.convert with native_serialization=False has been deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)" "|Calling from_dlpack with a DLPack tensor is deprecated)"
) )
) )
@ -1203,7 +1203,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
super().setUp() super().setUp()
self.warning_ctx = jtu.ignore_warning( self.warning_ctx = jtu.ignore_warning(
message=( message=(
"(jax2tf.convert with native_serialization=False is deprecated" "(jax2tf.convert with native_serialization=False has been deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)" "|Calling from_dlpack with a DLPack tensor is deprecated)"
) )
) )

View File

@ -79,7 +79,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.warning_ctx = jtu.ignore_warning( self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated" message="jax2tf.convert with native_serialization=False has been deprecated"
) )
self.warning_ctx.__enter__() self.warning_ctx.__enter__()
@ -1722,7 +1722,7 @@ class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.warning_ctx = jtu.ignore_warning( self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated" message="jax2tf.convert with native_serialization=False has been deprecated"
) )
self.warning_ctx.__enter__() self.warning_ctx.__enter__()
@ -1763,7 +1763,7 @@ class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
super().setUp() super().setUp()
@jtu.ignore_warning( @jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated" message="jax2tf.convert with native_serialization=False has been deprecated"
) )
def test_simple(self): def test_simple(self):
self.ConvertAndCompare(jnp.sin, 0.7) self.ConvertAndCompare(jnp.sin, 0.7)

View File

@ -1031,7 +1031,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(f_jax(x), restored_f(x)) self.assertAllClose(f_jax(x), restored_f(x))
@jtu.ignore_warning( @jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated" message="jax2tf.convert with native_serialization=False has been deprecated"
) )
def test_readme_examples(self): def test_readme_examples(self):
"""Some of the examples from the README.""" """Some of the examples from the README."""
@ -1124,31 +1124,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# JAX with static shapes sees that x.shape[0] != x.shape[1] # JAX with static shapes sees that x.shape[0] != x.shape[1]
self.assertEqual(jnp.sum(x45), f2_jax(x45)) self.assertEqual(jnp.sum(x45), f2_jax(x45))
# In graph serialization eager mode, we catch the broken assumption b >= 1
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
re.escape(
"Found inconsistency between dimension size args[0].shape[1] (= 5) "
"and the specification 'b' (= 4)")):
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False)(x45)
# In graph serialization graph mode we also catch it (except on TPU, where
# the behavior is as for jit_compile=1)
f2_tf = tf.function(
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False),
autograph=False,
).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
if jtu.test_device_matches(["tpu"]):
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
else:
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
r"Found inconsistency"):
_ = f2_tf(x45)
# We also catch the error with native serialization # We also catch the error with native serialization
with self.assertRaisesRegex( with self.assertRaisesRegex(
tf.errors.InvalidArgumentError, tf.errors.InvalidArgumentError,