mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[jax2tf] Disable jax2tf with non-native serialization.
jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024. This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error. PiperOrigin-RevId: 689708392
This commit is contained in:
parent
0bc70bbd73
commit
9088adda68
10
CHANGELOG.md
10
CHANGELOG.md
@ -12,6 +12,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## jax 0.4.36
|
||||
|
||||
* Breaking Changes
|
||||
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
|
||||
or with `enable_xla=False` have been deprecated since July 2024, with
|
||||
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
|
||||
with native serialization will still be supported.
|
||||
|
||||
## jax 0.4.35 (Oct 22, 2024)
|
||||
|
||||
* Breaking Changes
|
||||
@ -21,6 +27,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* `jax.experimental.host_callback` has been deprecated since March 2024, with
|
||||
JAX version 0.4.26. Now we removed it.
|
||||
See {jax-issue}`#20385` for a discussion of alternatives.
|
||||
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
|
||||
or with `enable_xla=False` have been deprecated since July 2024, with
|
||||
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
|
||||
with native serialization is still supported.
|
||||
|
||||
* Changes:
|
||||
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
|
||||
|
@ -119,6 +119,10 @@ def _sanitize_scope_name(name):
|
||||
# Line below is different externally and internally.
|
||||
allow_enable_xla_false = lambda: True
|
||||
|
||||
# TODO(b/353437398): Deprecate support for `native_serialization=False`.
|
||||
# Line below is different externally and internally.
|
||||
allow_native_serialization_false = lambda: True
|
||||
|
||||
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
|
||||
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
|
||||
TfVal = Any
|
||||
@ -294,8 +298,8 @@ def convert(fun_jax: Callable,
|
||||
See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
||||
for more details.
|
||||
|
||||
polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of
|
||||
the form `e1 >= e2` or `e1 <= e2`.
|
||||
polymorphic_constraints: a sequence of constraints on symbolic dimension
|
||||
expressions, of the form `e1 >= e2` or `e1 <= e2`.
|
||||
See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
|
||||
with_gradient: if set (default), add a tf.custom_gradient to the lowered
|
||||
function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
|
||||
@ -332,28 +336,38 @@ def convert(fun_jax: Callable,
|
||||
tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
|
||||
only TensorFlow ops and thus can be called from a TensorFlow program.
|
||||
"""
|
||||
if not enable_xla:
|
||||
if allow_enable_xla_false():
|
||||
warnings.warn("jax2tf.convert with enable_xla=False is deprecated.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
else:
|
||||
raise ValueError("jax2tf.convert with enable_xla=False is not supported.")
|
||||
|
||||
if native_serialization is DEFAULT_NATIVE_SERIALIZATION:
|
||||
if not enable_xla:
|
||||
native_serialization = False
|
||||
else:
|
||||
native_serialization = config.jax2tf_default_native_serialization.value
|
||||
|
||||
if not native_serialization:
|
||||
warnings.warn(
|
||||
"jax2tf.convert with native_serialization=False is deprecated.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
if native_serialization and not enable_xla:
|
||||
raise ValueError(
|
||||
"native_serialization is not supported with enable_xla=False")
|
||||
if not enable_xla:
|
||||
if allow_enable_xla_false():
|
||||
warnings.warn(
|
||||
"jax2tf.convert with enable_xla=False has been deprecated "
|
||||
"since July 2024.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
if native_serialization:
|
||||
raise ValueError(
|
||||
"native_serialization is not supported with enable_xla=False")
|
||||
else:
|
||||
raise ValueError(
|
||||
"jax2tf.convert with enable_xla=False has been deprecated "
|
||||
"since July 2024 and it is not supported anymore.")
|
||||
|
||||
elif not native_serialization:
|
||||
if allow_native_serialization_false():
|
||||
warnings.warn(
|
||||
"jax2tf.convert with native_serialization=False has been deprecated "
|
||||
"since July 2024.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
else:
|
||||
raise ValueError(
|
||||
"jax2tf.convert with native_serialization=False has been deprecated "
|
||||
"since July 2024 and it is not supported anymore.")
|
||||
|
||||
if not native_serialization and polymorphic_constraints:
|
||||
raise ValueError(
|
||||
@ -2188,7 +2202,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
|
||||
_out_aval: core.ShapedArray):
|
||||
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
|
||||
# TODO(b/293247337): we ought to turn on this safety check, but this leads to
|
||||
# failures. Since we are going to turn on native serializaton soon, wait
|
||||
# failures. Since we are going to turn on native serialization soon, wait
|
||||
# until then to turn on this check.
|
||||
# lhs_aval, rhs_aval = _in_avals
|
||||
# if lhs_aval.dtype != rhs_aval.dtype:
|
||||
|
@ -90,7 +90,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message=(
|
||||
"(jax2tf.convert with native_serialization=False is deprecated"
|
||||
"(jax2tf.convert with native_serialization=False has been deprecated"
|
||||
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
||||
)
|
||||
)
|
||||
@ -897,7 +897,7 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message=(
|
||||
"(jax2tf.convert with native_serialization=False is deprecated"
|
||||
"(jax2tf.convert with native_serialization=False has been deprecated"
|
||||
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
||||
)
|
||||
)
|
||||
@ -1203,7 +1203,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message=(
|
||||
"(jax2tf.convert with native_serialization=False is deprecated"
|
||||
"(jax2tf.convert with native_serialization=False has been deprecated"
|
||||
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
||||
)
|
||||
)
|
||||
|
@ -79,7 +79,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
message="jax2tf.convert with native_serialization=False has been deprecated"
|
||||
)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
@ -1722,7 +1722,7 @@ class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
message="jax2tf.convert with native_serialization=False has been deprecated"
|
||||
)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
@ -1763,7 +1763,7 @@ class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
|
||||
super().setUp()
|
||||
|
||||
@jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
message="jax2tf.convert with native_serialization=False has been deprecated"
|
||||
)
|
||||
def test_simple(self):
|
||||
self.ConvertAndCompare(jnp.sin, 0.7)
|
||||
|
@ -1031,7 +1031,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertAllClose(f_jax(x), restored_f(x))
|
||||
|
||||
@jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
message="jax2tf.convert with native_serialization=False has been deprecated"
|
||||
)
|
||||
def test_readme_examples(self):
|
||||
"""Some of the examples from the README."""
|
||||
@ -1124,31 +1124,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
# JAX with static shapes sees that x.shape[0] != x.shape[1]
|
||||
self.assertEqual(jnp.sum(x45), f2_jax(x45))
|
||||
|
||||
# In graph serialization eager mode, we catch the broken assumption b >= 1
|
||||
with self.assertRaisesRegex(
|
||||
tf.errors.InvalidArgumentError,
|
||||
re.escape(
|
||||
"Found inconsistency between dimension size args[0].shape[1] (= 5) "
|
||||
"and the specification 'b' (= 4)")):
|
||||
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
||||
native_serialization=False)(x45)
|
||||
|
||||
# In graph serialization graph mode we also catch it (except on TPU, where
|
||||
# the behavior is as for jit_compile=1)
|
||||
|
||||
f2_tf = tf.function(
|
||||
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
||||
native_serialization=False),
|
||||
autograph=False,
|
||||
).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
|
||||
else:
|
||||
with self.assertRaisesRegex(
|
||||
tf.errors.InvalidArgumentError,
|
||||
r"Found inconsistency"):
|
||||
_ = f2_tf(x45)
|
||||
|
||||
# We also catch the error with native serialization
|
||||
with self.assertRaisesRegex(
|
||||
tf.errors.InvalidArgumentError,
|
||||
|
Loading…
x
Reference in New Issue
Block a user