mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[jax2tf] Disable jax2tf with non-native serialization.
jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024. This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error. PiperOrigin-RevId: 689708392
This commit is contained in:
parent
0bc70bbd73
commit
9088adda68
10
CHANGELOG.md
10
CHANGELOG.md
@ -12,6 +12,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
|
|
||||||
## jax 0.4.36
|
## jax 0.4.36
|
||||||
|
|
||||||
|
* Breaking Changes
|
||||||
|
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
|
||||||
|
or with `enable_xla=False` have been deprecated since July 2024, with
|
||||||
|
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
|
||||||
|
with native serialization will still be supported.
|
||||||
|
|
||||||
## jax 0.4.35 (Oct 22, 2024)
|
## jax 0.4.35 (Oct 22, 2024)
|
||||||
|
|
||||||
* Breaking Changes
|
* Breaking Changes
|
||||||
@ -21,6 +27,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
* `jax.experimental.host_callback` has been deprecated since March 2024, with
|
* `jax.experimental.host_callback` has been deprecated since March 2024, with
|
||||||
JAX version 0.4.26. Now we removed it.
|
JAX version 0.4.26. Now we removed it.
|
||||||
See {jax-issue}`#20385` for a discussion of alternatives.
|
See {jax-issue}`#20385` for a discussion of alternatives.
|
||||||
|
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
|
||||||
|
or with `enable_xla=False` have been deprecated since July 2024, with
|
||||||
|
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
|
||||||
|
with native serialization is still supported.
|
||||||
|
|
||||||
* Changes:
|
* Changes:
|
||||||
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
|
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
|
||||||
|
@ -119,6 +119,10 @@ def _sanitize_scope_name(name):
|
|||||||
# Line below is different externally and internally.
|
# Line below is different externally and internally.
|
||||||
allow_enable_xla_false = lambda: True
|
allow_enable_xla_false = lambda: True
|
||||||
|
|
||||||
|
# TODO(b/353437398): Deprecate support for `native_serialization=False`.
|
||||||
|
# Line below is different externally and internally.
|
||||||
|
allow_native_serialization_false = lambda: True
|
||||||
|
|
||||||
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
|
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
|
||||||
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
|
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
|
||||||
TfVal = Any
|
TfVal = Any
|
||||||
@ -294,8 +298,8 @@ def convert(fun_jax: Callable,
|
|||||||
See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
||||||
for more details.
|
for more details.
|
||||||
|
|
||||||
polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of
|
polymorphic_constraints: a sequence of constraints on symbolic dimension
|
||||||
the form `e1 >= e2` or `e1 <= e2`.
|
expressions, of the form `e1 >= e2` or `e1 <= e2`.
|
||||||
See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
|
See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
|
||||||
with_gradient: if set (default), add a tf.custom_gradient to the lowered
|
with_gradient: if set (default), add a tf.custom_gradient to the lowered
|
||||||
function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
|
function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
|
||||||
@ -332,28 +336,38 @@ def convert(fun_jax: Callable,
|
|||||||
tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
|
tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
|
||||||
only TensorFlow ops and thus can be called from a TensorFlow program.
|
only TensorFlow ops and thus can be called from a TensorFlow program.
|
||||||
"""
|
"""
|
||||||
if not enable_xla:
|
|
||||||
if allow_enable_xla_false():
|
|
||||||
warnings.warn("jax2tf.convert with enable_xla=False is deprecated.",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2)
|
|
||||||
else:
|
|
||||||
raise ValueError("jax2tf.convert with enable_xla=False is not supported.")
|
|
||||||
|
|
||||||
if native_serialization is DEFAULT_NATIVE_SERIALIZATION:
|
if native_serialization is DEFAULT_NATIVE_SERIALIZATION:
|
||||||
if not enable_xla:
|
if not enable_xla:
|
||||||
native_serialization = False
|
native_serialization = False
|
||||||
else:
|
else:
|
||||||
native_serialization = config.jax2tf_default_native_serialization.value
|
native_serialization = config.jax2tf_default_native_serialization.value
|
||||||
|
|
||||||
if not native_serialization:
|
if not enable_xla:
|
||||||
warnings.warn(
|
if allow_enable_xla_false():
|
||||||
"jax2tf.convert with native_serialization=False is deprecated.",
|
warnings.warn(
|
||||||
DeprecationWarning,
|
"jax2tf.convert with enable_xla=False has been deprecated "
|
||||||
stacklevel=2)
|
"since July 2024.",
|
||||||
if native_serialization and not enable_xla:
|
DeprecationWarning,
|
||||||
raise ValueError(
|
stacklevel=2)
|
||||||
"native_serialization is not supported with enable_xla=False")
|
if native_serialization:
|
||||||
|
raise ValueError(
|
||||||
|
"native_serialization is not supported with enable_xla=False")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"jax2tf.convert with enable_xla=False has been deprecated "
|
||||||
|
"since July 2024 and it is not supported anymore.")
|
||||||
|
|
||||||
|
elif not native_serialization:
|
||||||
|
if allow_native_serialization_false():
|
||||||
|
warnings.warn(
|
||||||
|
"jax2tf.convert with native_serialization=False has been deprecated "
|
||||||
|
"since July 2024.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"jax2tf.convert with native_serialization=False has been deprecated "
|
||||||
|
"since July 2024 and it is not supported anymore.")
|
||||||
|
|
||||||
if not native_serialization and polymorphic_constraints:
|
if not native_serialization and polymorphic_constraints:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -2188,7 +2202,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
|
|||||||
_out_aval: core.ShapedArray):
|
_out_aval: core.ShapedArray):
|
||||||
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
|
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
|
||||||
# TODO(b/293247337): we ought to turn on this safety check, but this leads to
|
# TODO(b/293247337): we ought to turn on this safety check, but this leads to
|
||||||
# failures. Since we are going to turn on native serializaton soon, wait
|
# failures. Since we are going to turn on native serialization soon, wait
|
||||||
# until then to turn on this check.
|
# until then to turn on this check.
|
||||||
# lhs_aval, rhs_aval = _in_avals
|
# lhs_aval, rhs_aval = _in_avals
|
||||||
# if lhs_aval.dtype != rhs_aval.dtype:
|
# if lhs_aval.dtype != rhs_aval.dtype:
|
||||||
|
@ -90,7 +90,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
|||||||
super().setUp()
|
super().setUp()
|
||||||
self.warning_ctx = jtu.ignore_warning(
|
self.warning_ctx = jtu.ignore_warning(
|
||||||
message=(
|
message=(
|
||||||
"(jax2tf.convert with native_serialization=False is deprecated"
|
"(jax2tf.convert with native_serialization=False has been deprecated"
|
||||||
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -897,7 +897,7 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
|
|||||||
super().setUp()
|
super().setUp()
|
||||||
self.warning_ctx = jtu.ignore_warning(
|
self.warning_ctx = jtu.ignore_warning(
|
||||||
message=(
|
message=(
|
||||||
"(jax2tf.convert with native_serialization=False is deprecated"
|
"(jax2tf.convert with native_serialization=False has been deprecated"
|
||||||
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -1203,7 +1203,7 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
|||||||
super().setUp()
|
super().setUp()
|
||||||
self.warning_ctx = jtu.ignore_warning(
|
self.warning_ctx = jtu.ignore_warning(
|
||||||
message=(
|
message=(
|
||||||
"(jax2tf.convert with native_serialization=False is deprecated"
|
"(jax2tf.convert with native_serialization=False has been deprecated"
|
||||||
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -79,7 +79,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.warning_ctx = jtu.ignore_warning(
|
self.warning_ctx = jtu.ignore_warning(
|
||||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
message="jax2tf.convert with native_serialization=False has been deprecated"
|
||||||
)
|
)
|
||||||
self.warning_ctx.__enter__()
|
self.warning_ctx.__enter__()
|
||||||
|
|
||||||
@ -1722,7 +1722,7 @@ class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.warning_ctx = jtu.ignore_warning(
|
self.warning_ctx = jtu.ignore_warning(
|
||||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
message="jax2tf.convert with native_serialization=False has been deprecated"
|
||||||
)
|
)
|
||||||
self.warning_ctx.__enter__()
|
self.warning_ctx.__enter__()
|
||||||
|
|
||||||
@ -1763,7 +1763,7 @@ class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
|
|||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
@jtu.ignore_warning(
|
@jtu.ignore_warning(
|
||||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
message="jax2tf.convert with native_serialization=False has been deprecated"
|
||||||
)
|
)
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
self.ConvertAndCompare(jnp.sin, 0.7)
|
self.ConvertAndCompare(jnp.sin, 0.7)
|
||||||
|
@ -1031,7 +1031,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
|||||||
self.assertAllClose(f_jax(x), restored_f(x))
|
self.assertAllClose(f_jax(x), restored_f(x))
|
||||||
|
|
||||||
@jtu.ignore_warning(
|
@jtu.ignore_warning(
|
||||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
message="jax2tf.convert with native_serialization=False has been deprecated"
|
||||||
)
|
)
|
||||||
def test_readme_examples(self):
|
def test_readme_examples(self):
|
||||||
"""Some of the examples from the README."""
|
"""Some of the examples from the README."""
|
||||||
@ -1124,31 +1124,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
|||||||
# JAX with static shapes sees that x.shape[0] != x.shape[1]
|
# JAX with static shapes sees that x.shape[0] != x.shape[1]
|
||||||
self.assertEqual(jnp.sum(x45), f2_jax(x45))
|
self.assertEqual(jnp.sum(x45), f2_jax(x45))
|
||||||
|
|
||||||
# In graph serialization eager mode, we catch the broken assumption b >= 1
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
tf.errors.InvalidArgumentError,
|
|
||||||
re.escape(
|
|
||||||
"Found inconsistency between dimension size args[0].shape[1] (= 5) "
|
|
||||||
"and the specification 'b' (= 4)")):
|
|
||||||
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
|
||||||
native_serialization=False)(x45)
|
|
||||||
|
|
||||||
# In graph serialization graph mode we also catch it (except on TPU, where
|
|
||||||
# the behavior is as for jit_compile=1)
|
|
||||||
|
|
||||||
f2_tf = tf.function(
|
|
||||||
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
|
||||||
native_serialization=False),
|
|
||||||
autograph=False,
|
|
||||||
).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
|
|
||||||
if jtu.test_device_matches(["tpu"]):
|
|
||||||
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
|
|
||||||
else:
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
tf.errors.InvalidArgumentError,
|
|
||||||
r"Found inconsistency"):
|
|
||||||
_ = f2_tf(x45)
|
|
||||||
|
|
||||||
# We also catch the error with native serialization
|
# We also catch the error with native serialization
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
tf.errors.InvalidArgumentError,
|
tf.errors.InvalidArgumentError,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user