# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX APIs for exporting JAX functions for interoperation.
"""
from __future__ import annotations
from collections.abc import Sequence
import copy
import dataclasses
import functools
import itertools
import re
from typing import Any, Callable, Union
import warnings
from absl import logging
import numpy as np
import jax
from jax import sharding
from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src import pjit
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax.experimental.export import _shape_poly
map = util.safe_map
zip = util.safe_zip
DType = Any
Shape = jax._src.core.Shape
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 9
maximum_supported_serialization_version = 9
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
class DisabledSafetyCheck:
"""A safety check should be skipped on (de)serialization.
Most of these checks are performed on serialization, but some are deferred to
deserialization. The list of disabled checks is attached to the serialization,
e.g., as a sequence of string attributes to `jax_export.Exported` or of
`tf.XlaCallModuleOp`.
You can disable more deserialization safety checks by passing
`TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`.
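For example (a sketch; `my_new_target` is a hypothetical custom call target
name), checks can be disabled by passing them to `export`:

    exported = jax_export.export(
        f_jax,
        disabled_checks=[jax_export.DisabledSafetyCheck.custom_call("my_new_target")],
    )(*args)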
"""
_impl: str
@classmethod
def platform(cls) -> DisabledSafetyCheck:
"""Allows the execution platform to differ from the serialization platform.
Has effect only on deserialization.
"""
return DisabledSafetyCheck("platform")
@classmethod
def custom_call(cls, target_name: str) -> DisabledSafetyCheck:
"""Allows the serialization of a call target not known to be stable.
Has effect only on serialization.
Args:
target_name: the name of the custom call target to allow.
"""
return DisabledSafetyCheck(f"custom_call:{target_name}")
@classmethod
def shape_assertions(cls) -> DisabledSafetyCheck:
"""Allows invocations with shapes that do not meet the constraints.
Has effect on serialization (to suppress the generation of the assertions)
and also on deserialization (to suppress the checking of the assertions).
"""
return DisabledSafetyCheck("shape_assertions")
def is_custom_call(self) -> str | None:
"""Returns the custom call target allowed by this directive."""
m = re.match(r'custom_call:(.+)$', self._impl)
return m.group(1) if m else None
def __init__(self, _impl: str):
# Do not use directly, use builders `platform`, `custom_call`.
self._impl = _impl
def __str__(self):
return self._impl
__repr__ = __str__
def __eq__(self, other) -> bool:
return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl
def __hash__(self) -> int:
return hash(self._impl)
@dataclasses.dataclass(frozen=True)
class Exported:
"""A JAX function lowered to StableHLO.
Attributes:
fun_name: the name of the exported function, for error messages.
in_tree: a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX
function. The actual lowering does not depend on the `in_tree`, but this
can be used to invoke the exported function using the same argument
structure.
in_avals: the flat tuple of input abstract values. May contain dimension
expressions in the shapes.
out_tree: a PyTreeDef describing the result of the lowered JAX function.
out_avals: the flat tuple of output abstract values. May contain dimension
expressions in the shapes, with dimension variables among those in
`in_avals`.
in_shardings: the flattened input shardings, as long as `in_avals`.
out_shardings: the flattened output shardings, as long as `out_avals`.
nr_devices: the number of devices that the module has been lowered for.
lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
'cuda', 'rocm'. See below for the calling convention for when
there are multiple lowering platforms.
ordered_effects: the ordered effects present in the serialized module.
This is present from serialization version 9. See below for the
calling convention in presence of ordered effects.
unordered_effects: the unordered effects present in the serialized module.
This is present from serialization version 9.
mlir_module_serialized: the serialized lowered VHLO module.
mlir_module_serialization_version: a version number for the serialized module.
See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
because they are not used.
uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape
polymorphism. This may be because `in_avals` contains dimension
variables, or due to inner calls of Exported modules that have
dimension variables or platform index arguments. Such modules need
shape refinement before XLA compilation.
disabled_safety_checks: a list of descriptors of safety checks that have been
disabled at export time. See docstring for `DisabledSafetyCheck`.
_get_vjp: an optional function that takes the current exported function and
returns the exported VJP function.
The VJP function takes a flat list of arguments,
starting with the primal arguments and followed by a cotangent argument
for each primal output. It returns a tuple with the cotangents
corresponding to the flattened primal inputs.
Calling convention for the exported module (for latest supported version):
The `mlir_module` has a `main` function that takes an optional first
platform index argument if the module supports multiple platforms
(`len(lowering_platforms) > 1`), followed by the token arguments corresponding
to the ordered effects, followed by the kept array
arguments (corresponding to `module_kept_var_idx` and `in_avals`).
The platform index is an i32 or i64 scalar encoding the index of the current
compilation platform into the `lowering_platforms` sequence.
Inner functions use a different calling convention: an optional
platform index argument, optional dimension variable arguments
(scalar tensors of type i32 or i64),
followed by optional token arguments (in presence of ordered effects),
followed by the regular array arguments.
The dimension arguments correspond to the dimension variables appearing in
the `args_avals`, in sorted order of their names.
Consider the lowering of a function with one array argument of type "f32[w,
2 * h]", where "w" and "h" are two dimension variables.
Assume that we use multi-platform lowering, and we have
one ordered effect. The `main` function will be as follows:
func public main(
platform_index: i32 {jax.global_constant="_platform_index"},
token_in: token,
arg: f32[?, ?]) {
arg_w = hlo.get_dimension_size(arg, 0)
dim1 = hlo.get_dimension_size(arg, 1)
arg_h = hlo.floordiv(dim1, 2)
call _check_shape_assertions(arg) # See below
token = new_token()
token_out, res = call _wrapped_jax_export_main(platform_index,
arg_h,
arg_w,
token_in,
arg)
return token_out, res
}
The actual computation is in `_wrapped_jax_export_main`, which also takes
the values of the `h` and `w` dimension variables.
The signature of the `_wrapped_jax_export_main` is:
func private _wrapped_jax_export_main(
platform_index: i32 {jax.global_constant="_platform_index"},
arg_h: i32 {jax.global_constant="h"},
arg_w: i32 {jax.global_constant="w"},
arg_token: stablehlo.token {jax.token=True},
arg: f32[?, ?]) -> (stablehlo.token, ...)
Prior to serialization version 9 the calling convention for effects is
different: the `main` function does not take or return a token. Instead
the function creates dummy tokens of type `i1[0]` and passes them to
`_wrapped_jax_export_main`, which takes dummy tokens of type `i1[0]` and
creates real tokens internally to pass to the inner functions. The inner
functions use real tokens (both before and after serialization version 9).
Also starting with serialization version 9, function arguments that contain
the platform index or the dimension variable values have a
`jax.global_constant` string attribute whose value is the name of the
global constant, either `_platform_index` or a dimension variable name.
The global constant name may be empty if it is not known.
Some global constant computations use inner functions, e.g., for
`floor_divide`. The arguments of such functions have a `jax.global_constant`
attribute for all arguments, meaning that the result of the function is
also a global constant.
Note that `main` contains a call to `_check_shape_assertions`.
JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
have values >= 1. We must check these constraints when we invoke the
module. We use a special custom call `@shape_assertion` that takes
a boolean first operand, a string `error_message` attribute that may contain
format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
scalar operands corresponding to the format specifiers.
func private _check_shape_assertions(arg: f32[?, ?]) {
# Check that w is >= 1
arg_w = hlo.get_dimension_size(arg, 0)
custom_call @shape_assertion(arg_w >= 1, arg_w,
error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
# Check that dim1 is even
dim1 = hlo.get_dimension_size(arg, 1)
custom_call @shape_assertion(dim1 % 2 == 0, dim1,
error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
# Check that h >= 1
arg_h = hlo.floordiv(dim1, 2)
custom_call @shape_assertion(arg_h >= 1, arg_h,
error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
}
If we invoke the module with `call_exported` we perform these checks
statically (in `call_exported_abstract_eval`).
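As a minimal sketch (assuming the `jax.experimental.export` APIs and the
example function above with one `f32[w, 2*h]` argument), such an `Exported`
can be obtained with:

    from jax.experimental import export
    x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("w, 2*h"), np.float32)
    exp = export.export(f_jax)(x_spec)
    # exp.in_avals holds the symbolic input shapes,
    # exp.uses_shape_polymorphism is True, and
    # exp.mlir_module_serialized holds the serialized StableHLO.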
"""
fun_name: str
in_tree: tree_util.PyTreeDef
in_avals: tuple[core.AbstractValue, ...]
out_tree: tree_util.PyTreeDef
out_avals: tuple[core.AbstractValue, ...]
in_shardings: tuple[Sharding, ...]
out_shardings: tuple[Sharding, ...]
nr_devices: int
lowering_platforms: tuple[str, ...]
ordered_effects: tuple[effects.Effect, ...]
unordered_effects: tuple[effects.Effect, ...]
disabled_safety_checks: Sequence[DisabledSafetyCheck]
mlir_module_serialized: bytes
mlir_module_serialization_version: int
module_kept_var_idx: tuple[int, ...]
uses_shape_polymorphism: bool
_get_vjp: Callable[[Exported], Exported] | None
def mlir_module(self) -> ir.Module:
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
def __str__(self):
# This is called to make an MLIR source location when we call an Exported, and we
# do not want the entire serialized module to end up in locations.
return f"Exported(fun_name={self.fun_name}, ...)"
def has_vjp(self) -> bool:
return self._get_vjp is not None
def vjp(self) -> Exported:
"""Gets the exported VJP.
Returns None if not available, which can happen if the Exported has been
loaded from an external format, without a VJP."""
if self._get_vjp is None:
raise ValueError("No VJP is available")
return self._get_vjp(self)
def default_lowering_platform() -> str:
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
return xb.canonicalize_platform(jax.default_backend())
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
"""Returns the shape and dtype of a jax.Array or a j"""
if isinstance(a, jax.ShapeDtypeStruct):
return a.shape, a.dtype
aval = core.raise_to_shaped(core.get_aval(a))
return aval.shape, aval.dtype
def args_specs(
args, # pytree of arguments
polymorphic_shapes, # prefix pytree of strings
get_shape_and_dtype=shape_and_dtype_jax_array,
):
# TODO: deprecated in January 2024, to be removed.
warnings.warn(
"export.args_specs is deprecated in favor of export.symbolic_args_specs",
DeprecationWarning, stacklevel=2)
if get_shape_and_dtype is not shape_and_dtype_jax_array:
# This was needed in some older jax2tf implementations
args = tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(* get_shape_and_dtype(a)),
args)
return _shape_poly.symbolic_args_specs(args, polymorphic_shapes)
def _keep_main_tokens(serialization_version: int) -> bool:
return serialization_version >= _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS
def export(fun_jax: Callable,
*,
lowering_platforms: Sequence[str] | None = None,
disabled_checks: Sequence[DisabledSafetyCheck] = (),
) -> Callable[..., Exported]:
"""Exports native serialization for a JAX function.
Args:
fun_jax: the function to lower and serialize.
lowering_platforms:
Optional sequence containing a subset of 'tpu', 'cpu',
'cuda', 'rocm'. If more than one platform is specified, then
the lowered code takes an argument specifying the platform.
If None, then use the default JAX backend.
The calling convention for multiple platforms is explained in the
`jax_export.Exported` docstring.
disabled_checks: the safety checks to disable. See docstring
of `DisabledSafetyCheck`.
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
or values with `.shape` and `.dtype` attributes, and returns an
`Exported`.
Usage:
def f_jax(*args, **kwargs): ...
exported = jax_export.export(f_jax)(*args, **kwargs)
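    A more concrete sketch (the shapes, dtypes, and platforms below are
    assumptions for illustration only):

        x_spec = jax.ShapeDtypeStruct((2, 3), np.float32)
        exported = jax_export.export(f_jax,
                                     lowering_platforms=("cpu", "tpu"))(x_spec)
        serialized = exported.mlir_module_serialized  # bytes that can be stored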
"""
fun_name = getattr(fun_jax, "__name__", "unknown")
version = config.jax_serialization_version.value
if (version < minimum_supported_serialization_version or
version > maximum_supported_serialization_version):
raise ValueError(
f"The requested jax_serialization version {version} is outside the "
f"range of supported versions [{minimum_supported_serialization_version}"
f"..{maximum_supported_serialization_version}]")
def do_export(*args_specs, **kwargs_specs) -> Exported:
if not hasattr(fun_jax, "lower"):
# We support export(pjit(f_jax)) and export(jit(f_jax)) but also
# export(f_jax), in which case a "jit" is implied. In that case we raise
# an error if the lowered function contains non-replicated sharding annotations.
wrapped_fun_jax = jax.jit(fun_jax)
allow_non_replicated_sharding = False
else:
# If we have a pjit or pmap already we do not wrap with another, and we
# allow shardings.
wrapped_fun_jax = fun_jax # type: ignore
allow_non_replicated_sharding = True
if lowering_platforms is not None:
actual_lowering_platforms = tuple(lowering_platforms)
else:
actual_lowering_platforms = (default_lowering_platform(),)
# Do not include shape assertions if the version is < 7.
enable_shape_assertions = (
DisabledSafetyCheck.shape_assertions() not in disabled_checks and
version >= _VERSION_START_SUPPORT_SHAPE_ASSERTIONS) # type: ignore
try:
prev_enable_shape_assertions = _shape_poly.thread_local_state.enable_shape_assertions
_shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
replace_tokens_with_dummy = not _keep_main_tokens(version)
symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
# Static args may have no `shape` attribute.
if not hasattr(aval, "shape"):
continue
for d in aval.shape:
if _shape_poly.is_symbolic_dim(d):
if symbolic_scope is None:
symbolic_scope = (d.scope, k_path)
continue
symbolic_scope[0]._check_same_scope(
d, when=f"when exporting {fun_name}",
self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
other_descr=_shape_poly.args_kwargs_path_to_str(k_path))
lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_parameters=mlir.LoweringParameters(
platforms=actual_lowering_platforms,
replace_tokens_with_dummy=replace_tokens_with_dummy,
))
lowering = lowered._lowering # type: ignore
_check_lowering(lowering)
mlir_module = lowering.stablehlo()
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
if "mut" in lowering.compile_args:
if lowering.compile_args["mut"]: raise NotImplementedError
if "kept_var_idx" in lowering.compile_args:
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
else:
# For pmap
module_kept_var_idx = tuple(range(len(args_avals_flat)))
shape_poly_state = lowering.compile_args["shape_poly_state"]
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
or lowering.compile_args.get("ordered_effects", [])):
mlir_module = _wrap_main_func(
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree,
has_platform_index_argument=shape_poly_state.has_platform_index_argument,
module_kept_var_idx=module_kept_var_idx,
serialization_version=version)
finally:
_shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions
with mlir_module.context:
mlir_module_attrs = mlir_module.operation.attributes
mlir_module_attrs["jax.uses_shape_polymorphism"] = (
mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars))
mlir_module_serialized = _module_to_bytecode(mlir_module)
# Figure out the result types and shapes
if "global_out_avals" in lowering.compile_args:
# This is currently the case for pjit
out_avals_flat = lowering.compile_args["global_out_avals"]
elif "shards" in lowering.compile_args: # for PmapComputation
out_avals_flat = lowering.compile_args["shards"].out_sharded_avals
else:
out_avals_flat = lowering.compile_args["out_avals"]
# Log and then check the module.
if logging.vlog_is_on(3):
logmsg = (f"version={version} "
f"lowering_platforms={actual_lowering_platforms} "
f"disabled_checks={disabled_checks}")
logging.info("Lowered JAX module: %s\n", logmsg)
if dumped_to := mlir.dump_module_to_file(mlir_module, "export"):
logging.info("Dumped the exported MLIR module to %s", dumped_to)
_check_module(mlir_module,
allow_non_replicated_sharding=allow_non_replicated_sharding,
disabled_checks=disabled_checks)
ordered_effects = tuple(lowering.compile_args["ordered_effects"])
unordered_effects = tuple(lowering.compile_args["unordered_effects"])
if version < _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
ordered_effects = unordered_effects = ()
nr_devices = len(lowering.compile_args["device_assignment"])
def export_sharding(s: LoweringSharding,
aval: core.ShapedArray) -> Sharding:
if sharding_impls.is_unspecified(s):
return None
return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr]
all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"],
module_kept_var_idx,
len(args_avals_flat))
in_shardings = tuple(
export_sharding(s, aval)
for s, aval in zip(all_in_shardings, args_avals_flat))
out_shardings = tuple(
export_sharding(s, aval)
for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat))
return Exported(
fun_name=fun_name,
in_tree=lowered.in_tree,
out_tree=lowered.out_tree,
in_avals=tuple(args_avals_flat),
out_avals=tuple(out_avals_flat),
in_shardings=in_shardings,
out_shardings=out_shardings,
nr_devices=nr_devices,
lowering_platforms=actual_lowering_platforms,
ordered_effects=ordered_effects,
unordered_effects=unordered_effects,
disabled_safety_checks=tuple(disabled_checks),
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
mlir_module_serialization_version=version, # type: ignore
_get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))
return do_export
def _module_to_bytecode(module: ir.Module) -> bytes:
mlir_str = mlir.module_to_bytecode(module)
if hlo.get_api_version() < 4:
target_version = hlo.get_earliest_forward_compatible_version()
else:
# `target_version` is used to manage situations when a StableHLO producer
# (in this case, jax2tf) and a StableHLO consumer were built using
# different versions of StableHLO.
#
# Each StableHLO version `producer_version` has a compatibility window,
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
# where StableHLO portable artifacts serialized by `producer_version`
# can be deserialized by `consumer_version` within the window.
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
# for the exact extent of these compatibility guarantees.
#
# `hlo.get_minimum_version()` returns `consumer_version_min`
# for the current version of StableHLO. We are using it here to maximize
# forward compatibility, i.e. to maximize how far into the past we can go
# and still have the payloads produced by `serialize_portable_artifact`
# compatible with potential consumers from the past.
target_version = hlo.get_minimum_version()
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
mlir_str, target_version)
return module_serialized
def _wrap_main_func(
module: ir.Module,
args_avals_flat: Sequence[core.ShapedArray],
*,
args_kwargs_tree: tree_util.PyTreeDef,
has_platform_index_argument: bool,
module_kept_var_idx: tuple[int, ...],
serialization_version: int
) -> ir.Module:
"""Wraps the lowered module with a new "main" handling dimension arguments.
See calling convention documentation for `jax_export.Exported`.
Args:
module: the HLO module as obtained from lowering. See the calling convention
for inner functions in `jax_export.Exported`.
args_avals_flat: the avals for all the arguments of the lowered function,
which correspond to the array arguments of the `module`.
args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error
messages.
has_platform_index_argument: whether the `module` has a first platform
index argument
module_kept_var_idx: a sorted tuple of integers with the indices of arguments
in `args_avals_flat` that are kept as `module` arguments.
serialization_version: the target serialization version
Returns the wrapped module, without dimension arguments (and, for
serialization versions prior to 9, without token arguments).
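For example (matching the calling-convention example in `jax_export.Exported`),
if the original `main` has arguments `(platform_index, arg_h, arg_w, token, arg)`,
the new `main` at serialization version 9 has arguments `(platform_index, token, arg)`
and computes `arg_h` and `arg_w` from the shape of `arg`.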
"""
dim_vars = _shape_poly.all_dim_vars(args_avals_flat)
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
# Make a copy, do not mutate because it may be cached
wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))
symbol_table = ir.SymbolTable(wrapped_module.operation)
orig_main = symbol_table["main"]
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
def is_token(typ, attrs):
if typ == mlir.token_type()[0]:
return True
# TODO(b/302258959): in older versions we cannot use the token type
try:
return ir.BoolAttr(ir.DictAttr(attrs)["jax.token"]).value
except KeyError:
return False
orig_input_types = orig_main.type.inputs
arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs))
# The order of args: platform_index_arg, dim args, token args, array args.
nr_platform_index_args = 1 if has_platform_index_argument else 0
nr_dim_args = len(dim_vars)
token_arg_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_input_types,
arg_attrs))
if is_token(typ, attrs)]
nr_token_args = len(token_arg_idxs)
if nr_token_args > 0:
assert min(token_arg_idxs) == nr_platform_index_args + nr_dim_args
assert token_arg_idxs == list(
range(nr_platform_index_args + nr_dim_args,
nr_platform_index_args + nr_dim_args + nr_token_args))
nr_array_args = (len(orig_input_types) - nr_platform_index_args
- nr_dim_args - nr_token_args)
assert nr_array_args >= 0
(platform_input_types, dim_var_input_types,
token_input_types, array_input_types) = util.split_list(
orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args])
# The order of results: tokens, array results
orig_output_types = orig_main.type.results
result_attrs = list(ir.ArrayAttr(orig_main.result_attrs))
token_result_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_output_types,
result_attrs))
if is_token(typ, attrs)]
nr_token_results = len(token_result_idxs)
assert token_result_idxs == list(range(0, nr_token_results))
nr_array_results = len(orig_output_types) - nr_token_results
assert nr_array_results >= 0
if _keep_main_tokens(serialization_version):
new_main_arg_indices = (tuple(range(0, nr_platform_index_args)) +
tuple(range(nr_platform_index_args + nr_dim_args,
len(orig_input_types))))
new_main_result_indices = tuple(range(0, len(orig_output_types)))
else:
new_main_arg_indices = (
tuple(range(0, nr_platform_index_args)) +
tuple(range(nr_platform_index_args + nr_dim_args + nr_token_args,
len(orig_input_types))))
new_main_result_indices = tuple(range(nr_token_results, len(orig_output_types)))
new_main_input_types = [orig_input_types[idx] for idx in new_main_arg_indices]
new_main_output_types = [orig_output_types[idx] for idx in new_main_result_indices]
new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
new_main_op = func_dialect.FuncOp(
"main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body))
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
try:
new_arg_attrs = []
for idx in new_main_arg_indices:
new_arg_attr = {}
for attr in arg_attrs[idx]:
if attr.name == "tf.aliasing_output":
i = new_main_result_indices.index(attr.attr.value)
new_arg_attr[attr.name] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), i
)
else:
new_arg_attr[attr.name] = attr.attr
new_arg_attrs.append(ir.DictAttr.get(new_arg_attr))
new_main_op.arg_attrs = ir.ArrayAttr.get(new_arg_attrs)
except KeyError:
pass # TODO: better detection if orig_main.arg_attrs does not exist
try:
new_main_op.result_attrs = ir.ArrayAttr.get(
[result_attrs[idx] for idx in new_main_result_indices])
except KeyError:
pass
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
# Make a context just for lowering the dimension value computations
module_context = mlir.ModuleContext(
backend_or_name="cpu", platforms=["cpu"],
axis_context=sharding_impls.ShardingContext(0),
keepalives=[], channel_iterator=itertools.count(1),
host_callbacks=[], module=wrapped_module, context=context,
lowering_parameters=mlir.LoweringParameters(
global_constant_computation=True
))
ctx = mlir.LoweringRuleContext(
module_context=module_context,
name_stack=source_info_util.new_name_stack(), primitive=None,
avals_in=args_avals_flat, avals_out=None,
tokens_in=mlir.TokenSet(), tokens_out=None)
# We compute dim_values from the array arguments.
new_main_op_array_args = new_main_op.arguments[-nr_array_args:]
if _shape_poly.all_dim_vars(args_avals_flat):
# TODO(necula): handle module_kept_var_idx in presence of shape
# polymorphism. For now we ensured upstream that we keep all variables.
assert len(set(module_kept_var_idx)) == len(args_avals_flat)
dim_values = mlir.lower_fun(
functools.partial(_shape_poly.compute_dim_vars_from_arg_shapes,
args_avals_flat, args_kwargs_tree=args_kwargs_tree),
multiple_results=True)(ctx, *new_main_op_array_args)
else:
dim_values = ()
# The arguments to pass to the call to orig_main
orig_main_args: list[ir.Value] = []
# The platform index and the dimension variables
for arg, arg_type in zip(
list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values),
platform_input_types + dim_var_input_types):
if arg.type != arg_type:
orig_main_args.append(hlo.convert(arg_type, arg))
else:
orig_main_args.append(arg)
# Then the token arguments
if _keep_main_tokens(serialization_version):
orig_main_args.extend(
new_main_op.arguments[nr_platform_index_args: nr_platform_index_args + nr_token_args])
else:
orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args)
# Then the array arguments. We insert a ConvertOp as the only use of
# an input argument. This helps the downstream shape refinement because
# it will set the type of input arguments to static shapes, and this
# can invalidate the module if the argument is used as the result of a
# function, or if it appears as the input to a custom_call with
# output_operand_alias attribute. See b/287386268.
for arg, arg_type in zip(new_main_op_array_args, array_input_types):
if arg.type != arg_type:
orig_main_args.append(hlo.convert(arg_type, arg))
else:
orig_main_args.append(arg)
call = func_dialect.CallOp(orig_output_types,
ir.FlatSymbolRefAttr.get(orig_main_name),
orig_main_args)
func_dialect.ReturnOp([call.results[idx] for idx in new_main_result_indices])
symbol_table.set_symbol_name(new_main_op, "main")
return wrapped_module
def _check_lowering(lowering) -> None:
if not isinstance(lowering, pxla.MeshComputation):
raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]:
raise NotImplementedError("serialization of host_callbacks is not yet implemented")
# Check that we do not see new compile_args. When we add a new compile_arg it is
# safe to add it to the allowed_compile_args if it does not change the semantics
# or the calling convention of the lowered module.
allowed_compile_args = [
"backend", "mesh", "global_in_avals",
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
"mut", "spmd_lowering", "auto_spmd_lowering",
"tuple_args", "ordered_effects", "unordered_effects",
"keepalive", "host_callbacks", "pmap_nreps", "committed",
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"]
for compile_arg in lowering.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
# We have not implemented support for some of the compile_args. Check here that
# the compile_args have the values that have been implemented.
not_implemented_msgs = []
for compile_arg, check_value, err_msg in (
("spmd_lowering", lambda v: v, "True"),
("auto_spmd_lowering", lambda v: not v, "False"),
# tuple_args is a compilation flag, does not affect lowering.
("tuple_args", lambda v: True, "N/A"),
# unordered_effects do not change the calling convention. Those from
# jax.debug will also result in keepalive being non-empty and unsupported
# custom calls. The CallTfEffect is an exception, but we want to allow
# that one.
("unordered_effects", lambda v: True, "N/A"),
("ordered_effects", lambda v: True, "N/A"),
# used for TPU jax.debug, send/recv. Not supported yet.
("host_callbacks", lambda v: not v, "empty"),
# used on all platforms for callbacks. Not supported yet.
("keepalive", lambda v: not v, "empty"),
("pmap_nreps", lambda v: v == 1, "1"),
("shape_poly_state", lambda v: True, "N/A"),
):
if compile_arg in lowering.compile_args:
if not check_value(lowering.compile_args[compile_arg]):
not_implemented_msgs.append(
f"{compile_arg} must be {err_msg} and it is {lowering.compile_args[compile_arg]}")
if not_implemented_msgs:
raise NotImplementedError(
"serialization error, unimplemented lowered.compile_args:\n" +
"\n".join(not_implemented_msgs))
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"dynamic_ducc_fft", "cu_threefry2x32",
# cholesky on CPU
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
# eigh on CPU
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
# eigh on GPU
"cusolver_syevj", "cusolver_syevd",
# eigh on TPU
"Eigh",
# eig on CPU
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
# qr on CPU
"lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
# householder product on CPU
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
# svd on CPU
"lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd",
# qr on GPU
"cusolver_geqrf", "cublas_geqrf_batched",
"cusolver_orgqr",
# qr and svd on TPU
"Qr", "ProductOfElementaryHouseholderReflectors",
# triangular_solve on CPU
"blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm",
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
# # lu on CPU
"lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf",
# schur on CPU
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
# # lu on GPU
# "cublas_getrf_batched", "cusolver_getrf",
# "hipblas_getrf_batched", "hipsolver_getrf",
# lu on TPU
"LuDecomposition",
# ApproxTopK on TPU
"ApproxTopK",
"tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True)
"tpu_custom_call", # Pallas/TPU kernels
# TODO(burmako): maintain backwards compatibility for these, until they
# are upstreamed to StableHLO.
# See https://github.com/openxla/stablehlo/issues/8.
"stablehlo.dynamic_reduce_window",
"stablehlo.dynamic_rng_bit_generator",
"stablehlo.dynamic_top_k",
"shape_assertion", # Used by shape_poly to evaluate assertions
}
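# Shardings accepted by `_check_module` as "replicated": the `{replicated}`
# sharding, `{unknown shard_as...}` shardings, or an empty sharding string.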
check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|"")$")
def _check_module(mod: ir.Module, *,
allow_non_replicated_sharding: bool,
disabled_checks: Sequence[DisabledSafetyCheck]) -> None:
"""Run a number of checks on the module.
Args:
allow_non_replicated_sharding: whether the module is allowed to contain
non_replicated sharding annotations.
disabled_checks: the safety checks that are disabled.
"""
sharding_attr = ir.StringAttr.get("Sharding", mod.context)
shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context)
allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
for dc in disabled_checks:
target = dc.is_custom_call()
if target is not None:
allowed_custom_call_targets.add(target)
allowed_custom_call_targets_attrs = {
ir.StringAttr.get(target, mod.context)
for target in allowed_custom_call_targets}
disallowed_custom_call_ops: list[str] = []
def check_sharding(op: ir.Operation, loc: ir.Location):
if not allow_non_replicated_sharding:
try:
sharding = op.attributes["mhlo.sharding"]
except KeyError:
pass
else:
if not re.match(check_sharding_pattern, ir.StringAttr(sharding).value):
raise ValueError(
"Lowered function does not have a top-level pjit but it has"
f" non-replicated sharding annotations, e.g., {op} at {loc}.\nSee"
" https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning"
" for a discussion."
)
def check_op(op: ir.Operation):
op_name = op.operation.name
if op_name == "func.func":
check_sharding(op.operation, op.location)
elif op_name == "stablehlo.custom_call":
call_target_name_attr = op.operation.attributes["call_target_name"]
if (call_target_name_attr not in allowed_custom_call_targets_attrs):
disallowed_custom_call_ops.append(f"{op} at {op.location}")
if call_target_name_attr == sharding_attr:
check_sharding(op, op.location)
elif call_target_name_attr == shape_assertion_attr:
assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks)
def walk_operations(op):
check_op(op)
for region in op.operation.regions:
for block in region:
for op in block:
walk_operations(op)
walk_operations(mod)
if disallowed_custom_call_ops:
disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops)
msg = ("Cannot serialize code with custom calls whose targets have no "
"compatibility guarantees. Examples are:\n"
f"{disallowed_custom_call_ops_str}.\n"
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
raise ValueError(msg)
def expand_in_shardings(in_shardings: Sequence[LoweringSharding],
module_kept_var_idx: Sequence[int],
nr_inputs: int) -> Sequence[LoweringSharding]:
"""Expands in_shardings with unspecified shardings for inputs not kept.
Assumes in_shardings corresponds to module_kept_var_idx.
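For example (an illustrative sketch), with `nr_inputs=3`,
`module_kept_var_idx=(0, 2)` and `in_shardings=(s0, s2)`, the result is
`(s0, UNSPECIFIED, s2)`.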
"""
assert len(in_shardings) == len(module_kept_var_idx)
assert nr_inputs >= len(module_kept_var_idx)
all_in_shardings: list[LoweringSharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings):
all_in_shardings[idx] = in_s
return tuple(all_in_shardings)
# TODO(yashkatariya, necula): remove this function once we relax the checks
# in the jit front-end.
def canonical_shardings(
device_assignment: Sequence[jax.Device],
in_shardings: Sequence[Sharding],
out_shardings: Sequence[Sharding]
) -> tuple[(pxla.UnspecifiedValue |
Sequence[sharding.XLACompatibleSharding]),
(pxla.UnspecifiedValue |
Sequence[sharding.XLACompatibleSharding])]:
"""Prepares canonical in_ and out_shardings for a pjit invocation.
The pjit front-end is picky about what in- and out-shardings it accepts,
e.g., if all are unspecified then the whole sharding should be the
sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are
replaced with the replicated sharding.
Returns: a pair with the canonicalized input and output shardings.
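For example (illustrative), if all given shardings are `None` the result is
`sharding_impls.UNSPECIFIED`; if they are `(None, s)` the result is
`(replicated, GSPMDSharding(device_assignment, s))`.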
"""
replicated_s = sharding.GSPMDSharding.get_replicated(device_assignment)
def canonicalize(
ss: Sequence[Sharding]) -> (pxla.UnspecifiedValue |
Sequence[sharding.XLACompatibleSharding]):
if all(s is None for s in ss):
return sharding_impls.UNSPECIFIED
return tuple(
sharding.GSPMDSharding(device_assignment, s) if s is not None else replicated_s
for s in ss)
return (canonicalize(in_shardings), canonicalize(out_shardings))
def _get_vjp_fun(primal_fun: Callable, *,
in_tree: tree_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
in_shardings: tuple[Sharding, ...],
out_shardings: tuple[Sharding, ...],
nr_devices: int,
apply_jit: bool
) -> tuple[Callable, Sequence[core.AbstractValue]]:
# Since jax.vjp does not handle kwargs, it is easier to do all the work
# here with flattened functions.
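# E.g. (illustrative), for a primal function with two flat inputs and one flat
# output, `args_and_out_cts_flat_jax` is `(x0, x1, out_ct)` and the result is
# the pair of input cotangents `(x0_ct, x1_ct)`.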
def fun_vjp_jax(*args_and_out_cts_flat_jax):
# Takes a flat list of primals and output cotangents
def flattened_primal_fun_jax(*args_flat):
args, kwargs = in_tree.unflatten(args_flat)
res = primal_fun(*args, **kwargs)
res_flat, _ = tree_util.tree_flatten(res)
return res_flat
args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax,
[len(in_avals)])
_, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax)
return pullback_jax(out_cts_flat_jax)
vjp_in_avals = list(
itertools.chain(in_avals,
map(lambda a: a.at_least_vspace(), out_avals)))
if apply_jit:
# Prepare a device assignment. For exporting purposes, all that matters
# is the number of devices.
device_assignment = jax.devices(jax.default_backend())[:nr_devices]
assert len(device_assignment) == nr_devices
vjp_in_shardings, vjp_out_shardings = canonical_shardings(
device_assignment,
tuple(itertools.chain(in_shardings, out_shardings)),
in_shardings)
return pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,
out_shardings=vjp_out_shardings), vjp_in_avals
else:
return fun_vjp_jax, vjp_in_avals
def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
# Export the VJP of `primal_fun`. See the documentation for Exported.vjp.
fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun,
in_tree=primal.in_tree,
in_avals=primal.in_avals,
in_shardings=primal.in_shardings,
out_avals=primal.out_avals,
out_shardings=primal.out_shardings,
nr_devices=primal.nr_devices,
apply_jit=True)
return export(fun_vjp_jax,
lowering_platforms=primal.lowering_platforms,
disabled_checks=primal.disabled_safety_checks)(*vjp_in_avals)
### Calling the exported function
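# A minimal usage sketch (assuming `exported` was produced by `export` above):
#
#     f = call(exported)      # a callable that can be used in JAX computations
#     y = jax.jit(f)(*args)   # can be jit-ed; its VJP uses `exported.vjp()`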
def call(exported: Exported) -> Callable[..., jax.Array]:
if not isinstance(exported, Exported):
raise ValueError(
"The exported argument must be an export.Exported. "
f"Found {exported}.")
@jax.custom_vjp
def f_flat(*args_flat):
return call_exported_p.bind(*args_flat, exported=exported)
def f_flat_vjp_fwd(*args_flat):
# Return the primal arguments as the residual
# TODO: keep as residuals only the arguments that are needed
return f_flat(*args_flat), args_flat
def f_flat_vjp_bwd(residual, ct_res_flat):
args_flat = residual # residual is the primal argument flat tuple
exp_vjp = exported.vjp()
# ct_res_flat may contain arrays of zeros where exp_vjp expects float0.
# We make the proper arrays of float0 to invoke exp_vjp.
def fix_float0_ct(ct_res, expected_aval):
if expected_aval.dtype != dtypes.float0:
return ct_res
return ad_util.zeros_like_aval(expected_aval)
ct_res_fixed = map(fix_float0_ct,
ct_res_flat, exp_vjp.in_avals[len(args_flat):])
in_ct_flat = call_exported(exp_vjp)(*args_flat, *ct_res_fixed)
return in_ct_flat
f_flat.defvjp(f_flat_vjp_fwd, f_flat_vjp_bwd)
def f_imported(*args, **kwargs):
# since custom_vjp does not support kwargs, flatten the function first.
args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
if in_tree != exported.in_tree:
# Give errors with the precise tree difference; use fake leaves so we can
# use tree_util.equality_errors.
in_args = in_tree.unflatten([0] * in_tree.num_leaves)
exp_in_args = exported.in_tree.unflatten([0] * exported.in_tree.num_leaves)
msg = (
"The invocation args and kwargs must have the same pytree structure "
f"as when the function '{exported.fun_name}' was exported, but they "
"have the following structural differences:\n" +
("\n".join(
f" - {_shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a "
f"{thing2} when exported, so {explanation}.\n"
for path, thing1, thing2, explanation
in tree_util.equality_errors(in_args, exp_in_args))))
raise ValueError(msg)
res_flat = f_flat(*args_flat)
return exported.out_tree.unflatten(res_flat)
return f_imported
call_exported = call
# A JAX primitive for invoking a serialized JAX function.
call_exported_p = core.Primitive("call_exported")
call_exported_p.multiple_results = True
@util.cache()
def _call_exported_abstract_eval(
*in_avals: core.AbstractValue,
exported: Exported
) -> tuple[tuple[core.AbstractValue, ...], set[effects.Effect]]:
exported_dim_vars = _shape_poly.all_dim_vars(exported.in_avals)
assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure
# Check that the expected shapes match the actual ones
for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)):
def pp_arg_dim(dim_idx: int | None) -> str:
return _shape_poly.pretty_print_dimension_descriptor(exported.in_tree,
arg_idx, dim_idx)
if len(exp_aval.shape) != len(actual_aval.shape):
raise ValueError(
f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} "
f"and called with {actual_aval.shape}")
if exp_aval.dtype != actual_aval.dtype:
raise ValueError(
f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} "
f"and called with {actual_aval.dtype}")
for dim_idx, aval_d in enumerate(exp_aval.shape):
# If the exp_aval has a constant dimension then the actual argument must have
# a matching constant dimension.
if core.is_constant_dim(aval_d):
if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or
aval_d != actual_aval.shape[dim_idx]):
raise ValueError(
f"Shape mismatch for {pp_arg_dim(dim_idx)} "
"(expected same constant): "
f"expected {exp_aval.shape} and called with {actual_aval.shape}")
# Must express the exported_dim_vars in terms of the shapes in in_avals.
solution, shape_constraints, synth_dim_vars = _shape_poly.solve_dim_vars(
exported.in_avals, args_kwargs_tree=exported.in_tree)
synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx]
for (vname, arg_idx, dim_idx) in synth_dim_vars}
synthetic_eval = _shape_poly.CachingShapeEvaluator(**synthetic_env)
# We discharge all the constraints statically. This results in much simpler
# composability: we do not have to worry about the constraints of an Exported
# that is called recursively, only about the entry-point constraints. It also
# means that we get the same errors whether we invoke the exported module or
# trace the exported function. Consider, for example, an exported module with
# signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument
# `f32[c, d]` it is better to fail, because `c == d` is inconclusive, than to
# succeed and add a compile-time check that `c == d`. In the latter case it
# would be ambiguous whether we should continue tracing with a result
# of type `f32[c]` or `f32[d]`.
shape_constraints.check_statically(synthetic_eval)
exported_dim_values = [synthetic_eval.evaluate(solution[var])
for var in exported_dim_vars]
out_avals = tuple(
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
*exported_dim_values),
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
named_shape=out_aval.named_shape)
for out_aval in exported.out_avals)
return out_avals, set(exported.ordered_effects + exported.unordered_effects)
call_exported_p.def_effectful_abstract_eval(_call_exported_abstract_eval)
def _call_exported_impl(*args, exported: Exported):
return dispatch.apply_primitive(call_exported_p, *args, exported=exported)
call_exported_p.def_impl(_call_exported_impl)
def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
exported: Exported):
if exported.uses_shape_polymorphism:
ctx.module_context.shape_poly_state.uses_dim_vars = True
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.ShardingContext):
num_devices = axis_context.num_devices
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
num_devices = axis_context.mesh.size
else:
raise NotImplementedError(type(axis_context))
if num_devices != exported.nr_devices:
raise NotImplementedError(
f"Exported module {exported.fun_name} was lowered for "
f"{exported.nr_devices} devices and is called in a context with "
f"{num_devices} devices"
)
# Apply in_shardings
args = tuple(
wrap_with_sharding(ctx, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings))
submodule = ir.Module.parse(exported.mlir_module())
symtab = ir.SymbolTable(submodule.operation)
# The called function may have been exported with polymorphic shapes and called
# now with more refined shapes. We insert hlo.ConvertOp to ensure the module
# is valid.
def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
new_ir_type = mlir.aval_to_ir_type(new_aval)
if x.type != new_ir_type:
return hlo.convert(mlir.aval_to_ir_type(new_aval), x)
else:
return x
callee_type = symtab["main"].type
# TODO: maybe cache multiple calls
fn = mlir.merge_mlir_modules(ctx.module_context.module,
f"call_exported_{exported.fun_name}",
submodule)
submodule_args = []
# All the platforms for the current lowering must be among the platforms
# for which the callee was lowered.
lowering_platforms = ctx.module_context.platforms
callee_lowering_platform_index: list[int] = []
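# E.g. (illustrative), if the callee was lowered for ("cpu", "cuda") and the
# current lowering platforms are ("cuda",), then
# callee_lowering_platform_index == [1].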
for platform in lowering_platforms:
if platform in exported.lowering_platforms:
callee_lowering_platform_index.append(
exported.lowering_platforms.index(platform))
elif DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
callee_lowering_platform_index.append(0)
else:
raise ValueError(
f"The exported function '{exported.fun_name}' was lowered for "
f"platforms '{exported.lowering_platforms}' but it is used "
f"on '{lowering_platforms}'.")
if len(exported.lowering_platforms) > 1:
# The exported module takes a platform index argument
if len(lowering_platforms) > 1:
current_platform_idx = ctx.dim_var_values[0]
else:
current_platform_idx = mlir.ir_constant(np.int32(0))
# Compute the rule index based on the current platform
i32_type = mlir.aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
if current_platform_idx.type != i32_type:
current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
callee_platform_idx = hlo.CaseOp([i32_type],
index=current_platform_idx,
num_branches=len(lowering_platforms))
for i in range(len(lowering_platforms)):
branch = callee_platform_idx.regions[i].blocks.append()
with ir.InsertionPoint(branch):
hlo.return_(mlir.ir_constants(
np.int32(callee_lowering_platform_index[i])))
if callee_platform_idx.result.type != callee_type.inputs[0]:
callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0],
callee_platform_idx)
submodule_args.append(callee_platform_idx)
else:
assert len(lowering_platforms) == 1
if _keep_main_tokens(exported.mlir_module_serialization_version):
ordered_effects = exported.ordered_effects
else:
ordered_effects = ()
for eff in ordered_effects:
token_in = ctx.tokens_in.get(eff)[0]
submodule_args.append(token_in)
kept_args = [
convert_shape(a, a_aval, exported_in_aval)
for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals))
if i in exported.module_kept_var_idx]
submodule_args = submodule_args + kept_args
call = func_dialect.CallOp(callee_type.results,
ir.FlatSymbolRefAttr.get(fn),
submodule_args)
if ordered_effects:
tokens_out = {eff: (call.results[effect_idx],)
for effect_idx, eff in enumerate(ordered_effects)}
ctx.set_tokens_out(mlir.TokenSet(tokens_out))
# The ctx.avals_out already contain the abstract values refined by
# _call_exported_abstract_eval.
results = tuple(
convert_shape(out, out_aval, refined_out_aval)
for out, out_aval, refined_out_aval in zip(call.results[len(ordered_effects):],
exported.out_avals, ctx.avals_out))
# Apply out_shardings
results = tuple(
wrap_with_sharding(ctx, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings)
)
return results
mlir.register_lowering(call_exported_p, _call_exported_lowering)
def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
x: ir.Value,
x_aval: core.AbstractValue,
x_sharding: Sharding) -> ir.Value:
if x_sharding is None:
return x
return mlir.wrap_with_sharding_op(
ctx, x, x_aval, x_sharding.to_proto())
# TODO(necula): Previously, users wrote `from jax.experimental.export import export`.
# We now want to simplify the usage and expose the public APIs directly from
# `jax.experimental.export`, so that `jax.experimental.export.export` refers to
# the `export` function. Since there may still be users of the old API in other
# packages, we add the old public API as attributes of the exported function.
# We will clean this up after a deprecation period.
def wrap_with_deprecation_warning(f):
msg = (f"You are using function `{f.__name__}` from "
"`jax.experimental.export.export`. You should instead use it directly "
"from `jax.experimental.export`. Instead of "
"`from jax.experimental.export import export` you should use "
"`from jax.experimental import export`.")
def wrapped_f(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return f(*args, **kwargs)
return wrapped_f
export.export = wrap_with_deprecation_warning(export)
export.Exported = Exported
export.call_exported = wrap_with_deprecation_warning(call_exported)
export.DisabledSafetyCheck = DisabledSafetyCheck
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
export.symbolic_shape = wrap_with_deprecation_warning(_shape_poly.symbolic_shape)
export.args_specs = wrap_with_deprecation_warning(args_specs)
export.minimum_supported_serialization_version = minimum_supported_serialization_version
export.maximum_supported_serialization_version = maximum_supported_serialization_version