mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Remove deprecated jax.experimental.export module.
These tools are now available at jax.export.
This commit is contained in:
parent
f1c3109bf5
commit
e61a20b45a
@ -20,6 +20,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed
|
||||
after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`,
|
||||
`xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`.
|
||||
* The deprecated module `jax.experimental.export` has been removed. It was replaced
|
||||
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
|
||||
for information on migrating to the new API.
|
||||
|
||||
## jax 0.4.35 (Oct 22, 2024)
|
||||
|
||||
|
@ -1,42 +0,0 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# JAX-export provides APIs for exporting StableHLO for serialization purposes.
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_deps",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "export",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
# TODO: b/255503696: enable pytype
|
||||
tags = ["pytype_unchecked_annotations"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//jax",
|
||||
] + py_deps("numpy") + py_deps("flatbuffers"),
|
||||
)
|
@ -1,73 +0,0 @@
|
||||
# Copyright 2023 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
_deprecation_message = (
|
||||
"The jax.experimental.export module is deprecated. "
|
||||
"Use jax.export instead. "
|
||||
"See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export."
|
||||
)
|
||||
|
||||
from jax._src.export import _export as _src_export
|
||||
from jax._src.export import shape_poly as _src_shape_poly
|
||||
from jax._src.export import serialization as _src_serialization
|
||||
# Import only to set the shape poly decision procedure
|
||||
from jax._src.export import shape_poly_decision
|
||||
del shape_poly_decision
|
||||
|
||||
# All deprecations added Jun 14, 2024
|
||||
_deprecations = {
|
||||
# Added Jun 13, 2024
|
||||
"Exported": (_deprecation_message, _src_export.Exported),
|
||||
"DisabledSafetyCheck": (_deprecation_message, _src_export.DisabledSafetyCheck),
|
||||
"export": (_deprecation_message, _src_export.export_back_compat),
|
||||
"call": (_deprecation_message, _src_export.call),
|
||||
"call_exported": (_deprecation_message, _src_export.call_exported),
|
||||
"default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform),
|
||||
"minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version),
|
||||
"maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version),
|
||||
|
||||
"serialize": (_deprecation_message, _src_serialization.serialize),
|
||||
"deserialize": (_deprecation_message, _src_serialization.deserialize),
|
||||
|
||||
"SymbolicScope": (_deprecation_message, _src_shape_poly.SymbolicScope),
|
||||
"is_symbolic_dim": (_deprecation_message, _src_shape_poly.is_symbolic_dim),
|
||||
"symbolic_shape": (_deprecation_message, _src_shape_poly.symbolic_shape),
|
||||
"symbolic_args_specs": (_deprecation_message, _src_shape_poly.symbolic_args_specs),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
Exported = _src_export.Exported
|
||||
DisabledSafetyCheck = _src_export.DisabledSafetyCheck
|
||||
export = _src_export.export_back_compat
|
||||
call = _src_export.call
|
||||
call_exported = _src_export.call_exported
|
||||
default_lowering_platform = _src_export.default_lowering_platform
|
||||
|
||||
serialize = _src_serialization.serialize
|
||||
deserialize = _src_serialization.deserialize
|
||||
|
||||
SymbolicScope = _src_shape_poly.SymbolicScope
|
||||
is_symbolic_dim = _src_shape_poly.is_symbolic_dim
|
||||
symbolic_shape = _src_shape_poly.symbolic_shape
|
||||
symbolic_args_specs = _src_shape_poly.symbolic_args_specs
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
del _src_export
|
||||
del _src_serialization
|
||||
del _src_shape_poly
|
@ -1409,9 +1409,6 @@ jax_multiplatform_test(
|
||||
"tpu_v3_2x2",
|
||||
],
|
||||
tags = [],
|
||||
deps = [
|
||||
"//jax/experimental/export",
|
||||
],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
|
@ -244,22 +244,6 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
"Function to be exported must be the result of `jit`"):
|
||||
_ = export.export(lambda x: jnp.sin(x))
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="The jax.experimental.export module is deprecated")
|
||||
def test_export_experimental_back_compat(self):
|
||||
if not CAN_SERIALIZE:
|
||||
self.skipTest("serialization disabled")
|
||||
from jax.experimental import export
|
||||
# Can export a lambda, without jit
|
||||
exp = export.export(lambda x: jnp.sin(x))(.1)
|
||||
self.assertAllClose(exp.call(1.), np.sin(1.))
|
||||
|
||||
blob = export.serialize(exp, vjp_order=1)
|
||||
rehydrated = export.deserialize(blob)
|
||||
|
||||
self.assertAllClose(export.call(exp)(1.), np.sin(1.))
|
||||
self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.))
|
||||
|
||||
def test_call_exported_lambda(self):
|
||||
# When we export a lambda, the exported.fun_name is not a valid MLIR function name
|
||||
f = jax.jit(lambda x: jnp.sin(x))
|
||||
|
Loading…
x
Reference in New Issue
Block a user