1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 23:56:06 +00:00
2024-06-05 09:07:27 -07:00

1338 lines
59 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX APIs for exporting JAX functions for interoperation.
"""
from __future__ import annotations
from collections.abc import Sequence
import copy
import dataclasses
import functools
import itertools
import re
from typing import Any, Callable, Union
import warnings
from absl import logging
import numpy as np
import jax
from jax import sharding
from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src import pjit
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.export import shape_poly
# Shadow the builtins with strict variants that raise on length mismatches.
map = util.safe_map
zip = util.safe_zip

# Type aliases used in annotations throughout this module.
DType = Any
Shape = core.Shape
# The input and output shardings produced by lowering: either a concrete
# `Sharding` or the "unspecified" marker value.
LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
HloSharding = xla_client.HloSharding

# Inclusive range of serialization versions this module can produce. See
# https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 9
maximum_supported_serialization_version = 9
class DisabledSafetyCheck:
  """Describes a safety check that should be skipped on (de)serialization.

  Most checks run at serialization time, but some are deferred until
  deserialization. The disabled checks are recorded with the serialized
  artifact, e.g., as a sequence of string attributes on `jax_export.Exported`
  or on `tf.XlaCallModuleOp`.

  Additional deserialization safety checks can be disabled by passing
  `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`.
  """
  _impl: str

  def __init__(self, _impl: str):
    # Not meant to be called directly; use the `platform` or `custom_call`
    # builders instead.
    self._impl = _impl

  @classmethod
  def platform(cls) -> DisabledSafetyCheck:
    """Allows the execution platform to differ from the serialization platform.

    Has effect only on deserialization.
    """
    return DisabledSafetyCheck("platform")

  @classmethod
  def custom_call(cls, target_name: str) -> DisabledSafetyCheck:
    """Allows the serialization of a call target not known to be stable.

    Has effect only on serialization.

    Args:
      target_name: the name of the custom call target to allow.
    """
    return DisabledSafetyCheck(f"custom_call:{target_name}")

  @classmethod
  def shape_assertions(cls) -> DisabledSafetyCheck:
    """A no-op. DEPRECATED.

    Previously allowed invocations with shapes that do not meet the
    constraints. Shape assertions can no longer be disabled, so this has no
    effect other than emitting a `DeprecationWarning`.
    """
    # TODO(necula): remove this after compatibility period. Was deprecated in
    # May 2024.
    warnings.warn(
        "DisabledSafetyCheck.shape_assertions is deprecated, has no effect anymore",
        DeprecationWarning, stacklevel=2)
    return DisabledSafetyCheck("shape_assertions")

  def is_custom_call(self) -> str | None:
    """Returns the custom call target this directive allows, or None.

    None is returned for directives that were not built with `custom_call`.
    """
    target_match = re.match(r'custom_call:(.+)$', self._impl)
    if target_match is None:
      return None
    return target_match.group(1)

  def __str__(self):
    return self._impl

  def __repr__(self):
    return self._impl

  def __eq__(self, other) -> bool:
    if not isinstance(other, DisabledSafetyCheck):
      return False
    return self._impl == other._impl

  def __hash__(self) -> int:
    return hash(self._impl)
@dataclasses.dataclass(frozen=True)
class Exported:
  """A JAX function lowered to StableHLO.

  Attributes:
    fun_name: the name of the exported function, for error messages.
    in_tree: a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX
        function. The actual lowering does not depend on the `in_tree`, but this
        can be used to invoke the exported function using the same argument
        structure.
    in_avals: the flat tuple of input abstract values. May contain dimension
        expressions in the shapes.
    out_tree: a PyTreeDef describing the result of the lowered JAX function.
    out_avals: the flat tuple of output abstract values. May contain dimension
        expressions in the shapes, with dimension variables among those in
        `in_avals`.
    in_shardings: the flattened input shardings, with the same length as
        `in_avals`. `None` means unspecified sharding.
        Note that these do not include the mesh or the actual devices used in
        the mesh. See `xla_compatible_in_shardings` for a way to turn these
        into sharding specification that can be used with JAX APIs.
    out_shardings: the flattened output shardings, with the same length as
        `out_avals`. `None` means unspecified sharding.
        Note that these do not include the mesh or the actual devices used in
        the mesh. See `xla_compatible_out_shardings` for a way to turn these
        into sharding specification that can be used with JAX APIs.
    nr_devices: the number of devices that the module has been lowered for.
    lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
        'cuda', 'rocm'. See below for the calling convention for when
        there are multiple lowering platforms.
    ordered_effects: the ordered effects present in the serialized module.
        This is present from serialization version 9. See below for the
        calling convention in presence of ordered effects.
    unordered_effects: the unordered effects present in the serialized module.
        This is present from serialization version 9.
    mlir_module_serialized: the serialized lowered VHLO module.
    mlir_module_serialization_version: a version number for the serialized module.
        See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
    module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
        must be passed to the module. The other arguments have been dropped
        because they are not used.
    uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape
        polymorphism. This may be because `in_avals` contains dimension
        variables, or due to inner calls of Exported modules that have
        dimension variables or platform index arguments. Such modules need
        shape refinement before XLA compilation.
    disabled_safety_checks: a list of descriptors of safety checks that have been
        disabled at export time. See docstring for `DisabledSafetyCheck`.
    _get_vjp: an optional function that takes the current exported function and
        returns the exported VJP function.
        The VJP function takes a flat list of arguments,
        starting with the primal arguments and followed by a cotangent argument
        for each primal output. It returns a tuple with the cotangents
        corresponding to the flattened primal inputs.

  Calling convention for the exported module (for latest supported version):

  The `mlir_module` has a `main` function that takes an optional first
  platform index argument if the module supports multiple platforms
  (`len(lowering_platforms) > 1`), followed by the token arguments corresponding
  to the ordered effects, followed by the kept array
  arguments (corresponding to `module_kept_var_idx` and `in_avals`).
  The platform index is a i32 or i64 scalar encoding the index of the current
  compilation platform into the `lowering_platforms` sequence.

  Inner functions use a different calling convention: an optional
  platform index argument, optional dimension variable arguments
  (scalar tensors of type i32 or i64),
  followed by optional token arguments (in presence of ordered effects),
  followed by the regular array arguments.
  The dimension arguments correspond to the dimension variables appearing in
  the `in_avals`, in sorted order of their names.

  Consider the lowering of a function with one array argument of type
  "f32[w, 2 * h]", where "w" and "h" are two dimension variables.
  Assume that we use multi-platform lowering, and we have
  one ordered effect. The `main` function will be as follows:

      func public main(
            platform_index: i32 {jax.global_constant="_platform_index"},
            token_in: token,
            arg: f32[?, ?]) {
         arg_w = hlo.get_dimension_size(arg, 0)
         dim1 = hlo.get_dimension_size(arg, 1)
         arg_h = hlo.floordiv(dim1, 2)
         call _check_shape_assertions(arg)  # See below
         token_out, res = call _wrapped_jax_export_main(platform_index,
                                                        arg_h,
                                                        arg_w,
                                                        token_in,
                                                        arg)
         return token_out, res
      }

  The actual computation is in `_wrapped_jax_export_main`, taking also
  the values of `h` and `w` dimension variables.

  The signature of the `_wrapped_jax_export_main` is:

      func private _wrapped_jax_export_main(
          platform_index: i32 {jax.global_constant="_platform_index"},
          arg_h: i32 {jax.global_constant="h"},
          arg_w: i32 {jax.global_constant="w"},
          arg_token: stablehlo.token {jax.token=True},
          arg: f32[?, ?]) -> (stablehlo.token, ...)

  Prior to serialization version 9 the calling convention for effects is
  different: the `main` function does not take or return a token. Instead
  the function creates dummy tokens of type `i1[0]` and passes them to the
  `_wrapped_jax_export_main`. The `_wrapped_jax_export_main`
  takes dummy tokens of type `i1[0]` and will create internally real
  tokens to pass to the inner functions. The inner functions use real
  tokens (both before and after serialization version 9).

  Also starting with serialization version 9, function arguments that contain
  the platform index or the dimension variable values have a
  `jax.global_constant` string attribute whose value is the name of the
  global constant, either `_platform_index` or a dimension variable name.
  The global constant name may be empty if it is not known.
  Some global constant computations use inner functions, e.g., for
  `floor_divide`. The arguments of such functions have a `jax.global_constant`
  attribute for all arguments, meaning that the result of the function is
  also a global constant.

  Note that `main` contains a call to `_check_shape_assertions`.
  JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
  have values >= 1. We must check these constraints when we invoke the
  module. We use a special custom call `@shape_assertion` that takes
  a boolean first operand, a string `error_message` attribute that may contain
  format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
  scalar operands corresponding to the format specifiers.

       func private _check_shape_assertions(arg: f32[?, ?]) {
         # Check that w is >= 1
         arg_w = hlo.get_dimension_size(arg, 0)
         custom_call @shape_assertion(arg_w >= 1, arg_w,
            error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
         # Check that dim1 is even
         dim1 = hlo.get_dimension_size(arg, 1)
         custom_call @shape_assertion(dim1 % 2 == 0, dim1 % 2,
            error_message="Division had remainder {0} when computing the value of 'h'")
         # Check that h >= 1
         arg_h = hlo.floordiv(dim1, 2)
         custom_call @shape_assertion(arg_h >= 1, arg_h,
            error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")

  If we `call_exported` with this module we perform these checks
  statically (in `call_exported_abstract_eval`).
  """
  fun_name: str
  in_tree: tree_util.PyTreeDef
  in_avals: tuple[core.AbstractValue, ...]
  out_tree: tree_util.PyTreeDef
  out_avals: tuple[core.AbstractValue, ...]

  in_shardings: tuple[HloSharding | None, ...]
  out_shardings: tuple[HloSharding | None, ...]
  nr_devices: int
  lowering_platforms: tuple[str, ...]
  ordered_effects: tuple[effects.Effect, ...]
  unordered_effects: tuple[effects.Effect, ...]
  disabled_safety_checks: Sequence[DisabledSafetyCheck]

  mlir_module_serialized: bytes
  mlir_module_serialization_version: int
  module_kept_var_idx: tuple[int, ...]
  uses_shape_polymorphism: bool

  _get_vjp: Callable[[Exported], Exported] | None

  def mlir_module(self) -> str:
    """Deserializes `mlir_module_serialized` and returns it as MLIR text."""
    return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)

  def __str__(self):
    # This is called to make a MLIR source location when we call an Exported, and we
    # do not want the entire serialized module to end up in locations.
    return f"Exported(fun_name={self.fun_name}, ...)"

  def xla_compatible_in_shardings(
      self,
      mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
    """Creates Shardings corresponding to self.in_shardings.

    The Exported object stores `in_shardings` as HloShardings, which are
    independent of a mesh or set of devices. This method constructs
    Sharding that can be used in JAX APIs such as `jax.jit` or
    `jax.device_put`.

    Example usage:
    >>> from jax.experimental import export
    >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
    >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
    ...                             in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
    ...     )(np.arange(jax.device_count()))
    >>> exp.in_shardings
    ({devices=[8]<=[8]},)
    >>> # Create a mesh for running the exported object
    >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
    >>> # Put the args and kwargs on the appropriate devices
    >>> run_arg = jax.device_put(np.arange(jax.device_count()),
    ...     exp.xla_compatible_in_shardings(run_mesh)[0])
    >>> res = export.call(exp)(run_arg)
    >>> res.addressable_shards
    [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
     Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
     Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
     Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
     Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
     Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
     Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
     Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
    """
    return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh)
                 for s in self.in_shardings)

  def xla_compatible_out_shardings(
      self,
      mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
    """Creates Shardings corresponding to self.out_shardings.

    See documentation for xla_compatible_in_shardings.
    """
    return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh)
                 for s in self.out_shardings)

  def has_vjp(self) -> bool:
    """Returns whether `vjp()` can produce the exported VJP."""
    return self._get_vjp is not None

  def vjp(self) -> Exported:
    """Gets the exported VJP.

    Raises:
      ValueError: if no VJP is available, which can happen if the Exported
        has been loaded from an external format, without a VJP.
    """
    if self._get_vjp is None:
      raise ValueError("No VJP is available")
    return self._get_vjp(self)
def default_lowering_platform() -> str:
  """Returns the canonical platform name of the default JAX backend.

  The name is canonicalized, so a generic 'gpu' backend is reported as
  'cuda' or 'rocm'.
  """
  backend_name = jax.default_backend()
  return xb.canonicalize_platform(backend_name)
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
  """Returns the shape and dtype of a jax.Array or a jax.ShapeDtypeStruct.

  Args:
    a: a `jax.ShapeDtypeStruct`, or any value for which JAX can compute an
      abstract value via `core.get_aval` (e.g., a `jax.Array`).

  Returns:
    A pair `(shape, dtype)`.
  """
  if isinstance(a, jax.ShapeDtypeStruct):
    return a.shape, a.dtype
  aval = core.raise_to_shaped(core.get_aval(a))
  return aval.shape, aval.dtype
def args_specs(
    args,  # pytree of arguments
    polymorphic_shapes,  # prefix pytree of strings
    get_shape_and_dtype=shape_and_dtype_jax_array,
):
  """DEPRECATED: use `export.symbolic_args_specs` instead.

  Emits a `DeprecationWarning` and forwards to
  `shape_poly.symbolic_args_specs`. When a non-default `get_shape_and_dtype`
  is given, each leaf of `args` is first converted to a `jax.ShapeDtypeStruct`
  built from the values it returns.
  """
  # TODO: deprecated in January 2024, to be removed.
  warnings.warn(
      "export.args_specs is deprecated in favor of export.symbolic_args_specs",
      DeprecationWarning, stacklevel=2)
  uses_custom_getter = get_shape_and_dtype is not shape_and_dtype_jax_array
  if uses_custom_getter:
    # This was needed in some older jax2tf implementations
    def _to_shape_dtype_struct(leaf):
      return jax.ShapeDtypeStruct(*get_shape_and_dtype(leaf))
    args = tree_util.tree_map(_to_shape_dtype_struct, args)
  return shape_poly.symbolic_args_specs(args, polymorphic_shapes)
def export(fun_jax: Callable,
           *,
           lowering_platforms: Sequence[str] | None = None,
           disabled_checks: Sequence[DisabledSafetyCheck] = (),
           _device_assignment_for_internal_jax2tf_use_only = None,
           ) -> Callable[..., Exported]:
  """Exports native serialization for a JAX function.

  Args:
    fun_jax: the function to lower and serialize.
    lowering_platforms:
        Optional sequence containing a subset of 'tpu', 'cpu',
        'cuda', 'rocm'. If more than one platform is specified, then
        the lowered code takes an argument specifying the platform.
        If None, then use the default JAX backend.
        The calling convention for multiple platforms is explained in the
        `jax_export.Exported` docstring.
    disabled_checks: the safety checks to disable. See docstring
        of `DisabledSafetyCheck`.

  Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
      or values with `.shape` and `.dtype` attributes, and returns an
      `Exported`.

  Usage:

      def f_jax(*args, **kwargs): ...
      exported = jax_export.export(f_jax)(*args, **kwargs)
  """
  def do_export(*args_specs, **kwargs_specs) -> Exported:
    # A function without a `.lower` method (i.e., not already a jit/pjit/pmap)
    # is wrapped in `jax.jit`; an already-wrapped function is used as-is.
    wrapped_fun_jax = fun_jax if hasattr(fun_jax, "lower") else jax.jit(fun_jax)  # type: ignore

    if lowering_platforms is None:
      actual_lowering_platforms = (default_lowering_platform(),)
    else:
      actual_lowering_platforms = tuple(lowering_platforms)

    # TODO: move to `lower`
    # All symbolic dimensions in the specs must belong to the same scope. We
    # remember the first scope seen (with its key path, for error messages)
    # and check each later symbolic dimension against it.
    symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
    flat_specs_with_paths, _ = tree_util.tree_flatten_with_path(
        (args_specs, kwargs_specs))
    for k_path, aval in flat_specs_with_paths:
      # Static args may have no `shape` attribute.
      for d in getattr(aval, "shape", ()):
        if not shape_poly.is_symbolic_dim(d):
          continue
        if symbolic_scope is None:
          symbolic_scope = (d.scope, k_path)
          continue
        first_scope, first_scope_path = symbolic_scope
        first_scope._check_same_scope(
            d, when=f"when exporting {util.fun_name(wrapped_fun_jax)}",
            self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(first_scope_path)}) ",
            other_descr=shape_poly.args_kwargs_path_to_str(k_path))

    lowering_params = mlir.LoweringParameters(
        platforms=actual_lowering_platforms,
    )
    lowered = wrapped_fun_jax.lower(
        *args_specs, **kwargs_specs,
        _experimental_lowering_parameters=lowering_params)
    return _export_lowered(
        lowered, disabled_checks=disabled_checks,
        _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)

  return do_export
def _export_lowered(
    lowered: stages.Lowered,
    disabled_checks: Sequence[DisabledSafetyCheck] = (),
    _device_assignment_for_internal_jax2tf_use_only = None,
) -> Exported:
  """Builds an `Exported` from a lowered JAX computation.

  Args:
    lowered: the result of calling `.lower(...)` on a jitted function.
    disabled_checks: the safety checks to disable; passed to `_check_module`
      and recorded in the result.
    _device_assignment_for_internal_jax2tf_use_only: if not None, a mutable
      sequence whose element 0 is overwritten with the device assignment of
      the lowering.

  Returns:
    The `Exported` object. Its VJP is only exported on demand, when
    `Exported.vjp()` is called.

  Raises:
    ValueError: if the configured `jax_serialization_version` is outside the
      supported range.
    NotImplementedError: if the lowering cannot be serialized (see
      `_check_lowering`), or has a non-empty `mut` compile arg.
  """
  version = config.jax_serialization_version.value
  if (version < minimum_supported_serialization_version or
      version > maximum_supported_serialization_version):
    raise ValueError(
      f"The requested jax_serialization version {version} is outside the "
      f"range of supported versions [{minimum_supported_serialization_version}"
      f"..{maximum_supported_serialization_version}]")

  lowering = lowered._lowering
  _check_lowering(lowering)
  compile_args = lowering.compile_args
  mlir_module = lowering.stablehlo()

  args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
  if compile_args.get("mut"):
    raise NotImplementedError(
        "serialization of lowerings with non-empty 'mut' compile_args is not "
        "implemented")
  if "kept_var_idx" in compile_args:
    module_kept_var_idx = tuple(sorted(compile_args["kept_var_idx"]))
  else:
    # For pmap
    module_kept_var_idx = tuple(range(len(args_avals_flat)))
  shape_poly_state = compile_args["shape_poly_state"]
  # A wrapper `main` is needed to compute the dimension variable values from
  # the argument shapes, and to handle the token arguments of ordered effects.
  if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
      or compile_args.get("ordered_effects", [])):
    mlir_module = _wrap_main_func(
        mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree,
        has_platform_index_argument=shape_poly_state.has_platform_index_argument,
        module_kept_var_idx=module_kept_var_idx,
        serialization_version=version)

  with mlir_module.context:
    mlir_module_attrs = mlir_module.operation.attributes
    mlir_module_attrs["jax.uses_shape_polymorphism"] = (
        mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars))

  mlir_module_serialized = _module_to_bytecode(mlir_module)

  # Figure out the result types and shapes
  if "global_out_avals" in compile_args:
    # This is currently the case for pjit
    out_avals_flat = compile_args["global_out_avals"]
  elif "shards" in compile_args:  # for PmapComputation
    out_avals_flat = compile_args["shards"].out_sharded_avals
  else:
    # Read from the lowering's compile_args, like every other compile arg
    # in this function (previously this read `lowered.compile_args`).
    out_avals_flat = compile_args["out_avals"]

  # Log and then check the module.
  if logging.vlog_is_on(3):
    logmsg = (f"version={version} "
              f"lowering_platforms={compile_args['platforms']} "
              f"disabled_checks={disabled_checks}")
    logging.info("Lowered JAX module: %s\n", logmsg)
    if dumped_to := mlir.dump_module_to_file(mlir_module, "export"):
      logging.info("Dumped the exported MLIR module to %s", dumped_to)

  _check_module(mlir_module,
                disabled_checks=disabled_checks)

  ordered_effects = tuple(compile_args["ordered_effects"])
  unordered_effects = tuple(compile_args["unordered_effects"])

  device_assignment = compile_args["device_assignment"]
  nr_devices = len(device_assignment)

  def export_sharding(s: LoweringSharding,
                      aval: core.ShapedArray) -> HloSharding | None:
    # Unspecified shardings are recorded as None in the Exported.
    if sharding_impls.is_unspecified(s):
      return None
    return s._to_xla_hlo_sharding(aval.ndim)  # type: ignore[union-attr]

  all_in_shardings = expand_in_shardings(compile_args["in_shardings"],
                                         module_kept_var_idx,
                                         len(args_avals_flat))
  in_shardings = tuple(
    export_sharding(s, aval)
    for s, aval in zip(all_in_shardings, args_avals_flat))
  out_shardings = tuple(
    export_sharding(s, aval)
    for s, aval in zip(compile_args["out_shardings"], out_avals_flat))

  if _device_assignment_for_internal_jax2tf_use_only is not None:
    _device_assignment_for_internal_jax2tf_use_only[0] = device_assignment

  def _get_exported_vjp(exp_primal: Exported) -> Exported:
    # Turn the primal jaxpr into a function, in preparation for exporting
    # the VJP. Note that jaxpr_as_fun produces a function with flat arguments
    assert lowered._jaxpr is not None  # None only when the lowered was created outside JAX
    fun_jax = core.jaxpr_as_fun(lowered._jaxpr)

    fun_vjp_jax, vjp_in_avals = _get_vjp_fun(fun_jax,
                                             in_tree=exp_primal.in_tree,
                                             in_avals=exp_primal.in_avals,
                                             in_shardings=exp_primal.in_shardings,
                                             out_avals=exp_primal.out_avals,
                                             out_shardings=exp_primal.out_shardings,
                                             device_assignment=device_assignment,
                                             apply_jit=True,
                                             flat_primal_fun=True)
    return export(fun_vjp_jax,
                  lowering_platforms=exp_primal.lowering_platforms,
                  disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)

  return Exported(
      fun_name=lowered._fun_name,
      in_tree=lowered.in_tree,
      out_tree=lowered.out_tree,
      in_avals=tuple(args_avals_flat),
      out_avals=tuple(out_avals_flat),
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      nr_devices=nr_devices,
      lowering_platforms=lowering._platforms,
      ordered_effects=ordered_effects,
      unordered_effects=unordered_effects,
      disabled_safety_checks=tuple(disabled_checks),
      mlir_module_serialized=mlir_module_serialized,
      module_kept_var_idx=module_kept_var_idx,
      uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
      mlir_module_serialization_version=version,
      _get_vjp=_get_exported_vjp)
def _module_to_bytecode(module: ir.Module) -> bytes:
  """Serializes `module` as a StableHLO portable artifact.

  The artifact targets `hlo.get_minimum_version()`, the oldest StableHLO
  version that the current StableHLO can still produce artifacts for.
  """
  # `target_version` handles the case where the StableHLO producer (here,
  # jax2tf) and a StableHLO consumer were built with different StableHLO
  # versions.
  #
  # Every StableHLO `producer_version` has a compatibility window
  # [`consumer_version_min`, `consumer_version_max`]: portable artifacts
  # serialized by `producer_version` can be deserialized by any
  # `consumer_version` inside that window.
  # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
  # for the exact extent of these compatibility guarantees.
  #
  # `hlo.get_minimum_version()` yields `consumer_version_min` for the current
  # StableHLO. Targeting it maximizes forward compatibility: payloads from
  # `serialize_portable_artifact` stay readable by consumers as far in the
  # past as possible.
  return xla_client._xla.mlir.serialize_portable_artifact(
      mlir.module_to_bytecode(module), hlo.get_minimum_version())
def _wrap_main_func(
    module: ir.Module,
    args_avals_flat: Sequence[core.ShapedArray],
    *,
    args_kwargs_tree: tree_util.PyTreeDef,
    has_platform_index_argument: bool,
    module_kept_var_idx: tuple[int, ...],
    serialization_version: int
) -> ir.Module:
  """Wraps the lowered module with a new "main" handling dimension arguments.

  See calling convention documentation for `jax_export.Exported`.

  Args:
    module: the HLO module as obtained from lowering. See the calling convention
      for inner functions in `jax_export.Exported`.
    args_avals_flat: the avals for all the arguments of the lowered function,
      which correspond to the array arguments of the `module`.
    args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error
      messages.
    has_platform_index_argument: whether the `module` has a first platform
      index argument
    module_kept_var_idx: a sorted tuple of integers with the indices of arguments
      in `args_avals_flat` that are kept as `module` arguments.
    serialization_version: the target serialization version (not used in the
      body of this function).

  Returns:
    A copy of `module` in which the original `main` is renamed to the private
    `_wrapped_jax_export_main`, and a new public `main` takes the optional
    platform index argument, the token arguments, and the array arguments,
    but no dimension variable arguments. The new `main` computes the dimension
    variable values from the shapes of the array arguments and passes them to
    `_wrapped_jax_export_main`.
  """
  dim_vars = shape_poly.all_dim_vars(args_avals_flat)
  context = mlir.make_ir_context()
  with context, ir.Location.unknown(context):
    # Make a copy, do not mutate because it may be cached
    wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))  # type: ignore
    symbol_table = ir.SymbolTable(wrapped_module.operation)
    orig_main = symbol_table["main"]
    orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
    symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
    orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value

    def is_token(typ, attrs):
      return (typ == mlir.token_type()[0])

    orig_input_types = orig_main.type.inputs
    arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs))
    # The order of args: platform_index_arg, dim args, token args, array args.
    nr_platform_index_args = 1 if has_platform_index_argument else 0
    nr_dim_args = len(dim_vars)
    token_arg_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_input_types,
                                                              arg_attrs))
                      if is_token(typ, attrs)]
    nr_token_args = len(token_arg_idxs)
    if nr_token_args > 0:
      assert min(token_arg_idxs) == nr_platform_index_args + nr_dim_args
      assert token_arg_idxs == list(
          range(nr_platform_index_args + nr_dim_args,
                nr_platform_index_args + nr_dim_args + nr_token_args))
    nr_array_args = (len(orig_input_types) - nr_platform_index_args
                     - nr_dim_args - nr_token_args)
    assert nr_array_args >= 0

    (platform_input_types, dim_var_input_types,
     token_input_types, array_input_types) = util.split_list(
        orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args])

    # The order of results: tokens, array results
    orig_output_types = orig_main.type.results
    result_attrs = list(ir.ArrayAttr(orig_main.result_attrs))
    token_result_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_output_types,
                                                                 result_attrs))
                         if is_token(typ, attrs)]
    nr_token_results = len(token_result_idxs)
    assert token_result_idxs == list(range(0, nr_token_results))
    nr_array_results = len(orig_output_types) - nr_token_results
    assert nr_array_results >= 0
    # The new main drops only the dimension variable arguments; it keeps the
    # platform index, token and array arguments, and all the results.
    new_main_arg_indices = (
        *range(nr_platform_index_args),
        *range(nr_platform_index_args + nr_dim_args, len(orig_input_types)))
    new_main_result_indices = tuple(range(0, len(orig_output_types)))

    new_main_input_types = [orig_input_types[idx] for idx in new_main_arg_indices]
    new_main_output_types = [orig_output_types[idx] for idx in new_main_result_indices]
    new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
    new_main_op = func_dialect.FuncOp(
        "main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body))
    new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
    try:
      new_arg_attrs = []
      for idx in new_main_arg_indices:
        new_arg_attr = {}
        for attr in arg_attrs[idx]:
          if attr.name == "tf.aliasing_output":
            # Re-map the aliased output index to the new main's results.
            i = new_main_result_indices.index(attr.attr.value)
            new_arg_attr[attr.name] = ir.IntegerAttr.get(
                ir.IntegerType.get_signless(32), i
            )
          else:
            new_arg_attr[attr.name] = attr.attr
        new_arg_attrs.append(ir.DictAttr.get(new_arg_attr))
      new_main_op.arg_attrs = ir.ArrayAttr.get(new_arg_attrs)
    except KeyError:
      pass  # TODO: better detection if orig_main.arg_attrs does not exist
    try:
      new_main_op.result_attrs = ir.ArrayAttr.get(
          [result_attrs[idx] for idx in new_main_result_indices])
    except KeyError:
      pass
    symbol_table.insert(new_main_op)
    entry_block = new_main_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
      # Make a context just for lowering the dimension value computations
      module_context = mlir.ModuleContext(
          backend_or_name="cpu", platforms=["cpu"],
          axis_context=sharding_impls.ShardingContext(0),
          keepalives=[], channel_iterator=itertools.count(1),
          host_callbacks=[], module=wrapped_module, context=context,
          lowering_parameters=mlir.LoweringParameters(
              global_constant_computation=True
          ))
      ctx = mlir.LoweringRuleContext(
        module_context=module_context,
        name_stack=source_info_util.new_name_stack(), primitive=None,
        avals_in=args_avals_flat, avals_out=None,
        tokens_in=mlir.TokenSet(), tokens_out=None)
      # The new main's arguments are: optional platform index, tokens, arrays.
      # Slice from a non-negative start: `arguments[-nr_array_args:]` would
      # return ALL the arguments when there are no array arguments.
      new_main_op_array_args = new_main_op.arguments[
          nr_platform_index_args + nr_token_args:]
      # We compute dim_values from the array arguments.
      if dim_vars:
        # TODO(necula): handle module_kept_var_idx in presence of shape
        # polymorphism. For now we ensured upstream that we keep all variables.
        assert len(set(module_kept_var_idx)) == len(args_avals_flat)
        dim_values = mlir.lower_fun(
            functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
                              args_avals_flat, args_kwargs_tree=args_kwargs_tree),
            multiple_results=True)(ctx, *new_main_op_array_args)
      else:
        dim_values = ()
      # The arguments to pass to the call to orig_main
      orig_main_args: list[ir.Value] = []
      # The platform index and the dimension variables
      for arg, arg_type in zip(
          list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values),
          platform_input_types + dim_var_input_types):
        if arg.type != arg_type:
          orig_main_args.append(hlo.convert(arg_type, arg))
        else:
          orig_main_args.append(arg)
      # Then the token arguments
      orig_main_args.extend(
        new_main_op.arguments[nr_platform_index_args: nr_platform_index_args + nr_token_args])
      # Then the array arguments. We insert a ConvertOp as the only use of
      # an input argument. This helps the downstream shape refinement because
      # it will set the type of input arguments to static shapes, and this
      # can invalidate the module if the argument is used as the result of a
      # function, or if it appears as the input to a custom_call with
      # output_operand_alias attribute. See b/287386268.
      for arg, arg_type in zip(new_main_op_array_args, array_input_types):
        if arg.type != arg_type:
          orig_main_args.append(hlo.convert(arg_type, arg))
        else:
          orig_main_args.append(arg)
      call = func_dialect.CallOp(orig_output_types,
                                 ir.FlatSymbolRefAttr.get(orig_main_name),
                                 orig_main_args)
      func_dialect.ReturnOp([call.results[idx] for idx in new_main_result_indices])
    symbol_table.set_symbol_name(new_main_op, "main")

  return wrapped_module
def _check_lowering(lowering) -> None:
  """Verifies that `lowering` is something we know how to serialize.

  Raises:
    NotImplementedError: if `lowering` is not a `pxla.MeshComputation`, if it
      has host callbacks or keepalive objects, if its `compile_args` contain a
      key that is not explicitly allowed, or if some `compile_args` have a
      value for which serialization is not implemented.
  """
  if not isinstance(lowering, pxla.MeshComputation):
    raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
  compile_args = lowering.compile_args
  if compile_args["host_callbacks"] or compile_args["keepalive"]:
    raise NotImplementedError("serialization of host_callbacks is not yet implemented")

  # Reject compile_args we have never seen. A new compile_arg may be added to
  # this set only if it changes neither the semantics nor the calling
  # convention of the lowered module.
  known_compile_args = frozenset({
      "backend", "platforms", "mesh", "global_in_avals",
      "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
      "mut", "spmd_lowering", "auto_spmd_lowering",
      "tuple_args", "ordered_effects", "unordered_effects",
      "keepalive", "host_callbacks", "pmap_nreps", "committed",
      "device_assignment", "jaxpr_debug_info", "shape_poly_state",
      "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info",
      "pgle_profiler"})
  for arg_name in compile_args.keys():
    if arg_name not in known_compile_args:
      raise NotImplementedError(f"Unrecognized lowered.compile_args[{arg_name}]")

  def accept_any(_):
    return True

  # Some compile_args are supported only for some values. Each entry maps the
  # arg name to (predicate accepting supported values, description of them).
  supported_values = {
      "spmd_lowering": (bool, "True"),
      "auto_spmd_lowering": (lambda v: not v, "False"),
      # tuple_args is a compilation flag, does not affect lowering.
      "tuple_args": (accept_any, "N/A"),
      # unordered_effects do not change the calling convention. Those from
      # jax.debug also make keepalive non-empty and produce unsupported custom
      # calls. The CallTfEffect is an exception, but we want to allow it.
      "unordered_effects": (accept_any, "N/A"),
      "ordered_effects": (accept_any, "N/A"),
      # Used for TPU jax.debug, send/recv. Not supported yet.
      "host_callbacks": (lambda v: not v, "empty"),
      # Used on all platforms for callbacks. Not supported yet.
      "keepalive": (lambda v: not v, "empty"),
      "pmap_nreps": (lambda v: v == 1, "1"),
      "shape_poly_state": (accept_any, "N/A"),
  }
  not_implemented_msgs = [
      f"{arg_name} must be {expected} and it is {compile_args[arg_name]}"
      for arg_name, (is_supported, expected) in supported_values.items()
      if arg_name in compile_args and not is_supported(compile_args[arg_name])
  ]
  if not_implemented_msgs:
    raise NotImplementedError(
        "serialization error, unimplemented lowered.compile_args:\n" +
        "\n".join(not_implemented_msgs))
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
# _check_module rejects any stablehlo.custom_call whose target is not in this
# set, unless the target is explicitly allowed by a DisabledSafetyCheck.
# _check_module extends a copy of this set, never the set itself.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
    "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
    "cu_threefry2x32",
    "__gpu$xla.gpu.triton",  # Pallas call on GPU
    # cholesky on CPU
    "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
    # eigh on CPU
    "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
    # eigh on GPU
    "cusolver_syevj", "cusolver_syevd",
    # eigh on TPU
    "Eigh",
    # eig on CPU
    "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
    # qr on CPU
    "lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
    # householder product on CPU
    "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
    # svd on CPU
    "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd",
    # qr on GPU
    "cusolver_geqrf", "cublas_geqrf_batched",
    "cusolver_orgqr",
    # qr and svd on TPU
    "Qr", "ProductOfElementaryHouseholderReflectors",
    # triangular_solve on CPU
    "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm",
    # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
    # lu on CPU (allowed, but not yet covered by back_compat tests)
    "lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf",
    # schur on CPU
    "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
    # # lu on GPU (not yet allowed)
    # "cublas_getrf_batched", "cusolver_getrf",
    # "hipblas_getrf_batched", "hipsolver_getrf",
    # lu on TPU
    "LuDecomposition",
    # ApproxTopK on TPU
    "ApproxTopK",
    "tf.call_tf_function",  # From jax2tf.call_tf(func, call_tf_graph=True)
    "tpu_custom_call",  # Pallas/TPU kernels
    # TODO(burmako): maintain backwards compatibility for these, until they
    # are upstreamed to StableHLO.
    # See https://github.com/openxla/stablehlo/issues/8.
    "stablehlo.dynamic_reduce_window",
    "stablehlo.dynamic_rng_bit_generator",
    "stablehlo.dynamic_top_k",
    "shape_assertion",  # Used by shape_poly to evaluate assertions
}
# Matches the `mhlo.sharding` attribute values that _check_module does NOT
# count as non-replicated: "{replicated}", "{unknown shard_as ...}", or the
# empty string. The empty alternative is spelled explicitly; the previous
# spelling relied on implicit concatenation of `r"...|"` and `")$"`.
check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|)$")
def _check_module(mod: ir.Module, *,
                  disabled_checks: Sequence[DisabledSafetyCheck]) -> bool:
  """Run a number of checks on the module.

  Walks every operation in `mod`. It records the `mhlo.sharding` attribute of
  `func.func` ops and of `Sharding` custom calls, and it collects custom calls
  whose target is not guaranteed stable.

  Args:
    mod: the module to check.
    disabled_checks: the safety checks that are disabled. A custom-call entry
      adds its target to the allowed custom-call targets.

  Returns:
    True if the module uses non-replicated shardings.

  Raises:
    ValueError: if the module contains custom calls whose targets are neither
      in `_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE` nor allowed by
      `disabled_checks`.
  """
  sharding_attr = ir.StringAttr.get("Sharding", mod.context)
  # Copy so that adding the targets from disabled_checks does not mutate the
  # module-level set.
  allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
  for dc in disabled_checks:
    target = dc.is_custom_call()
    if target is not None:
      allowed_custom_call_targets.add(target)

  allowed_custom_call_targets_attrs = {
      ir.StringAttr.get(target, mod.context)
      for target in allowed_custom_call_targets}
  disallowed_custom_call_ops: list[str] = []
  module_uses_non_replicated_sharding = False

  def check_sharding(op: ir.Operation) -> None:
    nonlocal module_uses_non_replicated_sharding
    try:
      mhlo_sharding = op.attributes["mhlo.sharding"]
    except KeyError:
      return
    try:
      sharding_value = ir.StringAttr(mhlo_sharding).value
    except UnicodeDecodeError:
      # The mhlo.sharding attribute may be in pretty-printed format, or
      # as an encoding of an HloSharding protobuf in some rare situations.
      # We handle the latter by conservatively assuming it is non-replicated.
      module_uses_non_replicated_sharding = True
      return
    if not check_sharding_pattern.match(sharding_value):
      module_uses_non_replicated_sharding = True

  def check_op(op: ir.Operation) -> None:
    op_name = op.operation.name
    if op_name == "func.func":
      check_sharding(op.operation)
    elif op_name == "stablehlo.custom_call":
      call_target_name_attr = op.operation.attributes["call_target_name"]
      if call_target_name_attr not in allowed_custom_call_targets_attrs:
        disallowed_custom_call_ops.append(f"{op} at {op.location}")
      if call_target_name_attr == sharding_attr:
        check_sharding(op)

  def walk_operations(op) -> None:
    # Pre-order traversal: the op itself, then every op nested in its regions.
    check_op(op)
    for region in op.operation.regions:
      for block in region:
        for nested_op in block:
          walk_operations(nested_op)

  walk_operations(mod)
  if disallowed_custom_call_ops:
    disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops)
    msg = ("Cannot serialize code with custom calls whose targets have no "
           "compatibility guarantees. Examples are:\n"
           f"{disallowed_custom_call_ops_str}.\n"
           "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
    raise ValueError(msg)
  return module_uses_non_replicated_sharding
def expand_in_shardings(in_shardings: Sequence[LoweringSharding],
                        module_kept_var_idx: Sequence[int],
                        nr_inputs: int) -> Sequence[LoweringSharding]:
  """Returns one sharding per input, using UNSPECIFIED for dropped inputs.

  `in_shardings` holds the shardings of the inputs kept by the lowered module.
  They are listed in increasing order of the indices in `module_kept_var_idx`.
  Every input index not in `module_kept_var_idx` gets
  `sharding_impls.UNSPECIFIED`.

  Args:
    in_shardings: the shardings of the kept inputs.
    module_kept_var_idx: the indices, into the full input list, of the kept
      inputs.
    nr_inputs: the total number of inputs.

  Returns:
    A tuple of `nr_inputs` shardings.
  """
  assert len(in_shardings) == len(module_kept_var_idx)
  assert nr_inputs >= len(module_kept_var_idx)
  expanded: list[LoweringSharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
  kept_positions = sorted(module_kept_var_idx)
  for position, kept_sharding in zip(kept_positions, in_shardings):
    expanded[position] = kept_sharding
  return tuple(expanded)
def _hlo_sharding_to_xla_compatible_sharding(
    hlo_sharding: HloSharding | None,
    mesh: sharding.Mesh) -> sharding.Sharding | None:
  """Converts an HloSharding to a sharding over `mesh`; None maps to None.

  The HloSharding is first turned into a GSPMDSharding over the flattened
  devices of `mesh`, then converted by
  `sharding_impls._gspmd_to_named_sharding_via_mesh`.
  """
  if hlo_sharding is not None:
    mesh_devices = tuple(mesh.devices.flat)
    gspmd = _hlo_sharding_to_gspmd_sharding(hlo_sharding, mesh_devices)  # type: ignore[arg-type]
    return sharding_impls._gspmd_to_named_sharding_via_mesh(gspmd, mesh)
  return None
def _hlo_sharding_to_gspmd_sharding(
    hlo_sharding: HloSharding | None,
    device_assignment: Sequence[jax.Device]) -> sharding.GSPMDSharding | None:
  """Wraps `hlo_sharding` in a GSPMDSharding over `device_assignment`.

  Returns None when `hlo_sharding` is None.
  """
  return (None if hlo_sharding is None
          else sharding.GSPMDSharding(device_assignment, hlo_sharding))
def _get_vjp_fun(primal_fun: Callable, *,
                 in_tree: tree_util.PyTreeDef,
                 in_avals: Sequence[core.AbstractValue],
                 out_avals: Sequence[core.AbstractValue],
                 in_shardings: tuple[HloSharding | None, ...],
                 out_shardings: tuple[HloSharding | None, ...],
                 device_assignment: Sequence[sharding_impls.Device] | None,
                 apply_jit: bool,
                 flat_primal_fun: bool = False,
                 ) -> tuple[Callable, Sequence[core.AbstractValue]]:
  """Builds a flat VJP function for `primal_fun`.

  The returned function takes the flat primal arguments followed by the flat
  output cotangents, and returns the cotangents of the primal arguments.
  `jax.vjp` does not handle kwargs, so all the work is done on flattened
  functions.

  Args:
    primal_fun: the function to differentiate. When `flat_primal_fun` is False
      (used only from jax2tf), it takes PyTree `*args` and `**kwargs`
      described by `in_tree`; otherwise it takes flat arguments.
    in_tree: the PyTree structure of `(args, kwargs)`. Used only when
      `flat_primal_fun` is False.
    in_avals: the flat primal input avals.
    out_avals: the flat primal output avals.
    in_shardings: per-input HloShardings (or None), used when `apply_jit`.
    out_shardings: per-output HloShardings (or None), used when `apply_jit`.
    device_assignment: must be given when `apply_jit` is True.
    apply_jit: whether to wrap the VJP function in `pjit` with shardings
      derived from `in_shardings` and `out_shardings`. False is only used for
      backwards compatibility with graph serialization.
    flat_primal_fun: see `primal_fun`.

  Returns:
    A pair of the VJP function and its input avals: `in_avals` followed by
    `at_least_vspace()` of each of `out_avals`.
  """
  def flattened_primal_fun_jax(*args_flat):
    args, kwargs = in_tree.unflatten(args_flat)
    res_flat, _ = tree_util.tree_flatten(primal_fun(*args, **kwargs))
    return res_flat

  differentiated_fun = primal_fun if flat_primal_fun else flattened_primal_fun_jax

  def fun_vjp_jax(*args_and_out_cts_flat_jax):
    # The leading len(in_avals) inputs are primals; the rest are cotangents.
    primals_flat, out_cts_flat = util.split_list(args_and_out_cts_flat_jax,
                                                 [len(in_avals)])
    _, pullback = jax.vjp(differentiated_fun, *primals_flat)
    return pullback(out_cts_flat)

  vjp_in_avals = [*in_avals, *(aval.at_least_vspace() for aval in out_avals)]
  if not apply_jit:
    return fun_vjp_jax, vjp_in_avals

  assert device_assignment is not None

  def to_gspmd(s):
    return _hlo_sharding_to_gspmd_sharding(s, device_assignment)

  # The VJP takes (primals, output cotangents) and returns input cotangents.
  vjp_in_shardings = tuple(to_gspmd(s) for s in (*in_shardings, *out_shardings))
  vjp_out_shardings = tuple(to_gspmd(s) for s in in_shardings)
  return pjit.pjit(fun_vjp_jax,
                   in_shardings=vjp_in_shardings,
                   out_shardings=vjp_out_shardings), vjp_in_avals
### Calling the exported function
def call(exported: Exported) -> Callable[..., jax.Array]:
  """Returns a function that invokes `exported` through `call_exported_p`.

  The returned function takes the same PyTree of positional and keyword
  arguments as the function that was exported. It returns results with the
  structure `exported.out_tree`. It is differentiable through a custom VJP
  that invokes `exported.vjp()`.

  Raises:
    ValueError: if `exported` is not an `Exported`, or, when the returned
      function is invoked, if the arguments' PyTree structure differs from
      `exported.in_tree`.
  """
  if not isinstance(exported, Exported):
    raise ValueError(
        "The exported argument must be an export.Exported. "
        f"Found {exported}.")

  @jax.custom_vjp
  def f_flat(*args_flat):
    return call_exported_p.bind(*args_flat, exported=exported)

  def f_flat_vjp_fwd(*args_flat):
    # Save all the primal arguments as the residual.
    # TODO: keep as residuals only the arguments that are needed
    return f_flat(*args_flat), args_flat

  def f_flat_vjp_bwd(residual, ct_res_flat):
    primals_flat = residual  # the flat tuple of primal arguments from fwd
    exp_vjp = exported.vjp()

    # ct_res_flat may hold ordinary zero arrays where exp_vjp expects float0
    # cotangents; build proper float0 zeros for those positions.
    def as_expected_ct(ct, expected_aval):
      if expected_aval.dtype == dtypes.float0:
        return ad_util.zeros_like_aval(expected_aval)
      return ct

    expected_ct_avals = exp_vjp.in_avals[len(primals_flat):]
    cts_fixed = map(as_expected_ct, ct_res_flat, expected_ct_avals)
    return call_exported(exp_vjp)(*primals_flat, *cts_fixed)

  f_flat.defvjp(f_flat_vjp_fwd, f_flat_vjp_bwd)

  def f_imported(*args, **kwargs):
    # custom_vjp does not support kwargs, so flatten everything up front.
    args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
    if in_tree == exported.in_tree:
      return exported.out_tree.unflatten(f_flat(*args_flat))
    # Report the precise tree difference. Placeholder leaves let us use
    # tree_util.equality_errors.
    in_args = in_tree.unflatten([0] * in_tree.num_leaves)
    exp_in_args = exported.in_tree.unflatten([0] * exported.in_tree.num_leaves)
    differences = "\n".join(
        f" - {shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a "
        f"{thing2} when exported, so {explanation}.\n"
        for path, thing1, thing2, explanation
        in tree_util.equality_errors(in_args, exp_in_args))
    raise ValueError(
        "The invocation args and kwargs must have the same pytree structure "
        f"as when the function '{exported.fun_name}' was exported, but they "
        "have the following structural differences:\n" + differences)

  return f_imported


call_exported = call
# A JAX primitive for invoking a serialized JAX function.
# The `Exported` being invoked is passed as the `exported` parameter to bind.
# The primitive produces a list of results (multiple_results = True).
call_exported_p = core.Primitive("call_exported")
call_exported_p.multiple_results = True
@util.cache()
def _call_exported_abstract_eval(
    *in_avals: core.AbstractValue,
    exported: Exported
    ) -> tuple[tuple[core.AbstractValue, ...], set[effects.Effect]]:
  """Abstract evaluation rule for `call_exported_p` (memoized by util.cache).

  Checks that the actual flat argument avals are compatible with
  `exported.in_avals`: same rank, same dtype, and the same value for every
  dimension that is a constant in the exported signature. Then it solves the
  exported dimension variables in terms of the actual argument shapes and
  computes the output avals for the actual arguments.

  Args:
    in_avals: the abstract values of the actual flat arguments.
    exported: the `Exported` being invoked.

  Returns:
    A pair: the output avals, with shapes evaluated for the actual
    arguments, and the set of `exported`'s ordered and unordered effects.

  Raises:
    ValueError: on a rank, dtype, or constant-dimension mismatch. Failures of
      the shape constraints are reported by `check_statically`.
  """
  exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals)
  assert len(in_avals) == len(exported.in_avals)  # since the pytrees have the same structure
  # Check that the expected shapes match the actual ones
  for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)):
    # Describes argument `arg_idx` (and optionally one of its dimensions) for
    # error messages.
    def pp_arg_dim(dim_idx: int | None) -> str:
      return shape_poly.pretty_print_dimension_descriptor(exported.in_tree,
                                                          arg_idx, dim_idx)
    if len(exp_aval.shape) != len(actual_aval.shape):
      raise ValueError(
          f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} "
          f"and called with {actual_aval.shape}")
    if exp_aval.dtype != actual_aval.dtype:
      raise ValueError(
          f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} "
          f"and called with {actual_aval.dtype}")
    for dim_idx, aval_d in enumerate(exp_aval.shape):
      # If the exp_aval has a constant dimension then the actual argument must have
      # a matching constant dimension.
      if core.is_constant_dim(aval_d):
        if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or
            aval_d != actual_aval.shape[dim_idx]):
          raise ValueError(
              f"Shape mismatch for {pp_arg_dim(dim_idx)} "
              "(expected same constant): "
              f"expected {exp_aval.shape} and called with {actual_aval.shape}")

  # Must express the exported_dim_vars in terms of the shapes in in_avals.
  solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars(
      exported.in_avals, args_kwargs_tree=exported.in_tree)
  # Bind each synthetic dimension variable to the corresponding dimension of
  # the actual argument it was read from.
  synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx]
                   for (vname, arg_idx, dim_idx) in synth_dim_vars}
  synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env)
  # We discharge all the constraints statically. This results in much simpler
  # composability (because we do not have to worry about the constraints of the
  # Exported called recursively; we only need to worry about entry-point
  # constraints). This also makes sense from a composability point of view,
  # because we get the same errors if we invoke the exported module, or if we
  # trace the exported function. Consider for example, an exported module with
  # signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument
  # `f32[c, d]` it is better to fail because `c == d` is inconclusive, than
  # succeed and add a compile-time check that `c == d`. In the latter case,
  # it would be ambiguous whether we should continue tracing with a result
  # of type `f32[c]` or `f32[d]`.
  shape_constraints.check_statically(synthetic_eval)
  # Values of the exported dimension variables for the actual arguments.
  exported_dim_values = [synthetic_eval.evaluate(solution[var])
                         for var in exported_dim_vars]
  out_avals = tuple(
      core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
                                           *exported_dim_values),
                       dtype=out_aval.dtype, weak_type=out_aval.weak_type,
                       named_shape=out_aval.named_shape)
      for out_aval in exported.out_avals)
  return out_avals, set(exported.ordered_effects + exported.unordered_effects)


call_exported_p.def_effectful_abstract_eval(_call_exported_abstract_eval)
def _call_exported_impl(*args, exported: Exported):
  """Eager implementation of `call_exported_p`.

  Delegates to `dispatch.apply_primitive`, forwarding `exported` as the
  primitive's parameter.
  """
  params = {"exported": exported}
  return dispatch.apply_primitive(call_exported_p, *args, **params)


call_exported_p.def_impl(_call_exported_impl)
def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
                            exported: Exported):
  """MLIR lowering rule for `call_exported_p`.

  Merges the serialized module of `exported` into the module being lowered
  (via `mlir.merge_mlir_modules`) and emits a `func.call` to it. Around the
  call it:
    * checks that the device count of the current axis context agrees with
      `exported.nr_devices`, allowing a mismatch only for a 1-device module
      with no non-replicated sharding annotations;
    * applies `exported.in_shardings` to the arguments and
      `exported.out_shardings` to the results;
    * passes a platform-index argument when the callee was lowered for
      multiple platforms, mapping the current platform index to the callee's;
    * passes one token per effect in `exported.ordered_effects` and sets the
      resulting output tokens;
    * passes only the arguments whose index is in
      `exported.module_kept_var_idx`;
    * inserts `hlo.convert` wherever an actual type differs from the exported
      type.

  Raises:
    NotImplementedError: for an unsupported axis context, or a disallowed
      device-count mismatch.
    ValueError: if a current lowering platform is not among
      `exported.lowering_platforms` and the platform safety check was not
      disabled.
  """
  if exported.uses_shape_polymorphism:
    ctx.module_context.shape_poly_state.uses_dim_vars = True
  submodule = ir.Module.parse(exported.mlir_module())

  # Determine the number of devices of the calling context.
  axis_context = ctx.module_context.axis_context
  if isinstance(axis_context, sharding_impls.ShardingContext):
    num_devices = axis_context.num_devices
  elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
    num_devices = axis_context.mesh.size
  elif isinstance(axis_context, sharding_impls.ReplicaAxisContext):
    num_devices = axis_context.axis_env.nreps
  else:
    raise NotImplementedError(type(axis_context))
  if num_devices != exported.nr_devices:
    # In some special cases we allow running with a different number of devices
    # than the function was exported for.
    err_msg = ""
    if exported.nr_devices != 1:
      err_msg = "the module was lowered for more than 1 device."
    elif (_check_module(submodule, disabled_checks=()) or
          any(s is not None and not s.is_replicated()
              for s in exported.in_shardings + exported.out_shardings)):
      err_msg = "the module contains non-replicated sharding annotations."
    if err_msg:
      raise NotImplementedError(
        f"Exported module {exported.fun_name} was lowered for "
        f"{exported.nr_devices} devices and is called in a context with "
        f"{num_devices} devices. This is disallowed because: {err_msg}"
      )

  # Apply in_shardings
  args = tuple(
    wrap_with_sharding(ctx, x, x_aval, x_sharding)
    for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings))
  symtab = ir.SymbolTable(submodule.operation)
  # The called function may have been exported with polymorphic shapes and called
  # now with more refined shapes. We insert hlo.ConvertOp to ensure the module
  # is valid.
  # Returns `x` unchanged when it already has the IR type of `new_aval`;
  # `x_aval` is not used.
  def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
    new_ir_type = mlir.aval_to_ir_type(new_aval)
    if x.type != new_ir_type:
      return hlo.convert(mlir.aval_to_ir_type(new_aval), x)
    else:
      return x

  # The signature of the callee's "main", read before merging the modules.
  callee_type = symtab["main"].type
  # TODO: maybe cache multiple calls
  # `fn` is used below as the symbol name of the merged callee.
  fn = mlir.merge_mlir_modules(ctx.module_context.module,
                               f"call_exported_{exported.fun_name}",
                               submodule,
                               dst_symtab=ctx.module_context.symbol_table)

  submodule_args: list[ir.Value] = []
  # All the platforms for the current lowering must be among the platforms
  # for which the callee was lowered.
  lowering_platforms = ctx.module_context.platforms

  # For each current lowering platform, the index of the same platform in
  # exported.lowering_platforms (0 when the platform check is disabled and
  # the platform is missing).
  callee_lowering_platform_index: list[int] = []
  for platform in lowering_platforms:
    if platform in exported.lowering_platforms:
      callee_lowering_platform_index.append(
        exported.lowering_platforms.index(platform))
    elif DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
      callee_lowering_platform_index.append(0)
    else:
      raise ValueError(
        f"The exported function '{exported.fun_name}' was lowered for "
        f"platforms '{exported.lowering_platforms}' but it is used "
        f"on '{lowering_platforms}'.")

  if len(exported.lowering_platforms) > 1:
    # The exported module takes a platform index argument
    if len(lowering_platforms) > 1:
      # The current lowering is itself multi-platform: its platform index is
      # the first of ctx.dim_var_values.
      current_platform_idx = ctx.dim_var_values[0]
    else:
      current_platform_idx = mlir.ir_constant(np.int32(0))
    # Compute the rule index based on the current platform
    i32_type = mlir.aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
    if current_platform_idx.type != i32_type:
      current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
    # Branch i of the case returns the callee's index for lowering_platforms[i].
    callee_platform_idx = hlo.CaseOp([i32_type],
                                     index=current_platform_idx,
                                     num_branches=len(lowering_platforms))
    for i in range(len(lowering_platforms)):
      branch = callee_platform_idx.regions[i].blocks.append()
      with ir.InsertionPoint(branch):
        hlo.return_(mlir.ir_constants(
          np.int32(callee_lowering_platform_index[i])))
    # Match the type the callee declares for its platform index argument.
    if callee_platform_idx.result.type != callee_type.inputs[0]:
      callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0],
                                          callee_platform_idx)

    submodule_args.append(callee_platform_idx)
  else:
    assert len(lowering_platforms) == 1

  # Then one input token per ordered effect.
  ordered_effects = exported.ordered_effects
  for eff in ordered_effects:
    token_in = ctx.tokens_in.get(eff)[0]
    submodule_args.append(token_in)
  # Then the array arguments that the exported module kept, converted to the
  # exported types.
  kept_args = [
      convert_shape(a, a_aval, exported_in_aval)
      for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals))
      if i in exported.module_kept_var_idx]
  submodule_args = submodule_args + kept_args

  call = func_dialect.CallOp(callee_type.results,
                             ir.FlatSymbolRefAttr.get(fn),
                             submodule_args)
  # The first len(ordered_effects) results are the output tokens.
  if ordered_effects:
    tokens_out = {eff: (call.results[effect_idx],)
                  for effect_idx, eff in enumerate(ordered_effects)}
    ctx.set_tokens_out(mlir.TokenSet(tokens_out))
  # The ctx.avals_out already contain the abstract values refined by
  # _call_exported_abstract_eval.
  results = tuple(
      convert_shape(out, out_aval, refined_out_aval)
      for out, out_aval, refined_out_aval in zip(call.results[len(ordered_effects):],
                                                 exported.out_avals, ctx.avals_out))
  # Apply out_shardings
  results = tuple(
    wrap_with_sharding(ctx, x, x_aval, x_sharding)
    for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings)
  )
  return results


mlir.register_lowering(call_exported_p, _call_exported_lowering)
def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
                       x: ir.Value,
                       x_aval: core.AbstractValue,
                       x_sharding: HloSharding | None) -> ir.Value:
  """Annotates `x` with `x_sharding`, or returns `x` itself when it is None.

  The annotation goes through `mlir.wrap_with_sharding_op`, which receives
  the sharding in its proto form.
  """
  if x_sharding is not None:
    sharding_proto = x_sharding.to_proto()
    return mlir.wrap_with_sharding_op(ctx, x, x_aval, sharding_proto)
  return x
# TODO(necula): Previously, users wrote `from jax.experimental.export import export`
# and accessed the public APIs as attributes of that `export`. We now want to
# simplify the usage: the public APIs are exported directly from
# `jax.experimental.export`, and `jax.experimental.export.export` now refers
# to the `export` function. Since there may still be users of the old API in
# other packages, we attach the old public API as attributes of the `export`
# function. We will clean this up after a deprecation period.
def wrap_with_deprecation_warning(f):
  """Returns a wrapper around `f` that emits a DeprecationWarning on each call.

  The warning tells users to use the API directly from
  `jax.experimental.export` instead of `jax.experimental.export.export`.
  `functools.wraps` copies `f`'s `__name__`, `__doc__` and related metadata
  onto the wrapper (and sets `__wrapped__`). Without it, the deprecated
  aliases would all show up as `wrapped_f` with no docstring or signature.
  """
  msg = (f"You are using function `{f.__name__}` from "
         "`jax.experimental.export.export`. You should instead use it directly "
         "from `jax.experimental.export`. Instead of "
         "`from jax.experimental.export import export` you should use "
         "`from jax.experimental import export`.")

  @functools.wraps(f)
  def wrapped_f(*args, **kwargs):
    # stacklevel=2 attributes the warning to the caller of the wrapper.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return f(*args, **kwargs)
  return wrapped_f
# Old `jax.experimental.export.export` API, kept as attributes of the `export`
# function for backwards compatibility. Functions are wrapped so that calling
# them emits a DeprecationWarning. `Exported`, `DisabledSafetyCheck` and the
# supported serialization versions are attached as-is.
export.export = wrap_with_deprecation_warning(export)
export.Exported = Exported
export.call_exported = wrap_with_deprecation_warning(call_exported)
export.DisabledSafetyCheck = DisabledSafetyCheck
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
export.symbolic_shape = wrap_with_deprecation_warning(shape_poly.symbolic_shape)
export.args_specs = wrap_with_deprecation_warning(args_specs)
export.minimum_supported_serialization_version = minimum_supported_serialization_version
export.maximum_supported_serialization_version = maximum_supported_serialization_version