[jax2tf] Refactor the gradient machinery for native serialization

In #15341 we refactored jax2tf to separate the JAX and TF pieces
of gradient handling. Here we continue the refactoring and move the
JAX-only pieces from jax2tf.py into jax_export.py. The goal is to
collect in jax_export.py all the pure JAX pieces needed for
serialization.

This is a pure refactoring; there should be no change in semantics.
George Necula 2023-04-04 13:23:43 +02:00
parent 053affd173
commit 8ad5b0ef6b
6 changed files with 264 additions and 195 deletions
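
For orientation, a minimal sketch of the refactored jax_export API, using the signatures that appear in the diffs below. This is illustrative usage only; the import path and the defaults shown are assumptions based on this commit, not an official example.

# Sketch of the refactored jax_export API (names taken from this diff;
# usage details are illustrative).
import jax
import jax.numpy as jnp
import numpy as np
from jax._src import core                       # core.ShapedArray, as in the tests below
from jax.experimental.jax2tf import jax_export  # assumed module location in this commit

def f(x):                                       # f32[10, 20] -> f32[20, 10]
  return jnp.sin(x.T)

args_avals = [core.ShapedArray((10, 20), np.dtype(np.float32))]
exported = jax_export.export_native(
    jax.jit(f), args_avals,
    lowering_platform=None,                     # None selects the default JAX backend
    strict_checks=True)

# The Exported carries the serialized module plus the metadata jax2tf needs.
print(exported.lowering_platform, exported.xla_call_module_version)

# The VJP is exported lazily: its inputs are the primal inputs followed by one
# cotangent per primal output; its outputs are the cotangents of the inputs.
exported_vjp = exported.vjp()
print([a.shape for a in exported_vjp.in_avals])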

View File

@ -386,7 +386,7 @@ def convert(fun_jax: Callable,
lowering_platform = native_serialization_platforms[0]
else:
lowering_platform = None
exported: Optional[jax_export.Exported] = jax_export.serialize_native(
exported: Optional[jax_export.Exported] = jax_export.export_native(
fun_flat_jax, args_avals_flat,
lowering_platform=lowering_platform,
strict_checks=native_serialization_strict_checks)
@ -394,8 +394,7 @@ def convert(fun_jax: Callable,
args_flat_tf: Sequence[TfVal]
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
outs_tf, out_avals = run_exported_as_tf(
args_avals_flat, args_flat_tf, exported, # type: ignore
native_serialization_strict_checks)
args_flat_tf, exported) # type: ignore
return outs_tf, out_avals
else:
dim_vars = shape_poly.all_dim_vars(args_avals_flat)
@ -428,7 +427,6 @@ def convert(fun_jax: Callable,
_has_registered_tf_source_path = True
if with_gradient:
@tf.custom_gradient
def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
outs_tf, out_avals = run_fun_flat_as_tf(args_flat_tf)
@ -439,9 +437,6 @@ def convert(fun_jax: Callable,
polymorphic_shapes_flat=polymorphic_shapes_flat,
out_avals=out_avals,
outs_tf=outs_tf,
native_serialization=native_serialization,
native_serialization_platforms=native_serialization_platforms,
native_serialization_strict_checks=native_serialization_strict_checks,
exported_primal=exported))
out_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
@ -562,14 +557,17 @@ def _make_custom_gradient_fn_tf(*,
polymorphic_shapes_flat: Sequence[str],
out_avals: Sequence[core.ShapedArray],
outs_tf: Sequence[TfVal],
native_serialization: Union[str, bool],
native_serialization_platforms: Sequence[str],
native_serialization_strict_checks: bool,
exported_primal: Optional[jax_export.Exported]):
"""Prepares the TF function to be used with tf.custom_gradient.
Args:
fun_flat_jax: the flattened JAX primal function
args_flat_tf: the TF arguments of the primal function
polymorphic_shapes_flat: the flat polymorphic shape specifications for the primal arguments
out_avals: the output JAX abstract values of the primal function
outs_tf: the TF outputs of the primal function
exported_primal: is None for graph serialization, and is the exported
primal function for native serialization.
"""
def grad_fn_tf(*out_cts_flat_tf: TfVal,
variables=None):
if variables:
@ -578,63 +576,12 @@ def _make_custom_gradient_fn_tf(*,
"This should not happen for first-order differentiation. "
f"{variables=}")
out_cts_flat_polymorphic_shapes = tuple(str(out_aval.shape) # Note: may be _DimExpr, not just DimVar
for out_aval in out_avals) # type: ignore
vjp_polymorphic_shapes = [
polymorphic_shapes_flat, out_cts_flat_polymorphic_shapes
]
def fun_vjp_jax(args_flat_jax, out_cts_flat_jax):
# One may think that we can get the pullback while we are converting
# the main function in the first place. That is problematic, because the
# pullback may contain captured tracers from the conversion of the
# main function. Those tracers will confuse the conversion of the
# pullback. So, we construct the vjp anew and we convert it separately.
def fun_vjp_jax(*args_and_out_cts_flat_jax):
# Takes a flat list of primals and output cotangents
args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(args_flat_tf)])
_, pullback_jax = jax.vjp(fun_flat_jax, *args_flat_jax)
return pullback_jax(list(out_cts_flat_jax))
return pullback_jax(out_cts_flat_jax)
if exported_primal is not None:
# Native lowering
all_in_shardings = [pxla._UNSPECIFIED] * len(exported_primal.in_avals)
for idx, in_s in zip(sorted(exported_primal.module_kept_var_idx),
exported_primal.in_shardings):
all_in_shardings[idx] = in_s # type: ignore
all_shardings = all_in_shardings + list(exported_primal.out_shardings)
# We cannot mix unspecified and specified shardings. Make the unspecified
# ones replicated
specified_shardings = [
s for s in all_shardings if not pxla._is_unspecified(s)]
if 0 < len(specified_shardings) < len(all_shardings):
# There are some specified, but not all
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
all_shardings = [
s if not pxla._is_unspecified(s) else replicated_s
for s in all_shardings]
# Since fun_vjp_jax takes two tuples of arguments we must split the in_shardings
vjp_in_args_shardings: Any
vjp_in_out_ct_shardings: Any
vjp_in_shardings: Any
vjp_in_args_shardings, vjp_in_out_ct_shardings = util.split_list(all_shardings,
[len(exported_primal.in_avals)])
# pjit front-end does not like all-unspecified
if all(pxla._is_unspecified(s) for s in vjp_in_args_shardings):
vjp_in_args_shardings = pxla._UNSPECIFIED
else:
vjp_in_args_shardings = tuple(vjp_in_args_shardings)
if all(pxla._is_unspecified(s) for s in vjp_in_out_ct_shardings):
vjp_in_out_ct_shardings = pxla._UNSPECIFIED
else:
vjp_in_out_ct_shardings = tuple(vjp_in_out_ct_shardings)
if pxla._is_unspecified(vjp_in_args_shardings) and pxla._is_unspecified(vjp_in_args_shardings):
vjp_in_shardings = pxla._UNSPECIFIED
else:
vjp_in_shardings = (vjp_in_args_shardings, vjp_in_out_ct_shardings)
fun_vjp_jax = pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,
out_shardings=vjp_in_args_shardings)
# TODO: enable higher-order gradients
with tf.name_scope("jax2tf_vjp"):
def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal):
@ -650,14 +597,22 @@ def _make_custom_gradient_fn_tf(*,
return tf.zeros_like(out_tf, dtype=_tf_np_dtype_for_float0)
out_cts_fixed_flat_tf = tuple(map(fix_out_ct, out_cts_flat_tf, out_avals, outs_tf))
in_cts_flat = convert(
fun_vjp_jax,
with_gradient=False,
polymorphic_shapes=vjp_polymorphic_shapes,
native_serialization=native_serialization,
native_serialization_platforms=native_serialization_platforms,
native_serialization_strict_checks=native_serialization_strict_checks
)(args_flat_tf, out_cts_fixed_flat_tf)
vjp_args_flat_tf = tuple(args_flat_tf) + out_cts_fixed_flat_tf
if exported_primal is not None:
exported_vjp = exported_primal.vjp()
vjp_args_flat_tf = tuple(tf.identity(arg, f"jax2tf_arg_{arg_idx}")
for arg_idx, arg in enumerate(vjp_args_flat_tf))
in_cts_flat, _ = run_exported_as_tf(vjp_args_flat_tf, exported_vjp)
in_cts_flat = tuple(tf.identity(arg, "jax2tf_out") for arg in in_cts_flat)
else:
out_cts_flat_polymorphic_shapes = tuple(str(out_aval.shape) # Note: may be _DimExpr, not just DimVar
for out_aval in out_avals) # type: ignore
vjp_polymorphic_shapes = tuple(polymorphic_shapes_flat) + out_cts_flat_polymorphic_shapes
in_cts_flat = convert(
fun_vjp_jax,
with_gradient=False,
polymorphic_shapes=vjp_polymorphic_shapes,
native_serialization=False)(*vjp_args_flat_tf)
return in_cts_flat
return grad_fn_tf
@ -692,13 +647,16 @@ def _interpret_fun_jax(
return util.unzip2(out_vals)
def run_exported_as_tf(args_avals: Sequence[core.ShapedArray],
args_tf: Sequence[TfVal],
def run_exported_as_tf(args_tf: Sequence[TfVal],
exported: jax_export.Exported,
native_serialization_strict_checks: bool,
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
"""Runs the `exported` as an XlaCallModule TF op."""
out_shapes = tuple(
"""Runs the `exported` as an XlaCallModule TF op.
Returns:
a tuple with the results and the abstract values.
"""
args_avals = exported.in_avals
out_shapes_tf = tuple(
tuple(d if type(d) is int else None
for d in out_aval.shape)
for out_aval in exported.out_avals)
@ -715,10 +673,10 @@ def run_exported_as_tf(args_avals: Sequence[core.ShapedArray],
call_module_attrs = dict(
version=exported.xla_call_module_version,
Tout=out_types,
Sout=out_shapes)
Sout=out_shapes_tf)
if exported.xla_call_module_version >= 3:
if native_serialization_strict_checks:
if exported.strict_checks:
call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
else:
call_module_attrs["platforms"] = () # No platform checking

View File

@ -13,13 +13,13 @@
# limitations under the License.
"""JAX APIs for exporting code for interoperation.
This module is used with jax2tf, but should have no TensorFlow dependencies.
This module is used with jax2tf, but has no TensorFlow dependencies.
"""
import dataclasses
import functools
import itertools
import re
from typing import Callable, List, Optional, Sequence, Union
from typing import Any, Callable, List, Optional, Sequence, Union
from absl import logging
@ -27,6 +27,7 @@ import jax
from jax import sharding
from jax._src import core
from jax._src import pjit
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
@ -43,63 +44,88 @@ from jax.experimental.jax2tf import shape_poly
map = util.safe_map
zip = util.safe_zip
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = [
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"ducc_fft", "cu_threefry2x32",
# eigh on CPU
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
# eigh on GPU
"cusolver_syevj", "cusolver_syevd",
# eigh on TPU
"Eigh",
# qr on CPU
"lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
# qr on GPU
"cusolver_geqrf", "cublas_geqrf_batched",
"cusolver_geqrf", "cusolver_orgqr",
# qr and svd on TPU
"Qr", "ProductOfElementaryHouseholderReflectors",
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
# # lu on CPU
# "lapack_sgetrf" , "lapack_dgetrf" , "lapack_cgetrf" , "lapack_zgetrf",
# # lu on GPU
# "cublas_getrf_batched", "cusolver_getrf",
# "hipblas_getrf_batched", "hipsolver_getrf",
# lu on TPU
"LuDecomposition",
]
@dataclasses.dataclass
class Exported:
"""Represents a lowered and serialized JAX module."""
"""A lowered and serialized JAX function.
Currently this works only for functions that take a flat argument list
and return a tuple of results. (No pytree support yet.)
Attributes:
in_avals: the input abstract values. May contain dimension expressions in
the shapes.
out_avals: the output abstract values. May contain dimension expressions in
the shapes, with dimension variables among those in `in_avals`.
in_shardings: the input shardings. Only for the `module_kept_var_idx`.
out_shardings: the output shardings.
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'
mlir_module_serialized: the serialized lowered VHLO module.
xla_call_module_version: a version number for the serialized module.
The following version numbers are valid:
4 - mlir_module_serialized is a portable artifact.
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
because they are not used. Same length as `in_shardings`.
strict_checks: whether the module was serialized with the following safety
checks: (A) the lowered computation can be executed only on a platform
for which it was lowered; (B) the serialized computation contains only
custom calls with targets that are guaranteed to be stable (more to come).
_get_vjp: an optional function that takes the current exported function and
returns the exported VJP function.
The VJP function takes a flat list of arguments,
starting with the primal arguments and followed by a cotangent argument
for each primal output. It returns a tuple with the cotangents
corresponding to the primal inputs.
"""
in_avals: Sequence[core.ShapedArray]
out_avals: Sequence[core.ShapedArray]
# The in_shardings reflect only the module_kept_var_idx
in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
out_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
lowering_platform: str
strict_checks: bool
lowering_platform: str # One of "tpu", "cpu", "cuda", "rocm"
mlir_module_serialized: bytes
xla_call_module_version: int
module_kept_var_idx: Sequence[int]
mlir_module: mlir.ir.Module
mlir_module_serialized: bytes # VHLO bytecode format
xla_call_module_version: int # Follows the versions of XlaCallModule
module_kept_var_idx: Sequence[int] # Sorted indices of the arguments that are kept
# in the lowering. Same length as `in_shardings`.
_get_vjp: Optional[Callable[["Exported"], "Exported"]]
@property
def mlir_module(self) -> mlir.ir.Module:
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
def vjp(self) -> "Exported":
"""Gets the exported VJP.
Raises ValueError if a VJP is not available, which can happen if the
Exported has been loaded from an external format without a VJP."""
if self._get_vjp is None:
raise ValueError("No VJP is available")
return self._get_vjp(self)
def default_jax_backend() -> str:
def _default_jax_backend() -> str:
# Canonicalize to turn into CUDA or ROCM
return xb.canonicalize_platform(jax.default_backend())
def serialize_native(fun_jax: Callable,
args_avals: Sequence[core.ShapedArray], *,
lowering_platform: Optional[str],
strict_checks: bool) -> Exported:
def export_native(fun_jax: Callable,
args_avals: Sequence[core.ShapedArray], *,
lowering_platform: Optional[str],
strict_checks: bool) -> Exported:
"""Exports native serialization for a JAX function.
At the moment works only for JAX functions that take a flat list of arguments
and return a flat list of results.
Args:
fun_jax: the function to lower and serialize
args_avals: the abstract values at which to lower.
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'
strict_checks: whether to do strict safety checks. See Exported.strict_checks
for more details.
"""
arg_specs_jax = [
jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape)
for aval in args_avals
@ -109,62 +135,19 @@ def serialize_native(fun_jax: Callable,
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. In that case we raise
# an error if the lowered function contains non-replicated sharding annotations.
fun_jax_lower = jax.jit(fun_jax).lower
wrapped_fun_jax = jax.jit(fun_jax)
allow_non_replicated_sharding = False
else:
# If we have a pjit or pmap already we do not wrap with another, and we
# allow shardings.
fun_jax_lower = fun_jax.lower
wrapped_fun_jax = fun_jax # type: ignore
allow_non_replicated_sharding = True
lowered = fun_jax_lower(
lowered = wrapped_fun_jax.lower(
*arg_specs_jax,
_experimental_lowering_platform=lowering_platform)._lowering # type: ignore
if not isinstance(lowered, pxla.MeshComputation):
raise NotImplementedError(f"serialization is supported only for pjit. {lowered}")
# Check that we do not see new compile_args. When we add a new compile_arg, it is
# safe to add it to the allowed_compile_args if it does not change the semantics
# or the calling convention of the lowered module.
allowed_compile_args = ["backend", "mesh", "global_in_avals",
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
"spmd_lowering", "auto_spmd_lowering",
"tuple_args", "ordered_effects", "unordered_effects",
"host_callbacks", "keepalive", "pmap_nreps", "committed", "device_assignment"]
for compile_arg in lowered.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
# We have not implemented support for some of the compile_args.
not_implemented_msgs = []
for compile_arg, check_value, err_msg in (
("spmd_lowering", lambda v: v, "True"),
("auto_spmd_lowering", lambda v: not v, "False"),
# tuple_args is a compilation flag, does not affect lowering.
("tuple_args", lambda v: True, "N/A"),
# Used for debug(ordered=True), changes the calling convention, but will
# also set keepalive to non-empty.
("ordered_effects", lambda v: not v, "empty"),
# unordered_effects do not change the calling convention. Those from
# jax.debug will also result in keepalive being non-empty and unsupported
# custom calls. The CallTfEffect is an exception, but we want to allow
# that one.
("unordered_effects", lambda v: True, "N/A"),
# used for TPU jax.debug, send/recv. Not supported yet.
("host_callbacks", lambda v: not v, "empty"),
# used on all platforms for callbacks. Not supported yet.
("keepalive", lambda v: not v, "empty"),
("pmap_nreps", lambda v: v == 1, "1"),
):
if compile_arg in lowered.compile_args:
if not check_value(lowered.compile_args[compile_arg]):
not_implemented_msgs.append(
f"{compile_arg} must be {err_msg} and it is {lowered.compile_args[compile_arg]}")
if not_implemented_msgs:
raise NotImplementedError(
"serialization error, unimplemented lowered.compile_args:\n" +
"\n".join(not_implemented_msgs))
_check_lowered(lowered)
mlir_module = lowered.stablehlo()
if "kept_var_idx" in lowered.compile_args:
@ -176,7 +159,7 @@ def serialize_native(fun_jax: Callable,
if not all(core.is_constant_shape(a.shape) for a in args_avals):
# All arguments are kept if we have dimension variables.
assert len(module_kept_var_idx) == len(args_avals)
mlir_module = add_dim_arg_computation(mlir_module, args_avals)
mlir_module = _add_dim_arg_computation(mlir_module, args_avals)
xla_call_module_version = 4
mlir_str = mlir.module_to_bytecode(mlir_module)
@ -192,8 +175,6 @@ def serialize_native(fun_jax: Callable,
out_avals = lowered.compile_args["shards"].out_sharded_avals
else:
out_avals = lowered.compile_args["out_avals"]
if lowered.compile_args["host_callbacks"]:
raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering")
# Log and then check the module.
if logging.vlog_is_on(3):
@ -201,24 +182,25 @@ def serialize_native(fun_jax: Callable,
logmsg = f"version={xla_call_module_version} lowering_platform={lowering_platform}"
logging.vlog(3, "Lowered JAX module: %s\n%s", logmsg, mlir_module_text)
check_module(mlir_module,
allow_non_replicated_sharding=allow_non_replicated_sharding,
allow_all_custom_calls=not strict_checks)
_check_module(mlir_module,
allow_non_replicated_sharding=allow_non_replicated_sharding,
allow_all_custom_calls=not strict_checks)
return Exported(
in_avals=args_avals,
out_avals=out_avals,
in_shardings=lowered.compile_args["in_shardings"],
out_shardings=lowered.compile_args["out_shardings"],
lowering_platform=lowering_platform or default_jax_backend(),
mlir_module=mlir_module,
lowering_platform=lowering_platform or _default_jax_backend(),
strict_checks=strict_checks,
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
xla_call_module_version=xla_call_module_version)
xla_call_module_version=xla_call_module_version,
_get_vjp=lambda exported: _export_native_vjp(wrapped_fun_jax, exported))
def add_dim_arg_computation(module: mlir.ir.Module,
args_avals: Sequence[core.ShapedArray]) -> mlir.ir.Module:
def _add_dim_arg_computation(module: mlir.ir.Module,
args_avals: Sequence[core.ShapedArray]) -> mlir.ir.Module:
"""Wraps the lowered module with a new "main" that computes the dim args.
JAX lowering in the presence of shape polymorphism produces a `module` that
@ -294,7 +276,7 @@ def add_dim_arg_computation(module: mlir.ir.Module,
ctx = mlir.LoweringRuleContext(module_context=module_context,
primitive=None, avals_in=args_avals, avals_out=None,
tokens_in=mlir.TokenSet(), tokens_out=None)
dim_args = compute_dim_args(ctx, args_avals, tuple(new_main_op.arguments),
dim_args = _compute_dim_args(ctx, args_avals, tuple(new_main_op.arguments),
orig_input_types[:len(dim_vars)])
# The first arguments are the dimension variable
orig_main_args.extend(dim_args)
@ -308,7 +290,7 @@ def add_dim_arg_computation(module: mlir.ir.Module,
return new_module
def compute_dim_args(
def _compute_dim_args(
ctx: mlir.LoweringRuleContext,
args_avals: Sequence[core.ShapedArray],
array_args: Sequence[mlir.ir.Value],
@ -337,9 +319,86 @@ def compute_dim_args(
return tuple(res)
def check_module(mod: mlir.ir.Module, *,
allow_non_replicated_sharding: bool,
allow_all_custom_calls: bool):
def _check_lowered(lowered) -> None:
if not isinstance(lowered, pxla.MeshComputation):
raise NotImplementedError(f"serialization is supported only for pjit. {lowered}")
if lowered.compile_args["host_callbacks"] or lowered.compile_args["keepalive"]:
raise NotImplementedError("serialization of host_callbacks is not yet implemented")
# Check that we do not see new compile_args. When we add a new compile_arg, it is
# safe to add it to the allowed_compile_args if it does not change the semantics
# or the calling convention of the lowered module.
allowed_compile_args = ["backend", "mesh", "global_in_avals",
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
"spmd_lowering", "auto_spmd_lowering",
"tuple_args", "ordered_effects", "unordered_effects",
"keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment"]
for compile_arg in lowered.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
# We have not implemented support for some of the compile_args.
not_implemented_msgs = []
for compile_arg, check_value, err_msg in (
("spmd_lowering", lambda v: v, "True"),
("auto_spmd_lowering", lambda v: not v, "False"),
# tuple_args is a compilation flag, does not affect lowering.
("tuple_args", lambda v: True, "N/A"),
# Used for debug(ordered=True), changes the calling convention, but will
# also set keepalive to non-empty.
("ordered_effects", lambda v: not v, "empty"),
# unordered_effects do not change the calling convention. Those from
# jax.debug will also result in keepalive being non-empty and unsupported
# custom calls. The CallTfEffect is an exception, but we want to allow
# that one.
("unordered_effects", lambda v: True, "N/A"),
# used for TPU jax.debug, send/recv. Not supported yet.
("host_callbacks", lambda v: not v, "empty"),
# used on all platforms for callbacks. Not supported yet.
("keepalive", lambda v: not v, "empty"),
("pmap_nreps", lambda v: v == 1, "1"),
):
if compile_arg in lowered.compile_args:
if not check_value(lowered.compile_args[compile_arg]):
not_implemented_msgs.append(
f"{compile_arg} must be {err_msg} and it is {lowered.compile_args[compile_arg]}")
if not_implemented_msgs:
raise NotImplementedError(
"serialization error, unimplemented lowered.compile_args:\n" +
"\n".join(not_implemented_msgs))
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = [
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"ducc_fft", "cu_threefry2x32",
# eigh on CPU
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
# eigh on GPU
"cusolver_syevj", "cusolver_syevd",
# eigh on TPU
"Eigh",
# qr on CPU
"lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
# qr on GPU
"cusolver_geqrf", "cublas_geqrf_batched",
"cusolver_geqrf", "cusolver_orgqr",
# qr and svd on TPU
"Qr", "ProductOfElementaryHouseholderReflectors",
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
# # lu on CPU
# "lapack_sgetrf" , "lapack_dgetrf" , "lapack_cgetrf" , "lapack_zgetrf",
# # lu on GPU
# "cublas_getrf_batched", "cusolver_getrf",
# "hipblas_getrf_batched", "hipsolver_getrf",
# lu on TPU
"LuDecomposition",
]
def _check_module(mod: mlir.ir.Module, *,
allow_non_replicated_sharding: bool,
allow_all_custom_calls: bool):
"""Run a number of checks on the module.
Args:
@ -393,3 +452,55 @@ def check_module(mod: mlir.ir.Module, *,
f"{disallowed_custom_call_ops_str}.\n"
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
raise ValueError(msg)
def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
# Export the VJP of `primal_fun_jax`
def fun_vjp_jax(*args_and_out_cts_flat_jax):
# Takes a flat list of primals and output cotangents
args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(primal.in_avals)])
_, pullback_jax = jax.vjp(primal_fun_jax, *args_flat_jax)
return pullback_jax(out_cts_flat_jax)
vjp_in_avals = list(
itertools.chain(primal.in_avals,
map(lambda a: a.at_least_vspace(), primal.out_avals)))
# Expand in_shardings to all in_avals even not kept ones.
all_in_shardings = [pxla._UNSPECIFIED] * len(primal.in_avals)
for idx, in_s in zip(sorted(primal.module_kept_var_idx),
primal.in_shardings):
all_in_shardings[idx] = in_s # type: ignore
all_shardings = all_in_shardings + list(primal.out_shardings)
# Cannot mix unspecified and specified shardings. Make the unspecified
# ones replicated.
specified_shardings = [
s for s in all_shardings if not pxla._is_unspecified(s)]
vjp_in_shardings: Any # The primal inputs followed by output cotangents
vjp_out_shardings: Any # The primal output cotangents
if 0 == len(specified_shardings):
vjp_in_shardings = pxla._UNSPECIFIED
vjp_out_shardings = pxla._UNSPECIFIED
else:
if len(specified_shardings) < len(all_shardings):
# There are some specified, but not all; pjit front-end does not like the mix.
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
all_shardings = [
s if not pxla._is_unspecified(s) else replicated_s
for s in all_shardings]
vjp_in_shardings = tuple(all_shardings)
vjp_out_shardings = tuple(all_shardings[:len(primal.in_avals)])
if all(pxla._is_unspecified(s) for s in vjp_out_shardings):
vjp_out_shardings = pxla._UNSPECIFIED
fun_vjp_jax = pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,
out_shardings=vjp_out_shardings)
return export_native(fun_vjp_jax, vjp_in_avals,
lowering_platform=primal.lowering_platform,
strict_checks=primal.strict_checks)
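
Stripped of the sharding handling, the calling convention that _export_native_vjp exports is plain jax.vjp over a single flat argument list. A small plain-JAX illustration (the function names are hypothetical):

# Flat VJP convention: primal arguments first, then one cotangent per primal
# output; the result is the tuple of input cotangents. (Illustrative only.)
import jax
import jax.numpy as jnp

def primal(x, y):
  return (jnp.sin(x) * y,)                # a single flat primal output

def fun_vjp(*args_and_out_cts):
  n_primal_args = 2
  args_flat = args_and_out_cts[:n_primal_args]
  out_cts_flat = args_and_out_cts[n_primal_args:]
  _, pullback = jax.vjp(primal, *args_flat)
  return pullback(tuple(out_cts_flat))    # (x_ct, y_ct)

x = jnp.ones((3,))
y = jnp.full((3,), 2.0)
ct = jnp.ones((3,))                       # cotangent for the single output
x_ct, y_ct = fun_vjp(x, y, ct)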

View File

@ -203,7 +203,7 @@ class CompatTest(jtu.JaxTestCase):
res_from_jax = tuple(np.array(a) for a in res_from_jax)
# Use the native exporter, to make sure we get the proper serialized module.
exported = jax2tf.jax_export.serialize_native(
exported = jax2tf.jax_export.export_native(
jax.jit(func),
[core.ShapedArray(a.shape, a.dtype) for a in data.inputs],
lowering_platform=default_jax_backend(),

View File

@ -1531,7 +1531,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
stack.enter_context(mesh)
# Run the JAX native version, to check it works, and to fill caches.
_ = func_to_convert(*args)
exported = jax_export.serialize_native(
exported = jax_export.export_native(
func_to_convert,
[core.ShapedArray(a.shape, a.dtype) for a in args],
lowering_platform='tpu',
@ -1608,7 +1608,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
return jnp.sin(x)
with self.assertRaisesRegex(NotImplementedError,
"keepalive must be empty"):
"serialization of host_callbacks is not yet implemented"):
jax2tf.convert(f_jax, native_serialization=True)(np.float32(42.))
def f_ordered_jax(x):
@ -1616,7 +1616,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
return jnp.sin(x)
with self.assertRaisesRegex(NotImplementedError,
"keepalive must be empty"):
"serialization of host_callbacks is not yet implemented"):
jax2tf.convert(f_ordered_jax, native_serialization=True)(np.float32(42.))
def test_tuple_args(self):

View File

@ -372,7 +372,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
for out_shardings in ("missing", None, "P")
)
@jtu.with_mesh([("x", 2)])
def test_grad_pjit(self, in_shardings="missing", out_shardings="None"):
def test_grad_pjit(self, in_shardings="P", out_shardings=None):
def f_jax(x): # x: f32[10,20] -> f32[20,10]
return jnp.sin(x.T)
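
For context, gradients of a converted function such as the one in this test flow through tf.GradientTape into the custom gradient installed by jax2tf. A minimal illustrative sketch, without the mesh and pjit setup used in the test:

# Taking a TF gradient through a converted JAX function (illustrative only).
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):                 # x: f32[10, 20] -> f32[20, 10]
  return jnp.sin(x.T)

f_tf = jax2tf.convert(f_jax, with_gradient=True, native_serialization=True)

x = tf.constant(np.ones((10, 20), dtype=np.float32))
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f_tf(x)
grad_x = tape.gradient(y, x)  # flows through the exported VJP
print(grad_x.shape)           # (10, 20)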

View File

@ -133,7 +133,7 @@ SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
def _check_specs(error_type: SpecErrorType, specs: Any) -> None:
if error_type == SpecErrorType.input and specs is None:
raise TypeError(
f"shard_map in_specs argument must be a pytree of "
"shard_map in_specs argument must be a pytree of "
"`jax.sharding.PartitionSpec` instances, but it was None.\n"
"Instead of `in_specs=None`, did you mean `in_specs=P()`, "
"where `P = jax.sharding.PartitionSpec`?")