Remove deprecated jax.experimental.export module.

These tools are now available at jax.export.
This commit is contained in:
Jake VanderPlas 2024-10-30 05:27:29 -07:00
parent f1c3109bf5
commit e61a20b45a
5 changed files with 3 additions and 134 deletions

View File

@ -20,6 +20,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed
after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`,
`xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`.
* The deprecated module `jax.experimental.export` has been removed. It was replaced
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
for information on migrating to the new API.
## jax 0.4.35 (Oct 22, 2024)

View File

@ -1,42 +0,0 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# JAX-export provides APIs for exporting StableHLO for serialization purposes.
load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"py_deps",
)
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
)
py_library(
name = "export",
srcs = [
"__init__.py",
],
srcs_version = "PY3",
# TODO: b/255503696: enable pytype
tags = ["pytype_unchecked_annotations"],
visibility = ["//visibility:public"],
deps = [
"//jax",
] + py_deps("numpy") + py_deps("flatbuffers"),
)

View File

@ -1,73 +0,0 @@
# Copyright 2023 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
_deprecation_message = (
"The jax.experimental.export module is deprecated. "
"Use jax.export instead. "
"See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export."
)
from jax._src.export import _export as _src_export
from jax._src.export import shape_poly as _src_shape_poly
from jax._src.export import serialization as _src_serialization
# Import only to set the shape poly decision procedure
from jax._src.export import shape_poly_decision
del shape_poly_decision
# All deprecations added Jun 14, 2024
_deprecations = {
# Added Jun 13, 2024
"Exported": (_deprecation_message, _src_export.Exported),
"DisabledSafetyCheck": (_deprecation_message, _src_export.DisabledSafetyCheck),
"export": (_deprecation_message, _src_export.export_back_compat),
"call": (_deprecation_message, _src_export.call),
"call_exported": (_deprecation_message, _src_export.call_exported),
"default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform),
"minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version),
"maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version),
"serialize": (_deprecation_message, _src_serialization.serialize),
"deserialize": (_deprecation_message, _src_serialization.deserialize),
"SymbolicScope": (_deprecation_message, _src_shape_poly.SymbolicScope),
"is_symbolic_dim": (_deprecation_message, _src_shape_poly.is_symbolic_dim),
"symbolic_shape": (_deprecation_message, _src_shape_poly.symbolic_shape),
"symbolic_args_specs": (_deprecation_message, _src_shape_poly.symbolic_args_specs),
}
import typing
if typing.TYPE_CHECKING:
Exported = _src_export.Exported
DisabledSafetyCheck = _src_export.DisabledSafetyCheck
export = _src_export.export_back_compat
call = _src_export.call
call_exported = _src_export.call_exported
default_lowering_platform = _src_export.default_lowering_platform
serialize = _src_serialization.serialize
deserialize = _src_serialization.deserialize
SymbolicScope = _src_shape_poly.SymbolicScope
is_symbolic_dim = _src_shape_poly.is_symbolic_dim
symbolic_shape = _src_shape_poly.symbolic_shape
symbolic_args_specs = _src_shape_poly.symbolic_args_specs
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _src_export
del _src_serialization
del _src_shape_poly

View File

@ -1409,9 +1409,6 @@ jax_multiplatform_test(
"tpu_v3_2x2",
],
tags = [],
deps = [
"//jax/experimental/export",
],
)
jax_multiplatform_test(

View File

@ -244,22 +244,6 @@ class JaxExportTest(jtu.JaxTestCase):
"Function to be exported must be the result of `jit`"):
_ = export.export(lambda x: jnp.sin(x))
@jtu.ignore_warning(category=DeprecationWarning,
message="The jax.experimental.export module is deprecated")
def test_export_experimental_back_compat(self):
if not CAN_SERIALIZE:
self.skipTest("serialization disabled")
from jax.experimental import export
# Can export a lambda, without jit
exp = export.export(lambda x: jnp.sin(x))(.1)
self.assertAllClose(exp.call(1.), np.sin(1.))
blob = export.serialize(exp, vjp_order=1)
rehydrated = export.deserialize(blob)
self.assertAllClose(export.call(exp)(1.), np.sin(1.))
self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.))
def test_call_exported_lambda(self):
# When we export a lambda, the exported.fun_name is not a valid MLIR function name
f = jax.jit(lambda x: jnp.sin(x))