mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00

The functionality comes from the jax.experimental.export module, which will be deprecated. The following APIs are introduced:

```
from jax import export

def f(...): ...
ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

blob: bytearray = ex.serialize()
rehydrated: export.Exported = export.deserialize(blob)

def caller(...):
  ...
  rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly. There are no changes for now in the jax.experimental.export APIs.

Most of the changes in this PR are in tests, due to some differences in the new jax.export APIs compared to jax.experimental.export:

* Instead of `jax.experimental.export.call(exp)` we now write `exp.call`.
* `jax.experimental.export.export` allowed the function argument to be any Python callable and would wrap it with `jax.jit`. The new `jax.export.export` no longer does this; the user must wrap the function with `jax.jit` explicitly.
35 lines
1.9 KiB
Python
35 lines
1.9 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# Public names of this module. Each entry is re-exported by the import
# statements that follow, from private modules under ``jax._src.export``.
# The order of entries is kept stable.
__all__ = [
    # Implemented in jax._src.export._export.
    "DisabledSafetyCheck",
    "Exported",
    "export",
    "deserialize",
    "maximum_supported_serialization_version",
    "minimum_supported_serialization_version",
    "default_lowering_platform",
    # Implemented in jax._src.export.shape_poly.
    "SymbolicScope",
    "is_symbolic_dim",
    "symbolic_shape",
    "symbolic_args_specs",
]
|
|
|
|
from jax._src.export._export import DisabledSafetyCheck as DisabledSafetyCheck
|
|
from jax._src.export._export import Exported as Exported
|
|
from jax._src.export._export import export as export
|
|
from jax._src.export._export import deserialize as deserialize
|
|
from jax._src.export._export import maximum_supported_serialization_version as maximum_supported_serialization_version
|
|
from jax._src.export._export import minimum_supported_serialization_version as minimum_supported_serialization_version
|
|
from jax._src.export._export import default_lowering_platform as default_lowering_platform
|
|
|
|
from jax._src.export import shape_poly_decision # Import only to set the decision procedure
|
|
del shape_poly_decision
|
|
from jax._src.export.shape_poly import SymbolicScope as SymbolicScope
|
|
from jax._src.export.shape_poly import is_symbolic_dim as is_symbolic_dim
|
|
from jax._src.export.shape_poly import symbolic_shape as symbolic_shape
|
|
from jax._src.export.shape_poly import symbolic_args_specs as symbolic_args_specs
|