mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

Previously we used `from jax.experimental.export import export` and `export.export(fun)`. Now we want to add the public API directly to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export
exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialize(ser)
export.call(exp1)
```

This change requires changing the type of `jax.experimental.export.export` from a module to a function. This confuses pytype for the targets with strict type checking, which is why I am making this change atomically throughout the internal code base. In order to support backwards compatibility with OSS packages, this change also includes explicit JAX version checks in several OSS packages, and also adds to the `export` function the attributes that the old export module had. PiperOrigin-RevId: 596563481
47 lines
1.3 KiB
Python
47 lines
1.3 KiB
Python
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX-export provides APIs for exporting StableHLO for serialization purposes.

# `py_deps` is loaded from the jaxlib helper file; `py_library` comes from
# rules_python.
load("//jaxlib:jax.bzl", "py_deps")
load("@rules_python//python:defs.bzl", "py_library")

# Declare the license type for everything in this package.
licenses(["notice"])

# Package-wide defaults: an empty applicable-licenses list, and targets stay
# private unless they set their own `visibility`.
package(
    default_applicable_licenses = [],
    default_visibility = ["//visibility:private"],
)

# Python library target bundling the export sources listed below. It
# overrides the package's private default visibility so that other packages
# can depend on it.
py_library(
    name = "export",
    srcs = [
        "__init__.py",
        "_export.py",
        "serialization.py",
        "serialization_generated.py",
        "shape_poly.py",
    ],
    srcs_version = "PY3",
    # TODO: b/255503696: enable pytype
    tags = ["pytype_unchecked_annotations"],
    visibility = ["//visibility:public"],
    # //jax, plus the numpy and flatbuffers packages as resolved by the
    # py_deps() helper (presumably a list of labels; see //jaxlib:jax.bzl).
    deps = ["//jax"] + py_deps("numpy") + py_deps("flatbuffers"),
)