1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b

PiperOrigin-RevId: 643202512
This commit is contained in:
Jake VanderPlas 2024-06-13 19:53:10 -07:00 committed by jax authors
parent 06ec7d1ad5
commit a92fa547a0
5 changed files with 87 additions and 22 deletions
CHANGELOG.md
docs/export
jax/experimental/export
tests

@ -17,6 +17,10 @@ Remember to align the itemized text with the first line of an item within a list
* jax now depends on jaxlib directly. This change was enabled by the CUDA
plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with `pip install jax`, no extras required.
* Added an API for exporting and serializing JAX functions. This used
to exist in `jax.experimental.export` (which is being deprecated),
and will now live in `jax.export`.
See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html).
* Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
@ -24,6 +28,8 @@ Remember to align the itemized text with the first line of an item within a list
* Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX
release. This previously was the case, but there was an inadvertent regression in
the last several JAX releases.
* `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead.
See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export).
## jaxlib 0.4.30

@ -630,6 +630,29 @@ We list here a history of the calling convention version numbers:
This is the only supported version as of 27th of March, 2024.
## Migration guide from jax.experimental.export
On June 14, 2014 we deprecated the `jax.experimental.export` APIs
in favor of `jax.export` APIs. There have been some minor changes:
* `jax.experimental.export.export`:
* The old function used to allow any Python callable, or the result of
`jax.jit`. Now only the latter is accepted. You have to manually apply
`jax.jit` to the function to export before calling `export`.
* The old `lowering_parameters` kwarg is now named `platforms`
* `jax.experimental.export.default_lowering_platform()` is now
at {func}`jax.export.default_export_platform`.
* `jax.experimental.export.call` is now a method of the {class}`jax.export.Exported` object.
Instead of `export.call(exp)` you should use `exp.call`.
* `jax.experimental.export.serialize` is now a method of the {class}`jax.export.Exported`
object. Instead of `export.serialize(exp)` you should use `exp.serialize()`.
* The configuration flag `--jax-serialization-version` is deprecated.
Use `--jax-export-calling-convention-version`.
* The value `jax.experimental.export.minimum_supported_serialization_version`
is now at `jax.export.minimum_supported_calling_convention_version`.
* The following fields of {class}`jax.export.Exported` have been renamed
* `uses_shape_polymorphism` is now `uses_global_constants`
* `mlir_module_serialization_version` is now `calling_convention_version`
* `lowering_platforms` is now `platforms`.

@ -13,27 +13,61 @@
# limitations under the License.
# ==============================================================================
from jax._src.export._export import (
minimum_supported_calling_convention_version,
maximum_supported_calling_convention_version,
Exported,
call_exported, # TODO: deprecate
call,
DisabledSafetyCheck,
default_lowering_platform, # TODO: deprecate
_deprecation_message = (
"The jax.experimental.export module is deprecated. "
"Use jax.export instead. "
"See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export."
)
from jax._src.export._export import export_back_compat as export
from jax._src.export.shape_poly import (
is_symbolic_dim,
symbolic_shape,
symbolic_args_specs,
SymbolicScope,
)
from jax._src.export.serialization import (
serialize,
deserialize,
)
from jax._src.export import _export as _src_export
from jax._src.export import shape_poly as _src_shape_poly
from jax._src.export import serialization as _src_serialization
# Import only to set the shape poly decision procedure
from jax._src.export import shape_poly_decision
del shape_poly_decision
# All deprecations added Jun 14, 2024
_deprecations = {
# Added Jun 13, 2024
"Exported": (_deprecation_message, _src_export.Exported),
"DisabledSafetyCheck": (_deprecation_message, _src_export.DisabledSafetyCheck),
"export": (_deprecation_message, _src_export.export_back_compat),
"call": (_deprecation_message, _src_export.call),
"call_exported": (_deprecation_message, _src_export.call_exported),
"default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform),
"minimum_supported_serialization_version" : (_deprecation_message, _src_export.minimum_supported_calling_convention_version),
"maximum_supported_serialization_version" : (_deprecation_message, _src_export.maximum_supported_calling_convention_version),
"serialize": (_deprecation_message, _src_serialization.serialize),
"deserialize": (_deprecation_message, _src_serialization.deserialize),
"SymbolicScope": (_deprecation_message, _src_shape_poly.SymbolicScope),
"is_symbolic_dim": (_deprecation_message, _src_shape_poly.is_symbolic_dim),
"symbolic_shape": (_deprecation_message, _src_shape_poly.symbolic_shape),
"symbolic_args_specs": (_deprecation_message, _src_shape_poly.symbolic_args_specs),
}
import typing
if typing.TYPE_CHECKING:
Exported = _src_export.Exported
DisabledSafetyCheck = _src_export.DisabledSafetyCheck
export = _src_export.export_back_compat
call = _src_export.call
call_exported = _src_export.call_exported
default_lowering_platform = _src_export.default_lowering_platform
serialize = _src_serialization.serialize
deserialize = _src_serialization.deserialize
SymbolicScope = _src_shape_poly.SymbolicScope
is_symbolic_dim = _src_shape_poly.is_symbolic_dim
symbolic_shape = _src_shape_poly.symbolic_shape
symbolic_args_specs = _src_shape_poly.symbolic_args_specs
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _src_export
del _src_serialization
del _src_shape_poly

@ -232,6 +232,8 @@ class JaxExportTest(jtu.JaxTestCase):
"Function to be exported must be the result of `jit`"):
_ = export.export(lambda x: jnp.sin(x))
@jtu.ignore_warning(category=DeprecationWarning,
message="The jax.experimental.export module is deprecated")
def test_export_experimental_back_compat(self):
from jax.experimental import export
# Can export a lambda, without jit

@ -17,7 +17,7 @@
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax.experimental import export
from jax import export
# Import mosaic for flag definitions
from jax.experimental import mosaic as _ # noqa: F401
from jax.experimental import pallas as pl
@ -49,7 +49,7 @@ class ExportTest(jtu.JaxTestCase):
if (jtu.device_under_test() == "tpu" or
(jtu.device_under_test() == "gpu" and
jtu.is_cuda_compute_capability_at_least("8.0"))):
res = export.call(exp)(a, a)
res = exp.call(a, a)
self.assertAllClose(res, a + a)