[jax2tf] Refactor the gradient machinery for native serialization
In #15341 we refactored jax2tf to separate the JAX and TF pieces of the gradient handling. Here we continue that refactoring and move the JAX-only pieces from jax2tf.py into jax_export.py. The goal is to collect in jax_export all the pure JAX pieces needed for serialization. This is a pure refactoring; there should be no change in semantics.
commit 8ad5b0ef6b (parent 053affd173)
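A minimal sketch of how the export entry point touched by this change is meant to be driven. The flat function f_flat and the concrete avals below are made up for illustration; export_native and its keyword arguments come from the diff that follows, and the jax.experimental.jax2tf.jax_export import path is assumed from the tests in this commit.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax._src import core
    from jax.experimental.jax2tf import jax_export

    # A flat JAX function: flat argument list in, tuple of results out,
    # which is all that jax_export supports at this point.
    def f_flat(x, y):
      return (jnp.sin(x) * y,)

    # Lower and serialize for a fixed platform. strict_checks mirrors
    # jax2tf.convert(..., native_serialization_strict_checks=...).
    avals = [core.ShapedArray((3, 4), np.dtype(np.float32))] * 2
    exported = jax_export.export_native(
        f_flat, avals, lowering_platform="cpu", strict_checks=True)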
@@ -386,7 +386,7 @@ def convert(fun_jax: Callable,
lowering_platform = native_serialization_platforms[0]
else:
lowering_platform = None
exported: Optional[jax_export.Exported] = jax_export.serialize_native(
exported: Optional[jax_export.Exported] = jax_export.export_native(
fun_flat_jax, args_avals_flat,
lowering_platform=lowering_platform,
strict_checks=native_serialization_strict_checks)
@@ -394,8 +394,7 @@ def convert(fun_jax: Callable,
args_flat_tf: Sequence[TfVal]
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
outs_tf, out_avals = run_exported_as_tf(
args_avals_flat, args_flat_tf, exported, # type: ignore
native_serialization_strict_checks)
args_flat_tf, exported) # type: ignore
return outs_tf, out_avals
else:
dim_vars = shape_poly.all_dim_vars(args_avals_flat)
@@ -428,7 +427,6 @@ def convert(fun_jax: Callable,
_has_registered_tf_source_path = True

if with_gradient:

@tf.custom_gradient
def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
outs_tf, out_avals = run_fun_flat_as_tf(args_flat_tf)
@@ -439,9 +437,6 @@ def convert(fun_jax: Callable,
polymorphic_shapes_flat=polymorphic_shapes_flat,
out_avals=out_avals,
outs_tf=outs_tf,
native_serialization=native_serialization,
native_serialization_platforms=native_serialization_platforms,
native_serialization_strict_checks=native_serialization_strict_checks,
exported_primal=exported))

out_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
@@ -562,14 +557,17 @@ def _make_custom_gradient_fn_tf(*,
polymorphic_shapes_flat: Sequence[str],
out_avals: Sequence[core.ShapedArray],
outs_tf: Sequence[TfVal],
native_serialization: Union[str, bool],
native_serialization_platforms: Sequence[str],
native_serialization_strict_checks: bool,
exported_primal: Optional[jax_export.Exported]):
"""Prepares the TF function to be used with tf.custom_gradient.

Args:
fun_flat_jax: the flattened JAX primal function
args_flat_tf: the TF arguments of the primal function
out_avals: the output JAX abstract values of the primal function
outs_tf: the TF outputs of the primal function
exported_primal: is None for graph serialization, and is the exported
primal function for native serialization.
"""

def grad_fn_tf(*out_cts_flat_tf: TfVal,
variables=None):
if variables:
@@ -578,63 +576,12 @@ def _make_custom_gradient_fn_tf(*,
"This should not happen for first-order differentiation. "
f"{variables=}")

out_cts_flat_polymorphic_shapes = tuple(str(out_aval.shape) # Note: may be _DimExpr, not just DimVar
for out_aval in out_avals) # type: ignore
vjp_polymorphic_shapes = [
polymorphic_shapes_flat, out_cts_flat_polymorphic_shapes
]

def fun_vjp_jax(args_flat_jax, out_cts_flat_jax):
# One may think that we can get the pullback while we are converting
# the main function in the first place. That is problematic, because the
# pullback may contain captured tracers from the conversion of the
# main function. Those tracers will confuse the conversion of the
# pullback. So, we construct the vjp anew and we convert it separately.
def fun_vjp_jax(*args_and_out_cts_flat_jax):
# Takes a flat list of primals and output cotangents
args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(args_flat_tf)])
_, pullback_jax = jax.vjp(fun_flat_jax, *args_flat_jax)
return pullback_jax(list(out_cts_flat_jax))
return pullback_jax(out_cts_flat_jax)

if exported_primal is not None:
# Native lowering
all_in_shardings = [pxla._UNSPECIFIED] * len(exported_primal.in_avals)
for idx, in_s in zip(sorted(exported_primal.module_kept_var_idx),
exported_primal.in_shardings):
all_in_shardings[idx] = in_s # type: ignore
all_shardings = all_in_shardings + list(exported_primal.out_shardings)
# We cannot mix unspecified and specified shardings. Make the unspecified
# ones replicated
specified_shardings = [
s for s in all_shardings if not pxla._is_unspecified(s)]
if 0 < len(specified_shardings) < len(all_shardings):
# There are some specified, but not all
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
all_shardings = [
s if not pxla._is_unspecified(s) else replicated_s
for s in all_shardings]
# Since fun_vjp_jax takes two tuples of arguments we must split the in_shardings
vjp_in_args_shardings: Any
vjp_in_out_ct_shardings: Any
vjp_in_shardings: Any
vjp_in_args_shardings, vjp_in_out_ct_shardings = util.split_list(all_shardings,
[len(exported_primal.in_avals)])
# pjit front-end does not like all-unspecified
if all(pxla._is_unspecified(s) for s in vjp_in_args_shardings):
vjp_in_args_shardings = pxla._UNSPECIFIED
else:
vjp_in_args_shardings = tuple(vjp_in_args_shardings)
if all(pxla._is_unspecified(s) for s in vjp_in_out_ct_shardings):
vjp_in_out_ct_shardings = pxla._UNSPECIFIED
else:
vjp_in_out_ct_shardings = tuple(vjp_in_out_ct_shardings)

if pxla._is_unspecified(vjp_in_args_shardings) and pxla._is_unspecified(vjp_in_args_shardings):
vjp_in_shardings = pxla._UNSPECIFIED
else:
vjp_in_shardings = (vjp_in_args_shardings, vjp_in_out_ct_shardings)
fun_vjp_jax = pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,
out_shardings=vjp_in_args_shardings)
# TODO: enable higher-order gradients
with tf.name_scope("jax2tf_vjp"):
def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal):
@@ -650,14 +597,22 @@ def _make_custom_gradient_fn_tf(*,
return tf.zeros_like(out_tf, dtype=_tf_np_dtype_for_float0)

out_cts_fixed_flat_tf = tuple(map(fix_out_ct, out_cts_flat_tf, out_avals, outs_tf))
in_cts_flat = convert(
fun_vjp_jax,
with_gradient=False,
polymorphic_shapes=vjp_polymorphic_shapes,
native_serialization=native_serialization,
native_serialization_platforms=native_serialization_platforms,
native_serialization_strict_checks=native_serialization_strict_checks
)(args_flat_tf, out_cts_fixed_flat_tf)
vjp_args_flat_tf = tuple(args_flat_tf) + out_cts_fixed_flat_tf
if exported_primal is not None:
exported_vjp = exported_primal.vjp()
vjp_args_flat_tf = tuple(tf.identity(arg, f"jax2tf_arg_{arg_idx}")
for arg_idx, arg in enumerate(vjp_args_flat_tf))
in_cts_flat, _ = run_exported_as_tf(vjp_args_flat_tf, exported_vjp)
in_cts_flat = tuple(tf.identity(arg, "jax2tf_out") for arg in in_cts_flat)
else:
out_cts_flat_polymorphic_shapes = tuple(str(out_aval.shape) # Note: may be _DimExpr, not just DimVar
for out_aval in out_avals) # type: ignore
vjp_polymorphic_shapes = tuple(polymorphic_shapes_flat) + out_cts_flat_polymorphic_shapes
in_cts_flat = convert(
fun_vjp_jax,
with_gradient=False,
polymorphic_shapes=vjp_polymorphic_shapes,
native_serialization=False)(*vjp_args_flat_tf)
return in_cts_flat

return grad_fn_tf
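The hunks above rewire how the custom gradient is attached on the TF side: the primal runs through one converted function, and the gradient runs through a separately prepared VJP function, which in the native case is now obtained from exported_primal.vjp(). A toy sketch of that tf.custom_gradient wiring, with hypothetical primal_tf and vjp_tf callables standing in for the converted functions (this is an illustration of the pattern, not jax2tf's actual code):

    import tensorflow as tf

    def make_fn_with_custom_grad(primal_tf, vjp_tf):
      # primal_tf: TF callable for the primal function.
      # vjp_tf: TF callable taking (primal args + output cotangents) and
      # returning one input cotangent per primal argument.
      @tf.custom_gradient
      def fn(*args):
        outs = primal_tf(*args)
        def grad_fn(*out_cts, variables=None):
          if variables:
            raise ValueError("Unexpected variables used by the primal function")
          return vjp_tf(*args, *out_cts)
        return outs, grad_fn
      return fn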
@@ -692,13 +647,16 @@ def _interpret_fun_jax(
return util.unzip2(out_vals)


def run_exported_as_tf(args_avals: Sequence[core.ShapedArray],
args_tf: Sequence[TfVal],
def run_exported_as_tf(args_tf: Sequence[TfVal],
exported: jax_export.Exported,
native_serialization_strict_checks: bool,
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
"""Runs the `exported` as an XlaCallModule TF op."""
out_shapes = tuple(
"""Runs the `exported` as an XlaCallModule TF op.

Returns:
a tuple with the results and the abstract values.
"""
args_avals = exported.in_avals
out_shapes_tf = tuple(
tuple(d if type(d) is int else None
for d in out_aval.shape)
for out_aval in exported.out_avals)
@@ -715,10 +673,10 @@ def run_exported_as_tf(args_avals: Sequence[core.ShapedArray],
call_module_attrs = dict(
version=exported.xla_call_module_version,
Tout=out_types,
Sout=out_shapes)
Sout=out_shapes_tf)

if exported.xla_call_module_version >= 3:
if native_serialization_strict_checks:
if exported.strict_checks:
call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
else:
call_module_attrs["platforms"] = () # No platform checking
@@ -13,13 +13,13 @@
# limitations under the License.
"""JAX APIs for exporting code for interoperation.

This module is used with jax2tf, but should have no TensorFlow dependencies.
This module is used with jax2tf, but has no TensorFlow dependencies.
"""
import dataclasses
import functools
import itertools
import re
from typing import Callable, List, Optional, Sequence, Union
from typing import Any, Callable, List, Optional, Sequence, Union

from absl import logging

@@ -27,6 +27,7 @@ import jax
from jax import sharding

from jax._src import core
from jax._src import pjit
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
@@ -43,63 +44,88 @@ from jax.experimental.jax2tf import shape_poly
map = util.safe_map
zip = util.safe_zip

# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = [
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"ducc_fft", "cu_threefry2x32",
# eigh on CPU
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
# eigh on GPU
"cusolver_syevj", "cusolver_syevd",
# eigh on TPU
"Eigh",
# qr on CPU
"lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
# qr on GPU
"cusolver_geqrf", "cublas_geqrf_batched",
"cusolver_geqrf", "cusolver_orgqr",
# qr and svd on TPU
"Qr", "ProductOfElementaryHouseholderReflectors",
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
# # lu on CPU
# "lapack_sgetrf" , "lapack_dgetrf" , "lapack_cgetrf" , "lapack_zgetrf",
# # lu on GPU
# "cublas_getrf_batched", "cusolver_getrf",
# "hipblas_getrf_batched", "hipsolver_getrf",
# lu on TPU
"LuDecomposition",
]


@dataclasses.dataclass
class Exported:
"""Represents a lowered and serialized JAX module."""
"""A lowered and serialized JAX function.

Currently this works only for functions that take a flat argument list
and return a tuple of results. (No pytree support yet.)

Attributes:
in_avals: the input abstract values. May contain dimension expressions in
the shapes.
out_avals: the output abstract values. May contain dimension expressions in
the shapes, with dimension variables among those in `in_avals`.
in_shardings: the input shardings. Only for the `module_kept_var_idx`.
out_shardings: the output shardings.
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'

mlir_module_serialized: the serialized lowered VHLO module.
mlir_module_version: a version number for the serialized module.
The following version numbers are valid:
4 - mlir_module_serialized is a portable artifact.
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
because they are not used. Same length as `in_shardings`.
strict_checks: whether the module was serialized with the following safety
checking: (A) the lowered computation can only be executed on a platform
for which it was lowered; (B) the serialized computation contains only
custom calls with targets that are guaranteed to be stable, (more to come).
_get_vjp: an optional function that takes the current exported function and
returns the exported VJP function.
The VJP function takes a flat list of arguments,
starting with the primal arguments and followed by a cotangent argument
for each primal output. It returns a tuple with the cotangents
corresponding to the primal inputs.
"""
in_avals: Sequence[core.ShapedArray]
out_avals: Sequence[core.ShapedArray]
# The in_shardings reflect only the module_kept_var_idx
in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
out_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
lowering_platform: str
strict_checks: bool

lowering_platform: str # One of "tpu", "cpu", "cuda", "rocm"
mlir_module_serialized: bytes
xla_call_module_version: int
module_kept_var_idx: Sequence[int]

mlir_module: mlir.ir.Module
mlir_module_serialized: bytes # VHLO bytecode format
xla_call_module_version: int # Follows the versions of XlaCallModule
module_kept_var_idx: Sequence[int] # Specifies if an argument is kept in the
# lowering. Same length as `in_shardings`.
_get_vjp: Optional[Callable[["Exported"], "Exported"]]

@property
def mlir_module(self) -> mlir.ir.Module:
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)

def vjp(self) -> "Exported":
"""Gets the exported VJP.

Returns None if not available, which can happen if the Exported has been
loaded from an external format, without a VJP."""
if self._get_vjp is None:
raise ValueError("No VJP is available")
return self._get_vjp(self)


def default_jax_backend() -> str:
def _default_jax_backend() -> str:
# Canonicalize to turn into CUDA or ROCM
return xb.canonicalize_platform(jax.default_backend())


def serialize_native(fun_jax: Callable,
args_avals: Sequence[core.ShapedArray], *,
lowering_platform: Optional[str],
strict_checks: bool) -> Exported:
def export_native(fun_jax: Callable,
args_avals: Sequence[core.ShapedArray], *,
lowering_platform: Optional[str],
strict_checks: bool) -> Exported:
"""Exports native serialization for a JAX function.

At the moment works only for JAX functions that take a flat list of arguments
and return a flat list of results.

Args:
fun_jax: the function to lower and serialize
args_avals: the abstract values at which to lower.
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'
strict_checks: whether to do strict safety checks. See Exported.strict_checks
for more details.
"""
arg_specs_jax = [
jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape)
for aval in args_avals
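A small sketch of how the reworked Exported dataclass is meant to be consumed, continuing the exported value from the earlier sketch. The attribute names come from the docstring above; the printed values are only indicative.

    print(exported.lowering_platform)            # e.g. "cpu"
    print(exported.strict_checks)                # True
    print(exported.xla_call_module_version)      # 4: portable (VHLO) artifact
    print(len(exported.mlir_module_serialized))  # size of the serialized module in bytes

    # mlir_module is now a property that deserializes the portable artifact
    # on demand, instead of being a stored field.
    stablehlo_text = str(exported.mlir_module)

    # The VJP is exported lazily through the private _get_vjp hook.
    exported_vjp = exported.vjp()
    assert len(exported_vjp.in_avals) == len(exported.in_avals) + len(exported.out_avals)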
@@ -109,62 +135,19 @@ def serialize_native(fun_jax: Callable,
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. In that case we raise
# an error if the lowered function contains non-replicated sharding annotations.
fun_jax_lower = jax.jit(fun_jax).lower
wrapped_fun_jax = jax.jit(fun_jax)
allow_non_replicated_sharding = False
else:
# If we have a pjit or pmap already we do not wrap with another, and we
# allow shardings.
fun_jax_lower = fun_jax.lower
wrapped_fun_jax = fun_jax # type: ignore
allow_non_replicated_sharding = True

lowered = fun_jax_lower(
lowered = wrapped_fun_jax.lower(
*arg_specs_jax,
_experimental_lowering_platform=lowering_platform)._lowering # type: ignore

if not isinstance(lowered, pxla.MeshComputation):
raise NotImplementedError(f"serialization is supported only for pjit. {lowered}")

# Check that we do not see new compile_args. When we add a compile_args it is
# safe to add it to the allowed_compile_args if it does not change the semantics
# or the calling convention of the lowered module.
allowed_compile_args = ["backend", "mesh", "global_in_avals",
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
"spmd_lowering", "auto_spmd_lowering",
"tuple_args", "ordered_effects", "unordered_effects",
"host_callbacks", "keepalive", "pmap_nreps", "committed", "device_assignment"]
for compile_arg in lowered.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")

# We have not implemented support for some of the compile_args.
not_implemented_msgs = []
for compile_arg, check_value, err_msg in (
("spmd_lowering", lambda v: v, "True"),
("auto_spmd_lowering", lambda v: not v, "False"),
# tuple_args is a compilation flag, does not affect lowering.
("tuple_args", lambda v: True, "N/A"),
# Used for debug(ordered=True), changes the calling convention, but will
# also set keepalive to non-empty.
("ordered_effects", lambda v: not v, "empty"),
# unordered_effects do not change the calling convention. Those from
# jax.debug will also result in keepalive being non-empty and unsupported
# custom calls. The CallTfEffect is an exception, but we want to allow
# that one.
("unordered_effects", lambda v: True, "N/A"),
# used for TPU jax.debug, send/recv. Not supported yet.
("host_callbacks", lambda v: not v, "empty"),
# used on all platforms for callbacks. Not supported yet.
("keepalive", lambda v: not v, "empty"),
("pmap_nreps", lambda v: v == 1, "1"),
):
if compile_arg in lowered.compile_args:
if not check_value(lowered.compile_args[compile_arg]):
not_implemented_msgs.append(
f"{compile_arg} must be {err_msg} and it is {lowered.compile_args[compile_arg]}")
if not_implemented_msgs:
raise NotImplementedError(
"serialization error, unimplemented lowered.compile_args:\n" +
"\n".join(not_implemented_msgs))
_check_lowered(lowered)

mlir_module = lowered.stablehlo()
if "kept_var_idx" in lowered.compile_args:
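In terms of public APIs, the lowering step above corresponds roughly to the following (illustrative only; f_flat and avals are from the first sketch, named_shape is omitted, and export_native itself works with the private lowered._lowering object and also passes _experimental_lowering_platform):

    arg_specs = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in avals]
    lowered = jax.jit(f_flat).lower(*arg_specs)   # the "implied jit" case
    print(lowered.as_text()[:200])                # textual form of the lowered module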
@@ -176,7 +159,7 @@ def serialize_native(fun_jax: Callable,
if not all(core.is_constant_shape(a.shape) for a in args_avals):
# All arguments are kept if we have dimension variables.
assert len(module_kept_var_idx) == len(args_avals)
mlir_module = add_dim_arg_computation(mlir_module, args_avals)
mlir_module = _add_dim_arg_computation(mlir_module, args_avals)

xla_call_module_version = 4
mlir_str = mlir.module_to_bytecode(mlir_module)
@@ -192,8 +175,6 @@ def serialize_native(fun_jax: Callable,
out_avals = lowered.compile_args["shards"].out_sharded_avals
else:
out_avals = lowered.compile_args["out_avals"]
if lowered.compile_args["host_callbacks"]:
raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering")

# Log and then check the module.
if logging.vlog_is_on(3):
@@ -201,24 +182,25 @@ def serialize_native(fun_jax: Callable,
logmsg = f"version={xla_call_module_version} lowering_platform={lowering_platform}"
logging.vlog(3, "Lowered JAX module: %s\n%s", logmsg, mlir_module_text)

check_module(mlir_module,
allow_non_replicated_sharding=allow_non_replicated_sharding,
allow_all_custom_calls=not strict_checks)
_check_module(mlir_module,
allow_non_replicated_sharding=allow_non_replicated_sharding,
allow_all_custom_calls=not strict_checks)

return Exported(
in_avals=args_avals,
out_avals=out_avals,
in_shardings=lowered.compile_args["in_shardings"],
out_shardings=lowered.compile_args["out_shardings"],
lowering_platform=lowering_platform or default_jax_backend(),
mlir_module=mlir_module,
lowering_platform=lowering_platform or _default_jax_backend(),
strict_checks=strict_checks,
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
xla_call_module_version=xla_call_module_version)
xla_call_module_version=xla_call_module_version,
_get_vjp=lambda exported: _export_native_vjp(wrapped_fun_jax, exported))


def add_dim_arg_computation(module: mlir.ir.Module,
args_avals: Sequence[core.ShapedArray]) -> mlir.ir.Module:
def _add_dim_arg_computation(module: mlir.ir.Module,
args_avals: Sequence[core.ShapedArray]) -> mlir.ir.Module:
"""Wraps the lowered module with a new "main" that computes the dim args.

JAX lowering in presence of shape polymorphism produces a `module` that
@@ -294,7 +276,7 @@ def add_dim_arg_computation(module: mlir.ir.Module,
ctx = mlir.LoweringRuleContext(module_context=module_context,
primitive=None, avals_in=args_avals, avals_out=None,
tokens_in=mlir.TokenSet(), tokens_out=None)
dim_args = compute_dim_args(ctx, args_avals, tuple(new_main_op.arguments),
dim_args = _compute_dim_args(ctx, args_avals, tuple(new_main_op.arguments),
orig_input_types[:len(dim_vars)])
# The first arguments are the dimension variable
orig_main_args.extend(dim_args)
@@ -308,7 +290,7 @@ def add_dim_arg_computation(module: mlir.ir.Module,
return new_module


def compute_dim_args(
def _compute_dim_args(
ctx: mlir.LoweringRuleContext,
args_avals: Sequence[core.ShapedArray],
array_args: Sequence[mlir.ir.Value],
@@ -337,9 +319,86 @@ def compute_dim_args(
return tuple(res)


def check_module(mod: mlir.ir.Module, *,
allow_non_replicated_sharding: bool,
allow_all_custom_calls: bool):
def _check_lowered(lowered) -> None:
if not isinstance(lowered, pxla.MeshComputation):
raise NotImplementedError(f"serialization is supported only for pjit. {lowered}")

if lowered.compile_args["host_callbacks"] or lowered.compile_args["keepalive"]:
raise NotImplementedError("serialization of host_callbacks is not yet implemented")
# Check that we do not see new compile_args. When we add a compile_args it is
# safe to add it to the allowed_compile_args if it does not change the semantics
# or the calling convention of the lowered module.
allowed_compile_args = ["backend", "mesh", "global_in_avals",
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
"spmd_lowering", "auto_spmd_lowering",
"tuple_args", "ordered_effects", "unordered_effects",
"keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment"]
for compile_arg in lowered.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")

# We have not implemented support for some of the compile_args.
not_implemented_msgs = []
for compile_arg, check_value, err_msg in (
("spmd_lowering", lambda v: v, "True"),
("auto_spmd_lowering", lambda v: not v, "False"),
# tuple_args is a compilation flag, does not affect lowering.
("tuple_args", lambda v: True, "N/A"),
# Used for debug(ordered=True), changes the calling convention, but will
# also set keepalive to non-empty.
("ordered_effects", lambda v: not v, "empty"),
# unordered_effects do not change the calling convention. Those from
# jax.debug will also result in keepalive being non-empty and unsupported
# custom calls. The CallTfEffect is an exception, but we want to allow
# that one.
("unordered_effects", lambda v: True, "N/A"),
# used for TPU jax.debug, send/recv. Not supported yet.
("host_callbacks", lambda v: not v, "empty"),
# used on all platforms for callbacks. Not supported yet.
("keepalive", lambda v: not v, "empty"),
("pmap_nreps", lambda v: v == 1, "1"),
):
if compile_arg in lowered.compile_args:
if not check_value(lowered.compile_args[compile_arg]):
not_implemented_msgs.append(
f"{compile_arg} must be {err_msg} and it is {lowered.compile_args[compile_arg]}")
if not_implemented_msgs:
raise NotImplementedError(
"serialization error, unimplemented lowered.compile_args:\n" +
"\n".join(not_implemented_msgs))

# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = [
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"ducc_fft", "cu_threefry2x32",
# eigh on CPU
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
# eigh on GPU
"cusolver_syevj", "cusolver_syevd",
# eigh on TPU
"Eigh",
# qr on CPU
"lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
# qr on GPU
"cusolver_geqrf", "cublas_geqrf_batched",
"cusolver_geqrf", "cusolver_orgqr",
# qr and svd on TPU
"Qr", "ProductOfElementaryHouseholderReflectors",
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
# # lu on CPU
# "lapack_sgetrf" , "lapack_dgetrf" , "lapack_cgetrf" , "lapack_zgetrf",
# # lu on GPU
# "cublas_getrf_batched", "cusolver_getrf",
# "hipblas_getrf_batched", "hipsolver_getrf",
# lu on TPU
"LuDecomposition",
]

def _check_module(mod: mlir.ir.Module, *,
allow_non_replicated_sharding: bool,
allow_all_custom_calls: bool):
"""Run a number of checks on the module.

Args:
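A hypothetical example of a lowering that the new _check_lowered rejects: anything that registers host callbacks (for instance via jax.debug.print) makes host_callbacks/keepalive non-empty, so export fails with the error message that the tests further below are updated to expect. The specific callback used here is an assumption; the error text is taken from this diff.

    def f_debug(x):
      jax.debug.print("x = {x}", x=x)   # registers a host callback
      return jnp.sin(x)

    try:
      jax_export.export_native(
          f_debug, [core.ShapedArray((), np.dtype(np.float32))],
          lowering_platform="cpu", strict_checks=True)
    except NotImplementedError as e:
      print(e)  # "serialization of host_callbacks is not yet implemented"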
@@ -393,3 +452,55 @@ def check_module(mod: mlir.ir.Module, *,
f"{disallowed_custom_call_ops_str}.\n"
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
raise ValueError(msg)

def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
# Export the VJP of `fun_flat_jax`

def fun_vjp_jax(*args_and_out_cts_flat_jax):
# Takes a flat list of primals and output cotangents
args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(primal.in_avals)])
_, pullback_jax = jax.vjp(primal_fun_jax, *args_flat_jax)
return pullback_jax(out_cts_flat_jax)

vjp_in_avals = list(
itertools.chain(primal.in_avals,
map(lambda a: a.at_least_vspace(), primal.out_avals)))

# Expand in_shardings to all in_avals even not kept ones.
all_in_shardings = [pxla._UNSPECIFIED] * len(primal.in_avals)
for idx, in_s in zip(sorted(primal.module_kept_var_idx),
primal.in_shardings):
all_in_shardings[idx] = in_s # type: ignore
all_shardings = all_in_shardings + list(primal.out_shardings)
# Cannot mix unspecified and specified shardings. Make the unspecified
# ones replicated.
specified_shardings = [
s for s in all_shardings if not pxla._is_unspecified(s)]

vjp_in_shardings: Any # The primal inputs followed by output cotangents
vjp_out_shardings: Any # The primal output cotangents
if 0 == len(specified_shardings):
vjp_in_shardings = pxla._UNSPECIFIED
vjp_out_shardings = pxla._UNSPECIFIED
else:
if len(specified_shardings) < len(all_shardings):
# There are some specified, but not all; pjit front-end does not like
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
all_shardings = [
s if not pxla._is_unspecified(s) else replicated_s
for s in all_shardings]

vjp_in_shardings = tuple(all_shardings)
vjp_out_shardings = tuple(all_shardings[:len(primal.in_avals)])
if all(pxla._is_unspecified(s) for s in vjp_out_shardings):
vjp_out_shardings = pxla._UNSPECIFIED

fun_vjp_jax = pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,
out_shardings=vjp_out_shardings)

return export_native(fun_vjp_jax, vjp_in_avals,
lowering_platform=primal.lowering_platform,
strict_checks=primal.strict_checks)
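The flat calling convention that _export_native_vjp builds, shown directly with jax.vjp and continuing f_flat from the first sketch. This illustrates the convention (primal arguments first, then one cotangent per primal output, returning one cotangent per primal argument); it is not the exported code path itself.

    def f_vjp_flat(*args_and_out_cts):
      # First two entries are the primals, the rest are output cotangents.
      args, out_cts = args_and_out_cts[:2], args_and_out_cts[2:]
      _, pullback = jax.vjp(f_flat, *args)
      return pullback(out_cts)          # tuple of input cotangents

    x = np.ones((3, 4), np.float32)
    y = np.full((3, 4), 2.0, np.float32)
    out_ct = np.ones((3, 4), np.float32)  # one cotangent for the single output
    dx, dy = f_vjp_flat(x, y, out_ct)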
@@ -203,7 +203,7 @@ class CompatTest(jtu.JaxTestCase):
res_from_jax = tuple(np.array(a) for a in res_from_jax)

# Use the native exporter, to make sure we get the proper serialized module.
exported = jax2tf.jax_export.serialize_native(
exported = jax2tf.jax_export.export_native(
jax.jit(func),
[core.ShapedArray(a.shape, a.dtype) for a in data.inputs],
lowering_platform=default_jax_backend(),
@@ -1531,7 +1531,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
stack.enter_context(mesh)
# Run the JAX native version, to check it works, and to fill caches.
_ = func_to_convert(*args)
exported = jax_export.serialize_native(
exported = jax_export.export_native(
func_to_convert,
[core.ShapedArray(a.shape, a.dtype) for a in args],
lowering_platform='tpu',
@@ -1608,7 +1608,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
return jnp.sin(x)

with self.assertRaisesRegex(NotImplementedError,
"keepalive must be empty"):
"serialization of host_callbacks is not yet implemented"):
jax2tf.convert(f_jax, native_serialization=True)(np.float32(42.))

def f_ordered_jax(x):
@@ -1616,7 +1616,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
return jnp.sin(x)

with self.assertRaisesRegex(NotImplementedError,
"keepalive must be empty"):
"serialization of host_callbacks is not yet implemented"):
jax2tf.convert(f_ordered_jax, native_serialization=True)(np.float32(42.))

def test_tuple_args(self):
@@ -372,7 +372,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
for out_shardings in ("missing", None, "P")
)
@jtu.with_mesh([("x", 2)])
def test_grad_pjit(self, in_shardings="missing", out_shardings="None"):
def test_grad_pjit(self, in_shardings="P", out_shardings=None):
def f_jax(x): # x: f32[10,20] -> f32[20,10]
return jnp.sin(x.T)

@@ -133,7 +133,7 @@ SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
def _check_specs(error_type: SpecErrorType, specs: Any) -> None:
if error_type == SpecErrorType.input and specs is None:
raise TypeError(
f"shard_map in_specs argument must be a pytree of "
"shard_map in_specs argument must be a pytree of "
"`jax.sharding.PartitionSpec` instances, but it was None.\n"
"Instead of `in_specs=None`, did you mean `in_specs=P()`, "
"where `P = jax.sharding.PartitionSpec`?")