George Necula 69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialize(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a module to a function.
That confuses pytype on targets with strict type checking, which is
why I attempt to make this change atomically throughout the internal
code base.

To preserve backwards compatibility with OSS packages, this change
also adds explicit JAX version checks to several OSS packages, and
attaches to the `export` function the attributes that the old
`export` module had.
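
The compatibility approach described above can be sketched roughly as follows. This is a minimal illustration, not the code from this change: `Exported`, `export`, `parse_version`, `uses_new_api`, and the `NEW_API_MIN_VERSION` cutoff are hypothetical stand-ins, and the sketch deliberately avoids importing JAX.

```python
import re


# Hypothetical stand-in for illustration; not the real JAX class.
class Exported:
    def __init__(self, fun_name):
        self.fun_name = fun_name


def export(fun):
    """New-style entry point: `export(fun)` returns an Exported."""
    return Exported(fun.__name__)


# Old call sites treated `export` as a module (`export.export(fun)`,
# `export.Exported`). Attaching those names as function attributes
# keeps them working after `export` becomes a function.
export.export = export
export.Exported = Exported


def parse_version(version):
    """Parse a string such as '0.4.24.dev20240108' into (0, 4, 24)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    return tuple(parts)


# Placeholder cutoff chosen for this sketch; not taken from the commit.
NEW_API_MIN_VERSION = (0, 4, 24)


def uses_new_api(jax_version):
    """Downstream-style check: choose the import style by JAX version."""
    return parse_version(jax_version) >= NEW_API_MIN_VERSION


def double(x):
    return 2 * x


old_style = export.export(double)  # legacy spelling still resolves
new_style = export(double)         # new spelling
print(old_style.fun_name, new_style.fun_name)
print(uses_new_api("0.4.23"), uses_new_api("0.4.24.dev20240108"))
```

A downstream package would typically pass `jax.__version__` to a check like `uses_new_api` and branch on the result when choosing between the old and new import styles.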

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX-export provides APIs for exporting StableHLO for serialization purposes.

load(
    "//jaxlib:jax.bzl",
    "py_deps",
)
load("@rules_python//python:defs.bzl", "py_library")

licenses(["notice"])

package(
    default_applicable_licenses = [],
    default_visibility = ["//visibility:private"],
)

py_library(
    name = "export",
    srcs = [
        "__init__.py",
        "_export.py",
        "serialization.py",
        "serialization_generated.py",
        "shape_poly.py",
    ],
    srcs_version = "PY3",
    # TODO: b/255503696: enable pytype
    tags = ["pytype_unchecked_annotations"],
    visibility = ["//visibility:public"],
    deps = [
        "//jax",
    ] + py_deps("numpy") + py_deps("flatbuffers"),
)