diff --git a/CHANGELOG.md b/CHANGELOG.md index 542a7d417..9b629631e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. + * The deprecated module `jax.experimental.export` has been removed. It was replaced + by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) + for information on migrating to the new API. ## jax 0.4.35 (Oct 22, 2024) diff --git a/jax/experimental/export/BUILD b/jax/experimental/export/BUILD deleted file mode 100644 index 1246b0d40..000000000 --- a/jax/experimental/export/BUILD +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# JAX-export provides APIs for exporting StableHLO for serialization purposes. - -load("@rules_python//python:defs.bzl", "py_library") -load( - "//jaxlib:jax.bzl", - "py_deps", -) - -licenses(["notice"]) - -package( - default_applicable_licenses = [], - default_visibility = ["//visibility:private"], -) - -py_library( - name = "export", - srcs = [ - "__init__.py", - ], - srcs_version = "PY3", - # TODO: b/255503696: enable pytype - tags = ["pytype_unchecked_annotations"], - visibility = ["//visibility:public"], - deps = [ - "//jax", - ] + py_deps("numpy") + py_deps("flatbuffers"), -) diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py deleted file mode 100644 index d49aa2963..000000000 --- a/jax/experimental/export/__init__.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2023 The JAX Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -_deprecation_message = ( - "The jax.experimental.export module is deprecated. " - "Use jax.export instead. " - "See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export." -) - -from jax._src.export import _export as _src_export -from jax._src.export import shape_poly as _src_shape_poly -from jax._src.export import serialization as _src_serialization -# Import only to set the shape poly decision procedure -from jax._src.export import shape_poly_decision -del shape_poly_decision - -# All deprecations added Jun 14, 2024 -_deprecations = { - # Added Jun 13, 2024 - "Exported": (_deprecation_message, _src_export.Exported), - "DisabledSafetyCheck": (_deprecation_message, _src_export.DisabledSafetyCheck), - "export": (_deprecation_message, _src_export.export_back_compat), - "call": (_deprecation_message, _src_export.call), - "call_exported": (_deprecation_message, _src_export.call_exported), - "default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform), - "minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version), - "maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version), - - "serialize": (_deprecation_message, _src_serialization.serialize), - "deserialize": (_deprecation_message, _src_serialization.deserialize), - - "SymbolicScope": (_deprecation_message, _src_shape_poly.SymbolicScope), - "is_symbolic_dim": (_deprecation_message, _src_shape_poly.is_symbolic_dim), - "symbolic_shape": (_deprecation_message, _src_shape_poly.symbolic_shape), - "symbolic_args_specs": (_deprecation_message, _src_shape_poly.symbolic_args_specs), -} - -import typing -if typing.TYPE_CHECKING: - Exported = _src_export.Exported - DisabledSafetyCheck = _src_export.DisabledSafetyCheck - export = _src_export.export_back_compat - call = _src_export.call - call_exported = _src_export.call_exported - default_lowering_platform = _src_export.default_lowering_platform - - serialize = _src_serialization.serialize - deserialize = _src_serialization.deserialize - - SymbolicScope = _src_shape_poly.SymbolicScope - is_symbolic_dim = _src_shape_poly.is_symbolic_dim - symbolic_shape = _src_shape_poly.symbolic_shape - symbolic_args_specs = _src_shape_poly.symbolic_args_specs -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing -del _src_export -del _src_serialization -del _src_shape_poly diff --git a/tests/BUILD b/tests/BUILD index 316e98f5b..39e1d35a3 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1409,9 +1409,6 @@ jax_multiplatform_test( "tpu_v3_2x2", ], tags = [], - deps = [ - "//jax/experimental/export", - ], ) jax_multiplatform_test( diff --git a/tests/export_test.py b/tests/export_test.py index fd6bef11e..b6dde2372 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -244,22 +244,6 @@ class JaxExportTest(jtu.JaxTestCase): "Function to be exported must be the result of `jit`"): _ = export.export(lambda x: jnp.sin(x)) - @jtu.ignore_warning(category=DeprecationWarning, - message="The jax.experimental.export module is deprecated") - def test_export_experimental_back_compat(self): - if not CAN_SERIALIZE: - self.skipTest("serialization disabled") - from jax.experimental import export - # Can export a lambda, without jit - exp = export.export(lambda x: jnp.sin(x))(.1) - self.assertAllClose(exp.call(1.), np.sin(1.)) - - blob = export.serialize(exp, vjp_order=1) - rehydrated = export.deserialize(blob) - - self.assertAllClose(export.call(exp)(1.), np.sin(1.)) - self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.)) - def test_call_exported_lambda(self): # When we export a lambda, the exported.fun_name is not a valid MLIR function name f = jax.jit(lambda x: jnp.sin(x))