rocm_jax/jax/export.py
George Necula b33aca6b08 [export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  import jax
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Exported = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```
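A runnable version of the synopsis above might look like the sketch below. It assumes a JAX release that ships the `jax.export` module. The function `f` and the input `x` are illustrative stand-ins for the elided `...` in the synopsis. By default the export is lowered for the platform of the machine running it, so the rehydrated function is called on the same backend here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


# Illustrative function to export (a stand-in for the elided `f`).
def f(x):
    return jnp.sin(x) * 2.0


x = jnp.arange(4, dtype=jnp.float32)

# Trace and lower the jitted function for the shape and dtype of `x`.
exp: export.Exported = export.export(jax.jit(f))(x)

# Serialize to bytes, e.g. to store on disk or send to another process.
blob: bytearray = exp.serialize()

# Rebuild an Exported object from the serialized bytes.
rehydrated: export.Exported = export.deserialize(blob)


# The rehydrated function can be called from other (jitted) JAX code.
def caller(x):
    return rehydrated.call(x)


y = jax.jit(caller)(x)
```

Because only the serialized blob is needed to rebuild `rehydrated`, the caller does not need access to the Python source of `f`.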

Module documentation will follow shortly.
For now, the jax.experimental.export APIs are unchanged.

Most of the changes in this PR are in tests, because the new
jax.export APIs differ from jax.experimental.export as follows:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * `jax.experimental.export.export` accepted any Python callable
    as the function argument and wrapped it in `jax.jit`
    automatically. `jax.export.export` no longer does this; the
    caller must pass a function already wrapped with `jax.jit`.
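The two migration points in the list above can be sketched as follows, again assuming a JAX release with `jax.export`. The function `g` is hypothetical. The sketch also assumes that `jax.export.export` rejects a plain (non-`jax.jit`) callable with a `ValueError`; the exact exception type may vary across JAX versions. Since `jax.experimental.export` is deprecated, the old call form appears only in a comment.

```python
import jax
import jax.numpy as jnp
from jax import export


# Hypothetical function used only for this illustration.
def g(x):
    return x + 1.0


x = jnp.zeros((3,), dtype=jnp.float32)

# jax.export requires an explicitly jitted function.
exp = export.export(jax.jit(g))(x)

# Old style (jax.experimental.export): export.call(exp)
# New style: the Exported object's own `call` method.
y = exp.call(x)

# A plain Python callable is no longer wrapped in jax.jit automatically,
# so exporting it directly is expected to fail (assumed: ValueError).
try:
    export.export(g)(x)
    plain_callable_rejected = False
except ValueError:
    plain_callable_rejected = True
```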
2024-06-10 19:31:51 +02:00

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize",
           "maximum_supported_serialization_version",
           "minimum_supported_serialization_version",
           "default_lowering_platform",
           "SymbolicScope", "is_symbolic_dim",
           "symbolic_shape", "symbolic_args_specs"]
from jax._src.export._export import DisabledSafetyCheck as DisabledSafetyCheck
from jax._src.export._export import Exported as Exported
from jax._src.export._export import export as export
from jax._src.export._export import deserialize as deserialize
from jax._src.export._export import maximum_supported_serialization_version as maximum_supported_serialization_version
from jax._src.export._export import minimum_supported_serialization_version as minimum_supported_serialization_version
from jax._src.export._export import default_lowering_platform as default_lowering_platform
from jax._src.export import shape_poly_decision  # Import only to set the decision procedure
del shape_poly_decision
from jax._src.export.shape_poly import SymbolicScope as SymbolicScope
from jax._src.export.shape_poly import is_symbolic_dim as is_symbolic_dim
from jax._src.export.shape_poly import symbolic_shape as symbolic_shape
from jax._src.export.shape_poly import symbolic_args_specs as symbolic_args_specs