mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

Previously we used `from jax.experimental.export import export` and `export.export(fun)`. Now we want to add the public API directly to `jax.experimental.export`, for the following desired usage: ``` from jax.experimental import export exp: export.Exported = export.export(fun) ser: bytearray = export.serialize(exp) exp1: export.Exported = export.deserialized(ser) export.call(exp1) ``` This change requires changing the type of `jax.experimental.export.export` from a module to a function. This confuses pytype for the targets with strict type checking, which is why I attempt to make this change atomically throughout the internal code base. In order to support backwards compatibility with OSS packages, this change also includes explicit JAX version checks in several OSS packages, and also adds to the `export` function the attributes that the old export module had. PiperOrigin-RevId: 596563481
1340 lines
59 KiB
Python
1340 lines
59 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""JAX APIs for exporting JAX functions for interoperation.
|
|
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
import copy
|
|
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import re
|
|
from typing import Any, Callable, Optional, TypeVar, Union
|
|
import warnings
|
|
|
|
from absl import logging
|
|
import numpy as np
|
|
|
|
import jax
|
|
from jax import sharding
|
|
|
|
from jax._src import ad_util
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src import dispatch
|
|
from jax._src import dtypes
|
|
from jax._src import effects
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.interpreters import pxla
|
|
from jax._src.lib import xla_client
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.lib.mlir.dialects import hlo
|
|
from jax._src.lib.mlir.dialects import func as func_dialect
|
|
from jax._src import pjit
|
|
from jax._src import sharding_impls
|
|
from jax._src import source_info_util
|
|
from jax._src import tree_util
|
|
from jax._src import util
|
|
from jax._src import xla_bridge as xb
|
|
|
|
from jax.experimental.export import shape_poly
|
|
|
|
map = util.safe_map
|
|
zip = util.safe_zip
|
|
|
|
DType = Any
|
|
Shape = jax._src.core.Shape
|
|
# The values of input and output sharding from the lowering.
|
|
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
|
|
|
|
# None means unspecified sharding
|
|
Sharding = Union[xla_client.HloSharding, None]
|
|
|
|
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
|
|
# for a description of the different versions.
|
|
minimum_supported_serialization_version = 6
|
|
maximum_supported_serialization_version = 9
|
|
|
|
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
|
|
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
|
|
|
|
|
|
class DisabledSafetyCheck:
|
|
"""A safety check should be skipped on (de)serialization.
|
|
|
|
Most of these checks are performed on serialization, but some are deferred to
|
|
deserialization. The list of disabled checks is attached to the serialization,
|
|
e.g., as a sequence of string attributes to `jax_export.Exported` or of
|
|
`tf.XlaCallModuleOp`.
|
|
|
|
You can disable more deserialization safety checks by passing
|
|
`TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`.
|
|
"""
|
|
_impl: str
|
|
|
|
@classmethod
|
|
def platform(cls) -> DisabledSafetyCheck:
|
|
"""Allows the execution platform to differ from the serialization platform.
|
|
|
|
Has effect only on deserialization.
|
|
"""
|
|
return DisabledSafetyCheck("platform")
|
|
|
|
@classmethod
|
|
def custom_call(cls, target_name: str) -> DisabledSafetyCheck:
|
|
"""Allows the serialization of a call target not known to be stable.
|
|
|
|
Has effect only on serialization.
|
|
Args:
|
|
target_name: the name of the custom call target to allow.
|
|
"""
|
|
return DisabledSafetyCheck(f"custom_call:{target_name}")
|
|
|
|
@classmethod
|
|
def shape_assertions(cls) -> DisabledSafetyCheck:
|
|
"""Allows invocations with shapes that do not meet the constraints.
|
|
|
|
Has effect on serialization (to suppress the generation of the assertions)
|
|
and also on deserialization (to suppress the checking of the assertions).
|
|
"""
|
|
return DisabledSafetyCheck("shape_assertions")
|
|
|
|
def is_custom_call(self) -> str | None:
|
|
"""Returns the custom call target allowed by this directive."""
|
|
m = re.match(r'custom_call:(.+)$', self._impl)
|
|
return m.group(1) if m else None
|
|
|
|
def __init__(self, _impl:str):
|
|
# Do not use directly, use builders `platform`, `custom_call`.
|
|
self._impl = _impl
|
|
|
|
def __str__(self):
|
|
return self._impl
|
|
__repr__ = __str__
|
|
|
|
def __eq__(self, other) -> bool:
|
|
return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl
|
|
|
|
def __hash__(self) -> int:
|
|
return hash(self._impl)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class Exported:
|
|
"""A JAX function lowered to StableHLO.
|
|
|
|
Attributes:
|
|
fun_name: the name of the exported function, for error messages.
|
|
in_tree: a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX
|
|
function. The actual lowering does not depend on the `in_tree`, but this
|
|
can be used to invoke the exported function using the same argument
|
|
structure.
|
|
in_avals: the flat tuple of input abstract values. May contain dimension
|
|
expressions in the shapes.
|
|
out_tree: a PyTreeDef describing the result of the lowered JAX function.
|
|
out_avals: the flat tuple of output abstract values. May contain dimension
|
|
expressions in the shapes, with dimension variables among those in
|
|
`in_avals.
|
|
in_shardings: the flattened input shardings, as long as `in_avals`.
|
|
out_shardings: the flattened output shardings, as long as `out_avals`.
|
|
nr_devices: the number of devices that the module has been lowered for.
|
|
lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
|
|
'cuda', 'rocm'. See below for the calling convention for when
|
|
there are multiple lowering platforms.
|
|
ordered_effects: the ordered effects present in the serialized module.
|
|
This is present from serialization version 9. See below for the
|
|
calling convention in presence of ordered effects.
|
|
unordered_effects: the unordered effects present in the serialized module.
|
|
This is present from serialization version 9.
|
|
mlir_module_serialized: the serialized lowered VHLO module.
|
|
mlir_module_serialization_version: a version number for the serialized module.
|
|
See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
|
|
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
|
|
must be passed to the module. The other arguments have been dropped
|
|
because they are not used. Same length as `in_shardings`.
|
|
uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape
|
|
polymorphism. This may be because `in_avals` contains dimension
|
|
variables, or due to inner calls of Exported modules that have
|
|
dimension variables or platform index arguments. Such modules need
|
|
shape refinement before XLA compilation.
|
|
disabled_safety_checks: a list of descriptors of safety checks that have been
|
|
disabled at export time. See docstring for `DisabledSafetyCheck`.
|
|
_get_vjp: an optional function that takes the current exported function and
|
|
returns the exported VJP function.
|
|
The VJP function takes a flat list of arguments,
|
|
starting with the primal arguments and followed by a cotangent argument
|
|
for each primal output. It returns a tuple with the cotangents
|
|
corresponding to the flattened primal inputs.
|
|
|
|
Calling convention for the exported module (for latest supported version):
|
|
|
|
The `mlir_module` has a `main` function that takes an optional first
|
|
platform index argument if the module supports multiple platforms
|
|
(`len(lowering_platforms) > 1`), followed by the token arguments corresponding
|
|
to the ordered effects, followed by the kept array
|
|
arguments (corresponding to `module_kept_var_idx` and `in_avals`).
|
|
The platform index is a i32 or i64 scalar encoding the index of the current
|
|
compilation platform into the `lowering_platforms` sequence.
|
|
|
|
Inner functions use a different calling convention: an optional
|
|
platform index argument, optional dimension variable arguments
|
|
(scalar tensors of type i32 or i64),
|
|
followed by optional token arguments (in presence of ordered effects),
|
|
followed by the regular array arguments.
|
|
The dimension arguments correspond to the dimension variables appearing in
|
|
the `args_avals`, in sorted order of their names.
|
|
|
|
Consider the lowering of a function with one array argument of type "f32[w,
|
|
2 * h]", where "w" and "h" are two dimension variables.
|
|
Assume that we use multi-platform lowering, and we have
|
|
one ordered effect. The `main` function will be as follows:
|
|
|
|
func public main(
|
|
platform_index: i32 {jax.global_constant="_platform_index"},
|
|
token_in: token,
|
|
arg: f32[?, ?]) {
|
|
arg_w = hlo.get_dimension_size(arg, 0)
|
|
dim1 = hlo.get_dimension_size(arg, 1)
|
|
arg_h = hlo.floordiv(dim1, 2)
|
|
call _check_shape_assertions(arg) # See below
|
|
token = new_token()
|
|
token_out, res = call _wrapped_jax_export_main(platform_index,
|
|
arg_h,
|
|
arg_w,
|
|
token_in,
|
|
arg)
|
|
return token_out, res
|
|
}
|
|
|
|
The actual computation is in `_wrapped_jax_export_main`, taking also
|
|
the values of `h` and `w` dimension variables.
|
|
|
|
The signature of the `_wrapped_jax_export_main` is:
|
|
|
|
func private _wrapped_jax_export_main(
|
|
platform_index: i32 {jax.global_constant="_platform_index"},
|
|
arg_h: i32 {jax.global_constant="h"},
|
|
arg_w: i32 {jax.global_constant="w"},
|
|
arg_token: stablehlo.token {jax.token=True},
|
|
arg: f32[?, ?]) -> (stablehlo.token, ...)
|
|
|
|
Prior to serialization version 9 the calling convention for effects is
|
|
different: the `main` function does not take or return a token. Instead
|
|
the function creates dummy tokens of type `i1[0]` and passes them to the
|
|
`_wrapped_jax_export_main`. The `_wrapped_jax_export_main`
|
|
takes dummy tokens of type `i1[0]` and will create internally real
|
|
tokens to pass to the inner functions. The inner functions use real
|
|
tokens (both before and after serialization version 9)
|
|
|
|
Also starting with serialization version 9, function arguments that contain
|
|
the platform index or the dimension variable values have a
|
|
`jax.global_constant` string attribute whose value is the name of the
|
|
global constant, either `_platform_index` or a dimension variable name.
|
|
The global constant name may be empty if it is not known.
|
|
Some global constant computations use inner functions, e.g., for
|
|
`floor_divide`. The arguments of such functions have a `jax.global_constant`
|
|
attribute for all attributes, meaning that the result of the function is
|
|
also a global constant.
|
|
|
|
Note that `main` contains a call to `_check_shape_assertions.
|
|
JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
|
|
have values >= 1. We must check these constraints when we invoke the
|
|
module. We use a special custom call `@shape_assertion` that takes
|
|
a boolean first operand, a string `error_message` attribute that may contain
|
|
format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
|
|
scalar operands corresponding to the format specifiers.
|
|
|
|
func private _check_shape_assertions(arg: f32[?, ?]) {
|
|
# Check that w is >= 1
|
|
arg_w = hlo.get_dimension_size(arg, 0)
|
|
custom_call @shape_assertion(arg_w >= 1, arg_w,
|
|
error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
|
|
# Check that dim1 is even
|
|
dim1 = hlo.get_dimension_size(arg, 1)
|
|
custom_call @shape_assertion(dim1 % 2 == 0, dim1,
|
|
error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
|
|
# Check that h >= 1
|
|
arg_h = hlo.floordiv(dim1, 2)
|
|
custom_call @shape_assertion(arg_h >= 1, arg_h,
|
|
error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}")
|
|
|
|
If we `call_exported` with this module we perform these checks
|
|
statically (in `call_exported_abstract_eval`).
|
|
"""
|
|
fun_name: str
|
|
in_tree: tree_util.PyTreeDef
|
|
in_avals: tuple[core.AbstractValue, ...]
|
|
out_tree: tree_util.PyTreeDef
|
|
out_avals: tuple[core.AbstractValue, ...]
|
|
|
|
in_shardings: tuple[Sharding, ...]
|
|
out_shardings: tuple[Sharding, ...]
|
|
nr_devices: int
|
|
lowering_platforms: tuple[str, ...]
|
|
ordered_effects: tuple[effects.Effect, ...]
|
|
unordered_effects: tuple[effects.Effect, ...]
|
|
disabled_safety_checks: Sequence[DisabledSafetyCheck]
|
|
|
|
mlir_module_serialized: bytes
|
|
mlir_module_serialization_version: int
|
|
module_kept_var_idx: tuple[int, ...]
|
|
uses_shape_polymorphism: bool
|
|
|
|
_get_vjp: Callable[[Exported], Exported] | None
|
|
|
|
def mlir_module(self) -> ir.Module:
|
|
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
|
|
|
|
def __str__(self):
|
|
# This is called to make a MLIR source location when we call an Exported, and we
|
|
# do not want the entire serialized module to end up in locations.
|
|
return f"Exported(fun_name={self.fun_name}, ...)"
|
|
|
|
def has_vjp(self) -> bool:
|
|
return self._get_vjp is not None
|
|
|
|
def vjp(self) -> Exported:
|
|
"""Gets the exported VJP.
|
|
|
|
Returns None if not available, which can happen if the Exported has been
|
|
loaded from an external format, without a VJP."""
|
|
if self._get_vjp is None:
|
|
raise ValueError("No VJP is available")
|
|
return self._get_vjp(self)
|
|
|
|
|
|
def default_lowering_platform() -> str:
|
|
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
|
|
return xb.canonicalize_platform(jax.default_backend())
|
|
|
|
def symbolic_shape(
|
|
shape_spec: str | None,
|
|
*,
|
|
like: Sequence[int | None] | None = None) -> Shape:
|
|
"""Constructs a jax.ShapeDtypeStruct with polymorphic shapes.
|
|
|
|
Args:
|
|
shape_spec: a symbolic shape specification. None stands for "...".
|
|
like: when `shape_spec` contains placeholders ("_", "..."), use this
|
|
shape to fill in the placeholders.
|
|
The dimensions of `like` that are used for filling
|
|
must be known (not `None`). If a dimension in `like` is known and
|
|
the corresponding dimension in `shape_spec` is a constant then they
|
|
must be equal.
|
|
|
|
See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
|
for more details.
|
|
|
|
Returns: a jax.ShapeDTypeStruct with shapes that may contain symbolic
|
|
expressions involving dimension variables.
|
|
"""
|
|
return shape_poly.symbolic_shape(shape_spec, like=like)
|
|
|
|
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
|
|
"""Returns the shape and dtype of a jax.Array."""
|
|
aval = core.raise_to_shaped(core.get_aval(a))
|
|
return aval.shape, aval.dtype
|
|
|
|
def args_specs(
|
|
args, # pytree of arguments
|
|
polymorphic_shapes, # prefix pytree of strings
|
|
get_shape_and_dtype=shape_and_dtype_jax_array,
|
|
):
|
|
"""Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`.
|
|
|
|
Args:
|
|
args: a pytree of arguments
|
|
polymorphic_shapes: should be `None` (all arguments are monomorphic),
|
|
a single string (applies to all arguments), or a pytree matching a prefix
|
|
of the `args`.
|
|
See [how optional parameters are matched to
|
|
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
|
|
|
|
Note that this function does not ensure that the provided `args` shapes
|
|
are compatible with `polymorphic_shapes`. The `.shape` of the `args` are
|
|
used only to fill-in placeholders from `polymorphic_shapes`.
|
|
|
|
See docstring of `symbolic_shape` and
|
|
[the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
|
for more details.
|
|
get_shape_and_dtype: a function that given an argument extracts a tuple
|
|
of a shape and a dtype.
|
|
|
|
Returns: a pytree of jax.ShapeDTypeStruct matching `args`.
|
|
"""
|
|
args_flat, args_tree = tree_util.tree_flatten(args)
|
|
|
|
shapes_and_dtypes = tuple(map(get_shape_and_dtype, args_flat))
|
|
shapes, dtypes = util.unzip2(shapes_and_dtypes)
|
|
|
|
if isinstance(args, tuple) and isinstance(polymorphic_shapes, list):
|
|
# TODO: Remove backward-compatibility workaround
|
|
polymorphic_shapes_ = tuple(polymorphic_shapes)
|
|
else:
|
|
polymorphic_shapes_ = polymorphic_shapes
|
|
|
|
try:
|
|
polymorphic_shapes_flat = tree_util.broadcast_prefix(
|
|
polymorphic_shapes_, args,
|
|
is_leaf=lambda x: x is None)
|
|
except ValueError:
|
|
e, *_ = tree_util.prefix_errors(
|
|
polymorphic_shapes_, args,
|
|
is_leaf=lambda x: x is None)
|
|
raise e("jax_export polymorphic_shapes") from None
|
|
|
|
# Now add in the polymorphic shapes
|
|
args_specs_flat = (
|
|
jax.ShapeDtypeStruct(symbolic_shape(spec, like=s), t)
|
|
for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat))
|
|
|
|
return args_tree.unflatten(args_specs_flat)
|
|
|
|
|
|
def _keep_main_tokens(serialization_version: int) -> bool:
|
|
return serialization_version >= _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS
|
|
|
|
def export(fun_jax: Callable,
|
|
*,
|
|
lowering_platforms: Sequence[str] | None = None,
|
|
disabled_checks: Sequence[DisabledSafetyCheck] = (),
|
|
) -> Callable[..., Exported]:
|
|
"""Exports native serialization for a JAX function.
|
|
|
|
Args:
|
|
fun_jax: the function to lower and serialize.
|
|
lowering_platforms:
|
|
Optional sequence containing a subset of 'tpu', 'cpu',
|
|
'cuda', 'rocm'. If more than one platform is specified, then
|
|
the lowered code takes an argument specifying the platform.
|
|
If None, then use the default JAX backend.
|
|
The calling convention for multiple platforms is explained in the
|
|
`jax_export.Exported` docstring.
|
|
disabled_checks: the safety checks to disable. See docstring
|
|
of `DisabledSafetyCheck`.
|
|
|
|
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
|
|
or values with `.shape` and `.dtype` attributes, and returns an
|
|
`Exported`.
|
|
|
|
Usage:
|
|
|
|
def f_jax(*args, **kwargs): ...
|
|
exported = jax_export.export(f_jax)(*args, **kwargs)
|
|
"""
|
|
fun_name = getattr(fun_jax, "__name__", "unknown")
|
|
version = config.jax_serialization_version.value
|
|
if (version < minimum_supported_serialization_version or
|
|
version > maximum_supported_serialization_version):
|
|
raise ValueError(
|
|
f"The requested jax_serialization version {version} is outside the "
|
|
f"range of supported versions [{minimum_supported_serialization_version}"
|
|
f"..{maximum_supported_serialization_version}]")
|
|
|
|
def do_export(*args_specs, **kwargs_specs) -> Exported:
|
|
if not hasattr(fun_jax, "lower"):
|
|
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
|
|
# convert(f_jax), in which case a "jit" is implied. In that case we raise
|
|
# an error if the lowered function contains non-replicated sharding annotations.
|
|
wrapped_fun_jax = jax.jit(fun_jax)
|
|
allow_non_replicated_sharding = False
|
|
else:
|
|
# If we have a pjit or pmap already we do not wrap with another, and we
|
|
# allow shardings.
|
|
wrapped_fun_jax = fun_jax # type: ignore
|
|
allow_non_replicated_sharding = True
|
|
|
|
if lowering_platforms is not None:
|
|
actual_lowering_platforms = tuple(lowering_platforms)
|
|
else:
|
|
actual_lowering_platforms = (default_lowering_platform(),)
|
|
|
|
# Do not include shape assertions if the version is < 7.
|
|
enable_shape_assertions = (
|
|
DisabledSafetyCheck.shape_assertions() not in disabled_checks and
|
|
version >= _VERSION_START_SUPPORT_SHAPE_ASSERTIONS) # type: ignore
|
|
try:
|
|
prev_enable_shape_assertions = shape_poly.thread_local_state.enable_shape_assertions
|
|
shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
|
|
replace_tokens_with_dummy = not _keep_main_tokens(version)
|
|
lowered = wrapped_fun_jax.lower(
|
|
*args_specs, **kwargs_specs,
|
|
_experimental_lowering_parameters=mlir.LoweringParameters(
|
|
platforms=actual_lowering_platforms,
|
|
replace_tokens_with_dummy=replace_tokens_with_dummy,
|
|
))
|
|
|
|
lowering = lowered._lowering # type: ignore
|
|
_check_lowering(lowering)
|
|
mlir_module = lowering.stablehlo()
|
|
|
|
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
|
|
if "kept_var_idx" in lowering.compile_args:
|
|
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
|
|
else:
|
|
# For pmap
|
|
module_kept_var_idx = tuple(range(len(args_avals_flat)))
|
|
shape_poly_state = lowering.compile_args["shape_poly_state"]
|
|
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
|
|
or lowering.compile_args.get("ordered_effects", [])):
|
|
mlir_module = _wrap_main_func(
|
|
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree,
|
|
has_platform_index_argument=shape_poly_state.has_platform_index_argument,
|
|
module_kept_var_idx=module_kept_var_idx,
|
|
serialization_version=version)
|
|
finally:
|
|
shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions
|
|
|
|
with mlir_module.context:
|
|
mlir_module_attrs = mlir_module.operation.attributes
|
|
mlir_module_attrs["jax.uses_shape_polymorphism"] = (
|
|
mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars))
|
|
|
|
mlir_module_serialized = _module_to_bytecode(mlir_module)
|
|
|
|
# Figure out the result types and shapes
|
|
if "global_out_avals" in lowering.compile_args:
|
|
# This is currently the case for pjit
|
|
out_avals_flat = lowering.compile_args["global_out_avals"]
|
|
elif "shards" in lowering.compile_args: # for PmapComputation
|
|
out_avals_flat = lowering.compile_args["shards"].out_sharded_avals
|
|
else:
|
|
out_avals_flat = lowered.compile_args["out_avals"]
|
|
|
|
# Log and then check the module.
|
|
if logging.vlog_is_on(3):
|
|
logmsg = (f"version={version} "
|
|
f"lowering_platforms={actual_lowering_platforms} "
|
|
f"disabled_checks={disabled_checks}")
|
|
logging.info("Lowered JAX module: %s\n", logmsg)
|
|
if dumped_to := mlir.dump_module_to_file(mlir_module, "export"):
|
|
logging.info("Dumped the exported MLIR module to %s", dumped_to)
|
|
|
|
_check_module(mlir_module,
|
|
allow_non_replicated_sharding=allow_non_replicated_sharding,
|
|
disabled_checks=disabled_checks)
|
|
|
|
ordered_effects = tuple(lowering.compile_args["ordered_effects"])
|
|
unordered_effects = tuple(lowering.compile_args["unordered_effects"])
|
|
if version < _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
|
ordered_effects = unordered_effects = ()
|
|
|
|
nr_devices = len(lowering.compile_args["device_assignment"])
|
|
def export_sharding(s: LoweringSharding,
|
|
aval: core.ShapedArray) -> Sharding:
|
|
if sharding_impls.is_unspecified(s):
|
|
return None
|
|
return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr]
|
|
|
|
all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"],
|
|
module_kept_var_idx,
|
|
len(args_avals_flat))
|
|
in_shardings = tuple(
|
|
export_sharding(s, aval)
|
|
for s, aval in zip(all_in_shardings, args_avals_flat))
|
|
out_shardings = tuple(
|
|
export_sharding(s, aval)
|
|
for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat))
|
|
return Exported(
|
|
fun_name=fun_name,
|
|
in_tree=lowered.in_tree,
|
|
out_tree=lowered.out_tree,
|
|
in_avals=tuple(args_avals_flat),
|
|
out_avals=tuple(out_avals_flat),
|
|
in_shardings=in_shardings,
|
|
out_shardings=out_shardings,
|
|
nr_devices=nr_devices,
|
|
lowering_platforms=actual_lowering_platforms,
|
|
ordered_effects=ordered_effects,
|
|
unordered_effects=unordered_effects,
|
|
disabled_safety_checks=tuple(disabled_checks),
|
|
mlir_module_serialized=mlir_module_serialized,
|
|
module_kept_var_idx=module_kept_var_idx,
|
|
uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
|
|
mlir_module_serialization_version=version, # type: ignore
|
|
_get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))
|
|
|
|
return do_export
|
|
|
|
|
|
def _module_to_bytecode(module: ir.Module) -> bytes:
|
|
mlir_str = mlir.module_to_bytecode(module)
|
|
if hlo.get_api_version() < 4:
|
|
target_version = hlo.get_earliest_forward_compatible_version()
|
|
else:
|
|
# `target_version` is used to manage situations when a StableHLO producer
|
|
# (in this case, jax2tf) and a StableHLO consumer were built using
|
|
# different versions of StableHLO.
|
|
#
|
|
# Each StableHLO version `producer_version` has a compatibility window,
|
|
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
|
|
# where StableHLO portable artifacts serialized by `producer_version`
|
|
# can be deserialized by `consumer_version` within the window.
|
|
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
|
|
# for the exact extent of these compatibility guarantees.
|
|
#
|
|
# `hlo.get_minimum_version()` returns `consumer_version_min`
|
|
# for the current version of StableHLO. We are using it here to maximize
|
|
# forward compatibility, i.e. to maximize how far into the past we can go
|
|
# and still have the payloads produced by `serialize_portable_artifact`
|
|
# compatible with potential consumers from the past.
|
|
target_version = hlo.get_minimum_version()
|
|
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
|
mlir_str, target_version)
|
|
return module_serialized
|
|
|
|
|
|
def _wrap_main_func(
|
|
module: ir.Module,
|
|
args_avals_flat: Sequence[core.ShapedArray],
|
|
*,
|
|
args_kwargs_tree: tree_util.PyTreeDef,
|
|
has_platform_index_argument: bool,
|
|
module_kept_var_idx: tuple[int, ...],
|
|
serialization_version: int
|
|
) -> ir.Module:
|
|
"""Wraps the lowered module with a new "main" handling dimension arguments.
|
|
|
|
See calling convention documentation for `jax_export.Exported`.
|
|
|
|
Args:
|
|
module: the HLO module as obtained from lowering. See the calling convention
|
|
for inner functions in `jax_export.Exported`.
|
|
args_avals_flat: the avals for all the arguments of the lowered function,
|
|
which correspond to the array arguments of the `module`.
|
|
args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error
|
|
messages.
|
|
has_platform_index_argument: whether the `module` has a first platform
|
|
index argument
|
|
module_kept_var_idx: a sorted tuple of integers with the indices of arguments
|
|
in `args_avals_flat` that are kept as `module` arguments.
|
|
serialization_version: the target serialization version
|
|
|
|
Returns the wrapped module, without dimension and token arguments.
|
|
"""
|
|
dim_vars = shape_poly.all_dim_vars(args_avals_flat)
|
|
context = mlir.make_ir_context()
|
|
with context, ir.Location.unknown(context):
|
|
# Make a copy, do not mutate because it may be cached
|
|
wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))
|
|
symbol_table = ir.SymbolTable(wrapped_module.operation)
|
|
orig_main = symbol_table["main"]
|
|
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
|
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
|
|
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
|
|
|
|
def is_token(typ, attrs):
|
|
if typ == mlir.token_type()[0]:
|
|
return True
|
|
# TODO(b/302258959): in older versions we cannot use the token type
|
|
try:
|
|
return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value
|
|
except KeyError:
|
|
return False
|
|
|
|
orig_input_types = orig_main.type.inputs
|
|
arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs))
|
|
# The order of args: platform_index_arg, dim args, token args, array args.
|
|
nr_platform_index_args = 1 if has_platform_index_argument else 0
|
|
nr_dim_args = len(dim_vars)
|
|
token_arg_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_input_types,
|
|
arg_attrs))
|
|
if is_token(typ, attrs)]
|
|
nr_token_args = len(token_arg_idxs)
|
|
if nr_token_args > 0:
|
|
assert min(token_arg_idxs) == nr_platform_index_args + nr_dim_args
|
|
assert token_arg_idxs == list(
|
|
range(nr_platform_index_args + nr_dim_args,
|
|
nr_platform_index_args + nr_dim_args + nr_token_args))
|
|
nr_array_args = (len(orig_input_types) - nr_platform_index_args
|
|
- nr_dim_args - nr_token_args)
|
|
assert nr_array_args >= 0
|
|
|
|
(platform_input_types, dim_var_input_types,
|
|
token_input_types, array_input_types) = util.split_list(
|
|
orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args])
|
|
|
|
# The order of results: tokens, array results
|
|
orig_output_types = orig_main.type.results
|
|
result_attrs = list(ir.ArrayAttr(orig_main.result_attrs))
|
|
token_result_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_output_types,
|
|
result_attrs))
|
|
if is_token(typ, attrs)]
|
|
nr_token_results = len(token_result_idxs)
|
|
assert token_result_idxs == list(range(0, nr_token_results))
|
|
nr_array_results = len(orig_output_types) - nr_token_results
|
|
assert nr_array_results >= 0
|
|
if _keep_main_tokens(serialization_version):
|
|
new_main_arg_indices = (tuple(range(0, nr_platform_index_args)) +
|
|
tuple(range(nr_platform_index_args + nr_dim_args,
|
|
len(orig_input_types))))
|
|
new_main_result_indices = tuple(range(0, len(orig_output_types)))
|
|
else:
|
|
new_main_arg_indices = (
|
|
tuple(range(0, nr_platform_index_args)) +
|
|
tuple(range(nr_platform_index_args + nr_dim_args + nr_token_args,
|
|
len(orig_input_types))))
|
|
new_main_result_indices = tuple(range(nr_token_results, len(orig_output_types)))
|
|
new_main_input_types = [orig_input_types[idx] for idx in new_main_arg_indices]
|
|
new_main_output_types = [orig_output_types[idx] for idx in new_main_result_indices]
|
|
new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
|
|
new_main_op = func_dialect.FuncOp(
|
|
"main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body))
|
|
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
|
|
try:
|
|
new_arg_attrs = []
|
|
for idx in new_main_arg_indices:
|
|
new_arg_attr = {}
|
|
for attr in arg_attrs[idx]:
|
|
if attr.name == "tf.aliasing_output":
|
|
i = new_main_result_indices.index(attr.attr.value)
|
|
new_arg_attr[attr.name] = ir.IntegerAttr.get(
|
|
ir.IntegerType.get_signless(32), i
|
|
)
|
|
else:
|
|
new_arg_attr[attr.name] = attr.attr
|
|
new_arg_attrs.append(ir.DictAttr.get(new_arg_attr))
|
|
new_main_op.arg_attrs = ir.ArrayAttr.get(new_arg_attrs)
|
|
except KeyError:
|
|
pass # TODO: better detection if orig_main.arg_attrs does not exist
|
|
try:
|
|
new_main_op.result_attrs = ir.ArrayAttr.get(
|
|
[result_attrs[idx] for idx in new_main_result_indices])
|
|
except KeyError:
|
|
pass
|
|
symbol_table.insert(new_main_op)
|
|
entry_block = new_main_op.add_entry_block()
|
|
with ir.InsertionPoint(entry_block):
|
|
# Make a context just for lowering the dimension value computations
|
|
module_context = mlir.ModuleContext(
|
|
backend_or_name="cpu", platforms=["cpu"],
|
|
axis_context=sharding_impls.ShardingContext(0),
|
|
name_stack=source_info_util.new_name_stack(),
|
|
keepalives=[], channel_iterator=itertools.count(1),
|
|
host_callbacks=[], module=wrapped_module, context=context,
|
|
lowering_parameters=mlir.LoweringParameters(
|
|
global_constant_computation=True
|
|
))
|
|
ctx = mlir.LoweringRuleContext(
|
|
module_context=module_context, primitive=None,
|
|
avals_in=args_avals_flat, avals_out=None,
|
|
tokens_in=mlir.TokenSet(), tokens_out=None)
|
|
# We compute dim_values from the array arguments.
|
|
new_main_op_array_args = new_main_op.arguments[-nr_array_args:]
|
|
if shape_poly.all_dim_vars(args_avals_flat):
|
|
# TODO(necula): handle module_kept_var_idx in presence of shape
|
|
# polymorphism. For now we ensured upstream that we keep all variables.
|
|
assert len(set(module_kept_var_idx)) == len(args_avals_flat)
|
|
dim_values = mlir.lower_fun(
|
|
functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
|
|
args_avals_flat, args_kwargs_tree=args_kwargs_tree),
|
|
multiple_results=True)(ctx, *new_main_op_array_args)
|
|
else:
|
|
dim_values = ()
|
|
# The arguments to pass to the call to orig_main
|
|
orig_main_args: list[ir.Value] = []
|
|
# The platform index and the dimension variables
|
|
for arg, arg_type in zip(
|
|
list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values),
|
|
platform_input_types + dim_var_input_types):
|
|
if arg.type != arg_type:
|
|
orig_main_args.append(hlo.convert(arg_type, arg))
|
|
else:
|
|
orig_main_args.append(arg)
|
|
# Then the token arguments
|
|
if _keep_main_tokens(serialization_version):
|
|
orig_main_args.extend(
|
|
new_main_op.arguments[nr_platform_index_args: nr_platform_index_args + nr_token_args])
|
|
else:
|
|
orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args)
|
|
# Then the array arguments. We insert a ConvertOp as the only use of
|
|
# an input argument. This helps the downstream shape refinement because
|
|
# it will set the type of input arguments to static shapes, and this
|
|
# can invalidate the module if the argument is used as the result of a
|
|
# function, or if it appears as the input to a custom_call with
|
|
# output_operand_alias attribute. See b/287386268.
|
|
for arg, arg_type in zip(new_main_op_array_args, array_input_types):
|
|
if arg.type != arg_type:
|
|
orig_main_args.append(hlo.convert(arg_type, arg))
|
|
else:
|
|
orig_main_args.append(arg)
|
|
call = func_dialect.CallOp(orig_output_types,
|
|
ir.FlatSymbolRefAttr.get(orig_main_name),
|
|
orig_main_args)
|
|
func_dialect.ReturnOp([call.results[idx] for idx in new_main_result_indices])
|
|
symbol_table.set_symbol_name(new_main_op, "main")
|
|
|
|
return wrapped_module
|
|
|
|
def _check_lowering(lowering) -> None:
|
|
if not isinstance(lowering, pxla.MeshComputation):
|
|
raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
|
|
|
|
if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]:
|
|
raise NotImplementedError("serialization of host_callbacks is not yet implemented")
|
|
# Check that we do not see new compile_args. When we add a compile_args it is
|
|
# safe to add it to the allowed_compile_args if it does not change the semantics
|
|
# or the calling convention of the lowered module.
|
|
allowed_compile_args = [
|
|
"backend", "mesh", "global_in_avals",
|
|
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
|
|
"spmd_lowering", "auto_spmd_lowering",
|
|
"tuple_args", "ordered_effects", "unordered_effects",
|
|
"keepalive", "host_callbacks", "pmap_nreps", "committed",
|
|
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
|
|
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"]
|
|
for compile_arg in lowering.compile_args.keys():
|
|
if compile_arg not in allowed_compile_args:
|
|
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
|
|
|
|
# We have not implemented support for some of the compile_args. Check here that
|
|
# the compile_args have the values that have been implemented.
|
|
not_implemented_msgs = []
|
|
for compile_arg, check_value, err_msg in (
|
|
("spmd_lowering", lambda v: v, "True"),
|
|
("auto_spmd_lowering", lambda v: not v, "False"),
|
|
# tuple_args is a compilation flag, does not affect lowering.
|
|
("tuple_args", lambda v: True, "N/A"),
|
|
# unordered_effects do not change the calling convention. Those from
|
|
# jax.debug will also result in keepalive being non-empty and unsupported
|
|
# custom calls. The CallTfEffect is an exception, but we want to allow
|
|
# that one.
|
|
("unordered_effects", lambda v: True, "N/A"),
|
|
("ordered_effects", lambda v: True, "N/A"),
|
|
# used for TPU jax.debug, send/recv. Not supported yet.
|
|
("host_callbacks", lambda v: not v, "empty"),
|
|
# used on all platforms for callbacks. Not supported yet.
|
|
("keepalive", lambda v: not v, "empty"),
|
|
("pmap_nreps", lambda v: v == 1, "1"),
|
|
("shape_poly_state", lambda v: True, "N/A"),
|
|
):
|
|
if compile_arg in lowering.compile_args:
|
|
if not check_value(lowering.compile_args[compile_arg]):
|
|
not_implemented_msgs.append(
|
|
f"{compile_arg} must be {err_msg} and it is {lowering.compile_args[compile_arg]}")
|
|
if not_implemented_msgs:
|
|
raise NotImplementedError(
|
|
"serialization error, unimplemented lowered.compile_args:\n" +
|
|
"\n".join(not_implemented_msgs))
|
|
|
|
# These are the JAX custom call target names that are guaranteed to be stable.
|
|
# Their backwards compatibility is tested by back_compat_test.py.
|
|
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
|
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
|
|
"ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32",
|
|
# cholesky on CPU
|
|
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
|
|
# eigh on CPU
|
|
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
|
|
# eigh on GPU
|
|
"cusolver_syevj", "cusolver_syevd",
|
|
# eigh on TPU
|
|
"Eigh",
|
|
# eig on CPU
|
|
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
|
|
# qr on CPU
|
|
"lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
|
|
# householder product on CPU
|
|
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
|
|
# svd on CPU
|
|
"lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd",
|
|
# qr on GPU
|
|
"cusolver_geqrf", "cublas_geqrf_batched",
|
|
"cusolver_geqrf", "cusolver_orgqr",
|
|
# qr and svd on TPU
|
|
"Qr", "ProductOfElementaryHouseholderReflectors",
|
|
# triangular_solve on CPU
|
|
"blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm",
|
|
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
|
|
# # lu on CPU
|
|
"lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf",
|
|
# schur on CPU
|
|
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
|
|
# # lu on GPU
|
|
# "cublas_getrf_batched", "cusolver_getrf",
|
|
# "hipblas_getrf_batched", "hipsolver_getrf",
|
|
# lu on TPU
|
|
"LuDecomposition",
|
|
# ApproxTopK on TPU
|
|
"ApproxTopK",
|
|
"tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True)
|
|
"tpu_custom_call", # Pallas/TPU kernels
|
|
# TODO(burmako): maintain backwards compatibility for these, until they
|
|
# are upstreamed to StableHLO.
|
|
# See https://github.com/openxla/stablehlo/issues/8.
|
|
"stablehlo.dynamic_reduce_window",
|
|
"stablehlo.dynamic_rng_bit_generator",
|
|
"stablehlo.dynamic_top_k",
|
|
"shape_assertion", # Used by shape_poly to evaluate assertions
|
|
}
|
|
|
|
check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|"")$")
|
|
|
|
def _check_module(mod: ir.Module, *,
|
|
allow_non_replicated_sharding: bool,
|
|
disabled_checks: Sequence[DisabledSafetyCheck]) -> None:
|
|
"""Run a number of checks on the module.
|
|
|
|
Args:
|
|
allow_non_replicated_sharding: whether the module is allowed to contain
|
|
non_replicated sharding annotations.
|
|
disabled_checks: the safety checks that are disabled.
|
|
"""
|
|
sharding_attr = ir.StringAttr.get("Sharding", mod.context)
|
|
shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context)
|
|
allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
|
|
for dc in disabled_checks:
|
|
target = dc.is_custom_call()
|
|
if target is not None:
|
|
allowed_custom_call_targets.add(target)
|
|
|
|
allowed_custom_call_targets_attrs = {
|
|
ir.StringAttr.get(target, mod.context)
|
|
for target in allowed_custom_call_targets}
|
|
disallowed_custom_call_ops: list[str] = []
|
|
def check_sharding(op: ir.Operation, loc: ir.Location):
|
|
if not allow_non_replicated_sharding:
|
|
try:
|
|
sharding = op.attributes["mhlo.sharding"]
|
|
except KeyError:
|
|
pass
|
|
else:
|
|
if not re.match(check_sharding_pattern, ir.StringAttr(sharding).value):
|
|
raise ValueError(
|
|
"Lowered function does not have a top-level pjit but it has"
|
|
f" non-replicated sharding annotations, e.g., {op} at {loc}.\nSee"
|
|
" https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning"
|
|
" for a discussion."
|
|
)
|
|
|
|
def check_op(op: ir.Operation):
|
|
op_name = op.operation.name
|
|
if op_name == "func.func":
|
|
check_sharding(op.operation, op.location)
|
|
|
|
elif op_name == "stablehlo.custom_call" or op_name == "mhlo.custom_call":
|
|
call_target_name_attr = op.operation.attributes["call_target_name"]
|
|
if (call_target_name_attr not in allowed_custom_call_targets_attrs):
|
|
disallowed_custom_call_ops.append(f"{op} at {op.location}")
|
|
if call_target_name_attr == sharding_attr:
|
|
check_sharding(op, op.location)
|
|
elif call_target_name_attr == shape_assertion_attr:
|
|
assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks)
|
|
|
|
def walk_operations(op):
|
|
check_op(op)
|
|
for region in op.operation.regions:
|
|
for block in region:
|
|
for op in block:
|
|
walk_operations(op)
|
|
|
|
walk_operations(mod)
|
|
if disallowed_custom_call_ops:
|
|
disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops)
|
|
msg = ("Cannot serialize code with custom calls whose targets have no "
|
|
"compatibility guarantees. Examples are:\n"
|
|
f"{disallowed_custom_call_ops_str}.\n"
|
|
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
|
|
raise ValueError(msg)
|
|
|
|
def expand_in_shardings(in_shardings: Sequence[LoweringSharding],
|
|
module_kept_var_idx: Sequence[int],
|
|
nr_inputs: int) -> Sequence[LoweringSharding]:
|
|
"""Expands in_shardings with unspecified shardings for inputs not kept.
|
|
|
|
Assumes in_shardings corresponds to module_kept_var_idx.
|
|
"""
|
|
assert len(in_shardings) == len(module_kept_var_idx)
|
|
assert nr_inputs >= len(module_kept_var_idx)
|
|
all_in_shardings: list[LoweringSharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
|
|
for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings):
|
|
all_in_shardings[idx] = in_s
|
|
return tuple(all_in_shardings)
|
|
|
|
# TODO(yashkatariya, necula): remove this function once we relax the checks
|
|
# in the jit front-end.
|
|
def canonical_shardings(
|
|
device_assignment: Sequence[jax.Device],
|
|
in_shardings: Sequence[Sharding],
|
|
out_shardings: Sequence[Sharding]
|
|
) -> tuple[(pxla.UnspecifiedValue |
|
|
Sequence[sharding.XLACompatibleSharding]),
|
|
(pxla.UnspecifiedValue |
|
|
Sequence[sharding.XLACompatibleSharding])]:
|
|
"""Prepares canonical in_ and out_shardings for a pjit invocation.
|
|
|
|
The pjit front-end is picky about what in- and out-shardings it accepts,
|
|
e.g., if all are unspecified then the whole sharding should be the
|
|
sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are
|
|
replaced with the replicated sharding.
|
|
|
|
Returns: a pair with the canonicalized input and output shardings.
|
|
"""
|
|
replicated_s = sharding.GSPMDSharding.get_replicated(device_assignment)
|
|
def canonicalize(
|
|
ss: Sequence[Sharding]) -> (pxla.UnspecifiedValue |
|
|
Sequence[sharding.XLACompatibleSharding]):
|
|
if all(s is None for s in ss):
|
|
return sharding_impls.UNSPECIFIED
|
|
return tuple(
|
|
sharding.GSPMDSharding(device_assignment, s) if s is not None else replicated_s
|
|
for s in ss)
|
|
return (canonicalize(in_shardings), canonicalize(out_shardings))
|
|
|
|
def _get_vjp_fun(primal_fun: Callable, *,
|
|
in_tree: tree_util.PyTreeDef,
|
|
in_avals: Sequence[core.AbstractValue],
|
|
out_avals: Sequence[core.AbstractValue],
|
|
in_shardings: tuple[Sharding, ...],
|
|
out_shardings: tuple[Sharding, ...],
|
|
nr_devices: int,
|
|
apply_jit: bool
|
|
) -> tuple[Callable, Sequence[core.AbstractValue]]:
|
|
# Since jax.vjp does not handle kwargs, it is easier to do all the work
|
|
# here with flattened functions.
|
|
def fun_vjp_jax(*args_and_out_cts_flat_jax):
|
|
# Takes a flat list of primals and output cotangents
|
|
def flattened_primal_fun_jax(*args_flat):
|
|
args, kwargs = in_tree.unflatten(args_flat)
|
|
res = primal_fun(*args, **kwargs)
|
|
res_flat, _ = tree_util.tree_flatten(res)
|
|
return res_flat
|
|
|
|
args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax,
|
|
[len(in_avals)])
|
|
_, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax)
|
|
return pullback_jax(out_cts_flat_jax)
|
|
|
|
vjp_in_avals = list(
|
|
itertools.chain(in_avals,
|
|
map(lambda a: a.at_least_vspace(), out_avals)))
|
|
|
|
if apply_jit:
|
|
# Prepare a device assignment. For exporting purposes, all it matters
|
|
# is the number of devices.
|
|
device_assignment = jax.devices(jax.default_backend())[:nr_devices]
|
|
assert len(device_assignment) == nr_devices
|
|
vjp_in_shardings, vjp_out_shardings = canonical_shardings(
|
|
device_assignment,
|
|
tuple(itertools.chain(in_shardings, out_shardings)),
|
|
in_shardings)
|
|
return pjit.pjit(fun_vjp_jax,
|
|
in_shardings=vjp_in_shardings,
|
|
out_shardings=vjp_out_shardings), vjp_in_avals
|
|
else:
|
|
return fun_vjp_jax, vjp_in_avals
|
|
|
|
def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
|
|
# Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp
|
|
fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun,
|
|
in_tree=primal.in_tree,
|
|
in_avals=primal.in_avals,
|
|
in_shardings=primal.in_shardings,
|
|
out_avals=primal.out_avals,
|
|
out_shardings=primal.out_shardings,
|
|
nr_devices=primal.nr_devices,
|
|
apply_jit=True)
|
|
return export(fun_vjp_jax,
|
|
lowering_platforms=primal.lowering_platforms,
|
|
disabled_checks=primal.disabled_safety_checks)(*vjp_in_avals)
|
|
|
|
### Calling the exported function
|
|
|
|
def call(exported: Exported) -> Callable[..., jax.Array]:
|
|
if not isinstance(exported, Exported):
|
|
raise ValueError(
|
|
"The exported argument must be an export.Exported. "
|
|
f"Found {exported}.")
|
|
@jax.custom_vjp
|
|
def f_flat(*args_flat):
|
|
return call_exported_p.bind(*args_flat, exported=exported)
|
|
|
|
def f_flat_vjp_fwd(*args_flat):
|
|
# Return the primal arguments as the residual
|
|
# TODO: keep as residuals only the arguments that are needed
|
|
return f_flat(*args_flat), args_flat
|
|
|
|
def f_flat_vjp_bwd(residual, ct_res_flat):
|
|
args_flat = residual # residual is the primal argument flat tuple
|
|
exp_vjp = exported.vjp()
|
|
# ct_res_flat may contain arrays of zeros where exp_vjp expect float0.
|
|
# We make the proper arrays of float0 to invoke exp_vjp.
|
|
def fix_float0_ct(ct_res, expected_aval):
|
|
if expected_aval.dtype != dtypes.float0:
|
|
return ct_res
|
|
return ad_util.zeros_like_aval(expected_aval)
|
|
|
|
ct_res_fixed = map(fix_float0_ct,
|
|
ct_res_flat, exp_vjp.in_avals[len(args_flat):])
|
|
in_ct_flat = call_exported(exp_vjp)(*args_flat, *ct_res_fixed)
|
|
return in_ct_flat
|
|
|
|
f_flat.defvjp(f_flat_vjp_fwd, f_flat_vjp_bwd)
|
|
|
|
def f_imported(*args, **kwargs):
|
|
# since custom_vjp does not support kwargs, flatten the function first.
|
|
args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
|
|
if in_tree != exported.in_tree:
|
|
# Give errors with the precise tree difference; use fake leaves so we can
|
|
# use tree_util.equality_errors.
|
|
in_args = in_tree.unflatten([0] * in_tree.num_leaves)
|
|
exp_in_args = exported.in_tree.unflatten([0] * exported.in_tree.num_leaves)
|
|
|
|
msg = (
|
|
"The invocation args and kwargs must have the same pytree structure "
|
|
f"as when the function '{exported.fun_name}' was exported, but they "
|
|
"have the following structural differences:\n" +
|
|
("\n".join(
|
|
f" - {shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a "
|
|
f"{thing2} when exported, so {explanation}.\n"
|
|
for path, thing1, thing2, explanation
|
|
in tree_util.equality_errors(in_args, exp_in_args))))
|
|
raise ValueError(msg)
|
|
|
|
res_flat = f_flat(*args_flat)
|
|
return exported.out_tree.unflatten(res_flat)
|
|
return f_imported
|
|
|
|
call_exported = call
|
|
|
|
# A JAX primitive for invoking a serialized JAX function.
|
|
call_exported_p = core.Primitive("call_exported")
|
|
call_exported_p.multiple_results = True
|
|
|
|
@util.cache()
|
|
def _call_exported_abstract_eval(
|
|
*in_avals: core.AbstractValue,
|
|
exported: Exported
|
|
) -> tuple[tuple[core.AbstractValue, ...], set[effects.Effect]]:
|
|
exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals)
|
|
assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure
|
|
# Check that the expected shapes match the actual ones
|
|
for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)):
|
|
def pp_arg_dim(dim_idx: int | None) -> str:
|
|
return shape_poly.pretty_print_dimension_descriptor(exported.in_tree,
|
|
arg_idx, dim_idx)
|
|
if len(exp_aval.shape) != len(actual_aval.shape):
|
|
raise ValueError(
|
|
f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} "
|
|
f"and called with {actual_aval.shape}")
|
|
if exp_aval.dtype != actual_aval.dtype:
|
|
raise ValueError(
|
|
f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} "
|
|
f"and called with {actual_aval.dtype}")
|
|
for dim_idx, aval_d in enumerate(exp_aval.shape):
|
|
# If the exp_aval has a constant dimension then the actual argument must have
|
|
# a matching constant dimension.
|
|
if core.is_constant_dim(aval_d):
|
|
if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or
|
|
aval_d != actual_aval.shape[dim_idx]):
|
|
raise ValueError(
|
|
f"Shape mismatch for {pp_arg_dim(dim_idx)} "
|
|
"(expected same constant): "
|
|
f"expected {exp_aval.shape} and called with {actual_aval.shape}")
|
|
|
|
# Must express the exported_dim_vars in terms of the shapes in in_avals.
|
|
solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars(
|
|
exported.in_avals, args_kwargs_tree=exported.in_tree)
|
|
synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx]
|
|
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
|
synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env)
|
|
# We discharge all the constraints statically. This results in much simpler
|
|
# composability (because we do not have to worry about the constraints of the
|
|
# Exported called recursively; we only need to worry about entry-point
|
|
# constraints). This also makes sense from a composability point of view,
|
|
# because we get the same errors if we invoke the exported module, or if we
|
|
# trace the exported function. Consider for example, an exported module with
|
|
# signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument
|
|
# `f32[c, d]` it is better to fail because `c == d` is inconclusive, than
|
|
# succeed and add a compile-time check that `c == d`. In the latter case,
|
|
# it would be ambiguous whether we should continue tracing with a result
|
|
# of type `f32[c]` or `f32[d]`.
|
|
shape_constraints.check_statically(synthetic_eval)
|
|
exported_dim_values = [synthetic_eval.evaluate(solution[var])
|
|
for var in exported_dim_vars]
|
|
out_avals = tuple(
|
|
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
|
|
*exported_dim_values),
|
|
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
|
|
named_shape=out_aval.named_shape)
|
|
for out_aval in exported.out_avals)
|
|
return out_avals, set(exported.ordered_effects + exported.unordered_effects)
|
|
|
|
|
|
call_exported_p.def_effectful_abstract_eval(_call_exported_abstract_eval)
|
|
|
|
def _call_exported_impl(*args, exported: Exported):
|
|
return dispatch.apply_primitive(call_exported_p, *args, exported=exported)
|
|
|
|
call_exported_p.def_impl(_call_exported_impl)
|
|
|
|
def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
|
exported: Exported):
|
|
if exported.uses_shape_polymorphism:
|
|
ctx.module_context.shape_poly_state.uses_dim_vars = True
|
|
|
|
axis_context = ctx.module_context.axis_context
|
|
if isinstance(axis_context, sharding_impls.ShardingContext):
|
|
num_devices = axis_context.num_devices
|
|
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
|
|
num_devices = axis_context.mesh.size
|
|
else:
|
|
raise NotImplementedError(type(axis_context))
|
|
if num_devices != exported.nr_devices:
|
|
raise NotImplementedError(
|
|
f"Exported module {exported.fun_name} was lowered for "
|
|
f"{exported.nr_devices} devices and is called in a context with "
|
|
f"{num_devices} devices"
|
|
)
|
|
|
|
# Apply in_shardings
|
|
args = tuple(
|
|
wrap_with_sharding(ctx, x, x_aval, x_sharding)
|
|
for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings))
|
|
submodule = ir.Module.parse(exported.mlir_module())
|
|
symtab = ir.SymbolTable(submodule.operation)
|
|
# The called function may have been exported with polymorphic shapes and called
|
|
# now with more refined shapes. We insert hlo.ConvertOp to ensure the module
|
|
# is valid.
|
|
def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
|
|
new_ir_type = mlir.aval_to_ir_type(new_aval)
|
|
if x.type != new_ir_type:
|
|
return hlo.convert(mlir.aval_to_ir_type(new_aval), x)
|
|
else:
|
|
return x
|
|
|
|
callee_type = symtab["main"].type
|
|
# TODO: maybe cache multiple calls
|
|
fn = mlir.merge_mlir_modules(ctx.module_context.module,
|
|
f"call_exported_{exported.fun_name}",
|
|
submodule)
|
|
|
|
submodule_args = []
|
|
# All the platforms for the current lowering must be among the platforms
|
|
# for which the callee was lowered.
|
|
lowering_platforms = ctx.module_context.platforms
|
|
|
|
callee_lowering_platform_index: list[int] = []
|
|
for platform in lowering_platforms:
|
|
if platform in exported.lowering_platforms:
|
|
callee_lowering_platform_index.append(
|
|
exported.lowering_platforms.index(platform))
|
|
elif DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
|
|
callee_lowering_platform_index.append(0)
|
|
else:
|
|
raise ValueError(
|
|
f"The exported function '{exported.fun_name}' was lowered for "
|
|
f"platforms '{exported.lowering_platforms}' but it is used "
|
|
f"on '{lowering_platforms}'.")
|
|
|
|
if len(exported.lowering_platforms) > 1:
|
|
# The exported module takes a platform index argument
|
|
if len(lowering_platforms) > 1:
|
|
current_platform_idx = ctx.dim_var_values[0]
|
|
else:
|
|
current_platform_idx = mlir.ir_constant(np.int32(0))
|
|
# Compute the rule index based on the current platform
|
|
i32_type = mlir.aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
|
|
if current_platform_idx.type != i32_type:
|
|
current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
|
|
callee_platform_idx = hlo.CaseOp([i32_type],
|
|
index=current_platform_idx,
|
|
num_branches=len(lowering_platforms))
|
|
for i in range(len(lowering_platforms)):
|
|
branch = callee_platform_idx.regions[i].blocks.append()
|
|
with ir.InsertionPoint(branch):
|
|
hlo.return_(mlir.ir_constants(
|
|
np.int32(callee_lowering_platform_index[i])))
|
|
if callee_platform_idx.result.type != callee_type.inputs[0]:
|
|
callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0],
|
|
callee_platform_idx)
|
|
|
|
submodule_args.append(callee_platform_idx)
|
|
else:
|
|
assert len(lowering_platforms) == 1
|
|
|
|
if _keep_main_tokens(exported.mlir_module_serialization_version):
|
|
ordered_effects = exported.ordered_effects
|
|
else:
|
|
ordered_effects = ()
|
|
for eff in ordered_effects:
|
|
token_in = ctx.tokens_in.get(eff)[0]
|
|
submodule_args.append(token_in)
|
|
kept_args = [
|
|
convert_shape(a, a_aval, exported_in_aval)
|
|
for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals))
|
|
if i in exported.module_kept_var_idx]
|
|
submodule_args = submodule_args + kept_args
|
|
|
|
call = func_dialect.CallOp(callee_type.results,
|
|
ir.FlatSymbolRefAttr.get(fn),
|
|
submodule_args)
|
|
if ordered_effects:
|
|
tokens_out = {eff: (call.results[effect_idx],)
|
|
for effect_idx, eff in enumerate(ordered_effects)}
|
|
ctx.set_tokens_out(mlir.TokenSet(tokens_out))
|
|
# The ctx.avals_out already contain the abstract values refined by
|
|
# _call_exported_abstract_eval.
|
|
results = tuple(
|
|
convert_shape(out, out_aval, refined_out_aval)
|
|
for out, out_aval, refined_out_aval in zip(call.results[len(ordered_effects):],
|
|
exported.out_avals, ctx.avals_out))
|
|
# Apply out_shardings
|
|
results = tuple(
|
|
wrap_with_sharding(ctx, x, x_aval, x_sharding)
|
|
for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings)
|
|
)
|
|
return results
|
|
|
|
mlir.register_lowering(call_exported_p, _call_exported_lowering)
|
|
|
|
def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
|
|
x: ir.Value,
|
|
x_aval: core.AbstractValue,
|
|
x_sharding: Sharding) -> ir.Value:
|
|
if x_sharding is None:
|
|
return x
|
|
return mlir.wrap_with_sharding_op(
|
|
ctx, x, x_aval, x_sharding.to_proto())
|
|
|
|
# TODO(necula): Previously, we had `from jax.experimental.export import export`
|
|
# Now we want to simplify the usage, and export the public APIs directly
|
|
# from `jax.experimental.export` and now `jax.experimental.export.export`
|
|
# refers to the `export` function. Since there may still be users of the
|
|
# old API in other packages, we add the old public API as attributes of the
|
|
# exported function. We will clean this up after a deprecation period.
|
|
def wrap_with_deprecation_warning(f):
|
|
msg = (f"You are using function `{f.__name__}` from "
|
|
"`jax.experimental.export.export`. You should instead use it directly "
|
|
"from `jax.experimental.export`. Instead of "
|
|
"`from jax.experimental.export import export` you should use "
|
|
"`from jax.experimental import export`.")
|
|
def wrapped_f(*args, **kwargs):
|
|
warnings.warn(msg, DeprecationWarning)
|
|
return f(*args, **kwargs)
|
|
return wrapped_f
|
|
|
|
export.export = wrap_with_deprecation_warning(export)
|
|
export.Exported = Exported
|
|
export.call_exported = wrap_with_deprecation_warning(call_exported)
|
|
export.DisabledSafetyCheck = DisabledSafetyCheck
|
|
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
|
|
export.symbolic_shape = wrap_with_deprecation_warning(symbolic_shape)
|
|
export.args_specs = wrap_with_deprecation_warning(args_specs)
|
|
export.minimum_supported_serialization_version = minimum_supported_serialization_version
|
|
export.maximum_supported_serialization_version = maximum_supported_serialization_version
|