mirror of
https://github.com/ROCm/jax.git
synced 2025-04-23 20:26:05 +00:00

We take the opportunity of the new `jax.export` package to rename some of the API entry points: * `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants` because this is more accurate. The dimension variables are global constants, but so is the platform index. And we need to run global constant propagation and shape refinement for all of these. * We replace the term "serialization version" with "calling convention version". Hence we now have `Exported.calling_convention_version`, and the configuration flag is renamed from `--jax-serialization-version` to `--jax-export-calling-convention-version`. Also, `jax.export.minimum_supported_serialization_version` is now `jax.export.minimum_supported_calling_convention_version`. * We rename `lowering_platforms` to `platforms` both as a field of `Exported` and as the kwarg to `export.export`. * We rename `jax.export.default_lowering_platform` to `jax.export.default_export_platform`.
37 lines
1.3 KiB
Python
37 lines
1.3 KiB
Python
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Public surface of `jax.export`: every name below is re-exported from a
# private implementation module under `jax._src.export`.

# Core export / deserialization entry points and the supported range of
# calling convention versions.
from jax._src.export._export import (
  DisabledSafetyCheck,
  Exported,
  export,
  deserialize,
  maximum_supported_calling_convention_version,
  minimum_supported_calling_convention_version,
  default_export_platform,
)

# This module is imported only for the side effect of setting the decision
# procedure; the module name itself is deleted so it is not re-exported.
from jax._src.export import shape_poly_decision
del shape_poly_decision

# Symbolic-dimension helpers.
from jax._src.export.shape_poly import (
  SymbolicScope,
  is_symbolic_dim,
  symbolic_shape,
  symbolic_args_specs,
)

__all__ = [
  # Re-exported from jax._src.export._export
  "DisabledSafetyCheck",
  "Exported",
  "export",
  "deserialize",
  "maximum_supported_calling_convention_version",
  "minimum_supported_calling_convention_version",
  "default_export_platform",
  # Re-exported from jax._src.export.shape_poly
  "SymbolicScope",
  "is_symbolic_dim",
  "symbolic_shape",
  "symbolic_args_specs",
]