mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b PiperOrigin-RevId: 643202512
This commit is contained in:
parent
06ec7d1ad5
commit
a92fa547a0
@ -17,6 +17,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* jax now depends on jaxlib directly. This change was enabled by the CUDA
|
||||
plugin switch: there are no longer multiple jaxlib variants. You can install
|
||||
a CPU-only jax with `pip install jax`, no extras required.
|
||||
* Added an API for exporting and serializing JAX functions. This used
|
||||
to exist in `jax.experimental.export` (which is being deprecated),
|
||||
and will now live in `jax.export`.
|
||||
See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html).
|
||||
|
||||
* Deprecations
|
||||
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
|
||||
@ -24,6 +28,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX
|
||||
release. This previously was the case, but there was an inadvertent regression in
|
||||
the last several JAX releases.
|
||||
* `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead.
|
||||
See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export).
|
||||
|
||||
## jaxlib 0.4.30
|
||||
|
||||
|
@ -630,6 +630,29 @@ We list here a history of the calling convention version numbers:
|
||||
This is the only supported version as of 27th of March, 2024.
|
||||
|
||||
|
||||
|
||||
## Migration guide from jax.experimental.export
|
||||
|
||||
On June 14, 2014 we deprecated the `jax.experimental.export` APIs
|
||||
in favor of `jax.export` APIs. There have been some minor changes:
|
||||
|
||||
* `jax.experimental.export.export`:
|
||||
* The old function used to allow any Python callable, or the result of
|
||||
`jax.jit`. Now only the latter is accepted. You have to manually apply
|
||||
`jax.jit` to the function to export before calling `export`.
|
||||
* The old `lowering_parameters` kwarg is now named `platforms`
|
||||
* `jax.experimental.export.default_lowering_platform()` is now
|
||||
at {func}`jax.export.default_export_platform`.
|
||||
* `jax.experimental.export.call` is now a method of the {class}`jax.export.Exported` object.
|
||||
Instead of `export.call(exp)` you should use `exp.call`.
|
||||
* `jax.experimental.export.serialize` is now a method of the {class}`jax.export.Exported`
|
||||
object. Instead of `export.serialize(exp)` you should use `exp.serialize()`.
|
||||
* The configuration flag `--jax-serialization-version` is deprecated.
|
||||
Use `--jax-export-calling-convention-version`.
|
||||
* The value `jax.experimental.export.minimum_supported_serialization_version`
|
||||
is now at `jax.export.minimum_supported_calling_convention_version`.
|
||||
* The following fields of {class}`jax.export.Exported` have been renamed
|
||||
* `uses_shape_polymorphism` is now `uses_global_constants`
|
||||
* `mlir_module_serialization_version` is now `calling_convention_version`
|
||||
* `lowering_platforms` is now `platforms`.
|
||||
|
||||
|
||||
|
@ -13,27 +13,61 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from jax._src.export._export import (
|
||||
minimum_supported_calling_convention_version,
|
||||
maximum_supported_calling_convention_version,
|
||||
Exported,
|
||||
call_exported, # TODO: deprecate
|
||||
call,
|
||||
DisabledSafetyCheck,
|
||||
default_lowering_platform, # TODO: deprecate
|
||||
_deprecation_message = (
|
||||
"The jax.experimental.export module is deprecated. "
|
||||
"Use jax.export instead. "
|
||||
"See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export."
|
||||
)
|
||||
from jax._src.export._export import export_back_compat as export
|
||||
|
||||
from jax._src.export.shape_poly import (
|
||||
is_symbolic_dim,
|
||||
symbolic_shape,
|
||||
symbolic_args_specs,
|
||||
SymbolicScope,
|
||||
)
|
||||
from jax._src.export.serialization import (
|
||||
serialize,
|
||||
deserialize,
|
||||
)
|
||||
from jax._src.export import _export as _src_export
|
||||
from jax._src.export import shape_poly as _src_shape_poly
|
||||
from jax._src.export import serialization as _src_serialization
|
||||
# Import only to set the shape poly decision procedure
|
||||
from jax._src.export import shape_poly_decision
|
||||
del shape_poly_decision
|
||||
|
||||
# All deprecations added Jun 14, 2024
|
||||
_deprecations = {
|
||||
# Added Jun 13, 2024
|
||||
"Exported": (_deprecation_message, _src_export.Exported),
|
||||
"DisabledSafetyCheck": (_deprecation_message, _src_export.DisabledSafetyCheck),
|
||||
"export": (_deprecation_message, _src_export.export_back_compat),
|
||||
"call": (_deprecation_message, _src_export.call),
|
||||
"call_exported": (_deprecation_message, _src_export.call_exported),
|
||||
"default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform),
|
||||
"minimum_supported_serialization_version" : (_deprecation_message, _src_export.minimum_supported_calling_convention_version),
|
||||
"maximum_supported_serialization_version" : (_deprecation_message, _src_export.maximum_supported_calling_convention_version),
|
||||
|
||||
"serialize": (_deprecation_message, _src_serialization.serialize),
|
||||
"deserialize": (_deprecation_message, _src_serialization.deserialize),
|
||||
|
||||
"SymbolicScope": (_deprecation_message, _src_shape_poly.SymbolicScope),
|
||||
"is_symbolic_dim": (_deprecation_message, _src_shape_poly.is_symbolic_dim),
|
||||
"symbolic_shape": (_deprecation_message, _src_shape_poly.symbolic_shape),
|
||||
"symbolic_args_specs": (_deprecation_message, _src_shape_poly.symbolic_args_specs),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
Exported = _src_export.Exported
|
||||
DisabledSafetyCheck = _src_export.DisabledSafetyCheck
|
||||
export = _src_export.export_back_compat
|
||||
call = _src_export.call
|
||||
call_exported = _src_export.call_exported
|
||||
default_lowering_platform = _src_export.default_lowering_platform
|
||||
|
||||
serialize = _src_serialization.serialize
|
||||
deserialize = _src_serialization.deserialize
|
||||
|
||||
SymbolicScope = _src_shape_poly.SymbolicScope
|
||||
is_symbolic_dim = _src_shape_poly.is_symbolic_dim
|
||||
symbolic_shape = _src_shape_poly.symbolic_shape
|
||||
symbolic_args_specs = _src_shape_poly.symbolic_args_specs
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
del _src_export
|
||||
del _src_serialization
|
||||
del _src_shape_poly
|
||||
|
@ -232,6 +232,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
"Function to be exported must be the result of `jit`"):
|
||||
_ = export.export(lambda x: jnp.sin(x))
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="The jax.experimental.export module is deprecated")
|
||||
def test_export_experimental_back_compat(self):
|
||||
from jax.experimental import export
|
||||
# Can export a lambda, without jit
|
||||
|
@ -17,7 +17,7 @@
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental import export
|
||||
from jax import export
|
||||
# Import mosaic for flag definitions
|
||||
from jax.experimental import mosaic as _ # noqa: F401
|
||||
from jax.experimental import pallas as pl
|
||||
@ -49,7 +49,7 @@ class ExportTest(jtu.JaxTestCase):
|
||||
if (jtu.device_under_test() == "tpu" or
|
||||
(jtu.device_under_test() == "gpu" and
|
||||
jtu.is_cuda_compute_capability_at_least("8.0"))):
|
||||
res = export.call(exp)(a, a)
|
||||
res = exp.call(a, a)
|
||||
self.assertAllClose(res, a + a)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user