# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX APIs for exporting JAX functions for interoperation.

"""

from __future__ import annotations

from collections.abc import Sequence
import copy
import dataclasses
import functools
import itertools
import re
from typing import Any, Callable, Union
import warnings

from absl import logging
import numpy as np

import jax
from jax import sharding

from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src import pjit
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb

from jax.experimental.export import _shape_poly

map = util.safe_map
zip = util.safe_zip

DType = Any
Shape = jax._src.core.Shape
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]

# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]

# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 9
maximum_supported_serialization_version = 9

_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9


class DisabledSafetyCheck:
  """A safety check that should be skipped on (de)serialization.

  Most of these checks are performed on serialization, but some are deferred to
  deserialization. The list of disabled checks is attached to the serialization,
  e.g., as a sequence of string attributes to `jax_export.Exported` or of
  `tf.XlaCallModuleOp`.

  You can disable more deserialization safety checks by passing
  `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`.
  """
  _impl: str

@classmethod
|
|
def platform(cls) -> DisabledSafetyCheck:
|
|
"""Allows the execution platform to differ from the serialization platform.
|
|
|
|
Has effect only on deserialization.
|
|
"""
|
|
return DisabledSafetyCheck("platform")
|
|
|
|
@classmethod
|
|
def custom_call(cls, target_name: str) -> DisabledSafetyCheck:
|
|
"""Allows the serialization of a call target not known to be stable.
|
|
|
|
Has effect only on serialization.
|
|
Args:
|
|
target_name: the name of the custom call target to allow.
|
|
"""
|
|
return DisabledSafetyCheck(f"custom_call:{target_name}")
|
|
|
|
@classmethod
|
|
def shape_assertions(cls) -> DisabledSafetyCheck:
|
|
"""Allows invocations with shapes that do not meet the constraints.
|
|
|
|
Has effect on serialization (to suppress the generation of the assertions)
|
|
and also on deserialization (to suppress the checking of the assertions).
|
|
"""
|
|
return DisabledSafetyCheck("shape_assertions")
|
|
|
|
def is_custom_call(self) -> str | None:
|
|
"""Returns the custom call target allowed by this directive."""
|
|
m = re.match(r'custom_call:(.+)$', self._impl)
|
|
return m.group(1) if m else None
|
|
|
|
def __init__(self, _impl:str):
|
|
# Do not use directly, use builders `platform`, `custom_call`.
|
|
self._impl = _impl
|
|
|
|
def __str__(self):
|
|
return self._impl
|
|
__repr__ = __str__
|
|
|
|
def __eq__(self, other) -> bool:
|
|
return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl
|
|
|
|
def __hash__(self) -> int:
|
|
return hash(self._impl)
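
# Usage sketch (illustrative only; `my_fn` and the custom call target name are
# hypothetical). Instances built above are typically passed to
# `export(..., disabled_checks=...)`:
#
#   from jax.experimental import export
#
#   exp = export.export(
#       my_fn,
#       disabled_checks=[
#           export.DisabledSafetyCheck.platform(),
#           export.DisabledSafetyCheck.custom_call("my_target"),
#       ])(jax.ShapeDtypeStruct((8,), jax.numpy.float32))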
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class Exported:
|
|
"""A JAX function lowered to StableHLO.
|
|
|
|
Attributes:
|
|
fun_name: the name of the exported function, for error messages.
|
|
in_tree: a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX
|
|
function. The actual lowering does not depend on the `in_tree`, but this
|
|
can be used to invoke the exported function using the same argument
|
|
structure.
|
|
in_avals: the flat tuple of input abstract values. May contain dimension
|
|
expressions in the shapes.
|
|
out_tree: a PyTreeDef describing the result of the lowered JAX function.
|
|
out_avals: the flat tuple of output abstract values. May contain dimension
expressions in the shapes, with dimension variables among those in
`in_avals`.
in_shardings: the flattened input shardings, as long as `in_avals`.
|
|
out_shardings: the flattened output shardings, as long as `out_avals`.
|
|
nr_devices: the number of devices that the module has been lowered for.
|
|
lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
|
|
'cuda', 'rocm'. See below for the calling convention for when
|
|
there are multiple lowering platforms.
|
|
ordered_effects: the ordered effects present in the serialized module.
|
|
This is present from serialization version 9. See below for the
|
|
calling convention in presence of ordered effects.
|
|
unordered_effects: the unordered effects present in the serialized module.
|
|
This is present from serialization version 9.
|
|
mlir_module_serialized: the serialized lowered VHLO module.
|
|
mlir_module_serialization_version: a version number for the serialized module.
|
|
See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
|
|
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
|
|
must be passed to the module. The other arguments have been dropped
|
|
because they are not used. Same length as `in_shardings`.
|
|
uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape
|
|
polymorphism. This may be because `in_avals` contains dimension
|
|
variables, or due to inner calls of Exported modules that have
|
|
dimension variables or platform index arguments. Such modules need
|
|
shape refinement before XLA compilation.
|
|
disabled_safety_checks: a list of descriptors of safety checks that have been
|
|
disabled at export time. See docstring for `DisabledSafetyCheck`.
|
|
_get_vjp: an optional function that takes the current exported function and
|
|
returns the exported VJP function.
|
|
The VJP function takes a flat list of arguments,
|
|
starting with the primal arguments and followed by a cotangent argument
|
|
for each primal output. It returns a tuple with the cotangents
|
|
corresponding to the flattened primal inputs.
|
|
|
|
Calling convention for the exported module (for latest supported version):
|
|
|
|
The `mlir_module` has a `main` function that takes an optional first
|
|
platform index argument if the module supports multiple platforms
|
|
(`len(lowering_platforms) > 1`), followed by the token arguments corresponding
|
|
to the ordered effects, followed by the kept array
|
|
arguments (corresponding to `module_kept_var_idx` and `in_avals`).
|
|
The platform index is an i32 or i64 scalar encoding the index of the current
compilation platform into the `lowering_platforms` sequence.
|
|
|
|
Inner functions use a different calling convention: an optional
|
|
platform index argument, optional dimension variable arguments
|
|
(scalar tensors of type i32 or i64),
|
|
followed by optional token arguments (in presence of ordered effects),
|
|
followed by the regular array arguments.
|
|
The dimension arguments correspond to the dimension variables appearing in
|
|
the `args_avals`, in sorted order of their names.
|
|
|
|
Consider the lowering of a function with one array argument of type "f32[w,
|
|
2 * h]", where "w" and "h" are two dimension variables.
|
|
Assume that we use multi-platform lowering, and we have
|
|
one ordered effect. The `main` function will be as follows:
|
|
|
|
func public main(
|
|
platform_index: i32 {jax.global_constant="_platform_index"},
|
|
token_in: token,
|
|
arg: f32[?, ?]) {
|
|
arg_w = hlo.get_dimension_size(arg, 0)
|
|
dim1 = hlo.get_dimension_size(arg, 1)
|
|
arg_h = hlo.floordiv(dim1, 2)
|
|
call _check_shape_assertions(arg) # See below
|
|
token = new_token()
|
|
token_out, res = call _wrapped_jax_export_main(platform_index,
|
|
arg_h,
|
|
arg_w,
|
|
token_in,
|
|
arg)
|
|
return token_out, res
|
|
}
|
|
|
|
The actual computation is in `_wrapped_jax_export_main`, taking also
|
|
the values of `h` and `w` dimension variables.
|
|
|
|
The signature of the `_wrapped_jax_export_main` is:
|
|
|
|
func private _wrapped_jax_export_main(
|
|
platform_index: i32 {jax.global_constant="_platform_index"},
|
|
arg_h: i32 {jax.global_constant="h"},
|
|
arg_w: i32 {jax.global_constant="w"},
|
|
arg_token: stablehlo.token {jax.token=True},
|
|
arg: f32[?, ?]) -> (stablehlo.token, ...)
|
|
|
|
Prior to serialization version 9 the calling convention for effects is
|
|
different: the `main` function does not take or return a token. Instead
|
|
the function creates dummy tokens of type `i1[0]` and passes them to the
|
|
`_wrapped_jax_export_main`. The `_wrapped_jax_export_main`
|
|
takes dummy tokens of type `i1[0]` and will create internally real
|
|
tokens to pass to the inner functions. The inner functions use real
|
|
tokens (both before and after serialization version 9)
|
|
|
|
Also starting with serialization version 9, function arguments that contain
|
|
the platform index or the dimension variable values have a
|
|
`jax.global_constant` string attribute whose value is the name of the
|
|
global constant, either `_platform_index` or a dimension variable name.
|
|
The global constant name may be empty if it is not known.
|
|
Some global constant computations use inner functions, e.g., for
`floor_divide`. The arguments of such functions have a `jax.global_constant`
attribute for all arguments, meaning that the result of the function is
also a global constant.

Note that `main` contains a call to `_check_shape_assertions`.
|
|
JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
|
|
have values >= 1. We must check these constraints when we invoke the
|
|
module. We use a special custom call `@shape_assertion` that takes
|
|
a boolean first operand, a string `error_message` attribute that may contain
|
|
format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
|
|
scalar operands corresponding to the format specifiers.
|
|
|
|
func private _check_shape_assertions(arg: f32[?, ?]) {
|
|
# Check that w is >= 1
|
|
arg_w = hlo.get_dimension_size(arg, 0)
|
|
custom_call @shape_assertion(arg_w >= 1, arg_w,
|
|
error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
|
|
# Check that dim1 is even
|
|
dim1 = hlo.get_dimension_size(arg, 1)
|
|
custom_call @shape_assertion(dim1 % 2 == 0, dim1,
|
|
error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
|
|
# Check that h >= 1
|
|
arg_h = hlo.floordiv(dim1, 2)
|
|
custom_call @shape_assertion(arg_h >= 1, arg_h,
error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
|
|
|
|
If we `call_exported` with this module we perform these checks
|
|
statically (in `call_exported_abstract_eval`).
|
|
"""
|
|
fun_name: str
|
|
in_tree: tree_util.PyTreeDef
|
|
in_avals: tuple[core.AbstractValue, ...]
|
|
out_tree: tree_util.PyTreeDef
|
|
out_avals: tuple[core.AbstractValue, ...]
|
|
|
|
in_shardings: tuple[Sharding, ...]
|
|
out_shardings: tuple[Sharding, ...]
|
|
nr_devices: int
|
|
lowering_platforms: tuple[str, ...]
|
|
ordered_effects: tuple[effects.Effect, ...]
|
|
unordered_effects: tuple[effects.Effect, ...]
|
|
disabled_safety_checks: Sequence[DisabledSafetyCheck]
|
|
|
|
mlir_module_serialized: bytes
|
|
mlir_module_serialization_version: int
|
|
module_kept_var_idx: tuple[int, ...]
|
|
uses_shape_polymorphism: bool
|
|
|
|
_get_vjp: Callable[[Exported], Exported] | None
|
|
|
|
def mlir_module(self) -> ir.Module:
|
|
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
|
|
|
|
def __str__(self):
|
|
# This is called to make an MLIR source location when we call an Exported, and we
|
|
# do not want the entire serialized module to end up in locations.
|
|
return f"Exported(fun_name={self.fun_name}, ...)"
|
|
|
|
def has_vjp(self) -> bool:
|
|
return self._get_vjp is not None
|
|
|
|
def vjp(self) -> Exported:
|
|
"""Gets the exported VJP.
|
|
|
|
Raises a ValueError if not available, which can happen if the Exported has been
loaded from an external format, without a VJP."""
|
|
if self._get_vjp is None:
|
|
raise ValueError("No VJP is available")
|
|
return self._get_vjp(self)
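
# Inspection sketch (illustrative): given an `exp: Exported`, useful accessors
# include:
#
#   exp.mlir_module()        # deserializes the StableHLO module (an ir.Module)
#   exp.in_avals, exp.out_avals
#   exp.lowering_platforms
#   exp.mlir_module_serialization_version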
|
|
|
|
|
|
def default_lowering_platform() -> str:
|
|
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
|
|
return xb.canonicalize_platform(jax.default_backend())
|
|
|
|
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
  """Returns the shape and dtype of a jax.Array or a jax.ShapeDtypeStruct."""
  if isinstance(a, jax.ShapeDtypeStruct):
    return a.shape, a.dtype
  aval = core.raise_to_shaped(core.get_aval(a))
  return aval.shape, aval.dtype
|
|
|
|
def args_specs(
|
|
args, # pytree of arguments
|
|
polymorphic_shapes, # prefix pytree of strings
|
|
get_shape_and_dtype=shape_and_dtype_jax_array,
|
|
):
|
|
# TODO: deprecated in January 2024, to be removed.
|
|
warnings.warn(
|
|
"export.args_specs is deprecated in favor of export.symbolic_args_specs",
|
|
DeprecationWarning, stacklevel=2)
|
|
if get_shape_and_dtype is not shape_and_dtype_jax_array:
|
|
# This was needed in some older jax2tf implementations
|
|
args = tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(* get_shape_and_dtype(a)),
|
|
args)
|
|
return _shape_poly.symbolic_args_specs(args, polymorphic_shapes)
|
|
|
|
|
|
|
|
def _keep_main_tokens(serialization_version: int) -> bool:
|
|
return serialization_version >= _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS
|
|
|
|
def export(fun_jax: Callable,
|
|
*,
|
|
lowering_platforms: Sequence[str] | None = None,
|
|
disabled_checks: Sequence[DisabledSafetyCheck] = (),
|
|
) -> Callable[..., Exported]:
|
|
"""Exports native serialization for a JAX function.
|
|
|
|
Args:
|
|
fun_jax: the function to lower and serialize.
|
|
lowering_platforms:
|
|
Optional sequence containing a subset of 'tpu', 'cpu',
|
|
'cuda', 'rocm'. If more than one platform is specified, then
|
|
the lowered code takes an argument specifying the platform.
|
|
If None, then use the default JAX backend.
|
|
The calling convention for multiple platforms is explained in the
|
|
`jax_export.Exported` docstring.
|
|
disabled_checks: the safety checks to disable. See docstring
|
|
of `DisabledSafetyCheck`.
|
|
|
|
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
|
|
or values with `.shape` and `.dtype` attributes, and returns an
|
|
`Exported`.
|
|
|
|
Usage:
|
|
|
|
def f_jax(*args, **kwargs): ...
|
|
exported = jax_export.export(f_jax)(*args, **kwargs)
|
|
"""
|
|
fun_name = getattr(fun_jax, "__name__", "unknown")
|
|
version = config.jax_serialization_version.value
|
|
if (version < minimum_supported_serialization_version or
|
|
version > maximum_supported_serialization_version):
|
|
raise ValueError(
|
|
f"The requested jax_serialization version {version} is outside the "
|
|
f"range of supported versions [{minimum_supported_serialization_version}"
|
|
f"..{maximum_supported_serialization_version}]")
|
|
|
|
def do_export(*args_specs, **kwargs_specs) -> Exported:
|
|
if not hasattr(fun_jax, "lower"):
|
|
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
|
|
# convert(f_jax), in which case a "jit" is implied. In that case we raise
|
|
# an error if the lowered function contains non-replicated sharding annotations.
|
|
wrapped_fun_jax = jax.jit(fun_jax)
|
|
allow_non_replicated_sharding = False
|
|
else:
|
|
# If we have a pjit or pmap already we do not wrap with another, and we
|
|
# allow shardings.
|
|
wrapped_fun_jax = fun_jax # type: ignore
|
|
allow_non_replicated_sharding = True
|
|
|
|
if lowering_platforms is not None:
|
|
actual_lowering_platforms = tuple(lowering_platforms)
|
|
else:
|
|
actual_lowering_platforms = (default_lowering_platform(),)
|
|
|
|
# Do not include shape assertions if the version is < 7.
|
|
enable_shape_assertions = (
|
|
DisabledSafetyCheck.shape_assertions() not in disabled_checks and
|
|
version >= _VERSION_START_SUPPORT_SHAPE_ASSERTIONS) # type: ignore
|
|
try:
|
|
prev_enable_shape_assertions = _shape_poly.thread_local_state.enable_shape_assertions
|
|
_shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
|
|
replace_tokens_with_dummy = not _keep_main_tokens(version)
|
|
|
|
symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
|
|
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
|
|
# Static args may have no `shape` attribute.
|
|
if not hasattr(aval, "shape"):
|
|
continue
|
|
for d in aval.shape:
|
|
if _shape_poly.is_symbolic_dim(d):
|
|
if symbolic_scope is None:
|
|
symbolic_scope = (d.scope, k_path)
|
|
continue
|
|
symbolic_scope[0]._check_same_scope(
|
|
d, when=f"when exporting {fun_name}",
|
|
self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
|
|
other_descr=_shape_poly.args_kwargs_path_to_str(k_path))
|
|
|
|
lowered = wrapped_fun_jax.lower(
|
|
*args_specs, **kwargs_specs,
|
|
_experimental_lowering_parameters=mlir.LoweringParameters(
|
|
platforms=actual_lowering_platforms,
|
|
replace_tokens_with_dummy=replace_tokens_with_dummy,
|
|
))
|
|
|
|
lowering = lowered._lowering # type: ignore
|
|
_check_lowering(lowering)
|
|
mlir_module = lowering.stablehlo()
|
|
|
|
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
|
|
if "mut" in lowering.compile_args:
|
|
if lowering.compile_args["mut"]: raise NotImplementedError
|
|
if "kept_var_idx" in lowering.compile_args:
|
|
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
|
|
else:
|
|
# For pmap
|
|
module_kept_var_idx = tuple(range(len(args_avals_flat)))
|
|
shape_poly_state = lowering.compile_args["shape_poly_state"]
|
|
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
|
|
or lowering.compile_args.get("ordered_effects", [])):
|
|
mlir_module = _wrap_main_func(
|
|
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree,
|
|
has_platform_index_argument=shape_poly_state.has_platform_index_argument,
|
|
module_kept_var_idx=module_kept_var_idx,
|
|
serialization_version=version)
|
|
finally:
|
|
_shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions
|
|
|
|
with mlir_module.context:
|
|
mlir_module_attrs = mlir_module.operation.attributes
|
|
mlir_module_attrs["jax.uses_shape_polymorphism"] = (
|
|
mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars))
|
|
|
|
mlir_module_serialized = _module_to_bytecode(mlir_module)
|
|
|
|
# Figure out the result types and shapes
|
|
if "global_out_avals" in lowering.compile_args:
|
|
# This is currently the case for pjit
|
|
out_avals_flat = lowering.compile_args["global_out_avals"]
|
|
elif "shards" in lowering.compile_args: # for PmapComputation
|
|
out_avals_flat = lowering.compile_args["shards"].out_sharded_avals
|
|
else:
|
|
out_avals_flat = lowering.compile_args["out_avals"]
|
|
|
|
# Log and then check the module.
|
|
if logging.vlog_is_on(3):
|
|
logmsg = (f"version={version} "
|
|
f"lowering_platforms={actual_lowering_platforms} "
|
|
f"disabled_checks={disabled_checks}")
|
|
logging.info("Lowered JAX module: %s\n", logmsg)
|
|
if dumped_to := mlir.dump_module_to_file(mlir_module, "export"):
|
|
logging.info("Dumped the exported MLIR module to %s", dumped_to)
|
|
|
|
_check_module(mlir_module,
|
|
allow_non_replicated_sharding=allow_non_replicated_sharding,
|
|
disabled_checks=disabled_checks)
|
|
|
|
ordered_effects = tuple(lowering.compile_args["ordered_effects"])
|
|
unordered_effects = tuple(lowering.compile_args["unordered_effects"])
|
|
if version < _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
|
ordered_effects = unordered_effects = ()
|
|
|
|
nr_devices = len(lowering.compile_args["device_assignment"])
|
|
def export_sharding(s: LoweringSharding,
|
|
aval: core.ShapedArray) -> Sharding:
|
|
if sharding_impls.is_unspecified(s):
|
|
return None
|
|
return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr]
|
|
|
|
all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"],
|
|
module_kept_var_idx,
|
|
len(args_avals_flat))
|
|
in_shardings = tuple(
|
|
export_sharding(s, aval)
|
|
for s, aval in zip(all_in_shardings, args_avals_flat))
|
|
out_shardings = tuple(
|
|
export_sharding(s, aval)
|
|
for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat))
|
|
return Exported(
|
|
fun_name=fun_name,
|
|
in_tree=lowered.in_tree,
|
|
out_tree=lowered.out_tree,
|
|
in_avals=tuple(args_avals_flat),
|
|
out_avals=tuple(out_avals_flat),
|
|
in_shardings=in_shardings,
|
|
out_shardings=out_shardings,
|
|
nr_devices=nr_devices,
|
|
lowering_platforms=actual_lowering_platforms,
|
|
ordered_effects=ordered_effects,
|
|
unordered_effects=unordered_effects,
|
|
disabled_safety_checks=tuple(disabled_checks),
|
|
mlir_module_serialized=mlir_module_serialized,
|
|
module_kept_var_idx=module_kept_var_idx,
|
|
uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
|
|
mlir_module_serialization_version=version, # type: ignore
|
|
_get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))
|
|
|
|
return do_export
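
# Example sketch (illustrative; `f` and the shapes are hypothetical):
#
#   from jax.experimental import export
#   import jax.numpy as jnp
#
#   def f(x):
#     return jnp.sin(x)
#
#   # Export for two platforms, with a symbolic batch dimension "b".
#   x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.float32)
#   exp = export.export(f, lowering_platforms=("cpu", "cuda"))(x_spec)
#   # exp.mlir_module_serialized now holds the portable StableHLO artifact.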
|
|
|
|
|
|
def _module_to_bytecode(module: ir.Module) -> bytes:
|
|
mlir_str = mlir.module_to_bytecode(module)
|
|
if hlo.get_api_version() < 4:
|
|
target_version = hlo.get_earliest_forward_compatible_version()
|
|
else:
|
|
# `target_version` is used to manage situations when a StableHLO producer
|
|
# (in this case, jax2tf) and a StableHLO consumer were built using
|
|
# different versions of StableHLO.
|
|
#
|
|
# Each StableHLO version `producer_version` has a compatibility window,
|
|
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
|
|
# where StableHLO portable artifacts serialized by `producer_version`
|
|
# can be deserialized by `consumer_version` within the window.
|
|
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
|
|
# for the exact extent of these compatibility guarantees.
|
|
#
|
|
# `hlo.get_minimum_version()` returns `consumer_version_min`
|
|
# for the current version of StableHLO. We are using it here to maximize
|
|
# forward compatibility, i.e. to maximize how far into the past we can go
|
|
# and still have the payloads produced by `serialize_portable_artifact`
|
|
# compatible with potential consumers from the past.
|
|
target_version = hlo.get_minimum_version()
|
|
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
|
mlir_str, target_version)
|
|
return module_serialized
|
|
|
|
|
|
def _wrap_main_func(
|
|
module: ir.Module,
|
|
args_avals_flat: Sequence[core.ShapedArray],
|
|
*,
|
|
args_kwargs_tree: tree_util.PyTreeDef,
|
|
has_platform_index_argument: bool,
|
|
module_kept_var_idx: tuple[int, ...],
|
|
serialization_version: int
|
|
) -> ir.Module:
|
|
"""Wraps the lowered module with a new "main" handling dimension arguments.
|
|
|
|
See calling convention documentation for `jax_export.Exported`.
|
|
|
|
Args:
|
|
module: the HLO module as obtained from lowering. See the calling convention
|
|
for inner functions in `jax_export.Exported`.
|
|
args_avals_flat: the avals for all the arguments of the lowered function,
|
|
which correspond to the array arguments of the `module`.
|
|
args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error
|
|
messages.
|
|
has_platform_index_argument: whether the `module` has a first platform
|
|
index argument
|
|
module_kept_var_idx: a sorted tuple of integers with the indices of arguments
|
|
in `args_avals_flat` that are kept as `module` arguments.
|
|
serialization_version: the target serialization version
|
|
|
|
Returns the wrapped module, without dimension and token arguments.
|
|
"""
|
|
dim_vars = _shape_poly.all_dim_vars(args_avals_flat)
|
|
context = mlir.make_ir_context()
|
|
with context, ir.Location.unknown(context):
|
|
# Make a copy, do not mutate because it may be cached
|
|
wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))
|
|
symbol_table = ir.SymbolTable(wrapped_module.operation)
|
|
orig_main = symbol_table["main"]
|
|
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
|
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
|
|
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
|
|
|
|
def is_token(typ, attrs):
|
|
if typ == mlir.token_type()[0]:
|
|
return True
|
|
# TODO(b/302258959): in older versions we cannot use the token type
|
|
try:
|
|
return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value
|
|
except KeyError:
|
|
return False
|
|
|
|
orig_input_types = orig_main.type.inputs
|
|
arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs))
|
|
# The order of args: platform_index_arg, dim args, token args, array args.
|
|
nr_platform_index_args = 1 if has_platform_index_argument else 0
|
|
nr_dim_args = len(dim_vars)
|
|
token_arg_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_input_types,
|
|
arg_attrs))
|
|
if is_token(typ, attrs)]
|
|
nr_token_args = len(token_arg_idxs)
|
|
if nr_token_args > 0:
|
|
assert min(token_arg_idxs) == nr_platform_index_args + nr_dim_args
|
|
assert token_arg_idxs == list(
|
|
range(nr_platform_index_args + nr_dim_args,
|
|
nr_platform_index_args + nr_dim_args + nr_token_args))
|
|
nr_array_args = (len(orig_input_types) - nr_platform_index_args
|
|
- nr_dim_args - nr_token_args)
|
|
assert nr_array_args >= 0
|
|
|
|
(platform_input_types, dim_var_input_types,
|
|
token_input_types, array_input_types) = util.split_list(
|
|
orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args])
|
|
|
|
# The order of results: tokens, array results
|
|
orig_output_types = orig_main.type.results
|
|
result_attrs = list(ir.ArrayAttr(orig_main.result_attrs))
|
|
token_result_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_output_types,
|
|
result_attrs))
|
|
if is_token(typ, attrs)]
|
|
nr_token_results = len(token_result_idxs)
|
|
assert token_result_idxs == list(range(0, nr_token_results))
|
|
nr_array_results = len(orig_output_types) - nr_token_results
|
|
assert nr_array_results >= 0
|
|
if _keep_main_tokens(serialization_version):
|
|
new_main_arg_indices = (tuple(range(0, nr_platform_index_args)) +
|
|
tuple(range(nr_platform_index_args + nr_dim_args,
|
|
len(orig_input_types))))
|
|
new_main_result_indices = tuple(range(0, len(orig_output_types)))
|
|
else:
|
|
new_main_arg_indices = (
|
|
tuple(range(0, nr_platform_index_args)) +
|
|
tuple(range(nr_platform_index_args + nr_dim_args + nr_token_args,
|
|
len(orig_input_types))))
|
|
new_main_result_indices = tuple(range(nr_token_results, len(orig_output_types)))
|
|
new_main_input_types = [orig_input_types[idx] for idx in new_main_arg_indices]
|
|
new_main_output_types = [orig_output_types[idx] for idx in new_main_result_indices]
|
|
new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
|
|
new_main_op = func_dialect.FuncOp(
|
|
"main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body))
|
|
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
|
|
try:
|
|
new_arg_attrs = []
|
|
for idx in new_main_arg_indices:
|
|
new_arg_attr = {}
|
|
for attr in arg_attrs[idx]:
|
|
if attr.name == "tf.aliasing_output":
|
|
i = new_main_result_indices.index(attr.attr.value)
|
|
new_arg_attr[attr.name] = ir.IntegerAttr.get(
|
|
ir.IntegerType.get_signless(32), i
|
|
)
|
|
else:
|
|
new_arg_attr[attr.name] = attr.attr
|
|
new_arg_attrs.append(ir.DictAttr.get(new_arg_attr))
|
|
new_main_op.arg_attrs = ir.ArrayAttr.get(new_arg_attrs)
|
|
except KeyError:
|
|
pass # TODO: better detection if orig_main.arg_attrs does not exist
|
|
try:
|
|
new_main_op.result_attrs = ir.ArrayAttr.get(
|
|
[result_attrs[idx] for idx in new_main_result_indices])
|
|
except KeyError:
|
|
pass
|
|
symbol_table.insert(new_main_op)
|
|
entry_block = new_main_op.add_entry_block()
|
|
with ir.InsertionPoint(entry_block):
|
|
# Make a context just for lowering the dimension value computations
|
|
module_context = mlir.ModuleContext(
|
|
backend_or_name="cpu", platforms=["cpu"],
|
|
axis_context=sharding_impls.ShardingContext(0),
|
|
keepalives=[], channel_iterator=itertools.count(1),
|
|
host_callbacks=[], module=wrapped_module, context=context,
|
|
lowering_parameters=mlir.LoweringParameters(
|
|
global_constant_computation=True
|
|
))
|
|
ctx = mlir.LoweringRuleContext(
|
|
module_context=module_context,
|
|
name_stack=source_info_util.new_name_stack(), primitive=None,
|
|
avals_in=args_avals_flat, avals_out=None,
|
|
tokens_in=mlir.TokenSet(), tokens_out=None)
|
|
# We compute dim_values from the array arguments.
|
|
new_main_op_array_args = new_main_op.arguments[-nr_array_args:]
|
|
if _shape_poly.all_dim_vars(args_avals_flat):
|
|
# TODO(necula): handle module_kept_var_idx in presence of shape
|
|
# polymorphism. For now we ensured upstream that we keep all variables.
|
|
assert len(set(module_kept_var_idx)) == len(args_avals_flat)
|
|
dim_values = mlir.lower_fun(
|
|
functools.partial(_shape_poly.compute_dim_vars_from_arg_shapes,
|
|
args_avals_flat, args_kwargs_tree=args_kwargs_tree),
|
|
multiple_results=True)(ctx, *new_main_op_array_args)
|
|
else:
|
|
dim_values = ()
|
|
# The arguments to pass to the call to orig_main
|
|
orig_main_args: list[ir.Value] = []
|
|
# The platform index and the dimension variables
|
|
for arg, arg_type in zip(
|
|
list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values),
|
|
platform_input_types + dim_var_input_types):
|
|
if arg.type != arg_type:
|
|
orig_main_args.append(hlo.convert(arg_type, arg))
|
|
else:
|
|
orig_main_args.append(arg)
|
|
# Then the token arguments
|
|
if _keep_main_tokens(serialization_version):
|
|
orig_main_args.extend(
|
|
new_main_op.arguments[nr_platform_index_args: nr_platform_index_args + nr_token_args])
|
|
else:
|
|
orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args)
|
|
# Then the array arguments. We insert a ConvertOp as the only use of
|
|
# an input argument. This helps the downstream shape refinement because
|
|
# it will set the type of input arguments to static shapes, and this
|
|
# can invalidate the module if the argument is used as the result of a
|
|
# function, or if it appears as the input to a custom_call with
|
|
# output_operand_alias attribute. See b/287386268.
|
|
for arg, arg_type in zip(new_main_op_array_args, array_input_types):
|
|
if arg.type != arg_type:
|
|
orig_main_args.append(hlo.convert(arg_type, arg))
|
|
else:
|
|
orig_main_args.append(arg)
|
|
call = func_dialect.CallOp(orig_output_types,
|
|
ir.FlatSymbolRefAttr.get(orig_main_name),
|
|
orig_main_args)
|
|
func_dialect.ReturnOp([call.results[idx] for idx in new_main_result_indices])
|
|
symbol_table.set_symbol_name(new_main_op, "main")
|
|
|
|
return wrapped_module
|
|
|
|
def _check_lowering(lowering) -> None:
|
|
if not isinstance(lowering, pxla.MeshComputation):
|
|
raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
|
|
|
|
if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]:
|
|
raise NotImplementedError("serialization of host_callbacks is not yet implemented")
|
|
# Check that we do not see new compile_args. When we add a new compile_arg it is
|
|
# safe to add it to the allowed_compile_args if it does not change the semantics
|
|
# or the calling convention of the lowered module.
|
|
allowed_compile_args = [
|
|
"backend", "mesh", "global_in_avals",
|
|
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
|
|
"mut", "spmd_lowering", "auto_spmd_lowering",
|
|
"tuple_args", "ordered_effects", "unordered_effects",
|
|
"keepalive", "host_callbacks", "pmap_nreps", "committed",
|
|
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
|
|
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"]
|
|
for compile_arg in lowering.compile_args.keys():
|
|
if compile_arg not in allowed_compile_args:
|
|
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
|
|
|
|
# We have not implemented support for some of the compile_args. Check here that
|
|
# the compile_args have the values that have been implemented.
|
|
not_implemented_msgs = []
|
|
for compile_arg, check_value, err_msg in (
|
|
("spmd_lowering", lambda v: v, "True"),
|
|
("auto_spmd_lowering", lambda v: not v, "False"),
|
|
# tuple_args is a compilation flag, does not affect lowering.
|
|
("tuple_args", lambda v: True, "N/A"),
|
|
# unordered_effects do not change the calling convention. Those from
|
|
# jax.debug will also result in keepalive being non-empty and unsupported
|
|
# custom calls. The CallTfEffect is an exception, but we want to allow
|
|
# that one.
|
|
("unordered_effects", lambda v: True, "N/A"),
|
|
("ordered_effects", lambda v: True, "N/A"),
|
|
# used for TPU jax.debug, send/recv. Not supported yet.
|
|
("host_callbacks", lambda v: not v, "empty"),
|
|
# used on all platforms for callbacks. Not supported yet.
|
|
("keepalive", lambda v: not v, "empty"),
|
|
("pmap_nreps", lambda v: v == 1, "1"),
|
|
("shape_poly_state", lambda v: True, "N/A"),
|
|
):
|
|
if compile_arg in lowering.compile_args:
|
|
if not check_value(lowering.compile_args[compile_arg]):
|
|
not_implemented_msgs.append(
|
|
f"{compile_arg} must be {err_msg} and it is {lowering.compile_args[compile_arg]}")
|
|
if not_implemented_msgs:
|
|
raise NotImplementedError(
|
|
"serialization error, unimplemented lowered.compile_args:\n" +
|
|
"\n".join(not_implemented_msgs))
|
|
|
|
# These are the JAX custom call target names that are guaranteed to be stable.
|
|
# Their backwards compatibility is tested by back_compat_test.py.
|
|
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
|
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
|
|
"dynamic_ducc_fft", "cu_threefry2x32",
|
|
# cholesky on CPU
|
|
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
|
|
# eigh on CPU
|
|
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
|
|
# eigh on GPU
|
|
"cusolver_syevj", "cusolver_syevd",
|
|
# eigh on TPU
|
|
"Eigh",
|
|
# eig on CPU
|
|
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
|
|
# qr on CPU
|
|
"lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
|
|
# householder product on CPU
|
|
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
|
|
# svd on CPU
|
|
"lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd",
|
|
# qr on GPU
|
|
"cusolver_geqrf", "cublas_geqrf_batched",
|
|
"cusolver_orgqr",
|
|
# qr and svd on TPU
|
|
"Qr", "ProductOfElementaryHouseholderReflectors",
|
|
# triangular_solve on CPU
|
|
"blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm",
|
|
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
|
|
# # lu on CPU
|
|
"lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf",
|
|
# schur on CPU
|
|
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
|
|
# # lu on GPU
|
|
# "cublas_getrf_batched", "cusolver_getrf",
|
|
# "hipblas_getrf_batched", "hipsolver_getrf",
|
|
# lu on TPU
|
|
"LuDecomposition",
|
|
# ApproxTopK on TPU
|
|
"ApproxTopK",
|
|
"tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True)
|
|
"tpu_custom_call", # Pallas/TPU kernels
|
|
# TODO(burmako): maintain backwards compatibility for these, until they
|
|
# are upstreamed to StableHLO.
|
|
# See https://github.com/openxla/stablehlo/issues/8.
|
|
"stablehlo.dynamic_reduce_window",
|
|
"stablehlo.dynamic_rng_bit_generator",
|
|
"stablehlo.dynamic_top_k",
|
|
"shape_assertion", # Used by shape_poly to evaluate assertions
|
|
}
|
|
|
|
check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|"")$")
|
|
|
|
def _check_module(mod: ir.Module, *,
|
|
allow_non_replicated_sharding: bool,
|
|
disabled_checks: Sequence[DisabledSafetyCheck]) -> None:
|
|
"""Run a number of checks on the module.
|
|
|
|
Args:
|
|
allow_non_replicated_sharding: whether the module is allowed to contain
non-replicated sharding annotations.
|
|
disabled_checks: the safety checks that are disabled.
|
|
"""
|
|
sharding_attr = ir.StringAttr.get("Sharding", mod.context)
|
|
shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context)
|
|
allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
|
|
for dc in disabled_checks:
|
|
target = dc.is_custom_call()
|
|
if target is not None:
|
|
allowed_custom_call_targets.add(target)
|
|
|
|
allowed_custom_call_targets_attrs = {
|
|
ir.StringAttr.get(target, mod.context)
|
|
for target in allowed_custom_call_targets}
|
|
disallowed_custom_call_ops: list[str] = []
|
|
def check_sharding(op: ir.Operation, loc: ir.Location):
|
|
if not allow_non_replicated_sharding:
|
|
try:
|
|
sharding = op.attributes["mhlo.sharding"]
|
|
except KeyError:
|
|
pass
|
|
else:
|
|
if not re.match(check_sharding_pattern, ir.StringAttr(sharding).value):
|
|
raise ValueError(
|
|
"Lowered function does not have a top-level pjit but it has"
|
|
f" non-replicated sharding annotations, e.g., {op} at {loc}.\nSee"
|
|
" https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning"
|
|
" for a discussion."
|
|
)
|
|
|
|
def check_op(op: ir.Operation):
|
|
op_name = op.operation.name
|
|
if op_name == "func.func":
|
|
check_sharding(op.operation, op.location)
|
|
|
|
elif op_name == "stablehlo.custom_call":
|
|
call_target_name_attr = op.operation.attributes["call_target_name"]
|
|
if (call_target_name_attr not in allowed_custom_call_targets_attrs):
|
|
disallowed_custom_call_ops.append(f"{op} at {op.location}")
|
|
if call_target_name_attr == sharding_attr:
|
|
check_sharding(op, op.location)
|
|
elif call_target_name_attr == shape_assertion_attr:
|
|
assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks)
|
|
|
|
def walk_operations(op):
|
|
check_op(op)
|
|
for region in op.operation.regions:
|
|
for block in region:
|
|
for op in block:
|
|
walk_operations(op)
|
|
|
|
walk_operations(mod)
|
|
if disallowed_custom_call_ops:
|
|
disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops)
|
|
msg = ("Cannot serialize code with custom calls whose targets have no "
|
|
"compatibility guarantees. Examples are:\n"
|
|
f"{disallowed_custom_call_ops_str}.\n"
|
|
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
|
|
raise ValueError(msg)
|
|
|
|
def expand_in_shardings(in_shardings: Sequence[LoweringSharding],
|
|
module_kept_var_idx: Sequence[int],
|
|
nr_inputs: int) -> Sequence[LoweringSharding]:
|
|
"""Expands in_shardings with unspecified shardings for inputs not kept.
|
|
|
|
Assumes in_shardings corresponds to module_kept_var_idx.
|
|
"""
|
|
assert len(in_shardings) == len(module_kept_var_idx)
|
|
assert nr_inputs >= len(module_kept_var_idx)
|
|
all_in_shardings: list[LoweringSharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
|
|
for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings):
|
|
all_in_shardings[idx] = in_s
|
|
return tuple(all_in_shardings)
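
# Worked example (illustrative): with in_shardings=(s0, s2),
# module_kept_var_idx=(0, 2), and nr_inputs=3, the result is
# (s0, sharding_impls.UNSPECIFIED, s2).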
|
|
|
|
# TODO(yashkatariya, necula): remove this function once we relax the checks
|
|
# in the jit front-end.
|
|
def canonical_shardings(
|
|
device_assignment: Sequence[jax.Device],
|
|
in_shardings: Sequence[Sharding],
|
|
out_shardings: Sequence[Sharding]
|
|
) -> tuple[(pxla.UnspecifiedValue |
|
|
Sequence[sharding.XLACompatibleSharding]),
|
|
(pxla.UnspecifiedValue |
|
|
Sequence[sharding.XLACompatibleSharding])]:
|
|
"""Prepares canonical in_ and out_shardings for a pjit invocation.
|
|
|
|
The pjit front-end is picky about what in- and out-shardings it accepts,
|
|
e.g., if all are unspecified then the whole sharding should be the
|
|
sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are
|
|
replaced with the replicated sharding.
|
|
|
|
Returns: a pair with the canonicalized input and output shardings.
|
|
"""
|
|
replicated_s = sharding.GSPMDSharding.get_replicated(device_assignment)
|
|
def canonicalize(
|
|
ss: Sequence[Sharding]) -> (pxla.UnspecifiedValue |
|
|
Sequence[sharding.XLACompatibleSharding]):
|
|
if all(s is None for s in ss):
|
|
return sharding_impls.UNSPECIFIED
|
|
return tuple(
|
|
sharding.GSPMDSharding(device_assignment, s) if s is not None else replicated_s
|
|
for s in ss)
|
|
return (canonicalize(in_shardings), canonicalize(out_shardings))
|
|
|
|
def _get_vjp_fun(primal_fun: Callable, *,
|
|
in_tree: tree_util.PyTreeDef,
|
|
in_avals: Sequence[core.AbstractValue],
|
|
out_avals: Sequence[core.AbstractValue],
|
|
in_shardings: tuple[Sharding, ...],
|
|
out_shardings: tuple[Sharding, ...],
|
|
nr_devices: int,
|
|
apply_jit: bool
|
|
) -> tuple[Callable, Sequence[core.AbstractValue]]:
|
|
# Since jax.vjp does not handle kwargs, it is easier to do all the work
|
|
# here with flattened functions.
|
|
def fun_vjp_jax(*args_and_out_cts_flat_jax):
|
|
# Takes a flat list of primals and output cotangents
|
|
def flattened_primal_fun_jax(*args_flat):
|
|
args, kwargs = in_tree.unflatten(args_flat)
|
|
res = primal_fun(*args, **kwargs)
|
|
res_flat, _ = tree_util.tree_flatten(res)
|
|
return res_flat
|
|
|
|
args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax,
|
|
[len(in_avals)])
|
|
_, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax)
|
|
return pullback_jax(out_cts_flat_jax)
|
|
|
|
vjp_in_avals = list(
|
|
itertools.chain(in_avals,
|
|
map(lambda a: a.at_least_vspace(), out_avals)))
|
|
|
|
if apply_jit:
|
|
# Prepare a device assignment. For exporting purposes, all that matters
# is the number of devices.
|
|
device_assignment = jax.devices(jax.default_backend())[:nr_devices]
|
|
assert len(device_assignment) == nr_devices
|
|
vjp_in_shardings, vjp_out_shardings = canonical_shardings(
|
|
device_assignment,
|
|
tuple(itertools.chain(in_shardings, out_shardings)),
|
|
in_shardings)
|
|
return pjit.pjit(fun_vjp_jax,
|
|
in_shardings=vjp_in_shardings,
|
|
out_shardings=vjp_out_shardings), vjp_in_avals
|
|
else:
|
|
return fun_vjp_jax, vjp_in_avals
|
|
|
|
def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
|
|
# Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp
|
|
fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun,
|
|
in_tree=primal.in_tree,
|
|
in_avals=primal.in_avals,
|
|
in_shardings=primal.in_shardings,
|
|
out_avals=primal.out_avals,
|
|
out_shardings=primal.out_shardings,
|
|
nr_devices=primal.nr_devices,
|
|
apply_jit=True)
|
|
return export(fun_vjp_jax,
|
|
lowering_platforms=primal.lowering_platforms,
|
|
disabled_checks=primal.disabled_safety_checks)(*vjp_in_avals)
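
# Sketch of the exported VJP calling convention (illustrative; `exp`, `x`, and
# `ct_out` are hypothetical). If `exp` was exported from f: f32[3] -> f32[3]:
#
#   exp_vjp = exp.vjp()               # takes (x, ct_out), returns (ct_x,)
#   ct_x, = call(exp_vjp)(x, ct_out)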
|
|
|
|
### Calling the exported function
|
|
|
|
def call(exported: Exported) -> Callable[..., jax.Array]:
|
|
if not isinstance(exported, Exported):
|
|
raise ValueError(
|
|
"The exported argument must be an export.Exported. "
|
|
f"Found {exported}.")
|
|
@jax.custom_vjp
|
|
def f_flat(*args_flat):
|
|
return call_exported_p.bind(*args_flat, exported=exported)
|
|
|
|
def f_flat_vjp_fwd(*args_flat):
|
|
# Return the primal arguments as the residual
|
|
# TODO: keep as residuals only the arguments that are needed
|
|
return f_flat(*args_flat), args_flat
|
|
|
|
def f_flat_vjp_bwd(residual, ct_res_flat):
|
|
args_flat = residual # residual is the primal argument flat tuple
|
|
exp_vjp = exported.vjp()
|
|
# ct_res_flat may contain arrays of zeros where exp_vjp expects float0.
# We make the proper arrays of float0 to invoke exp_vjp.
|
|
def fix_float0_ct(ct_res, expected_aval):
|
|
if expected_aval.dtype != dtypes.float0:
|
|
return ct_res
|
|
return ad_util.zeros_like_aval(expected_aval)
|
|
|
|
ct_res_fixed = map(fix_float0_ct,
|
|
ct_res_flat, exp_vjp.in_avals[len(args_flat):])
|
|
in_ct_flat = call_exported(exp_vjp)(*args_flat, *ct_res_fixed)
|
|
return in_ct_flat
|
|
|
|
f_flat.defvjp(f_flat_vjp_fwd, f_flat_vjp_bwd)
|
|
|
|
def f_imported(*args, **kwargs):
|
|
# since custom_vjp does not support kwargs, flatten the function first.
|
|
args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
|
|
if in_tree != exported.in_tree:
|
|
# Give errors with the precise tree difference; use fake leaves so we can
|
|
# use tree_util.equality_errors.
|
|
in_args = in_tree.unflatten([0] * in_tree.num_leaves)
|
|
exp_in_args = exported.in_tree.unflatten([0] * exported.in_tree.num_leaves)
|
|
|
|
msg = (
|
|
"The invocation args and kwargs must have the same pytree structure "
|
|
f"as when the function '{exported.fun_name}' was exported, but they "
|
|
"have the following structural differences:\n" +
|
|
("\n".join(
|
|
f" - {_shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a "
|
|
f"{thing2} when exported, so {explanation}.\n"
|
|
for path, thing1, thing2, explanation
|
|
in tree_util.equality_errors(in_args, exp_in_args))))
|
|
raise ValueError(msg)
|
|
|
|
res_flat = f_flat(*args_flat)
|
|
return exported.out_tree.unflatten(res_flat)
|
|
return f_imported
|
|
|
|
call_exported = call
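
# Round-trip sketch (illustrative; `f` is hypothetical):
#
#   exp = export(f)(jax.ShapeDtypeStruct((4,), jax.numpy.float32))
#   g = call(exp)        # a callable that can be used from other JAX code
#   y = jax.jit(g)(jax.numpy.ones((4,), jax.numpy.float32))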
|
|
|
|
# A JAX primitive for invoking a serialized JAX function.
|
|
call_exported_p = core.Primitive("call_exported")
|
|
call_exported_p.multiple_results = True
|
|
|
|
@util.cache()
|
|
def _call_exported_abstract_eval(
|
|
*in_avals: core.AbstractValue,
|
|
exported: Exported
|
|
) -> tuple[tuple[core.AbstractValue, ...], set[effects.Effect]]:
|
|
exported_dim_vars = _shape_poly.all_dim_vars(exported.in_avals)
|
|
assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure
|
|
# Check that the expected shapes match the actual ones
|
|
for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)):
|
|
def pp_arg_dim(dim_idx: int | None) -> str:
|
|
return _shape_poly.pretty_print_dimension_descriptor(exported.in_tree,
|
|
arg_idx, dim_idx)
|
|
if len(exp_aval.shape) != len(actual_aval.shape):
|
|
raise ValueError(
|
|
f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} "
|
|
f"and called with {actual_aval.shape}")
|
|
if exp_aval.dtype != actual_aval.dtype:
|
|
raise ValueError(
|
|
f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} "
|
|
f"and called with {actual_aval.dtype}")
|
|
for dim_idx, aval_d in enumerate(exp_aval.shape):
|
|
# If the exp_aval has a constant dimension then the actual argument must have
|
|
# a matching constant dimension.
|
|
if core.is_constant_dim(aval_d):
|
|
if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or
|
|
aval_d != actual_aval.shape[dim_idx]):
|
|
raise ValueError(
|
|
f"Shape mismatch for {pp_arg_dim(dim_idx)} "
|
|
"(expected same constant): "
|
|
f"expected {exp_aval.shape} and called with {actual_aval.shape}")
|
|
|
|
# Must express the exported_dim_vars in terms of the shapes in in_avals.
|
|
solution, shape_constraints, synth_dim_vars = _shape_poly.solve_dim_vars(
|
|
exported.in_avals, args_kwargs_tree=exported.in_tree)
|
|
synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx]
|
|
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
|
synthetic_eval = _shape_poly.CachingShapeEvaluator(**synthetic_env)
|
|
# We discharge all the constraints statically. This results in much simpler
|
|
# composability (because we do not have to worry about the constraints of the
|
|
# Exported called recursively; we only need to worry about entry-point
|
|
# constraints). This also makes sense from a composability point of view,
|
|
# because we get the same errors if we invoke the exported module, or if we
|
|
# trace the exported function. Consider for example, an exported module with
|
|
# signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument
|
|
# `f32[c, d]` it is better to fail because `c == d` is inconclusive, than
|
|
# succeed and add a compile-time check that `c == d`. In the latter case,
|
|
# it would be ambiguous whether we should continue tracing with a result
|
|
# of type `f32[c]` or `f32[d]`.
|
|
shape_constraints.check_statically(synthetic_eval)
|
|
exported_dim_values = [synthetic_eval.evaluate(solution[var])
|
|
for var in exported_dim_vars]
|
|
out_avals = tuple(
|
|
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
|
|
*exported_dim_values),
|
|
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
|
|
named_shape=out_aval.named_shape)
|
|
for out_aval in exported.out_avals)
|
|
return out_avals, set(exported.ordered_effects + exported.unordered_effects)
|
|
|
|
|
|
call_exported_p.def_effectful_abstract_eval(_call_exported_abstract_eval)
|
|
|
|
def _call_exported_impl(*args, exported: Exported):
|
|
return dispatch.apply_primitive(call_exported_p, *args, exported=exported)
|
|
|
|
call_exported_p.def_impl(_call_exported_impl)
|
|
|
|
def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
|
exported: Exported):
|
|
if exported.uses_shape_polymorphism:
|
|
ctx.module_context.shape_poly_state.uses_dim_vars = True
|
|
|
|
axis_context = ctx.module_context.axis_context
|
|
if isinstance(axis_context, sharding_impls.ShardingContext):
|
|
num_devices = axis_context.num_devices
|
|
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
|
|
num_devices = axis_context.mesh.size
|
|
else:
|
|
raise NotImplementedError(type(axis_context))
|
|
if num_devices != exported.nr_devices:
|
|
raise NotImplementedError(
|
|
f"Exported module {exported.fun_name} was lowered for "
|
|
f"{exported.nr_devices} devices and is called in a context with "
|
|
f"{num_devices} devices"
|
|
)
|
|
|
|
# Apply in_shardings
|
|
args = tuple(
|
|
wrap_with_sharding(ctx, x, x_aval, x_sharding)
|
|
for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings))
|
|
submodule = ir.Module.parse(exported.mlir_module())
|
|
symtab = ir.SymbolTable(submodule.operation)
|
|
# The called function may have been exported with polymorphic shapes and called
|
|
# now with more refined shapes. We insert hlo.ConvertOp to ensure the module
|
|
# is valid.
|
|
def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
|
|
new_ir_type = mlir.aval_to_ir_type(new_aval)
|
|
if x.type != new_ir_type:
|
|
return hlo.convert(mlir.aval_to_ir_type(new_aval), x)
|
|
else:
|
|
return x
|
|
|
|
callee_type = symtab["main"].type
|
|
# TODO: maybe cache multiple calls
|
|
fn = mlir.merge_mlir_modules(ctx.module_context.module,
|
|
f"call_exported_{exported.fun_name}",
|
|
submodule)
|
|
|
|
submodule_args = []
|
|
# All the platforms for the current lowering must be among the platforms
|
|
# for which the callee was lowered.
|
|
lowering_platforms = ctx.module_context.platforms
|
|
|
|
callee_lowering_platform_index: list[int] = []
|
|
for platform in lowering_platforms:
|
|
if platform in exported.lowering_platforms:
|
|
callee_lowering_platform_index.append(
|
|
exported.lowering_platforms.index(platform))
|
|
elif DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
|
|
callee_lowering_platform_index.append(0)
|
|
else:
|
|
raise ValueError(
|
|
f"The exported function '{exported.fun_name}' was lowered for "
|
|
f"platforms '{exported.lowering_platforms}' but it is used "
|
|
f"on '{lowering_platforms}'.")
|
|
|
|
if len(exported.lowering_platforms) > 1:
|
|
# The exported module takes a platform index argument
|
|
if len(lowering_platforms) > 1:
|
|
current_platform_idx = ctx.dim_var_values[0]
|
|
else:
|
|
current_platform_idx = mlir.ir_constant(np.int32(0))
|
|
# Compute the rule index based on the current platform
|
|
i32_type = mlir.aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
|
|
if current_platform_idx.type != i32_type:
|
|
current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
|
|
callee_platform_idx = hlo.CaseOp([i32_type],
|
|
index=current_platform_idx,
|
|
num_branches=len(lowering_platforms))
|
|
for i in range(len(lowering_platforms)):
|
|
branch = callee_platform_idx.regions[i].blocks.append()
|
|
with ir.InsertionPoint(branch):
|
|
hlo.return_(mlir.ir_constants(
|
|
np.int32(callee_lowering_platform_index[i])))
|
|
if callee_platform_idx.result.type != callee_type.inputs[0]:
|
|
callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0],
|
|
callee_platform_idx)
|
|
|
|
submodule_args.append(callee_platform_idx)
|
|
else:
|
|
assert len(lowering_platforms) == 1
|
|
|
|
if _keep_main_tokens(exported.mlir_module_serialization_version):
|
|
ordered_effects = exported.ordered_effects
|
|
else:
|
|
ordered_effects = ()
|
|
for eff in ordered_effects:
|
|
token_in = ctx.tokens_in.get(eff)[0]
|
|
submodule_args.append(token_in)
|
|
kept_args = [
|
|
convert_shape(a, a_aval, exported_in_aval)
|
|
for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals))
|
|
if i in exported.module_kept_var_idx]
|
|
submodule_args = submodule_args + kept_args
|
|
|
|
call = func_dialect.CallOp(callee_type.results,
|
|
ir.FlatSymbolRefAttr.get(fn),
|
|
submodule_args)
|
|
if ordered_effects:
|
|
tokens_out = {eff: (call.results[effect_idx],)
|
|
for effect_idx, eff in enumerate(ordered_effects)}
|
|
ctx.set_tokens_out(mlir.TokenSet(tokens_out))
|
|
# The ctx.avals_out already contain the abstract values refined by
|
|
# _call_exported_abstract_eval.
|
|
results = tuple(
|
|
convert_shape(out, out_aval, refined_out_aval)
|
|
for out, out_aval, refined_out_aval in zip(call.results[len(ordered_effects):],
|
|
exported.out_avals, ctx.avals_out))
|
|
# Apply out_shardings
|
|
results = tuple(
|
|
wrap_with_sharding(ctx, x, x_aval, x_sharding)
|
|
for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings)
|
|
)
|
|
return results
|
|
|
|
mlir.register_lowering(call_exported_p, _call_exported_lowering)
|
|
|
|
def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
|
|
x: ir.Value,
|
|
x_aval: core.AbstractValue,
|
|
x_sharding: Sharding) -> ir.Value:
|
|
if x_sharding is None:
|
|
return x
|
|
return mlir.wrap_with_sharding_op(
|
|
ctx, x, x_aval, x_sharding.to_proto())
|
|
|
|
# TODO(necula): Previously, we had `from jax.experimental.export import export`
|
|
# Now we want to simplify the usage and export the public APIs directly
# from `jax.experimental.export`, so that `jax.experimental.export.export`
# refers to the `export` function. Since there may still be users of the
|
|
# old API in other packages, we add the old public API as attributes of the
|
|
# exported function. We will clean this up after a deprecation period.
|
|
def wrap_with_deprecation_warning(f):
|
|
msg = (f"You are using function `{f.__name__}` from "
|
|
"`jax.experimental.export.export`. You should instead use it directly "
|
|
"from `jax.experimental.export`. Instead of "
|
|
"`from jax.experimental.export import export` you should use "
|
|
"`from jax.experimental import export`.")
|
|
def wrapped_f(*args, **kwargs):
|
|
warnings.warn(msg, DeprecationWarning)
|
|
return f(*args, **kwargs)
|
|
return wrapped_f
|
|
|
|
export.export = wrap_with_deprecation_warning(export)
|
|
export.Exported = Exported
|
|
export.call_exported = wrap_with_deprecation_warning(call_exported)
|
|
export.DisabledSafetyCheck = DisabledSafetyCheck
|
|
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
|
|
export.symbolic_shape = wrap_with_deprecation_warning(_shape_poly.symbolic_shape)
|
|
export.args_specs = wrap_with_deprecation_warning(args_specs)
|
|
export.minimum_supported_serialization_version = minimum_supported_serialization_version
|
|
export.maximum_supported_serialization_version = maximum_supported_serialization_version
|