# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
from functools import partial
import contextlib
import os
import re
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast

from absl import logging

import jax
from jax import lax
from jax import config
from jax import core, custom_derivatives
from jax import linear_util as lu
from jax import random, tree_util
from jax import numpy as jnp
from jax.experimental import maps
from jax.experimental import pjit
from jax._src import sharding
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla

from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import dispatch
from jax._src import dtypes
from jax._src import prng
from jax._src import random as random_internal
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import slicing as lax_slicing
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.numpy.ufuncs import logaddexp

from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import impl_no_xla


import numpy as np
import tensorflow as tf  # type: ignore[import]

# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla  # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2  # type: ignore[import]
from tensorflow.core.framework import attr_value_pb2  # type: ignore[import]
try:
  from tensorflow.python.compiler.xla.experimental import xla_sharding  # type: ignore[import]
except ModuleNotFoundError:
  # This can be removed when TF 2.10 support is no longer needed.
  from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding  # type: ignore[import]
from tensorflow.python.framework import ops as tf_ops  # type: ignore[import]
from tensorflow.python.eager import context as tf_context  # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import

NameStack = source_info_util.NameStack
|
||
PolyShape = shape_poly.PolyShape
|
||
|
||
# A temporary internal flag, to enable the wrapping of jax.jit functions
|
||
# with tf.function(jit_compile=True). See #7389. This change has triggered a
|
||
# number of failures in TF. We keep this until we are confident that it does
|
||
# not create problems.
|
||
# TODO(b/207464757): figure out why this change breaks test
|
||
_WRAP_JAX_JIT_WITH_TF_FUNCTION = False
|
||
|
||
# The scope name needs to be a valid TensorFlow name. See
|
||
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731
|
||
_VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
|
||
_INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/-]")
|
||
|
||
map = util.safe_map
|
||
zip = util.safe_zip
|
||
|
||
|
||
def _sanitize_scope_name(name):
|
||
scope_name = _INVALID_SCOPE_CHAR.sub("_", name)
|
||
if not _VALID_SCOPE_REGEX.match(scope_name):
|
||
scope_name = f".{scope_name}"
|
||
return scope_name
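
# Illustrative usage sketch (not part of the library): how _sanitize_scope_name
# maps JAX names to TF-valid scope names. The sample inputs are assumptions
# chosen only to exercise the two regexes above.
def _example_sanitize_scope_name_usage():
  # Characters TF rejects (parentheses, brackets, spaces) become "_".
  assert _sanitize_scope_name("jit(f)/while[cond]") == "jit_f_/while_cond_"
  # A name whose first character is still invalid gets a "." prefix.
  assert _sanitize_scope_name("*weird") == "._weird"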
|
||
|
||
|
||
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
|
||
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
|
||
TfVal = Any
|
||
DType = Any
|
||
PrecisionType = int # Enum xla_data.PrecisionConfig.Precision
|
||
|
||
def _is_tfval(v: TfVal) -> bool:
|
||
if isinstance(v, (tf.Tensor, tf.Variable)):
|
||
return True
|
||
try:
|
||
# Include all convertible types, even if not supported on accelerators.
|
||
with tf.device("CPU"):
|
||
tf.constant(v)
|
||
return True
|
||
except:
|
||
return False
|
||
|
||
|
||
# The implementation rules for primitives. The rule will be called with the
|
||
# arguments (TfVal) and must return TfVal (or a sequence thereof,
|
||
# if primitive.multiple_results). The exceptions are primarily the
|
||
# control-flow primitives.
|
||
tf_impl: Dict[core.Primitive, Callable[..., Any]] = {}
|
||
|
||
# Some primitive implementation rules need the abstract values of arguments
|
||
# and the results. This is the case for the primitives implemented using
|
||
# _convert_jax_impl and those that need to adjust the shape of the outputs
|
||
# due to missing TF shape inference rules for TFXLA ops. The rules for these
|
||
# primitives should be added to `tf_impl_with_avals`.
|
||
# The abstract values are passed to the implementation as two special kwargs
|
||
# `_in_avals` (a tuple of core.ShapedArray) and `_out_aval` (a
|
||
# core.ShapedArray, or a tuple thereof when primitive.multiple_results).
|
||
tf_impl_with_avals: Dict[core.Primitive, Callable[..., Any]] = {}
|
||
|
||
# XLA is not linked in all environments; if that is the case when converting a
# primitive, we first search for implementation rules for primitives in the
|
||
# following map. These implementations are workarounds, making use of TF ops
|
||
# that do work when XLA is not linked in.
|
||
tf_impl_no_xla = impl_no_xla.tf_impl_no_xla
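
# Illustrative sketch (not part of the library) of what an entry in these
# tables looks like. `example_p` is a hypothetical primitive, shown only to
# make the `_in_avals`/`_out_aval` calling convention described above concrete:
#
#   def _example_p_tf(x: TfVal, *, _in_avals, _out_aval) -> TfVal:
#     return tf.math.square(x)
#
#   tf_impl_with_avals[example_p] = _example_p_tf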
|
||
|
||
# In order to ensure that JAX picks up the proper user-frame for source
|
||
# locations we will register the TensorFlow source path as an internal
|
||
# path with source_info_util. The typical stack when a JAX primitive
|
||
# conversion happens is:
|
||
# jax2tf.process_primitive (top of stack)
|
||
# jax tracing machinery ...
|
||
# tf.custom_gradient machinery ...
|
||
# jax2tf.converted_fun
|
||
# tf function machinery ...
|
||
# user code invokes the converted function on TF tensors
|
||
#
|
||
# We need to skip over not only JAX internal frames, but TF internal frames
|
||
# also.
|
||
# We register the TensorFlow source path lazily
|
||
_has_registered_tf_source_path = False
|
||
|
||
class _ThreadLocalState(threading.local):
|
||
def __init__(self):
|
||
self.name_stack = ""
|
||
# XLA is not linked in all environments; when converting a primitive, if this
|
||
# variable is disabled, we try harder to use only standard TF ops if they are
|
||
# applicable to the concrete use case; if the resulting conversion path ends up
|
||
# requiring a TFXLA operation, an exception is thrown instead.
|
||
self.enable_xla = True
|
||
|
||
# Keep track if we are inside a call_tf. In that context we disable the
|
||
# safety check that we are not inside JAX transformations.
|
||
self.inside_call_tf = False
|
||
|
||
# Maps dimension variables to TF expressions
|
||
self.shape_env: Sequence[Tuple[str, TfVal]] = ()
|
||
|
||
# Whether to actually include XLA op metadata in the generated TF ops
|
||
# TODO(b/189306134): implement support for XLA metadata
|
||
self.include_xla_op_metadata = False
|
||
|
||
# A cache for the tf.convert_to_tensor for constants. We try to preserve
|
||
# sharing for constants, to enable tf.Graph to take advantage of it.
|
||
# See https://github.com/google/jax/issues/7992.
|
||
self.constant_cache = None # None means that we don't use a cache. We
|
||
# may be outside a conversion scope.
|
||
|
||
_thread_local_state = _ThreadLocalState()
|
||
|
||
def _get_current_name_stack() -> Union[NameStack, str]:
|
||
if config.jax_experimental_name_stack:
|
||
return source_info_util.current_name_stack()
|
||
return _thread_local_state.name_stack
|
||
|
||
@contextlib.contextmanager
|
||
def inside_call_tf():
|
||
# Set the inside_call_tf flag for a context.
|
||
prev = _thread_local_state.inside_call_tf
|
||
_thread_local_state.inside_call_tf = True
|
||
try:
|
||
yield
|
||
finally:
|
||
_thread_local_state.inside_call_tf = prev
|
||
|
||
@partial(api_util.api_hook, tag="jax2tf_convert")
|
||
def convert(fun_jax: Callable,
|
||
*,
|
||
polymorphic_shapes=None,
|
||
with_gradient=True,
|
||
enable_xla=True,
|
||
experimental_native_lowering="default") -> Callable:
|
||
"""Lowers `fun_jax` into a function that uses only TensorFlow ops.
|
||
|
||
See
|
||
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
|
||
for more details about usage and common problems.
|
||
|
||
Args:
|
||
fun_jax: JAX function to be lowered. Its arguments and return value should be
|
||
JAX arrays, or nested standard Python containers (tuple/list/dict) thereof
|
||
(pytrees).
|
||
polymorphic_shapes: Specifies input shapes to be treated polymorphically
|
||
during lowering.
|
||
|
||
.. warning:: The shape-polymorphic lowering is an experimental feature.
|
||
It is meant to be sound, but it is known to reject some JAX programs
|
||
that are shape polymorphic. The details of this feature can change.
|
||
|
||
It should be `None` (all arguments are monomorphic), a single PolyShape
|
||
or string (applies to all arguments), or a tuple/list of the same length
|
||
as the function arguments. For each argument the shape specification
|
||
should be `None` (monomorphic argument), or a Python object with the
|
||
same pytree structure as the argument.
|
||
See [how optional parameters are matched to
|
||
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
|
||
|
||
A shape specification for an array argument should be an object
|
||
`PolyShape(dim0, dim1, ..., dimn)`
|
||
where each `dim` is a dimension specification: a positive integer denoting
|
||
a monomorphic dimension of the given size, or a string denoting a
|
||
dimension variable assumed to range over non-zero dimension sizes, or
|
||
the special placeholder string "_" denoting a monomorphic dimension
|
||
whose size is given by the actual argument. As a shortcut, an Ellipsis
|
||
suffix in the list of dimension specifications stands for a list of "_"
|
||
placeholders.
|
||
|
||
For convenience, a shape specification can also be given as a string
|
||
representation, e.g.: "batch, ...", "batch, height, width, _", possibly
|
||
with surrounding parentheses: "(batch, ...)".
|
||
|
||
The lowering fails if it cannot ensure that it would produce the same
|
||
sequence of TF ops for any non-zero values of the dimension variables.
|
||
|
||
polymorphic_shapes are only supported for positional arguments; shape
|
||
polymorphism is not supported for keyword arguments.
|
||
|
||
See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
||
for more details.
|
||
|
||
with_gradient: if set (default), add a tf.custom_gradient to the lowered
|
||
function, by converting ``jax.vjp(fun)``. This means that reverse-mode
|
||
TensorFlow AD is supported for the output TensorFlow function, and the
|
||
value of the gradient will be JAX-accurate.
|
||
enable_xla: if set (default), use the simplest conversion
|
||
and use XLA TF ops when necessary. These ops are known to create issues
|
||
for the TFLite and TFjs converters. For those cases, unset this parameter
|
||
so the lowering tries harder to use non-XLA TF ops to lower the
|
||
function and aborts if this is not possible.
|
||
experimental_native_lowering: DO NOT USE, for experimental purposes only.
|
||
The value "default" defers to --jax2tf_default_experimental_native_lowering.
|
||
|
||
Returns:
|
||
A version of `fun_jax` that expects TfVals as arguments (or
tuples/lists/dicts thereof), returns TfVals as outputs, and uses
only TensorFlow ops.
|
||
"""
|
||
if experimental_native_lowering == "default":
|
||
experimental_native_lowering = config.jax2tf_default_experimental_native_lowering
|
||
|
||
if experimental_native_lowering and not enable_xla:
|
||
raise ValueError(
|
||
"experimental_native_lowering is not supported with enable_xla=False")
|
||
api._check_callable(fun_jax)
|
||
fun_name = getattr(fun_jax, "__name__", "unknown")
|
||
name_stack = util.wrap_name(fun_name, "jax2tf")
|
||
def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:
|
||
# TODO: is there a better way to check if we are inside a transformation?
|
||
if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
|
||
# It is Ok to nest convert when we are inside a call_tf
|
||
raise ValueError("convert must be used outside all JAX transformations." +
|
||
f"Trace state: {core.thread_local_state.trace_state.trace_stack}")
|
||
|
||
fun_flat_jax, args_flat_tf, in_tree, out_tree_thunk = flatten_fun_jax(fun_jax, args_tf, kwargs_tf)
|
||
# out_tree_thunk will be ready after we call fun_flat_jax below.
|
||
|
||
# Expand the polymorphic_shapes to match the args_flat_tf. The polymorphic_shapes
|
||
# argument refers to positional arguments only.
|
||
if polymorphic_shapes is None or isinstance(polymorphic_shapes, (PolyShape, str)):
|
||
polymorphic_shapes_ = (polymorphic_shapes,) * len(args_tf)
|
||
else:
|
||
if not (isinstance(polymorphic_shapes, Sequence) and len(polymorphic_shapes) == len(args_tf)):
|
||
msg = ("polymorphic_shapes must be a sequence with the same length as the positional argument list "
|
||
f"({len(args_tf)}). Got polymorphic_shapes={repr(polymorphic_shapes)}.")
|
||
raise TypeError(msg)
|
||
polymorphic_shapes_ = tuple(polymorphic_shapes)
|
||
|
||
polymorphic_shapes_flat = tuple(
|
||
api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
|
||
in_tree,
|
||
(polymorphic_shapes_, {k: None for k in kwargs_tf.keys()})))
|
||
|
||
args_and_avals = tuple(map(preprocess_arg_tf,
|
||
range(len(args_flat_tf)), args_flat_tf, polymorphic_shapes_flat))
|
||
args_flat_tf, args_avals_flat = util.unzip2(args_and_avals)
|
||
|
||
dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(args_avals_flat)
|
||
dim_values, _ = _interpret_fun_jax(get_dim_values_jax,
|
||
args_flat_tf, args_avals_flat, "")
|
||
shape_env = zip(dim_vars, dim_values)
|
||
|
||
try:
|
||
assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"
|
||
|
||
prev_enable_xla = _thread_local_state.enable_xla
|
||
_thread_local_state.enable_xla = enable_xla
|
||
|
||
prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
|
||
# TODO(b/189306134): implement support for XLA metadata
|
||
_thread_local_state.include_xla_op_metadata = False
|
||
|
||
_thread_local_state.shape_env = shape_env
|
||
global _has_registered_tf_source_path
|
||
if not _has_registered_tf_source_path:
|
||
source_info_util.register_exclusion(os.path.dirname(tf.__file__))
|
||
_has_registered_tf_source_path = True
|
||
|
||
if with_gradient:
|
||
|
||
@tf.custom_gradient
|
||
def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
|
||
outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
|
||
args_flat_tf, args_avals_flat,
|
||
name_stack,
|
||
fresh_constant_cache=True,
|
||
experimental_native_lowering=experimental_native_lowering)
|
||
return (tuple(outs_tf),
|
||
make_custom_gradient_fn_tf(
|
||
fun_flat_jax=fun_flat_jax,
|
||
args_flat_tf=args_flat_tf,
|
||
args_avals_flat=args_avals_flat,
|
||
polymorphic_shapes_flat=polymorphic_shapes_flat,
|
||
out_avals=out_avals))
|
||
|
||
out_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
|
||
else:
|
||
outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
|
||
args_flat_tf, args_avals_flat,
|
||
name_stack,
|
||
fresh_constant_cache=True,
|
||
experimental_native_lowering=experimental_native_lowering)
|
||
message = ("The jax2tf-converted function does not support gradients. "
|
||
"Use `with_gradient` parameter to enable gradients")
|
||
# We use PreventGradient, which is propagated through a SavedModel.
|
||
out_flat_tf = [
|
||
tf.raw_ops.PreventGradient(input=o, message=message)
|
||
for o in outs_tf
|
||
]
|
||
finally:
|
||
_thread_local_state.shape_env = ()
|
||
_thread_local_state.enable_xla = prev_enable_xla
|
||
_thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata
|
||
|
||
out_flat_tf = [tf.identity(x, "jax2tf_out") for x in out_flat_tf]
|
||
out_tf = tree_util.tree_unflatten(out_tree_thunk(), out_flat_tf)
|
||
return out_tf
|
||
|
||
return converted_fun_tf
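
# Illustrative usage sketch (not part of the library): wrapping a simple JAX
# function with `convert`. The function `f_jax` and the "(batch,)" shape spec
# are assumptions for demonstration only.
def _example_convert_usage():
  def f_jax(x):
    return jnp.sin(jnp.cos(x))
  # The batch dimension is left symbolic; the result is a callable that takes
  # and returns tf.Tensor values and uses only TF ops.
  f_tf = convert(f_jax, polymorphic_shapes=["(batch,)"])
  return f_tf(tf.constant([1., 2., 3.]))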
|
||
|
||
|
||
def dtype_of_val(val: TfVal) -> DType:
|
||
"""Computes the TensorFlow dtype using JAX's typing rules.
|
||
|
||
If the value is a tf.Tensor, it starts with its dtype. If the value is a
|
||
constant it uses JAX to infer its dtype. The resulting dtype follows the
|
||
JAX type inference rules, and depends on the value of the
|
||
JAX_ENABLE_X64 flag.
|
||
|
||
See README.md for how 64-bit values are treated.
|
||
"""
|
||
tval, _ = _tfval_to_tensor_jax_dtype(val)
|
||
return tval.dtype
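
# Illustrative sketch (not part of the library), assuming the default
# JAX_ENABLE_X64=False: Python constants get JAX's weak 32-bit types and
# 64-bit TF inputs are narrowed, so both results below are 32-bit TF dtypes.
def _example_dtype_of_val_usage():
  return dtype_of_val(3.14), dtype_of_val(tf.constant(1, tf.int64))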
|
||
|
||
# Internals
|
||
|
||
def flatten_fun_jax(fun_jax: Callable, args_tf: Sequence[TfVal],
|
||
kwargs_tf: Dict[str, TfVal]
|
||
) -> Tuple[Callable, Sequence[TfVal], Any, Callable]:
|
||
"""Wraps the function to take a (flat) list of positional args.
|
||
|
||
jax2tf works better and is simpler when the JAX function takes and returns
|
||
just a tuple of values (no pytrees, no kwargs). This is in part because
|
||
jax.vjp does not support kwargs and we can only set
|
||
tf.custom_gradient on functions with flat arguments and results
|
||
|
||
Returns:
|
||
* the wrapped JAX function taking and returning a flat list of arguments
|
||
* the flat list of TF arguments
|
||
* the in_tree corresponding to the tuple (args_tf, kwargs_tf)
|
||
* a thunk that can be called after the wrapped function has been called
|
||
to return the output pytree.
|
||
"""
|
||
# TODO(necula): technically we should use TF's flattening and unflattening
|
||
# because we are working with TF values.
|
||
args_flat_tf, in_tree = tree_util.tree_flatten((args_tf, kwargs_tf))
|
||
|
||
out_tree_ref = None
|
||
def fun_flat_jax(*args_flat_jax):
|
||
tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
|
||
tree_res = fun_jax(*tree_args, **tree_kwargs)
|
||
res_flat_jax, out_tree = tree_util.tree_flatten(tree_res)
|
||
nonlocal out_tree_ref
|
||
assert out_tree_ref is None or out_tree_ref == out_tree
|
||
out_tree_ref = out_tree
|
||
return res_flat_jax
|
||
|
||
if hasattr(fun_jax, "lower"):
|
||
# If fun_jax is already a jit(f) or pjit(f), we must
# preserve the lowering function. This will be used in _lower_native_and_run.
|
||
# We rely on the fact that the lowering is the same for the function
|
||
# taking pytrees, and the one taking flat args.
|
||
def fun_flat_jax_lower(*args_flat_jax):
|
||
tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
|
||
lowered = fun_jax.lower(*tree_args, **tree_kwargs)
|
||
out_tree = lowered.out_tree
|
||
nonlocal out_tree_ref
|
||
assert out_tree_ref is None or out_tree_ref == out_tree
|
||
out_tree_ref = out_tree
|
||
return lowered
|
||
setattr(fun_flat_jax, "lower", fun_flat_jax_lower)
|
||
|
||
return fun_flat_jax, args_flat_tf, in_tree, lambda: out_tree_ref
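
# Illustrative usage sketch (not part of the library): the flat wrapper and the
# out-tree thunk returned by flatten_fun_jax. The sample function and arguments
# are assumptions for demonstration only.
def _example_flatten_fun_jax_usage():
  def f_jax(x, *, scale):
    return {"out": x * scale}
  fun_flat, args_flat, in_tree, out_tree_thunk = flatten_fun_jax(
      f_jax, (np.float32(2.),), {"scale": np.float32(3.)})
  res_flat = fun_flat(*args_flat)   # flat list of results; sets the out tree
  return tree_util.tree_unflatten(out_tree_thunk(), res_flat)  # {"out": 6.0}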
|
||
|
||
def preprocess_arg_tf(arg_idx: int,
|
||
arg_tf: TfVal,
|
||
polymorphic_shape: Optional[str]
|
||
) -> Tuple[TfVal, core.ShapedArray]:
|
||
if not _is_tfval(arg_tf):
|
||
msg = (f"Argument {arg_tf} of type {type(arg_tf)} of jax2tf.convert(f) should "
|
||
"be NumPy array, scalar, tf.Variable, or tf.Tensor")
|
||
raise TypeError(msg)
|
||
|
||
# May cast the args_flat to JAX types, using JAX's interpretation
|
||
# of types of constants.
|
||
arg_tf, arg_jax_dtype = _tfval_to_tensor_jax_dtype(arg_tf)
|
||
# Name input tensors; do this after we have cast the arguments
|
||
arg_tf = tf.identity(arg_tf, f"jax2tf_arg_{arg_idx}")
|
||
|
||
# Fix the shape for TF1
|
||
tf_arg_shape = np.shape(arg_tf)
|
||
arg_shape = tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d for d in tf_arg_shape)
|
||
|
||
arg_aval = shape_poly.arg_aval(arg_shape, arg_jax_dtype, polymorphic_shape)
|
||
return arg_tf, arg_aval
|
||
|
||
|
||
# Prepare the grad_fn for tf.custom_gradient.
|
||
def make_custom_gradient_fn_tf(
|
||
fun_flat_jax: Callable,
|
||
args_flat_tf: Sequence[TfVal],
|
||
polymorphic_shapes_flat: Sequence[str],
|
||
args_avals_flat: Sequence[core.ShapedArray],
|
||
out_avals: Sequence[core.ShapedArray]):
|
||
|
||
def grad_fn_tf(*out_cts_flat_tf: TfVal,
|
||
variables=None):
|
||
if variables:
|
||
raise ValueError(
|
||
"Unexpected variables used in forward pass. "
|
||
"This should not happen for first-order differentiation. "
|
||
f"variables={variables}")
|
||
|
||
out_cts_flat_polymorphic_shapes = tuple(str(out_aval.shape) # Note: may be polynomials, not just DimVar
|
||
for out_aval in out_avals) # type: ignore
|
||
vjp_polymorphic_shapes = [
|
||
polymorphic_shapes_flat, out_cts_flat_polymorphic_shapes
|
||
]
|
||
|
||
def fun_vjp_jax(args_flat_jax, out_cts_flat_jax):
|
||
# One may think that we can get the pullback while we are converting
|
||
# the main function in the first place. That is problematic, because the
|
||
# pullback may contain captured tracers from the conversion of the
|
||
# main function. Those tracers will confuse the conversion of the
|
||
# pullback. So, we construct the vjp anew and we convert it separately.
|
||
_, pullback_jax = jax.vjp(fun_flat_jax, *args_flat_jax)
|
||
|
||
def fix_out_ct(out_ct_jax, out_ct_aval: core.ShapedArray):
|
||
# If the primal function has outputs of integer or bool types, and if we are
|
||
# under a tf.function context, then TF will pass None in _out_cts_flat
|
||
# in place of these values. We should change these to float0 or
|
||
# else JAX gets unhappy. See issue #6975.
|
||
if out_ct_jax is not None:
|
||
return out_ct_jax
|
||
assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, f"out_ct_jax={out_ct_jax}"
|
||
# Note that out_ct_aval.shape contains dimension variables from the
|
||
# primal function scope. It is Ok to use them here because we
|
||
# use the same shape variables for the VJP function.
|
||
return jnp.zeros(out_ct_aval.shape, dtype=_tf_np_dtype_for_float0)
|
||
|
||
out_cts_fixed_flat = list(map(fix_out_ct, out_cts_flat_jax, out_avals))
|
||
in_cts_flat_jax = pullback_jax(out_cts_fixed_flat)
|
||
|
||
def fix_in_ct(in_ct_jax, arg_aval: core.ShapedArray):
|
||
if jnp.issubdtype(arg_aval.dtype, jnp.inexact):
|
||
return in_ct_jax
|
||
else:
|
||
assert in_ct_jax.dtype == dtypes.float0
|
||
return jnp.zeros(arg_aval.shape, _tf_np_dtype_for_float0)
|
||
|
||
in_cts_fixed_flat_jax = tuple(map(fix_in_ct, in_cts_flat_jax, args_avals_flat))
|
||
return in_cts_fixed_flat_jax
|
||
|
||
# TODO: enable higher-order gradients
|
||
with tf.name_scope("jax2tf_vjp"):
|
||
in_cts_flat = convert(
|
||
fun_vjp_jax,
|
||
with_gradient=False,
|
||
polymorphic_shapes=vjp_polymorphic_shapes)(args_flat_tf, out_cts_flat_tf)
|
||
return in_cts_flat
|
||
|
||
return grad_fn_tf
|
||
|
||
@contextlib.contextmanager
|
||
def _extended_name_stack(extra_name_stack: Optional[str]):
|
||
if config.jax_experimental_name_stack:
|
||
name_ctx = (source_info_util.extend_name_stack(extra_name_stack)
|
||
if extra_name_stack
|
||
else contextlib.nullcontext())
|
||
with name_ctx:
|
||
yield
|
||
return
|
||
prev_name_stack = _thread_local_state.name_stack
|
||
if extra_name_stack:
|
||
if not prev_name_stack:
|
||
_thread_local_state.name_stack = extra_name_stack
|
||
else:
|
||
_thread_local_state.name_stack = util.extend_name_stack(
|
||
_thread_local_state.name_stack, extra_name_stack)
|
||
try:
|
||
yield
|
||
finally:
|
||
_thread_local_state.name_stack = prev_name_stack
|
||
|
||
|
||
def _interpret_fun_jax(
|
||
fun_jax: Callable,
|
||
args_tf: Sequence[TfVal],
|
||
args_avals: Sequence[core.ShapedArray],
|
||
extra_name_stack: Optional[str],
|
||
fresh_constant_cache: bool = False,
|
||
experimental_native_lowering: bool = False
|
||
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
|
||
if experimental_native_lowering:
|
||
del extra_name_stack
|
||
return _lower_native_and_run(fun_jax, args_avals, args_tf)
|
||
else:
|
||
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
|
||
subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals)
|
||
with _extended_name_stack(extra_name_stack):
|
||
with core.new_sublevel():
|
||
out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
|
||
_call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
|
||
fresh_constant_cache=fresh_constant_cache)
|
||
|
||
del main
|
||
|
||
return util.unzip2(out_vals)
|
||
|
||
def _lower_native_and_run(fun_jax: Callable,
|
||
args_avals: Sequence[core.ShapedArray],
|
||
args_tf: Sequence[TfVal],
|
||
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
|
||
"""Lowers the function using native lowering and then invokes it.
|
||
|
||
Work-in-progress.
|
||
|
||
Uses JAX native lowering to MHLO, and then wraps the result in a
|
||
XlaCallModule TF op. This op does not have backward-compatibility yet.
|
||
|
||
Special care must be taken in presence of shape polymorphism.
|
||
"""
|
||
# Look for shape polymorphism
|
||
# For each arg, map axis idx to dimension variable name
|
||
abstracted_axes: Sequence[Dict[int, str]] = []
|
||
# For each dimension variable, encode how to compute its value from the
|
||
# shape of the explicit arguments. E.g., "2.1" denotes args_tf[2].shape[1].
|
||
# Note: We assume that lowering will introduce dim args in the order in which
|
||
# dim variables are first seen when scanning the explicit arguments
|
||
# in order and then scanning their shapes for dim variables.
|
||
dim_args_spec: List[str] = []
|
||
dim_vars_seen: Set[str] = set()
|
||
for arg_idx, aval in enumerate(args_avals):
|
||
one_abstract_axes = {}
|
||
for axis_idx, d in enumerate(aval.shape):
|
||
if not core.is_constant_dim(d):
|
||
d_var = d.to_var()
|
||
if d_var is None:
|
||
raise ValueError(f"Only simple dimension variables supported: {aval.shape}")
|
||
if not d_var in dim_vars_seen:
|
||
dim_args_spec.append(f"{arg_idx}.{axis_idx}")
|
||
dim_vars_seen.add(d_var)
|
||
one_abstract_axes[axis_idx] = d_var
|
||
abstracted_axes.append(one_abstract_axes)
|
||
|
||
if any(abstracted_axes):
|
||
if not config.jax_dynamic_shapes:
|
||
raise ValueError(
|
||
"Found shape polymorphism but --jax_dynamic_shapes is not set")
|
||
abstracted_axes = tuple(abstracted_axes)
|
||
else:
|
||
abstracted_axes = None # type: ignore
|
||
|
||
arg_specs_jax = [
|
||
jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape)
|
||
for aval in args_avals
|
||
]
|
||
# TODO: specify the backend for experimental_native_lowering
|
||
backend = jax.default_backend()
|
||
if not hasattr(fun_jax, "lower") or abstracted_axes:
|
||
# We support convert(pjit(f_jax, ...)) and convert(jit(f_jax)) but also
|
||
# convert(f_jax), in which case a "jit" is implied. We also add a jit when
|
||
# we need to pass the abstracted axes.
|
||
fun_jax_lower = jax.jit(fun_jax, backend=backend,
|
||
keep_unused=True, # TODO: allow dropping unused
|
||
abstracted_axes=abstracted_axes).lower
|
||
else:
|
||
fun_jax_lower = fun_jax.lower
|
||
lowered = fun_jax_lower(*arg_specs_jax)._lowering
|
||
mhlo_module = lowered.mhlo()
|
||
if logging.vlog_is_on(3):
|
||
mhlo_module_text = mlir.module_to_string(mhlo_module)
|
||
logging.vlog(3, "XlaCallModule %s", mhlo_module_text)
|
||
|
||
mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module)
|
||
# Figure out the result types and shapes
|
||
if "global_out_avals" in lowered.compile_args:
|
||
# This is currently the case for pjit
|
||
out_avals = lowered.compile_args["global_out_avals"]
|
||
else:
|
||
out_avals = lowered.compile_args["out_avals"]
|
||
if lowered.compile_args["host_callbacks"]:
|
||
raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering")
|
||
|
||
# TODO(necula): handle d being InDBIdx
|
||
out_shapes = tuple(
|
||
tuple(d if type(d) is int else None
|
||
for d in out_aval.shape)
|
||
for out_aval in out_avals)
|
||
|
||
def _out_type(jax_type):
|
||
if jax_type == dtypes.float0:
|
||
return dtypes.bool_
|
||
return jax_type
|
||
out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals)
|
||
|
||
# Apply the shardings on arguments and results for pjit. This is redundant
|
||
# because the mhlo_module_text will already contain the shardings, but it
|
||
# makes it easier for tools like the TPU inference converter to see the
|
||
# sharding without digging into the `module` attribute of the `XlaCallModule`
|
||
# op, in the same way as it is done for the legacy jax2tf conversion.
|
||
if "in_shardings" in lowered.compile_args:
|
||
args_tf = tuple(
|
||
map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))
|
||
res = tfxla.call_module(
|
||
args_tf,
|
||
module=mhlo_serialized_module,
|
||
Tout=out_types,
|
||
Sout=out_shapes,
|
||
dim_args_spec=dim_args_spec)
|
||
if "out_shardings" in lowered.compile_args:
|
||
res = list(map(_shard_value, res, out_avals, lowered.compile_args["out_shardings"]))
|
||
|
||
# Convert the results to the needed TF types
|
||
def _convert_res(res_val, res_jax_type):
|
||
conversion_dtype = _to_tf_dtype(res_jax_type)
|
||
if conversion_dtype != res_jax_type:
|
||
return tf.cast(res_val, conversion_dtype)
|
||
else:
|
||
return res_val
|
||
|
||
res = tuple(
|
||
_convert_res(res_val, out_aval.dtype)
|
||
for res_val, out_aval in zip(res, out_avals))
|
||
return res, out_avals
|
||
|
||
def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
|
||
in_vals: Sequence[TfVal],
|
||
fresh_constant_cache: bool = False
|
||
) -> Sequence[Tuple[TfVal, core.ShapedArray]]:
|
||
try:
|
||
prev_constant_cache = _thread_local_state.constant_cache
|
||
prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set()
|
||
# Start a new cache, so that we don't share constants across tf.function
|
||
# boundaries.
|
||
if fresh_constant_cache:
|
||
_thread_local_state.constant_cache = {}
|
||
|
||
out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
|
||
fun.call_wrapped(*in_vals)
|
||
finally:
|
||
if prev_constant_cache is not None and not fresh_constant_cache:
|
||
newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys
|
||
# Delete the newly added keys
|
||
for k in newly_added_keys:
|
||
del prev_constant_cache[k]
|
||
_thread_local_state.constant_cache = prev_constant_cache
|
||
return out_vals
|
||
|
||
def _convert_jax_impl(impl_jax: Callable, *,
|
||
multiple_results=True,
|
||
with_physical_avals=False,
|
||
extra_name_stack: Optional[str] = None) -> Callable:
|
||
"""Convert the JAX implementation of a primitive.
|
||
|
||
Args:
|
||
impl_jax: typically the impl-rule for a primitive, with signature
|
||
`(*args_jax: JaxVal, **kwargs) -> Sequence[JaxVal]`. This function implements
|
||
a primitive in terms of other primitives.
|
||
multiple_results: whether `impl_jax` returns a sequence of results.
|
||
extra_name_stack: additional element to add to the name stack for the
|
||
converted ops.
|
||
|
||
Returns:
|
||
a function with signature `(*args_tf: TfVal, _in_avals, _out_aval, **kwargs)
|
||
-> Sequence[TfVal]`.
|
||
"""
|
||
|
||
def wrapped_tf(*args_tf: TfVal, _in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray,
|
||
**kwargs) -> Sequence[TfVal]:
|
||
|
||
if with_physical_avals:
|
||
_in_avals = map(_jax_physical_aval, _in_avals)
|
||
_out_aval = _jax_physical_aval(_out_aval)
|
||
|
||
# We wrap the impl_jax to always return a tuple of results.
|
||
def impl_multiple_results_jax(*args_jax):
|
||
results_jax = impl_jax(*args_jax, **kwargs)
|
||
return results_jax if multiple_results else [results_jax]
|
||
|
||
results_tf, _ = _interpret_fun_jax(
|
||
impl_multiple_results_jax, args_tf, _in_avals,
|
||
extra_name_stack)
|
||
return results_tf if multiple_results else results_tf[0]
|
||
|
||
return wrapped_tf
|
||
|
||
|
||
@lu.transformation
|
||
def _interpret_subtrace(main: core.MainTrace,
|
||
in_avals: Sequence[core.ShapedArray],
|
||
*in_vals: TfVal):
|
||
trace = TensorFlowTrace(main, core.cur_sublevel())
|
||
in_tracers = tuple(
|
||
TensorFlowTracer(trace, val, aval)
|
||
for val, aval in zip(in_vals, in_avals))
|
||
outs = yield in_tracers, {} # type: Sequence[TfVal]
|
||
out_tracers: Iterable[TensorFlowTracer] = (
|
||
map(trace.full_raise, outs)) # type: ignore
|
||
out_vals_with_avals: Sequence[Tuple[TfVal, core.ShapedArray]] = (
|
||
tuple((t.val, t.aval) for t in out_tracers))
|
||
yield out_vals_with_avals
|
||
|
||
|
||
def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal,
|
||
extra_name_stack: Optional[str]) -> Sequence[TfVal]:
|
||
"""Evaluates a Jaxpr with tf.Tensor arguments.
|
||
|
||
The output is a sequence of TfVal, suitable for use with TF.
|
||
"""
|
||
outs_tf, _ = _interpret_fun_jax(core.jaxpr_as_fun(jaxpr),
|
||
args_tf, jaxpr.in_avals, extra_name_stack)
|
||
return outs_tf
|
||
|
||
|
||
def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
|
||
"""Converts JAX avals from logical to physical, if relevant.
|
||
|
||
JAX avals may have logical and physical shapes/dtypes that differ, and
only the physical view is expected to relate to TF. TF impl rules
should operate on the physical form.
|
||
|
||
A JAX logical aval might even correspond, in principle, to several
|
||
physical avals, but we don't support those here. Instead we assert
|
||
there is only one and return it.
|
||
"""
|
||
if core.is_opaque_dtype(aval.dtype):
|
||
aval, = aval.dtype._rules.physical_avals(aval)
|
||
return aval
|
||
return aval
|
||
|
||
def _jax_physical_dtype(dtype):
|
||
# assuming () is a fine stand-in shape
|
||
return _jax_physical_aval(core.ShapedArray((), dtype)).dtype
|
||
|
||
|
||
def _aval_to_tf_shape(aval: core.ShapedArray) -> Tuple[Optional[int], ...]:
|
||
|
||
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
|
||
aval = _jax_physical_aval(aval)
|
||
return tuple(map(lambda d: None if shape_poly.is_poly_dim(d) else d,
|
||
aval.shape)) # type: ignore[attr-defined]
|
||
|
||
# In the TF world, we represent float0 as zeros of this type.
|
||
_tf_np_dtype_for_float0 = np.int32
|
||
|
||
def _to_tf_dtype(jax_dtype):
|
||
# Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
|
||
# due to float0 and 64-bit behavior.
|
||
try:
|
||
jax_dtype = _jax_physical_dtype(jax_dtype)
|
||
except TypeError:
|
||
# `jax_dtype` isn't actually a valid jax dtype (e.g. it is
|
||
# tf.float32), so there is no physical dtype anyway
|
||
pass
|
||
if jax_dtype == dtypes.float0:
|
||
jax_dtype = _tf_np_dtype_for_float0
|
||
return tf.dtypes.as_dtype(jax_dtype)
|
||
|
||
|
||
def _to_jax_dtype(tf_dtype):
|
||
# Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
|
||
# due to float0 and 64-bit behavior.
|
||
return dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
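
# Illustrative sketch (not part of the library) of the non-inverse behavior
# noted above, assuming the default JAX_ENABLE_X64=False.
def _example_dtype_round_trip():
  jax_dtype = _to_jax_dtype(tf.float64)   # canonicalized to float32
  return _to_tf_dtype(jax_dtype)          # tf.float32, not tf.float64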
|
||
|
||
|
||
def _maybe_decode_gda(gda_or_py_object: Any):
|
||
"""Convert GlobalDeviceArray into numpy object."""
|
||
if isinstance(gda_or_py_object, GlobalDeviceArray):
|
||
if jax.process_count() != 1:
|
||
raise RuntimeError("GlobalDeviceArray does not support multi-process"
|
||
f" currently. Process num = {jax.process_count()}")
|
||
return gda_or_py_object._value
|
||
return gda_or_py_object
|
||
|
||
|
||
def _tfval_to_tensor_jax_dtype(val: TfVal,
|
||
jax_dtype: Optional[DType] = None,
|
||
memoize_constants=False) -> Tuple[TfVal, DType]:
|
||
"""Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type.
|
||
|
||
If `jax_dtype` is missing, uses JAX typing rules.
|
||
See README.md for details regarding 64-bit values.
|
||
|
||
Args:
|
||
val: a scalar, ndarray, tf.Tensor, or tf.Variable
|
||
jax_dtype: an optional dtype to use. If missing, uses JAX type inference
|
||
rules for constants.
|
||
memoize_constants: whether to memoize TF constants. We can't do this
|
||
everywhere; we may be outside of a conversion scope.
|
||
|
||
Returns:
|
||
a tuple with a tf.Tensor with the type as needed by JAX, and the JAX type.
|
||
"""
|
||
if isinstance(val, (tf.Tensor, tf.Variable)):
|
||
jax_dtype = jax_dtype or _to_jax_dtype(val.dtype) # Give JAX a chance to pick the type
|
||
conversion_dtype = _to_tf_dtype(jax_dtype)
|
||
if conversion_dtype != val.dtype: # May need to cast for 64-bit values
|
||
return tf.cast(val, conversion_dtype), jax_dtype
|
||
else:
|
||
return val, jax_dtype
|
||
else: # A constant
|
||
jax_dtype = jax_dtype or xla.abstractify(val).dtype
|
||
# TODO(document): We assume that the value of a constant does not
|
||
# change through the scope of the function. But it may be an ndarray, ...
|
||
# JAX has the same problem when generating HLO.
|
||
const_key = (id(val), jax_dtype)
|
||
# Since we use id(val) as a cache key, we have to make sure that we keep
|
||
# the previous `val` alive. Otherwise, for an ndarray, it can get garbage
|
||
# collected and reused for a different value, which would create correctness
|
||
# issues. We keep the `val` alive by storing in the cache the pair
|
||
# `(val, tf_val)`.
|
||
do_memoize = (memoize_constants and np.shape(val) and _thread_local_state.constant_cache is not None)
|
||
if do_memoize:
|
||
_, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
|
||
else:
|
||
tf_val = None
|
||
if tf_val is None:
|
||
conversion_dtype = _to_tf_dtype(jax_dtype)
|
||
# The float0 type is not known to TF.
|
||
if jax_dtype == dtypes.float0:
|
||
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
|
||
tf_val = tf.convert_to_tensor(
|
||
_maybe_decode_gda(val), dtype=conversion_dtype)
|
||
if do_memoize:
|
||
_thread_local_state.constant_cache[const_key] = (val, tf_val)
|
||
return tf_val, jax_dtype
|
||
|
||
|
||
def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
|
||
assert all(map(lambda x: x is not None, shape)), (
|
||
f"Argument shape should be a valid JAX shape but got {shape}")
|
||
if dtype is not None:
|
||
shape = _jax_physical_aval(core.ShapedArray(shape, dtype)).shape
|
||
dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
|
||
eval_shape_jax, dim_avals = shape_poly.get_shape_evaluator(dim_vars, shape)
|
||
shape_values_tf, _ = _interpret_fun_jax(eval_shape_jax,
|
||
dim_values, dim_avals, "") # type: ignore
|
||
return shape_values_tf
|
||
|
||
def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]):
|
||
"""Asserts that shape matches x.shape in the known dimensions and has
|
||
dimension polynomials elsewhere."""
|
||
# Ensures that the shape does not contain None; it should contain polynomials
|
||
assert (len(x.shape) == len(shape) and
|
||
all((xd is None and isinstance(sd, shape_poly._DimPolynomial) or
|
||
core.is_constant_dim(sd) and xd == sd)
|
||
for xd, sd in zip(x.shape, shape))), \
|
||
f"Shape {shape} does not match x.shape {x.shape}"
|
||
|
||
# TODO(b/26854495): pylint doesn't understand slots and inheritance.
|
||
# pylint: disable=assigning-non-slot
|
||
|
||
|
||
class TensorFlowTracer(core.Tracer):
|
||
"""Tracer class that boxes a TF value and a JAX abstract value.
|
||
|
||
In addition to the TF value we carry the JAX abstract value because
|
||
there are some cases when it cannot be recovered from the value:
|
||
when we are converting with polymorphic shapes or when the JAX aval
|
||
has a custom element type. In these cases the shape of the value may
|
||
have dimensions set to `None`, or it may only correspond to the JAX
|
||
"physical" (TF/lowering-compatible) shape, so the JAX abstract value
|
||
may contain more precise information.
|
||
|
||
When the value has a partially-known shape, the dimensions marked as `None`
|
||
must correspond to non-constant dimensions in the abstract value.
|
||
|
||
See README.md for details.
|
||
"""
|
||
# val: TfVal
|
||
# _aval: core.ShapedArray
|
||
__slots__ = ["val", "_aval"]
|
||
|
||
def __init__(self, trace: "TensorFlowTrace", val: TfVal,
|
||
aval: core.AbstractValue):
|
||
self._trace = trace
|
||
self._aval = aval
|
||
phys_aval = _jax_physical_aval(self._aval) # type: ignore[arg-type]
|
||
|
||
if isinstance(val, (tf.Tensor, tf.Variable)):
|
||
val_shape = val.shape
|
||
|
||
if config.jax_enable_checks:
|
||
assert len(phys_aval.shape) == len(val_shape), f"_aval.shape={phys_aval.shape} different rank than val_shape={val_shape}"
|
||
# To compare types, we must handle float0 in JAX and x64 in TF
|
||
if phys_aval.dtype == dtypes.float0:
|
||
assert _to_tf_dtype(phys_aval.dtype) == val.dtype, f"expected {phys_aval.dtype} == {val.dtype}"
|
||
else:
|
||
assert phys_aval.dtype == _to_jax_dtype(val.dtype), f"expected {phys_aval.dtype} == {val.dtype}"
|
||
|
||
for aval_dim, val_dim in zip(phys_aval.shape, val_shape): # type: ignore[attr-defined]
|
||
if val_dim is None:
|
||
assert shape_poly.is_poly_dim(aval_dim), f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined]
|
||
elif not shape_poly.is_poly_dim(aval_dim):
|
||
assert aval_dim == val_dim, f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined]
|
||
else:
|
||
# We have a TF value with known shape, and the abstract shape is a shape variable.
|
||
try:
|
||
aval_int = int(_eval_shape([aval_dim])) # type: ignore
|
||
except (TypeError, KeyError):
|
||
continue
|
||
assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." # type: ignore
|
||
|
||
self.val = _tfval_to_tensor_jax_dtype(val,
|
||
phys_aval.dtype,
|
||
memoize_constants=True)[0] # type: ignore[attr-defined]
|
||
|
||
@property
|
||
def aval(self):
|
||
return self._aval
|
||
|
||
def full_lower(self):
|
||
return self
|
||
|
||
|
||
class TensorFlowTrace(core.Trace):
|
||
"""Trace class that underlies the jax2tf transformation.
|
||
|
||
We are going to ensure that jax2tf.convert is never nested inside other
|
||
transformations. This is sufficient for intended use cases (converting
|
||
fully-transformed JAX code). It also simplifies our job because we do not have
|
||
to handle situations where we apply primitives on a mix of TF values and
|
||
JAX tracers from an outer transformation. E.g., for addition both the TF
|
||
values
|
||
and the JAX tracers have an override and they get confused if they see values
|
||
from the other world.
|
||
|
||
Hence a TFT trace does not interact with non-TFT traces at a lower level. For
|
||
higher-order control-flow primitives we invoke recursively
|
||
_interpret_fun on the body of the conditional, which will create a nested TFT.
|
||
|
||
We do want to allow transformations nested inside a TensorFlowTrace (TFT), but
|
||
those will introduce their own MainTrace, and any operations involving those
|
||
will be done on those traces, i.e., not a concern for TFT.
|
||
"""
|
||
def pure(self, val: TfVal) -> TensorFlowTracer:
|
||
"""Lifts a non-Tracer into the TensorFlowTracer.
|
||
|
||
This function may be called by way of trace.full_raise.
|
||
"""
|
||
if hasattr(val, "__jax_array__"):
|
||
val = val.__jax_array__()
|
||
if isinstance(val, TensorFlowTracer):
|
||
return val
|
||
tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True)
|
||
return TensorFlowTracer(
|
||
self, val, core.ShapedArray(tf_val.shape, jax_dtype,
|
||
weak_type=dtypes.is_weakly_typed(val)))
|
||
|
||
def lift(self, val: core.Tracer) -> TensorFlowTracer:
|
||
# This would be called when we need to raise a tracer from a lower-level
|
||
# main into the TensorFlowTrace. Since the TensorFlowTrace is never nested
|
||
# inside another transform, there are no lower-level main traces.
|
||
assert False
|
||
|
||
def sublift(self, val: TensorFlowTracer) -> TensorFlowTracer:
|
||
# This is called when we need to raise a tracer from the same main,
|
||
# but a lower sublevel. This could come from a nested jit.
|
||
return TensorFlowTracer(self, val.val, val._aval)
|
||
|
||
def process_primitive(self, primitive: core.Primitive,
|
||
tracers: Sequence[TensorFlowTracer],
|
||
params) -> TensorFlowTracer:
|
||
impl, impl_needs_avals = self.get_primitive_impl(primitive)
|
||
args_avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
|
||
# This is a bit conservative: we do abstract_eval even in op-by-op execution,
# but we need it for, e.g., shape polymorphism, where only JAX's
# abstract evaluation rules can properly track polymorphic shapes.
# Op-by-op execution is one of the rare occasions where we actually
# need abstract evaluation.
|
||
out_aval, _ = primitive.abstract_eval(*args_avals, **params)
|
||
args_tf: Sequence[TfVal] = [t.val for t in tracers]
|
||
def invoke_impl() -> TfVal:
|
||
if impl_needs_avals:
|
||
return impl(
|
||
*args_tf,
|
||
_in_avals=args_avals, # type: ignore
|
||
_out_aval=out_aval,
|
||
**params)
|
||
else:
|
||
return impl(*args_tf, **params)
|
||
|
||
current_name_stack = _get_current_name_stack()
|
||
if config.jax_experimental_name_stack:
|
||
# We don't use `str(name_stack)` because it uses parentheses for
|
||
# transformations, which aren't allowed in `name_scope`.
|
||
scope = '/'.join([s.name for s in current_name_stack.stack]) # type: ignore[union-attr]
|
||
else:
|
||
scope = str(current_name_stack)
|
||
# We need to add a '/' to the name stack string to force `tf.name_scope`
|
||
# to interpret it as an absolute scope, not a relative scope.
|
||
scope = scope + '/'
|
||
name_scope = (
|
||
tf.name_scope(_sanitize_scope_name(scope)) if
|
||
config.jax_experimental_name_stack else contextlib.nullcontext())
|
||
|
||
with name_scope:
|
||
if _thread_local_state.include_xla_op_metadata:
|
||
op_metadata = xla.make_op_metadata(primitive, params,
|
||
name_stack=current_name_stack,
|
||
source_info=source_info_util.current())
|
||
op_metadata_proto = xla_data_pb2.OpMetadata(
|
||
op_type=op_metadata.op_type,
|
||
op_name=op_metadata.op_name,
|
||
source_file=op_metadata.source_file,
|
||
source_line=op_metadata.source_line
|
||
)
|
||
with tf_ops.get_default_graph()._attr_scope(
|
||
{"_XlaOpMetadata": attr_value_pb2.AttrValue(
|
||
s=op_metadata_proto.SerializeToString())}):
|
||
val_out = invoke_impl()
|
||
else:
|
||
val_out = invoke_impl()
|
||
|
||
if primitive.multiple_results:
|
||
out = [
|
||
TensorFlowTracer(self, v, a)
|
||
for v, a in zip(val_out, out_aval)
|
||
] # type: ignore
|
||
else:
|
||
out = TensorFlowTracer(self, val_out, out_aval) # type: ignore
|
||
|
||
# Check that the impl rule returned a value of expected shape and dtype
|
||
# TODO: adapt this to match polymorphic shapes
|
||
if config.jax_enable_checks:
|
||
if primitive.multiple_results:
|
||
for o, expected_aval in zip(out, out_aval): # type: ignore
|
||
assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), (
|
||
f"{primitive}: out.aval = {o.aval}; expected {expected_aval}")
|
||
else:
|
||
assert out.aval == out_aval, ( # type: ignore
|
||
f"{primitive}: out.aval = {out.aval}; expected {out_aval}"
|
||
) # type: ignore
|
||
return out # type: ignore
|
||
|
||
def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun,
|
||
tracers: Sequence[TensorFlowTracer], params):
|
||
assert call_primitive.multiple_results
|
||
vals: Sequence[TfVal] = [t.val for t in tracers]
|
||
avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
|
||
interpreted_fun = _interpret_subtrace(fun, self.main, avals)
|
||
extra_name_stack = None
|
||
if call_primitive == core.named_call_p:
|
||
extra_name_stack = util.wrap_name(params["name"], "named")
|
||
elif call_primitive == xla.xla_call_p:
|
||
extra_name_stack = util.wrap_name(params["name"], "jit")
|
||
with _extended_name_stack(extra_name_stack):
|
||
with core.new_sublevel():
|
||
if call_primitive == core.named_call_p:
|
||
with tf.name_scope(_sanitize_scope_name(params["name"])):
|
||
vals_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
|
||
interpreted_fun.call_wrapped(*vals)
|
||
elif call_primitive == xla.xla_call_p:
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
# Make a nested tf.function(jit_compile=True)
|
||
store_tf_res_avals: Sequence[core.ShapedArray] = []
|
||
def f_tf(*tf_args):
|
||
nonlocal store_tf_res_avals
|
||
tf_res_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
|
||
_call_wrapped_with_new_constant_cache(interpreted_fun, tf_args,
|
||
fresh_constant_cache=False)
|
||
tf_res_vals, tf_res_avals = util.unzip2(tf_res_out)
|
||
store_tf_res_avals = tf_res_avals
|
||
return tf_res_vals
|
||
tf_vals_out = tf.function(f_tf, autograph=False, jit_compile=True)(*vals)
|
||
vals_out = zip(tf_vals_out, store_tf_res_avals)
|
||
else:
|
||
vals_out = interpreted_fun.call_wrapped(*vals)
|
||
else:
|
||
vals_out = interpreted_fun.call_wrapped(*vals)
|
||
return [TensorFlowTracer(self, v, a) for v, a in vals_out]
|
||
|
||
def post_process_call(self, call_primitive: core.Primitive,
|
||
out_tracers: Sequence[TensorFlowTracer], params):
|
||
# We encountered a call primitive whose results (out_tracers) include
# TensorFlowTracers that were not passed through its arguments (they were
# captured from the environment).
|
||
vals = tuple(t.val for t in out_tracers)
|
||
main = self.main
|
||
|
||
def todo(vals: Sequence[TfVal]):
|
||
# TODO: is name_stack correct?
|
||
trace = TensorFlowTrace(main, core.cur_sublevel())
|
||
return [
|
||
TensorFlowTracer(trace, v, out_tracer.aval)
|
||
for v, out_tracer in zip(vals, out_tracers)
|
||
]
|
||
|
||
return vals, todo
|
||
|
||
def process_map(self, map_primitive, f, tracers, params):
|
||
raise NotImplementedError("process_map")
|
||
|
||
def post_process_map(self, map_primitive, out_tracers, params):
|
||
raise NotImplementedError("post_process_map")
|
||
|
||
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
||
# Drop the custom differentiation rule and act like a call primitive. This
|
||
# behavior is desirable because jax2tf stages code out of the JAX system, so
|
||
# there are no more JAX differentiation transformations to be applied.
|
||
del jvp # Unused.
|
||
return self.process_call(core.call_p, fun, tracers, {})
|
||
|
||
def post_process_custom_jvp_call(self, out_tracers, _):
|
||
assert False # unreachable assuming jax2tf runs with clean trace state
|
||
|
||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
||
# Drop the custom differentiation rule and act like a call primitive. This
|
||
# behavior is desirable because jax2tf stages code out of the JAX system, so
|
||
# there are no more JAX differentiation transformations to be applied.
|
||
del fwd, bwd, out_trees # Unused.
|
||
return self.process_call(core.call_p, fun, tracers, {})
|
||
|
||
def post_process_custom_vjp_call(self, out_tracers, _):
|
||
assert False # unreachable assuming jax2tf runs with clean trace state
|
||
|
||
def post_process_custom_vjp_call_fwd(self, *_, **__):
|
||
assert False # unreachable assuming jax2tf runs with clean trace state
|
||
|
||
def get_primitive_impl(self, p: core.Primitive) -> Tuple[Callable, bool]:
|
||
# Returns the primitive implementation and whether the implementation
|
||
# takes abstract values (see definition of tf_impl_with_avals)
|
||
if not _thread_local_state.enable_xla:
|
||
try:
|
||
return tf_impl_no_xla[p], True # Always require avals.
|
||
except KeyError:
|
||
pass
|
||
try:
|
||
return tf_impl[p], False
|
||
except KeyError:
|
||
try:
|
||
return tf_impl_with_avals[p], True
|
||
except KeyError as err:
|
||
msg = "TensorFlow interpretation rule for '{}' not implemented"
|
||
raise NotImplementedError(msg.format(p)) from err
|
||
|
||
def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
|
||
assert False, f"Encountered unexpected primitive {p}"
|
||
|
||
|
||
# Call primitives are inlined
|
||
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p, maps.xmap_p]:
|
||
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
|
||
|
||
# Primitives that are not yet implemented must be explicitly declared here.
|
||
tf_not_yet_impl = [
|
||
"clz",
|
||
"igamma_grad_a",
|
||
"random_gamma_grad",
|
||
"reduce_precision",
|
||
"reduce_xor",
|
||
"schur",
|
||
"closed_call",
|
||
"unreachable",
|
||
"bint",
|
||
"getslice",
|
||
"full_to_shard",
|
||
"shard_to_full",
|
||
"pure_callback",
|
||
"for",
|
||
"inspect_sharding",
|
||
"iota_32x2_shape",
|
||
|
||
# Not high priority?
|
||
"after_all",
|
||
"all_to_all",
|
||
"assert",
|
||
"create_token",
|
||
"custom_transpose_call",
|
||
"custom_vmap_call",
|
||
"infeed",
|
||
"linear_call",
|
||
"outfeed",
|
||
"pmax_p",
|
||
"pmin",
|
||
"ppermute",
|
||
"psum",
|
||
"pmax",
|
||
"pgather",
|
||
"reduce_scatter",
|
||
"axis_index",
|
||
"pdot",
|
||
"all_gather",
|
||
"lu_pivots_to_permutation",
|
||
"xla_pmap",
|
||
"geqrf",
|
||
"householder_product",
|
||
"hessenberg",
|
||
"tridiagonal",
|
||
"eigh_jacobi",
|
||
]
|
||
|
||
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
|
||
tf_impl[ad_util.zeros_like_p] = tf.zeros_like
|
||
|
||
|
||
def _add(x: TfVal, y: TfVal) -> TfVal:
|
||
return tf.raw_ops.AddV2(x=x, y=y)
|
||
|
||
|
||
tf_impl[ad_util.add_jaxvals_p] = _add
|
||
tf_impl[dispatch.device_put_p] = lambda x, device=None: x
|
||
tf_impl[lax_internal.copy_p] = lambda x: x
|
||
|
||
def _neg(x: TfVal) -> TfVal:
|
||
if x.dtype.is_unsigned:
|
||
signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[x.dtype]
|
||
x_signed = tf.cast(x, signed_dtype)
|
||
res_signed = tf.math.negative(x_signed)
|
||
return tf.cast(res_signed, x.dtype)
|
||
else:
|
||
return tf.math.negative(x)
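
# Illustrative sketch (not part of the library): negation of unsigned values is
# routed through the signed type so it wraps around modulo 2**bits.
def _example_neg_unsigned_usage():
  return _neg(tf.constant(1, tf.uint8))   # 255, i.e. -1 mod 2**8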
|
||
|
||
tf_impl[lax.neg_p] = _neg
|
||
|
||
|
||
def _sign(x: TfVal) -> TfVal:
|
||
if x.dtype.is_unsigned:
|
||
# TF and XLA do not support tf.math.sign for unsigned types.
|
||
return tf.where(
|
||
tf.math.equal(x, 0), tf.constant(0, dtype=x.dtype),
|
||
tf.constant(1, dtype=x.dtype))
|
||
else:
|
||
return tf.math.sign(x)
|
||
|
||
|
||
tf_impl[lax.sign_p] = _sign
|
||
tf_impl[lax.floor_p] = tf.math.floor
|
||
tf_impl[lax.ceil_p] = tf.math.ceil
|
||
|
||
|
||
def _round(operand, *, rounding_method,
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
if rounding_method is lax.RoundingMethod.AWAY_FROM_ZERO:
|
||
# JAX uses a single HLO op Round here
|
||
sign = _sign(operand)
|
||
operand *= sign
|
||
floor = tf.math.floor(operand)
|
||
operand -= floor
|
||
cond = tf.math.equal(operand, tf.constant(np.array(0.5), operand.dtype))
|
||
return sign * (
|
||
tf.where(cond, tf.constant(np.array(1), operand.dtype),
|
||
tf.math.round(operand)) + floor)
|
||
else: # rounding_method is RoundingMethod.TO_NEAREST_EVEN
|
||
return tf.math.round(operand)
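
# Illustrative sketch (not part of the library): with AWAY_FROM_ZERO, ties such
# as 0.5 and -1.5 round to 1 and -2, unlike tf.math.round's to-nearest-even.
# The dummy avals below are assumptions; the rule's body does not read them.
def _example_round_usage():
  aval = core.ShapedArray((2,), np.float32)
  return _round(tf.constant([0.5, -1.5]),
                rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO,
                _in_avals=(aval,), _out_aval=aval)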
|
||
|
||
tf_impl_with_avals[lax.round_p] = _round
|
||
tf_impl[lax.nextafter_p] = tf.math.nextafter


def _population_count(x):
  orig_dtype = x.dtype
  return tf.cast(tf.raw_ops.PopulationCount(x=x), orig_dtype)


tf_impl[lax.population_count_p] = _population_count
tf_impl[lax.is_finite_p] = tf.math.is_finite


def _abs(x: TfVal) -> TfVal:
  # TF and XLA do not support tf.math.abs for unsigned types.
  return tf.math.abs(x) if not x.dtype.is_unsigned else x


tf_impl[lax.abs_p] = _abs
tf_impl[lax.pow_p] = tf.math.pow


def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],
                 _out_aval: core.ShapedArray):
  # Follows the implementation in lax._integer_pow_translation_rule
  if y == 0:
    return tf.broadcast_to(
        tf.constant(1, dtype=x.dtype, shape=()), _eval_shape(_out_aval.shape))
  is_reciprocal = y < 0
  if is_reciprocal:
    y = -y
  acc = None
  while y > 0:
    if y & 1:
      acc = x if acc is None else tf.math.multiply(acc, x)
    y >>= 1
    if y > 0:
      x = tf.math.multiply(x, x)
  return tf.math.reciprocal(acc) if is_reciprocal else acc


tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
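

# Sketch (illustrative only): _integer_pow is binary exponentiation, so x**11
# costs five multiplies (building x^1 * x^2 * x^8) rather than ten.
def _example_integer_pow():
  aval = core.ShapedArray((2,), np.float32)
  return _integer_pow(tf.constant([2.0, 3.0]), y=11,
                      _in_avals=(aval,), _out_aval=aval)  # [2048., 177147.]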
|
||
tf_impl[lax.exp_p] = tf.math.exp
|
||
tf_impl[lax.expm1_p] = tf.math.expm1
|
||
tf_impl[lax.log_p] = tf.math.log
|
||
tf_impl[lax.log1p_p] = tf.math.log1p
|
||
tf_impl[lax.tan_p] = tf.math.tan
|
||
tf_impl[lax.tanh_p] = tf.math.tanh
|
||
tf_impl[lax.sin_p] = tf.math.sin
|
||
tf_impl[lax.sinh_p] = tf.math.sinh
|
||
tf_impl[lax.cos_p] = tf.math.cos
|
||
tf_impl[lax.cosh_p] = tf.math.cosh
|
||
tf_impl_with_avals[lax.acos_p] = _convert_jax_impl(
|
||
lax_internal.acos_impl, multiple_results=False)
|
||
tf_impl_with_avals[lax.asin_p] = _convert_jax_impl(
|
||
lax_internal.asin_impl, multiple_results=False)
|
||
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
|
||
lax_internal.atan_impl, multiple_results=False)
|
||
|
||
# TODO(phawkins): use tf.math.sigmoid here instead.
|
||
tf_impl_with_avals[lax.logistic_p] = _convert_jax_impl(
|
||
lax_internal.logistic_impl, multiple_results=False)
|
||
|
||
def _atan2(y, x, **kwargs):
|
||
if x.dtype.is_complex or y.dtype.is_complex:
|
||
complex_component_dtype = {
|
||
tf.complex64: tf.float32,
|
||
tf.complex128: tf.float64
|
||
}.get(y.dtype)
|
||
zero = tf.constant(0, complex_component_dtype)
|
||
one = tf.constant(1, complex_component_dtype)
|
||
i = tf.complex(zero, one)
|
||
return -i * tf.math.log((x + i * y)/tf.math.sqrt(x * x + y * y))
|
||
else:
|
||
return tf.math.atan2(y, x)
|
||
|
||
|
||
tf_impl[lax.atan2_p] = _atan2
|
||
tf_impl[lax.acosh_p] = tf.math.acosh
|
||
tf_impl[lax.atanh_p] = tf.math.atanh
|
||
tf_impl[lax.asinh_p] = tf.math.asinh
|
||
|
||
tf_impl[lax.sqrt_p] = tf.math.sqrt
|
||
tf_impl[lax.rsqrt_p] = tf.math.rsqrt
|
||
|
||
def _cbrt(x):
|
||
return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3)
|
||
|
||
tf_impl[lax.cbrt_p] = _cbrt
|
||
|
||
tf_impl[lax.lgamma_p] = tf.math.lgamma
|
||
tf_impl[lax.digamma_p] = tf.math.digamma
|
||
tf_impl[lax.igamma_p] = tf.math.igamma
|
||
tf_impl[lax.igammac_p] = tf.math.igammac
|
||
tf_impl[lax.regularized_incomplete_beta_p] = tf.math.betainc
|
||
tf_impl[lax.erf_p] = tf.math.erf
|
||
tf_impl[lax.erfc_p] = tf.math.erfc
|
||
tf_impl[lax.erf_inv_p] = tf.math.erfinv
|
||
tf_impl[lax.bessel_i0e_p] = tf.math.bessel_i0e
|
||
tf_impl[lax.bessel_i1e_p] = tf.math.bessel_i1e
|
||
|
||
tf_impl[lax.complex_p] = tf.complex
|
||
|
||
|
||
def _conj(x, **kwargs):
|
||
# The only dtypes that are allowed are: float32, float64, complex64, and
|
||
# complex128.
|
||
if x.dtype == tf.float32:
|
||
return tf.cast(x, tf.complex64)
|
||
elif x.dtype == tf.float64:
|
||
return tf.cast(x, tf.complex128)
|
||
else:
|
||
return tf.math.conj(x)
|
||
|
||
|
||
tf_impl[lax.conj_p] = _conj
|
||
tf_impl[lax.real_p] = tf.math.real
|
||
tf_impl[lax.imag_p] = tf.math.imag
|
||
|
||
tf_impl[lax.add_p] = _add
|
||
tf_impl[lax.sub_p] = tf.math.subtract
|
||
tf_impl[lax.mul_p] = tf.math.multiply
|
||
|
||
|
||
def _iota(*, dtype, shape, dimension):
|
||
dtype = _to_tf_dtype(dtype)
|
||
# Some dtypes are unsupported, like uint32, so we just fall back to int32.
|
||
# TODO(mattjj, necula): improve tf.range dtype handling
|
||
shape_tf = _eval_shape(shape)
|
||
vec = tf.range(tf.cast(shape_tf[dimension], tf.int32), dtype=tf.int32)
|
||
vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))]
|
||
return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape_tf), dtype)
|
||
|
||
|
||
tf_impl[lax.iota_p] = _iota
|
||
|
||
|
||
def _div(lhs, rhs):
  if lhs.dtype.is_integer:
    quotient = tf.math.floordiv(lhs, rhs)
    select = tf.math.logical_and(
        tf.not_equal(_sign(lhs), _sign(rhs)),
        tf.not_equal(tf.math.floormod(lhs, rhs), 0))
    return tf.where(select, quotient + 1, quotient)
  else:
    return tf.math.truediv(lhs, rhs)


def _rem(lhs, rhs):
  return _sign(lhs) * tf.math.floormod(_abs(lhs), _abs(rhs))


tf_impl[lax.div_p] = _div
tf_impl[lax.rem_p] = _rem
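

# Sketch (illustrative only): lax.div/lax.rem truncate toward zero (C semantics),
# unlike Python's floor division, hence the corrections above on top of
# tf.math.floordiv and tf.math.floormod.
def _example_div_rem_truncation():
  lhs = tf.constant([-7, 7], dtype=tf.int32)
  rhs = tf.constant([2, -2], dtype=tf.int32)
  return _div(lhs, rhs), _rem(lhs, rhs)  # ([-3, -3], [-1, 1])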


def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
            _in_avals: Sequence[core.ShapedArray],
            _out_aval: core.ShapedArray,) -> TfVal:
  # For complex numbers use lexicographic ordering, like JAX
  if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating):
    return _convert_jax_impl(
        partial(lax_internal._minmax_complex_lowering,
                lax_cmp_pick_x=lax.lt if is_min else lax.gt),
        multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
  elif x.dtype.as_numpy_dtype == np.bool_:
    return (tf.math.logical_and if is_min else tf.math.logical_or)(x, y)
  else:
    return (tf.math.minimum if is_min else tf.math.maximum)(x, y)

def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal:
  # For reducers we will need min/max for scalars only. In that case we
  # can construct the AbstractValues ourselves, even in the presence of
  # shape polymorphism.
  assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}"
  aval = core.ShapedArray((), _to_jax_dtype(x.dtype))
  return _minmax(x, y, is_min=is_min,
                 _in_avals=[aval, aval], _out_aval=aval)

tf_impl_with_avals[lax.max_p] = partial(_minmax, is_min=False)
tf_impl_with_avals[lax.min_p] = partial(_minmax, is_min=True)

# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    tf.int8: tf.uint8,
    tf.int16: tf.uint16,
    tf.int32: tf.uint32,
    tf.int64: tf.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {u: s for s, u in _SIGNED_TO_UNSIGNED_TABLE.items()}
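

# Sketch (illustrative only): on booleans JAX's max/min degenerate to logical
# or/and, which is the branch _minmax takes.
def _example_max_bool():
  return _minmax_scalar(tf.constant(True), tf.constant(False), is_min=False)  # True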
|
||
|
||
|
||
# Note: Bitwise operations only yield identical results on unsigned integers!
|
||
# pylint: disable=protected-access
|
||
def _shift_right_arithmetic_raw(x, y):
|
||
if x.dtype.is_unsigned:
|
||
assert x.dtype == y.dtype
|
||
orig_dtype = x.dtype
|
||
signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[orig_dtype]
|
||
x = tf.cast(x, signed_dtype)
|
||
y = tf.cast(y, signed_dtype)
|
||
res = tf.bitwise.right_shift(x, y)
|
||
return tf.cast(res, orig_dtype)
|
||
else:
|
||
return tf.bitwise.right_shift(x, y)
|
||
|
||
|
||
def _shift_right_arithmetic(x, y):
|
||
# TF shift is "implementation defined" if the shift amount is negative
|
||
# or larger or equal to the size of the value. We implement the XLA
|
||
# semantics to return the shift by the max value (x_bits - 1).
|
||
# TODO: it is likely better to add XlaOps for shifts
|
||
x_bits = 8 * x.dtype.size
|
||
clamp_y = tf.where(_shift_in_bounds(x, y), y, x_bits - 1)
|
||
return _shift_right_arithmetic_raw(x, clamp_y)
|
||
|
||
|
||
tf_impl[lax.shift_right_arithmetic_p] = _shift_right_arithmetic
|
||
|
||
|
||
def _shift_right_logical_raw(x, y):
|
||
if x.dtype.is_unsigned:
|
||
return tf.bitwise.right_shift(x, y)
|
||
else:
|
||
assert x.dtype == y.dtype
|
||
orig_dtype = x.dtype
|
||
unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[orig_dtype]
|
||
x = tf.cast(x, unsigned_dtype)
|
||
y = tf.cast(y, unsigned_dtype)
|
||
res = tf.bitwise.right_shift(x, y)
|
||
return tf.cast(res, orig_dtype)
|
||
|
||
|
||
def _shift_right_logical(x, y):
|
||
# TF shift is "implementation defined" if the shift amount is negative
|
||
# or larger or equal to the size of the value. We implement the XLA semantics
|
||
# to return 0.
|
||
# TODO: it is likely better to add XlaOps for shifts
|
||
return tf.where(
|
||
_shift_in_bounds(x, y), _shift_right_logical_raw(x, y), tf.zeros_like(x))
|
||
|
||
|
||
tf_impl[lax.shift_right_logical_p] = _shift_right_logical
|
||
|
||
|
||
def _shift_left(x, y):
|
||
# TF shift is "implementation defined" if the shift amount is negative
|
||
# or larger or equal to the size of the value. We implement the XLA semantics
|
||
# to return 0.
|
||
# TODO: it is likely better to add XlaOps for shifts
|
||
return tf.where(
|
||
_shift_in_bounds(x, y), tf.bitwise.left_shift(x, y), tf.zeros_like(x))
|
||
|
||
|
||
tf_impl[lax.shift_left_p] = _shift_left
|
||
|
||
|
||
def _shift_in_bounds(x: TfVal, y: TfVal) -> TfVal:
|
||
# Return the TF expression for when y is within bounds (0 <= y < |x|)
|
||
x_bits = 8 * x.dtype.size
|
||
# TF does not have comparisons for uint16 and uint32 (despite what the
|
||
# documentation says)
|
||
y_comp = tf.cast(
|
||
y, _UNSIGNED_TO_SIGNED_TABLE[y.dtype]) if y.dtype.is_unsigned else y
|
||
y_lt_x_bits = tf.math.less(y_comp, x_bits)
|
||
y_ge_0 = tf.math.greater_equal(y_comp, 0)
|
||
return tf.logical_and(y_lt_x_bits, y_ge_0)
|
||
|
||
|
||
def _not(x):
  """Computes bitwise not with support for booleans.

  Numpy and JAX support bitwise not for booleans by applying a logical not!
  This means that applying bitwise_not yields an unexpected result:
    jnp.bitwise_not(jnp.array([True, False]))
    >> DeviceArray([False, True], dtype=bool)

  if you assume that booleans are simply cast to integers.
    jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool)
    >> DeviceArray([True, True], dtype=bool)
  """
  if x.dtype == tf.bool:
    return tf.logical_not(x)
  else:
    return tf.bitwise.invert(x)


tf_impl[lax.not_p] = _not
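

# Sketch (illustrative only): the docstring above in code form; booleans take
# the logical_not path, integers the bitwise invert path.
def _example_not():
  return (_not(tf.constant([True, False])),           # [False, True]
          _not(tf.constant([1, 0], dtype=tf.int32)))  # [-2, -1]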


def handle_boolean_args(f, argnums: Sequence[int], boolean_f=None):
  """Computes functions with some bool args and bool results using int8.

  This is needed because some TF ops do not work for bool args, e.g.,
  inequalities, min/max.

  Args:
    f: a TF callable to wrap. It will be called with non-boolean arguments.
    argnums: the positional arguments that may be booleans.
    boolean_f: [Optional] a TF callable compatible with boolean
      arguments.

  Returns: a TF callable that can take a mix of boolean positional arguments
    (in the positions specified by `argnums`) and some non-boolean positional
    arguments. If there are no boolean arguments, just calls `f`. Otherwise,
    it calls `boolean_f` if defined. Otherwise, casts the boolean
    arguments to `int8`, calls `f`, then casts the result to `bool`.
  """
  argnums = tf.nest.flatten(argnums)

  def wrapper(*args: TfVal, **kwargs):
    argnum_types = {args[i].dtype for i in argnums}
    if tf.bool not in argnum_types:
      return f(*args, **kwargs)
    else:
      # All argnums should be boolean
      assert len(argnum_types) == 1, argnum_types
      if boolean_f is not None:
        return boolean_f(*args, **kwargs)
      else:
        args_cast = [(tf.cast(a, tf.int8) if i in argnums else a)
                     for i, a in enumerate(args)]
        if "_in_avals" in kwargs:

          def cast_aval(aval):
            assert aval.dtype == np.bool_
            return core.ShapedArray(aval.shape, np.int8)

          _in_avals_cast = [
              cast_aval(aval) if i in argnums else aval
              for i, aval in enumerate(kwargs["_in_avals"])
          ]
          _out_aval_cast = tf.nest.map_structure(cast_aval, kwargs["_out_aval"])
          kwargs = dict(
              kwargs, _in_avals=_in_avals_cast, _out_aval=_out_aval_cast)
        out = f(*args_cast, **kwargs)
        return tf.nest.map_structure(lambda o: tf.cast(o, tf.bool), out)

  return wrapper
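

# Sketch (illustrative only, hypothetical name): without a `boolean_f`, boolean
# arguments are cast to int8, the wrapped op runs on int8, and the result is
# cast back to bool.
_example_bool_sub = handle_boolean_args(tf.math.subtract, argnums=(0, 1))
# _example_bool_sub(tf.constant([True, False]), tf.constant([True, True]))
#   -> [False, True]   (computed as int8 [0, -1], then cast back)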


tf_impl[lax.or_p] = handle_boolean_args(tf.bitwise.bitwise_or, argnums=(0, 1), boolean_f=tf.logical_or)
tf_impl[lax.and_p] = handle_boolean_args(tf.bitwise.bitwise_and, argnums=(0, 1), boolean_f=tf.logical_and)
tf_impl[lax.xor_p] = handle_boolean_args(tf.bitwise.bitwise_xor, argnums=(0, 1), boolean_f=tf.math.logical_xor)

tf_impl[lax.eq_p] = tf.math.equal
tf_impl[lax.ne_p] = tf.math.not_equal

boolean_greater = lambda x, y: tf.logical_and(x, tf.logical_not(y))  # Only one combo: T,F -> T
boolean_less = lambda x, y: tf.logical_and(tf.logical_not(x), y)  # Only one combo: F,T -> T
boolean_greater_or_equal = lambda x, y: tf.logical_not(boolean_less(x, y))  # All cases except F,T
boolean_less_or_equal = lambda x, y: tf.logical_not(boolean_greater(x, y))  # All cases except T,F

tf_impl[lax.gt_p] = handle_boolean_args(tf.math.greater, argnums=(0, 1), boolean_f=boolean_greater)
tf_impl[lax.lt_p] = handle_boolean_args(tf.math.less, argnums=(0, 1), boolean_f=boolean_less)
tf_impl[lax.ge_p] = handle_boolean_args(tf.math.greater_equal, argnums=(0, 1), boolean_f=boolean_greater_or_equal)
tf_impl[lax.le_p] = handle_boolean_args(tf.math.less_equal, argnums=(0, 1), boolean_f=boolean_less_or_equal)

tf_impl[lax.linalg.cholesky_p] = tf.linalg.cholesky
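

# Sketch (illustrative only): the boolean comparison identities above;
# (True, False) is the only pair for which "greater" holds.
def _example_boolean_gt():
  x = tf.constant([True, True, False, False])
  y = tf.constant([True, False, True, False])
  return boolean_greater(x, y)  # [False, True, False, False]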
|
||
|
||
|
||
def _convert_element_type(operand, *, new_dtype, weak_type=False):
|
||
old_dtype = operand.dtype.as_numpy_dtype
|
||
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
|
||
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
||
operand = tf.math.real(operand)
|
||
if (dtypes.issubdtype(old_dtype, np.floating) and
|
||
not (dtypes.issubdtype(new_dtype, np.floating) or dtypes.issubdtype(
|
||
new_dtype, np.complexfloating) or new_dtype == np.bool_)):
|
||
sign = _sign(operand)
|
||
operand = sign * tf.math.floor(sign * operand)
|
||
return tf.dtypes.cast(operand, _to_tf_dtype(new_dtype))
|
||
|
||
|
||
tf_impl[lax.convert_element_type_p] = _convert_element_type
|
||
|
||
|
||
def _bitcast_convert_type(operand, new_dtype):
|
||
if operand.dtype == new_dtype:
|
||
return operand
|
||
return tf.bitcast(operand, _to_tf_dtype(new_dtype))
|
||
|
||
|
||
tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type
|
||
|
||
|
||
def _clamp(minval, operand, maxval, *, _in_avals, _out_aval):
|
||
# The below permits mirroring the behavior of JAX when maxval < minval
|
||
op_shape_tf_val = _eval_shape(_in_avals[1].shape, _in_avals[1].dtype)
|
||
maxval = tf.broadcast_to(maxval, op_shape_tf_val)
|
||
minval = tf.math.minimum(tf.broadcast_to(minval, op_shape_tf_val), maxval)
|
||
return tf.clip_by_value(operand, minval, maxval)
|
||
|
||
|
||
tf_impl_with_avals[lax.clamp_p] = _clamp
|
||
|
||
|
||
def _concatenate(*operands, dimension):
|
||
return tf.concat(operands, axis=dimension)
|
||
|
||
|
||
tf_impl[lax.concatenate_p] = _concatenate
|
||
|
||
|
||
def _conv_general_dimension_numbers_proto(dimension_numbers):
|
||
"""Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
|
||
assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
|
||
lhs_spec, rhs_spec, out_spec = dimension_numbers
|
||
proto = xla_data_pb2.ConvolutionDimensionNumbers()
|
||
proto.input_batch_dimension = lhs_spec[0]
|
||
proto.input_feature_dimension = lhs_spec[1]
|
||
proto.output_batch_dimension = out_spec[0]
|
||
proto.output_feature_dimension = out_spec[1]
|
||
proto.kernel_output_feature_dimension = rhs_spec[0]
|
||
proto.kernel_input_feature_dimension = rhs_spec[1]
|
||
proto.input_spatial_dimensions.extend(lhs_spec[2:])
|
||
proto.kernel_spatial_dimensions.extend(rhs_spec[2:])
|
||
proto.output_spatial_dimensions.extend(out_spec[2:])
|
||
return proto
|
||
|
||
|
||
def _precision_config_proto(precision: Optional[Tuple[PrecisionType,
|
||
PrecisionType]]):
|
||
"""Convert an integer to an XLA.PrecisionConfig."""
|
||
if precision is None:
|
||
return None
|
||
|
||
proto = xla_data_pb2.PrecisionConfig()
|
||
proto.operand_precision.append(int(precision[0]))
|
||
proto.operand_precision.append(int(precision[1]))
|
||
return proto
|
||
|
||
|
||
def _conv_general_dilated(lhs, rhs, *,
|
||
window_strides, padding, lhs_dilation,
|
||
rhs_dilation,
|
||
dimension_numbers: lax.ConvDimensionNumbers,
|
||
feature_group_count: int,
|
||
batch_group_count: int,
|
||
lhs_shape: Sequence[int],
|
||
rhs_shape: Sequence[int],
|
||
precision: Optional[Tuple[PrecisionType, PrecisionType]],
|
||
preferred_element_type: Optional[DType],
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
"""Implementation of lax.conv_general_dilated_p using XlaConv."""
|
||
out_tf_shape = _aval_to_tf_shape(_out_aval)
|
||
dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
|
||
precision_config_proto = _precision_config_proto(precision)
|
||
|
||
def gen_conv(lhs, rhs, preferred_element_type: Optional[DType]):
|
||
tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2])
|
||
if tf_version >= (2, 8):
|
||
# TODO(necula): remove when 2.8.0 is the stable TF version (and supports
|
||
# batch_group_count.
|
||
out = tfxla.conv(
|
||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||
dnums_proto,
|
||
feature_group_count=feature_group_count,
|
||
batch_group_count=batch_group_count,
|
||
precision_config=precision_config_proto,
|
||
preferred_element_type=preferred_element_type,
|
||
use_v2=True)
|
||
else:
|
||
if batch_group_count != 1:
|
||
raise ValueError(
|
||
"The batch_group_count parameter for conv requires TF version "
|
||
"at least 2.8.0. You may want to use tf-nightly.")
|
||
out = tfxla.conv(
|
||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||
dnums_proto,
|
||
feature_group_count=feature_group_count,
|
||
precision_config=precision_config_proto,
|
||
preferred_element_type=preferred_element_type,
|
||
use_v2=True)
|
||
# TODO: implement shape inference for XlaConv
|
||
out.set_shape(out_tf_shape)
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
out = tf.stop_gradient(out) # See #7839
|
||
return out
|
||
|
||
# Follow the lowering for complex convolutions from
|
||
# lax._conv_general_dilated_translation. We can use the same conversion on all
|
||
# platforms because on XLA:TPU the compiler does the same as a rewrite.
|
||
preferred_float_et: Optional[Any]
|
||
if np.issubdtype(_in_avals[0].dtype, np.complexfloating):
|
||
if preferred_element_type is not None:
|
||
# Convert complex dtype to types used for real and imaginary parts
|
||
assert np.issubdtype(preferred_element_type, np.complexfloating)
|
||
preferred_float_et = (
|
||
np.float64 if preferred_element_type == np.complex128 else np.float32)
|
||
else:
|
||
preferred_float_et = None
|
||
lhs_real, lhs_imag = tf.math.real(lhs), tf.math.imag(lhs)
|
||
rhs_real, rhs_imag = tf.math.real(rhs), tf.math.imag(rhs)
|
||
k1 = gen_conv(_add(lhs_real, lhs_imag), rhs_real, preferred_float_et)
|
||
k2 = gen_conv(lhs_real, tf.math.subtract(rhs_imag, rhs_real),
|
||
preferred_float_et)
|
||
k3 = gen_conv(lhs_imag, _add(rhs_real, rhs_imag), preferred_float_et)
|
||
return tf.complex(tf.math.subtract(k1, k3), _add(k1, k2))
|
||
else:
|
||
return gen_conv(lhs, rhs, preferred_element_type)
|
||
|
||
|
||
tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
|
||
|
||
|
||
def _dot_general(lhs, rhs, *, dimension_numbers,
|
||
precision: Optional[Tuple[PrecisionType, PrecisionType]],
|
||
preferred_element_type: Optional[DType],
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
|
||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||
dnums_proto = xla_data_pb2.DotDimensionNumbers()
|
||
dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
|
||
dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
|
||
dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
|
||
dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
|
||
precision_config_proto = _precision_config_proto(precision)
|
||
res = tfxla.dot_general(
|
||
lhs,
|
||
rhs,
|
||
dnums_proto,
|
||
precision_config_proto,
|
||
preferred_element_type=preferred_element_type,
|
||
use_v2=True)
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
res = tf.stop_gradient(res) # See #7839
|
||
return res
|
||
|
||
|
||
tf_impl_with_avals[lax.dot_general_p] = _dot_general
|
||
|
||
|
||
def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
# for i in range(len(operand.shape)):
|
||
# result.shape[bcast_dims[i]] <- operand.shape[i]
|
||
# bcast_dims must be strictly increasing.
|
||
# len(bcast_dims) == len(operand.shape)
|
||
op_shape = _in_avals[0].shape
|
||
dtype = _in_avals[0].dtype
|
||
add_1s_shape = [1] * len(shape)
|
||
for i, broadcast_dim_i in enumerate(broadcast_dimensions):
|
||
add_1s_shape[broadcast_dim_i] = op_shape[i]
|
||
with_1s = tf.reshape(operand, _eval_shape(add_1s_shape, dtype=dtype))
|
||
return tf.broadcast_to(with_1s, _eval_shape(shape, dtype=dtype))
|
||
|
||
|
||
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim
|
||
|
||
|
||
def _empty(*, dtype):
|
||
if core.is_opaque_dtype(dtype):
|
||
raise NotImplementedError # TODO(frostig,mattjj): jax2tf handlers
|
||
return tf.constant(np.array(0, dtype=dtype))
|
||
|
||
|
||
tf_impl[lax_internal.empty_p] = _empty
|
||
|
||
|
||
def _reshape(operand, *, new_sizes, dimensions, _in_avals, _out_aval):
|
||
if dimensions is None:
|
||
dimensions = tf.range(tf.rank(operand))
|
||
new_sizes_tf = _eval_shape(new_sizes, _in_avals[0].dtype)
|
||
return tf.reshape(tf.transpose(operand, dimensions), new_sizes_tf)
|
||
|
||
|
||
tf_impl_with_avals[lax.reshape_p] = _reshape
|
||
|
||
|
||
def _squeeze(operand, *, dimensions, _in_avals, _out_aval):
|
||
op_aval = _jax_physical_aval(_in_avals[0])
|
||
op_shape = op_aval.shape
|
||
new_shape = tuple(d for i, d in enumerate(op_shape) if i not in dimensions)
|
||
new_shape_tf = _eval_shape(new_shape, op_aval.dtype)
|
||
return tf.reshape(operand, new_shape_tf)
|
||
|
||
|
||
tf_impl_with_avals[lax.squeeze_p] = _squeeze
|
||
|
||
|
||
def _pad(operand, padding_value, *, padding_config,
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
low, high, interior = util.unzip3(padding_config)
|
||
out = tfxla.pad(operand, padding_value, low, high, interior)
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
out = tf.stop_gradient(out) # See #7839
|
||
return out
|
||
|
||
|
||
tf_impl_with_avals[lax.pad_p] = _pad
|
||
|
||
|
||
def _rev(operand, *, dimensions):
|
||
return tf.reverse(operand, dimensions)
|
||
|
||
|
||
tf_impl[lax.rev_p] = _rev
|
||
|
||
|
||
def _where(which, *cases):
  if which.dtype == tf.bool:
    assert len(cases) <= 2
    return cases if len(cases) == 1 else tf.where(which, cases[1], cases[0])

  def _select(offset, cases):
    assert len(cases) > 0
    if len(cases) == 1:
      return cases[0]
    mid = len(cases) // 2
    return tf.where(tf.less(which, offset + mid),
                    _select(offset, cases[:mid]),
                    _select(mid, cases[mid:]))

  return _select(0, cases)


tf_impl[lax.select_n_p] = _where
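

# Sketch (illustrative only): with an integer `which` and more than two cases,
# _where builds a balanced tree of tf.where calls (a binary search on the index).
def _example_select_n():
  which = tf.constant([0, 1, 2])
  c0 = tf.constant([10, 10, 10])
  c1 = tf.constant([20, 20, 20])
  c2 = tf.constant([30, 30, 30])
  return _where(which, c0, c1, c2)  # [10, 20, 30]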
|
||
|
||
|
||
def _transpose(operand, *, permutation):
|
||
return tf.transpose(operand, perm=permutation)
|
||
|
||
|
||
tf_impl[lax.transpose_p] = _transpose
|
||
|
||
axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)
|
||
|
||
# reduce_sum and reduce_prod are not supported for bool
|
||
tf_impl[lax.reduce_sum_p] = axes_to_axis(tf.reduce_sum)
|
||
tf_impl[lax.reduce_prod_p] = axes_to_axis(tf.reduce_prod)
|
||
tf_impl[lax.reduce_max_p] = handle_boolean_args(
|
||
axes_to_axis(tf.reduce_max), argnums=[0],
|
||
boolean_f=axes_to_axis(tf.reduce_any)) # Max is T if any one is T
|
||
tf_impl[lax.reduce_min_p] = handle_boolean_args(
|
||
axes_to_axis(tf.reduce_min), argnums=[0],
|
||
boolean_f=axes_to_axis(tf.reduce_all)) # Min is F if not all are T
|
||
tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
|
||
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)
|
||
|
||
|
||
def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
|
||
index_dtype: DType,
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
# Follow the JAX implementation, using a XlaReduce with a custom comparator
|
||
if is_min:
|
||
extra_name_stack = "argmin"
|
||
value_comparator = lax.lt
|
||
get_identity = lax_internal._get_min_identity
|
||
else:
|
||
extra_name_stack = "argmax"
|
||
value_comparator = lax.gt
|
||
get_identity = lax_internal._get_max_identity
|
||
|
||
res = _convert_jax_impl(
|
||
partial(lax_internal._compute_argminmax, value_comparator, get_identity),
|
||
multiple_results=False,
|
||
extra_name_stack=extra_name_stack)(
|
||
operand,
|
||
index_dtype=index_dtype,
|
||
axes=axes,
|
||
_in_avals=_in_avals,
|
||
_out_aval=_out_aval)
|
||
return res
|
||
|
||
|
||
tf_impl_with_avals[lax.argmin_p] = partial(_argminmax, True)
|
||
tf_impl_with_avals[lax.argmax_p] = partial(_argminmax, False)
|
||
|
||
|
||
_add_fn = tf.function(_add, autograph=False)
|
||
_ge_fn = tf.function(tf.math.greater_equal, autograph=False)
|
||
|
||
|
||
def _select_and_gather_add(
|
||
tangents: TfVal, operand: TfVal, select_prim: core.Primitive,
|
||
window_dimensions: Sequence[int], window_strides: Sequence[int],
|
||
base_dilation: Sequence[int], window_dilation: Sequence[int],
|
||
padding: Sequence[Tuple[int, int]], _in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
# Note: this function follows the pattern in
|
||
# jax.lax._select_and_gather_add_translation.
|
||
dtype = operand.dtype
|
||
nbits = dtypes.finfo(dtype.as_numpy_dtype).bits
|
||
|
||
# Specializing the function for 64 bits. Only up to 32 bits are supported on TPU,
|
||
# we thus intend to let the code throw a different exception on this platform.
|
||
max_bits = 64
|
||
|
||
assert nbits <= max_bits
|
||
double_word_reduction = nbits * 2 <= max_bits
|
||
|
||
const = lambda dtype, x: tf.constant(np.array(x), dtype)
|
||
|
||
if double_word_reduction:
|
||
word_dtype = lax_internal._UINT_DTYPES[nbits]
|
||
double_word_dtype = lax_internal._UINT_DTYPES[nbits * 2]
|
||
|
||
# Packs two values into a tuple.
|
||
def pack(a, b):
|
||
a = _bitcast_convert_type(a, word_dtype)
|
||
b = _bitcast_convert_type(b, word_dtype)
|
||
a = _convert_element_type(a, new_dtype=double_word_dtype)
|
||
b = _convert_element_type(b, new_dtype=double_word_dtype)
|
||
a = tf.bitwise.left_shift(a, const(double_word_dtype, nbits))
|
||
return tf.bitwise.bitwise_or(a, b)
|
||
|
||
# Unpacks the first element of a tuple.
|
||
def fst(t):
|
||
assert t.dtype == double_word_dtype
|
||
st = _shift_right_logical(t, const(double_word_dtype, nbits))
|
||
return _bitcast_convert_type(
|
||
_convert_element_type(st, new_dtype=word_dtype), dtype)
|
||
|
||
# Unpacks the second element of a tuple.
|
||
def snd(t):
|
||
return _bitcast_convert_type(
|
||
_convert_element_type(t, new_dtype=word_dtype), dtype)
|
||
|
||
else:
|
||
raise NotImplementedError(
|
||
f"TODO: need to pack {nbits * 2} bits but this platform can only go up to {max_bits} bits."
|
||
)
|
||
|
||
assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
|
||
|
||
def reducer(x, y):
|
||
which = tf_impl[select_prim]
|
||
return tf_impl[lax.select_n_p](which(fst(x), fst(y)), y, x)
|
||
|
||
init = -np.inf if select_prim is lax.ge_p else np.inf
|
||
init_identity = lambda x: pack(const(dtype, init), const(dtype, 0))
|
||
|
||
out = _specialized_reduce_window(
|
||
reducer,
|
||
init_identity,
|
||
pack(operand, tangents),
|
||
window_dimensions=window_dimensions,
|
||
window_strides=window_strides,
|
||
padding=padding,
|
||
base_dilation=base_dilation,
|
||
window_dilation=window_dilation,
|
||
_in_avals=_in_avals,
|
||
_out_aval=_out_aval)
|
||
|
||
return snd(out)
|
||
|
||
|
||
tf_impl_with_avals[lax.select_and_gather_add_p] = _select_and_gather_add
|
||
|
||
|
||
def _get_shape_from_tensor_or_array(x):
|
||
if isinstance(x.shape, tf.TensorShape):
|
||
return tuple(x.shape.as_list())
|
||
return tuple(x.shape)
|
||
|
||
|
||
def _common_reduce_window(operand, init_val, reducer, window_dimensions,
|
||
window_strides, padding, base_dilation,
|
||
window_dilation, _in_avals, _out_aval):
|
||
o_spec = tf.TensorSpec((), dtype=operand.dtype)
|
||
reducer_fn = tf.function(
|
||
reducer, autograph=False).get_concrete_function(o_spec, o_spec)
|
||
|
||
if not isinstance(init_val, (tf.Tensor, tf.Variable)):
|
||
init_val = tf.constant(init_val, operand.dtype)
|
||
out = tfxla.reduce_window(
|
||
operand,
|
||
init_val,
|
||
reducer_fn,
|
||
window_dimensions,
|
||
window_strides,
|
||
base_dilations=base_dilation,
|
||
window_dilations=window_dilation,
|
||
padding=padding)
|
||
# TODO: implement shape inference for XlaReduceWindow
|
||
out.set_shape(_aval_to_tf_shape(_out_aval))
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
out = tf.stop_gradient(out) # See #7839
|
||
return out
|
||
|
||
|
||
def _reduce_window(*args, jaxpr, consts, window_dimensions,
|
||
window_strides, padding, base_dilation, window_dilation,
|
||
_in_avals, _out_aval):
|
||
"""TensorFlow implementation of reduce_window.
|
||
|
||
Args:
|
||
operands: N dimensional arrays containing elements of type T
|
||
init_values: starting values of the reduction
|
||
jaxpr: the jaxpr corresponding to the reduction function
|
||
consts: the constants associated with jaxpr.
|
||
window_dimensions: array of integers for window dimension values
|
||
window_strides: array of integers for window stride values
|
||
padding: array of pairs of integers for padding values
|
||
base_dilation: array of integers for base dilation values
|
||
window_dilation: array of integers for window dilation values
|
||
|
||
Returns:
|
||
The reduced operand.
|
||
"""
|
||
assert len(consts) == 0, "Reduction computation cannot have constants"
|
||
operands, init_values = util.split_list(args, [len(args) // 2])
|
||
|
||
if len(operands) != 1:
|
||
raise NotImplementedError("jax2tf does not support variadic reduce_window")
|
||
|
||
def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:
|
||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||
res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None)
|
||
return res
|
||
|
||
return (_common_reduce_window(operands[0], init_values[0], reducer,
|
||
window_dimensions, window_strides, padding,
|
||
base_dilation, window_dilation, _in_avals,
|
||
_out_aval[0]),)
|
||
|
||
|
||
|
||
def _specialized_reduce_window(reducer,
|
||
identity,
|
||
operand,
|
||
*,
|
||
window_dimensions,
|
||
window_strides,
|
||
padding,
|
||
base_dilation,
|
||
window_dilation,
|
||
_in_avals,
|
||
_out_aval,
|
||
name=None):
|
||
"""Wraps the TensorFlow reduce window operation based on a reducer and an
|
||
|
||
identity function defining the initial value of the reduction depending on
|
||
the dtype of the operand.
|
||
|
||
Args:
|
||
reducer: reduction function of type TfVal -> TfVal -> TfVal
|
||
identity: function that takes a TensorFlow dtype as a parameter and returns
|
||
the starting value of the reduction.
|
||
operand: N dimensional array containing elements of type T
|
||
window_dimensions: array of integers for window dimension values
|
||
window_strides: array of integers for window stride values
|
||
padding: array of pairs of integers for padding values
|
||
base_dilation: array of integers for base dilation values
|
||
window_dilation: array of integers for window dilation values
|
||
name: the name of the specialized reduce window primitive for which this
|
||
conversion function is called. This information may help to choose a
|
||
different conversion path (optional)
|
||
|
||
Returns:
|
||
The reduced operand.
|
||
"""
|
||
return _common_reduce_window(operand, identity(operand.dtype), reducer,
|
||
window_dimensions, window_strides, padding,
|
||
base_dilation, window_dilation, _in_avals,
|
||
_out_aval)
|
||
|
||
|
||
def _get_max_identity(tf_dtype):
|
||
numpy_tf_dtype = tf_dtype.as_numpy_dtype
|
||
if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact):
|
||
return numpy_tf_dtype(-np.inf)
|
||
elif dtypes.issubdtype(numpy_tf_dtype, np.integer):
|
||
return dtypes.iinfo(numpy_tf_dtype).min
|
||
else:
|
||
assert dtypes.issubdtype(
|
||
numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined max identity")
|
||
return False
|
||
|
||
|
||
def _get_min_identity(tf_dtype):
|
||
numpy_tf_dtype = tf_dtype.as_numpy_dtype
|
||
if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact):
|
||
return numpy_tf_dtype(np.inf)
|
||
elif dtypes.issubdtype(numpy_tf_dtype, np.integer):
|
||
return dtypes.iinfo(numpy_tf_dtype).max
|
||
else:
|
||
assert dtypes.issubdtype(
|
||
numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined min identity")
|
||
return True
|
||
|
||
|
||
# pylint: disable=protected-access
|
||
tf_impl_with_avals[lax.reduce_window_sum_p] = (
|
||
partial(_specialized_reduce_window, _add, lambda x: 0,
|
||
name="reduce_window_sum"))
|
||
tf_impl_with_avals[lax.reduce_window_min_p] = (
|
||
partial(_specialized_reduce_window,
|
||
partial(_minmax_scalar, is_min=True),
|
||
_get_min_identity,
|
||
name="reduce_window_min"))
|
||
tf_impl_with_avals[lax.reduce_window_max_p] = (
|
||
partial(_specialized_reduce_window,
|
||
partial(_minmax_scalar, is_min=False),
|
||
_get_max_identity,
|
||
name="reduce_window_max"))
|
||
tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
|
||
# pylint: enable=protected-access
|
||
|
||
def _reduce(*operands: TfVal,
|
||
computation: Callable,
|
||
jaxpr: core.Jaxpr,
|
||
consts: Sequence[Any],
|
||
dimensions: Sequence[int],
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray) -> Sequence[TfVal]:
|
||
del computation
|
||
assert not consts
|
||
assert len(operands) % 2 == 0
|
||
# operands: op1, op2, ..., init_val1, init_val2, ...
|
||
# reducer takes op1[i], op2[i], ..., init_val1, init_val2, ...
|
||
nr_operands = len(operands) // 2
|
||
init_vals = operands[nr_operands:]
|
||
operands = operands[0:nr_operands]
|
||
|
||
reducer_arg_spec = tuple([tf.TensorSpec((), op.dtype) for op in init_vals] * 2)
|
||
|
||
def reducer_computation(*args: TfVal) -> TfVal:
|
||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||
res = _interpret_jaxpr(closed_jaxpr, *args, extra_name_stack=None)
|
||
return res
|
||
|
||
xla_reducer_computation = (
|
||
tf.function(reducer_computation,
|
||
autograph=False).get_concrete_function(*reducer_arg_spec))
|
||
|
||
outs = tfxla.variadic_reduce(operands, init_vals,
|
||
dimensions_to_reduce=dimensions,
|
||
reducer=xla_reducer_computation)
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
outs = tuple(tf.stop_gradient(out) for out in outs) # See #7839
|
||
return outs
|
||
|
||
tf_impl_with_avals[lax.reduce_p] = _reduce
|
||
|
||
|
||
# We use lax.cumred_reduce_window_impl to convert cummax,
|
||
# cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is
|
||
# O(n^2) on other backends. This may be implemented using associative_scan
|
||
# instead to favor different backends.
|
||
def _cumred(lax_reduce_fn: Callable,
|
||
lax_reduce_window_fn: Callable,
|
||
extra_name_stack: str):
|
||
if config.jax2tf_associative_scan_reductions:
|
||
return _convert_jax_impl(partial(lax_control_flow.associative_scan,
|
||
lax_reduce_fn),
|
||
multiple_results=False,
|
||
extra_name_stack=extra_name_stack)
|
||
else:
|
||
return _convert_jax_impl(partial(lax_control_flow.cumred_reduce_window_impl,
|
||
lax_reduce_window_fn),
|
||
multiple_results=False,
|
||
extra_name_stack=extra_name_stack)
|
||
|
||
|
||
tf_impl_with_avals[lax.cummax_p] = _cumred(
|
||
lax_reduce_window_fn=lax_windowed_reductions._reduce_window_max,
|
||
lax_reduce_fn=lax.max,
|
||
extra_name_stack="cummax")
|
||
tf_impl_with_avals[lax.cummin_p] = _cumred(
|
||
lax_reduce_window_fn=lax_windowed_reductions._reduce_window_min,
|
||
lax_reduce_fn=lax.min,
|
||
extra_name_stack="cummin")
|
||
tf_impl_with_avals[lax.cumlogsumexp_p] = _cumred(
|
||
lax_reduce_window_fn=lax_windowed_reductions._reduce_window_logaddexp,
|
||
lax_reduce_fn=logaddexp,
|
||
extra_name_stack="cumlogsumexp")
|
||
tf_impl_with_avals[lax.cumsum_p] = _cumred(
|
||
lax_reduce_window_fn=lax_windowed_reductions._reduce_window_sum,
|
||
lax_reduce_fn=lax.add,
|
||
extra_name_stack="cumsum")
|
||
tf_impl_with_avals[lax.cumprod_p] = _cumred(
|
||
lax_reduce_window_fn=lax_windowed_reductions._reduce_window_prod,
|
||
lax_reduce_fn=lax.mul,
|
||
extra_name_stack="cumprod")
|
||
|
||
|
||
def _select_and_scatter(operand, source, init_value, select_jaxpr,
|
||
select_consts, scatter_jaxpr, scatter_consts,
|
||
window_dimensions, window_strides, padding):
|
||
raise NotImplementedError("TODO: jax2tf can not convert _select_and_scatter")
|
||
|
||
|
||
tf_impl[lax.select_and_scatter_p] = _select_and_scatter
|
||
|
||
|
||
@partial(handle_boolean_args, argnums=(0, 1))
|
||
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
|
||
window_strides, padding, _in_avals, _out_aval):
|
||
init_value = tf.zeros((), operand.dtype)
|
||
select_fn = (
|
||
tf.function(tf_impl[select_prim], autograph=False).get_concrete_function(
|
||
init_value, init_value))
|
||
scatter_fn = _add_fn.get_concrete_function(init_value, init_value)
|
||
out = tfxla.select_and_scatter(operand, window_dimensions, window_strides,
|
||
padding, source, init_value, select_fn,
|
||
scatter_fn)
|
||
out.set_shape(_aval_to_tf_shape(_out_aval))
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
out = tf.stop_gradient(out) # See #7839
|
||
return out
|
||
|
||
|
||
tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
|
||
|
||
|
||
def _random_seed_impl(seeds: TfVal, *, impl, _in_avals, _out_aval):
|
||
|
||
def impl_wrapper(seeds: TfVal, *, impl):
|
||
return prng.random_seed_impl_base(seeds, impl=impl)
|
||
|
||
converted_impl = _convert_jax_impl(
|
||
impl_wrapper, multiple_results=False, with_physical_avals=True,
|
||
extra_name_stack="random_seed")
|
||
return converted_impl(
|
||
seeds, impl=impl, _in_avals=_in_avals, _out_aval=_out_aval)
|
||
|
||
tf_impl_with_avals[prng.random_seed_p] = _random_seed_impl
|
||
|
||
|
||
def _random_split_impl(keys: TfVal, *, count, _in_avals, _out_aval):
|
||
keys_aval, = _in_avals
|
||
|
||
def impl_wrapper(keys: TfVal, *, count):
|
||
return prng.random_split_impl_base(
|
||
keys_aval.dtype.impl, keys, keys_aval.ndim, count=count)
|
||
|
||
converted_impl = _convert_jax_impl(
|
||
impl_wrapper, multiple_results=False, with_physical_avals=True,
|
||
extra_name_stack="random_split")
|
||
return converted_impl(
|
||
keys, count=count, _in_avals=_in_avals, _out_aval=_out_aval)
|
||
|
||
tf_impl_with_avals[prng.random_split_p] = _random_split_impl
|
||
|
||
|
||
def _random_fold_in_impl(keys: TfVal, msgs: TfVal, *, _in_avals, _out_aval):
|
||
keys_aval, _ = _in_avals
|
||
|
||
def impl_wrapper(keys: TfVal, msgs: TfVal):
|
||
return prng.random_fold_in_impl_base(
|
||
keys_aval.dtype.impl, keys, msgs, keys_aval.shape)
|
||
|
||
converted_impl = _convert_jax_impl(
|
||
impl_wrapper, multiple_results=False, with_physical_avals=True,
|
||
extra_name_stack="random_fold_in")
|
||
return converted_impl(
|
||
keys, msgs, _in_avals=_in_avals, _out_aval=_out_aval)
|
||
|
||
tf_impl_with_avals[prng.random_fold_in_p] = _random_fold_in_impl
|
||
|
||
|
||
def _random_bits_impl(keys: TfVal, *, bit_width, shape, _in_avals, _out_aval):
|
||
keys_aval, = _in_avals
|
||
|
||
def impl_wrapper(keys: TfVal, **kwargs):
|
||
return prng.random_bits_impl_base(
|
||
keys_aval.dtype.impl, keys, keys_aval.ndim,
|
||
bit_width=bit_width, shape=shape)
|
||
|
||
converted_impl = _convert_jax_impl(
|
||
impl_wrapper, multiple_results=False, with_physical_avals=True,
|
||
extra_name_stack="random_bits")
|
||
return converted_impl(keys, bit_width=bit_width, shape=shape,
|
||
_in_avals=_in_avals, _out_aval=_out_aval)
|
||
|
||
tf_impl_with_avals[prng.random_bits_p] = _random_bits_impl
|
||
|
||
|
||
def _random_wrap_impl(base_arr: TfVal, *, impl, _in_avals, _out_aval):
|
||
return base_arr
|
||
|
||
tf_impl_with_avals[prng.random_wrap_p] = _random_wrap_impl
|
||
|
||
|
||
def _random_unwrap_impl(keys: TfVal, *, _in_avals, _out_aval):
|
||
return keys
|
||
|
||
tf_impl_with_avals[prng.random_unwrap_p] = _random_unwrap_impl
|
||
|
||
|
||
|
||
def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
|
||
res = _convert_jax_impl(
|
||
partial(prng._threefry2x32_lowering, use_rolled_loops=False),
|
||
multiple_results=True, extra_name_stack="threefry")(
|
||
*args, _in_avals=_in_avals, _out_aval=_out_aval)
|
||
return res
|
||
|
||
|
||
tf_impl_with_avals[prng.threefry2x32_p] = _threefry2x32_jax_impl
|
||
|
||
# Use the vmap implementation, otherwise on TPU the performance is really bad
|
||
# With use_vmap=True on, we get about the same performance for JAX and jax2tf.
|
||
tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl(
|
||
partial(random_internal._gamma_impl, use_vmap=True),
|
||
multiple_results=False, extra_name_stack="random_gamma")
|
||
|
||
|
||
def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm) -> Sequence[TfVal]:
|
||
is_uint32_key = key.dtype == _to_tf_dtype(jnp.uint32)
|
||
if is_uint32_key:
|
||
key = tf.reshape(key, (2, 2))
|
||
key = tfxla.bitcast_convert_type(key, _to_tf_dtype(jnp.uint64))
|
||
shape_tf = _eval_shape(shape)
|
||
# JAX uses XLA algorithm enums; tfxla uses tf.random.Algorithm
|
||
if algorithm == lax.RandomAlgorithm.RNG_THREE_FRY:
|
||
algorithm_tf = tf.random.Algorithm.THREEFRY
|
||
elif algorithm == lax.RandomAlgorithm.RNG_PHILOX:
|
||
algorithm_tf = tf.random.Algorithm.PHILOX
|
||
elif algorithm == lax.RandomAlgorithm.RNG_DEFAULT:
|
||
algorithm_tf = tf.random.Algorithm.AUTO_SELECT
|
||
else:
|
||
assert False
|
||
(new_key, res) = tfxla.rng_bit_generator(algorithm_tf.value, key, shape_tf,
|
||
dtype=_to_tf_dtype(dtype))
|
||
if is_uint32_key:
|
||
new_key = tfxla.bitcast_convert_type(new_key, _to_tf_dtype(jnp.uint32))
|
||
new_key = tf.reshape(new_key, (4,))
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
# See #7839
|
||
new_key = tf.stop_gradient(new_key)
|
||
res = tf.stop_gradient(res)
|
||
return new_key, res
|
||
|
||
|
||
tf_impl[lax.rng_bit_generator_p] = _rng_bit_generator
|
||
|
||
|
||
def _rng_uniform(minval: TfVal, maxval: TfVal, *, shape) -> TfVal:
|
||
shape_tf = _eval_shape(shape)
|
||
return tf.random.uniform(shape_tf, minval=minval, maxval=maxval, dtype=minval.dtype)
|
||
|
||
tf_impl[lax.rng_uniform_p] = _rng_uniform
|
||
|
||
|
||
def _gather_dimensions_proto(indices_shape, dimension_numbers):
|
||
proto = xla_data_pb2.GatherDimensionNumbers()
|
||
proto.offset_dims.extend(dimension_numbers.offset_dims)
|
||
proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims)
|
||
proto.start_index_map.extend(dimension_numbers.start_index_map)
|
||
assert indices_shape
|
||
proto.index_vector_dim = len(indices_shape) - 1
|
||
return proto
|
||
|
||
|
||
def _maybe_cast_to_int64(x: TfVal) -> TfVal:
|
||
if x.dtype != tf.int32 and x.dtype != tf.int64:
|
||
return tf.cast(x, tf.int64)
|
||
return x
|
||
|
||
|
||
@partial(handle_boolean_args, argnums=[0])
|
||
def _gather(operand, start_indices, *, dimension_numbers, slice_sizes: core.Shape,
|
||
indices_are_sorted, unique_indices, mode, fill_value,
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
"""Tensorflow implementation of gather."""
|
||
if mode == lax.GatherScatterMode.FILL_OR_DROP:
|
||
gather_fill_fn = _convert_jax_impl(lax_slicing._gather_fill,
|
||
multiple_results=False)
|
||
return gather_fill_fn(
|
||
operand, start_indices, dimension_numbers=dimension_numbers,
|
||
slice_sizes=slice_sizes, unique_indices=unique_indices,
|
||
indices_are_sorted=indices_are_sorted, fill_value=fill_value,
|
||
output_shape=_out_aval.shape, _in_avals=_in_avals, _out_aval=_out_aval)
|
||
|
||
start_indices = _maybe_cast_to_int64(start_indices)
|
||
proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
|
||
slice_sizes_tf = _eval_shape(slice_sizes, _in_avals[0].dtype)
|
||
out = tfxla.gather(operand, start_indices, proto, slice_sizes_tf,
|
||
indices_are_sorted)
|
||
out.set_shape(_aval_to_tf_shape(_out_aval))
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
out = tf.stop_gradient(out) # See #7839
|
||
return out
|
||
|
||
|
||
tf_impl_with_avals[lax.gather_p] = _gather
|
||
|
||
|
||
def _slice(operand, start_indices, limit_indices, strides, _in_avals,
|
||
_out_aval):
|
||
if strides is None:
|
||
strides = [1] * len(start_indices)
|
||
slices = tuple(
|
||
map(slice, _eval_shape(start_indices), _eval_shape(limit_indices),
|
||
_eval_shape(strides)))
|
||
out = operand[slices]
|
||
# TODO(b/184503314): improve shape inference for __getitem__
|
||
# E.g., operand.shape=(b, 5, 3), start_indices=(0, 1, 1), limit_indices=(b, 5, 3), strides=(1, 2, 1)
|
||
out.set_shape(_aval_to_tf_shape(_out_aval))
|
||
return out
|
||
|
||
|
||
tf_impl_with_avals[lax.slice_p] = _slice
|
||
|
||
|
||
def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape,
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
start_indices = _maybe_cast_to_int64(tf.stack(start_indices))
|
||
slice_sizes_tf = _eval_shape(slice_sizes, dtype=_in_avals[0].dtype)
|
||
|
||
res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes_tf)
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
res = tf.stop_gradient(res) # See #7839
|
||
return res
|
||
|
||
|
||
tf_impl_with_avals[lax.dynamic_slice_p] = _dynamic_slice
|
||
|
||
|
||
def _dynamic_update_slice(operand, update, *start_indices,
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
start_indices = _maybe_cast_to_int64(tf.stack(start_indices))
|
||
out = tfxla.dynamic_update_slice(operand, update, start_indices)
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
out = tf.stop_gradient(out) # See #7839
|
||
return out
|
||
|
||
|
||
tf_impl_with_avals[lax.dynamic_update_slice_p] = _dynamic_update_slice
|
||
|
||
|
||
def _scatter_dimensions_proto(indices_shape, dimension_numbers):
|
||
proto = xla_data_pb2.ScatterDimensionNumbers()
|
||
proto.update_window_dims.extend(dimension_numbers.update_window_dims)
|
||
proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims)
|
||
proto.scatter_dims_to_operand_dims.extend(
|
||
dimension_numbers.scatter_dims_to_operand_dims)
|
||
assert indices_shape
|
||
proto.index_vector_dim = len(indices_shape) - 1
|
||
return proto
|
||
|
||
|
||
def _scatter(operand, scatter_indices, updates, *, update_jaxpr, update_consts,
|
||
dimension_numbers, indices_are_sorted, unique_indices, mode,
|
||
_in_avals: Sequence[core.ShapedArray],
|
||
_out_aval: core.ShapedArray):
|
||
del unique_indices
|
||
|
||
if mode == lax.GatherScatterMode.CLIP:
|
||
clip_fn = _convert_jax_impl(lax_slicing._clamp_scatter_indices,
|
||
multiple_results=False)
|
||
scatter_indices = clip_fn(
|
||
operand, scatter_indices, updates, dnums=dimension_numbers,
|
||
_in_avals=_in_avals, _out_aval=_in_avals[1])
|
||
|
||
assert len(update_consts) == 0, "Update computation cannot have constants"
|
||
|
||
proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers)
|
||
|
||
def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal:
|
||
closed_jaxpr = core.ClosedJaxpr(update_jaxpr, update_consts)
|
||
res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None)
|
||
return res
|
||
|
||
o_spec = tf.TensorSpec((), dtype=operand.dtype)
|
||
xla_update_computation = (
|
||
tf.function(update_computation,
|
||
autograph=False).get_concrete_function(o_spec, o_spec))
|
||
out = tfxla.scatter(
|
||
operand,
|
||
scatter_indices,
|
||
updates,
|
||
xla_update_computation,
|
||
proto,
|
||
indices_are_sorted=indices_are_sorted)
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
out = tf.stop_gradient(out) # See #7839
|
||
return out
|
||
|
||
|
||
tf_impl_with_avals[lax.scatter_p] = _scatter
|
||
tf_impl_with_avals[lax.scatter_min_p] = _scatter
|
||
tf_impl_with_avals[lax.scatter_max_p] = _scatter
|
||
tf_impl_with_avals[lax.scatter_mul_p] = _scatter
|
||
tf_impl_with_avals[lax.scatter_add_p] = _scatter
|
||
|
||
|
||
def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
          linear: Sequence[bool]) -> Sequence[TfVal]:
  del linear
  # tf.cond needs lambdas with no arguments.
  branches_tf = [
      partial(_interpret_jaxpr, jaxpr, *operands,
              # Same name stack as the XLA translation of cond_p
              extra_name_stack=f"branch_{i}_fun")
      for i, jaxpr in enumerate(branches)
  ]
  if config.jax_experimental_name_stack:
    # Same name stack as XLA translation of cond_p
    # Note: extend_name_stack is a contextmanager, which is callable as a decorator.
    branches_tf = list(map(source_info_util.extend_name_stack("cond"),  # type: ignore[arg-type]
                           branches_tf))
  return tf.switch_case(index, branches_tf)


tf_impl[lax.cond_p] = _cond
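

# Sketch (illustrative only): a JAX cond arrives here as an index plus a list of
# branch jaxprs and becomes a tf.switch_case over zero-argument callables.
def _example_cond_roundtrip():
  from jax.experimental import jax2tf  # local import to avoid a cycle at load time
  f = lambda x: lax.cond(x > 0, lambda x: x + 1., lambda x: x - 1., x)
  return jax2tf.convert(f)(np.float32(2.))  # 3.0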


def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
           body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]:
  cond_consts, body_consts, init_carry = util.split_list(
      args, [cond_nconsts, body_nconsts])
  if cond_jaxpr.out_avals[0].shape:  # type: ignore[attr-defined]
    # The conditional is not a scalar, this must be a batched while
    return _batched_cond_while(
        *args,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr)

  # The conditional must return a single value to TF
  def cond_tf_func(*args: TfVal) -> TfVal:
    pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *args,
                             # Same name stack as the XLA translation of while_p
                             extra_name_stack="while/cond")
    return pred

  body_tf_func = partial(_interpret_jaxpr, body_jaxpr, *body_consts,
                         extra_name_stack="while/body")
  return tf.while_loop(cond_tf_func, body_tf_func, init_carry)
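

# Sketch (illustrative only): a scalar-predicate lax.while_loop maps directly
# onto tf.while_loop through the rule above.
def _example_while_roundtrip():
  from jax.experimental import jax2tf  # local import to avoid a cycle at load time
  f = lambda x: lax.while_loop(lambda c: c < 10, lambda c: c * 2, x)
  return jax2tf.convert(f)(np.int32(1))  # 16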
|
||
|
||
|
||
def _batched_cond_while(*args: TfVal, cond_nconsts: int,
|
||
cond_jaxpr: core.ClosedJaxpr, body_nconsts: int,
|
||
body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]:
|
||
"""Interprets a while_loop with a batched condition.
|
||
|
||
A batched while has a conditional that returns a tensor of booleans, and
|
||
a body that returns a list of tensors whose leading dimensions match those
|
||
of the conditional tensor.
|
||
|
||
We need to turn it into a while with scalar boolean conditional. We will
|
||
expand the loop carry to include a prefix with the current tensor boolean
|
||
condition. We prepend to the loop the first calculation of the tensor boolean
|
||
condition. The loop condition will use a "reduce_any" to calculate a scalar
|
||
boolean from the tensor boolean condition. The end of the loop body will
|
||
compute the new carry using a "tf.where", and we compute the new tensor
|
||
boolean condition.
|
||
"""
|
||
cond_consts, body_consts, init_carry = util.split_list(
|
||
args, [cond_nconsts, body_nconsts])
|
||
# Initial computation of batched condition
|
||
init_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *init_carry,
|
||
extra_name_stack="while/body_pred")
|
||
|
||
def new_cond_tf_func(pred_b: TfVal, *carry: TfVal) -> TfVal:
|
||
pred = tf.reduce_any(pred_b, axis=list(range(len(pred_b.shape))))
|
||
return pred
|
||
|
||
def new_body_tf_func(pred_b: TfVal, *carry: TfVal) -> Sequence[TfVal]:
|
||
new_carry: Sequence[TfVal] = _interpret_jaxpr(body_jaxpr, *body_consts,
|
||
*carry,
|
||
extra_name_stack="while/body")
|
||
# We repeat those carries for which the loop termination condition is false
|
||
def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal:
|
||
pred_b_bcast = _broadcast_in_dim(
|
||
pred_b,
|
||
shape=_jax_physical_aval(c_aval).shape, # a JAX shape
|
||
broadcast_dimensions=list(range(len(pred_b.shape))),
|
||
_in_avals=cond_jaxpr.out_avals,
|
||
_out_aval=core.ShapedArray(c_aval.shape, np.bool_))
|
||
return tf.where(pred_b_bcast, new_c, c)
|
||
|
||
selected_carry: Sequence[TfVal] = list(map(select_one_carry, new_carry, carry, body_jaxpr.out_avals))
|
||
next_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *selected_carry,
|
||
extra_name_stack="body_pred")
|
||
return (next_pred_b, *selected_carry)
|
||
|
||
_, *res_carry = tf.while_loop(new_cond_tf_func, new_body_tf_func,
|
||
(init_pred_b, *init_carry))
|
||
return res_carry
|
||
|
||
|
||
tf_impl[lax.while_p] = _while
|
||
|
||
# We use the scan impl rule to rewrite in terms of while.
|
||
tf_impl_with_avals[lax.scan_p] = _convert_jax_impl(
|
||
lax_control_flow._scan_impl,
|
||
extra_name_stack="scan")
|
||
|
||
tf_impl_with_avals[ad_checkpoint.remat_p] = \
|
||
_convert_jax_impl(partial(ad_checkpoint.remat_lowering,
|
||
# TODO: jax2tf cannot discriminate by platform
|
||
is_gpu_platform=False),
|
||
multiple_results=True,
|
||
extra_name_stack="checkpoint")
|
||
|
||
tf_impl[ad_checkpoint.name_p] = lambda x, *, name: x
|
||
|
||
# TODO: Remove once tensorflow is 2.10.0 everywhere.
|
||
if hasattr(tfxla, 'optimization_barrier'):
|
||
tf_impl[lax_control_flow.optimization_barrier_p] = tfxla.optimization_barrier
|
||
|
||
def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:
|
||
# Some types originally incompatible with tf.math.top_k can be promoted
|
||
# to a compatible type without loss of precision.
|
||
def promote_tf_dtype(tf_dtype):
|
||
if tf_dtype in [tf.bool, tf.uint8, tf.uint16]:
|
||
return tf.uint32
|
||
if tf_dtype in [tf.int8, tf.int16]:
|
||
return tf.int32
|
||
if tf_dtype is tf.float16:
|
||
return tf.float32
|
||
return None
|
||
|
||
conversion_dtype = promote_tf_dtype(operand.dtype)
|
||
if conversion_dtype:
|
||
values, indices = tf.math.top_k(
|
||
tf.dtypes.cast(operand, conversion_dtype), k=k, sorted=True)
|
||
return tf.dtypes.cast(values, operand.dtype), indices
|
||
else:
|
||
return tf.math.top_k(operand, k=k, sorted=True)
|
||
|
||
|
||
tf_impl[lax.top_k_p] = _top_k
|
||
|
||
|
||
def _approx_top_k(operand: TfVal, k: int, reduction_dimension: int,
|
||
recall_target: float, is_max_k: bool,
|
||
reduction_input_size_override: int,
|
||
aggregate_to_topk: bool) -> Tuple[TfVal, TfVal]:
|
||
if is_max_k:
|
||
return tf.math.approx_max_k(operand, k, reduction_dimension, recall_target,
|
||
reduction_input_size_override,
|
||
aggregate_to_topk)
|
||
else:
|
||
return tf.math.approx_min_k(operand, k, reduction_dimension, recall_target,
|
||
reduction_input_size_override,
|
||
aggregate_to_topk)
|
||
|
||
|
||
tf_impl[lax.approx_top_k_p] = _approx_top_k
|
||
|
||
|
||
def _sort(*operands: TfVal, dimension: int, is_stable: bool,
|
||
num_keys: int) -> Tuple[TfVal, ...]:
|
||
assert 1 <= num_keys <= len(operands)
|
||
assert 0 <= dimension < len(
|
||
operands[0].shape
|
||
), f"Invalid {dimension} for ndim {len(operands[0].shape)}"
|
||
|
||
comparator_spec: List[tf.TensorSpec] = []
|
||
comparator_jax_in_avals: List[core.ShapedArray] = []
|
||
for op in operands:
|
||
o_spec = tf.TensorSpec((), dtype=op.dtype)
|
||
comparator_spec.extend([o_spec, o_spec])
|
||
o_aval = core.ShapedArray((), _to_jax_dtype(op.dtype))
|
||
comparator_jax_in_avals.extend([o_aval, o_aval])
|
||
|
||
# Use the same comparator that JAX uses when compiling to XLA, to get the
|
||
# proper NaN/Inf total order, and the lexicographic ordering.
|
||
# The comparator is a 2N-argument TF function, with arguments [2k] and [2k +1]
|
||
# corresponding to two scalars from operand[k].
|
||
def lexicographic_comparator(*tf_args: TfVal) -> TfVal:
|
||
return _convert_jax_impl(
|
||
lax_internal._sort_lt_comparator, multiple_results=False)(
|
||
*tf_args,
|
||
_in_avals=comparator_jax_in_avals,
|
||
_out_aval=core.ShapedArray((), np.bool_),
|
||
num_keys=num_keys)
|
||
|
||
xla_comparator_computation = (
|
||
tf.function(lexicographic_comparator,
|
||
autograph=False).get_concrete_function(*comparator_spec))
|
||
results = tfxla.variadic_sort(
|
||
operands,
|
||
dimension=dimension,
|
||
is_stable=is_stable,
|
||
comparator=xla_comparator_computation)
|
||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||
results = tuple(tf.stop_gradient(out) for out in results) # See #7839
|
||
return results
|
||
|
||
|
||
tf_impl[lax.sort_p] = _sort


def _fft(x, fft_type, fft_lengths):
  FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3]))
  if fft_type == IRFFT:
    expected_lengths = x.shape[-len(fft_lengths):-1] + ((x.shape[-1] - 1) * 2,)
  else:
    expected_lengths = x.shape[-len(fft_lengths):]
  if expected_lengths != fft_lengths:
    raise NotImplementedError(
        f"Unsupported fft_lengths={fft_lengths} for fft_type={fft_type} of "
        f"array with shape={x.shape}.")
  tf_funcs = {
      FFT: [tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d],
      IFFT: [tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d],
      RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d],
      IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d]
  }
  return tf_funcs[fft_type][len(fft_lengths) - 1](x)


tf_impl[lax.fft_p] = _fft
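
# Note on _fft above (illustrative): the tf.signal functions are called here
# without explicit lengths, so only fft_lengths that match the input shape are
# supported. For IRFFT the innermost input dimension holds n // 2 + 1 complex
# entries, e.g. an operand of shape (..., 5) is accepted only when the last
# fft_length is (5 - 1) * 2 = 8.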


def _qr(operand, full_matrices):
  return tf.linalg.qr(operand, full_matrices=full_matrices)


tf_impl[lax.linalg.qr_p] = _qr


def _svd(operand, full_matrices, compute_uv):
  result = tf.linalg.svd(operand, full_matrices, compute_uv)
  if not compute_uv:
    return result,
  s, u, v = result
  return s, u, tf.linalg.adjoint(v)


tf_impl[lax.linalg.svd_p] = _svd
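
# Note on _svd above: lax.linalg.svd returns (s, u, v^H) whereas tf.linalg.svd
# returns (s, u, v), hence the tf.linalg.adjoint(v). Illustrative usage,
# assuming compute_uv=True and full_matrices=False:
#   s, u, vh = _svd(a_tf, full_matrices=False, compute_uv=True)
#   # a_tf is approximately u @ tf.linalg.diag(tf.cast(s, u.dtype)) @ vh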


def _eig(operand: TfVal, compute_left_eigenvectors: bool,
         compute_right_eigenvectors: bool):
  if compute_left_eigenvectors and compute_right_eigenvectors:
    # TODO(bchetioui): there is no obviously reliable and simple way to sort
    # the left eigenvectors into the right order, and the jax.numpy.linalg API
    # suggests that left eigenvectors are seldom used, so leaving this case
    # unimplemented seems acceptable for now.
    msg = ("Conversion of eig is not implemented when both "
           "compute_left_eigenvectors and compute_right_eigenvectors are set "
           "to True.")
    raise NotImplementedError(msg)
  elif not (compute_left_eigenvectors or compute_right_eigenvectors):
    return tuple([tf.linalg.eigvals(operand)])
  elif compute_right_eigenvectors:
    return tuple(tf.linalg.eig(operand))
  else:  # compute_left_eigenvectors == True
    wH, vl = tf.linalg.eig(tf.linalg.adjoint(operand))
    wHH = tf.math.conj(wH)
    return tuple([wHH, vl])


tf_impl[lax.linalg.eig_p] = _eig
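
# Note on _eig above: the left-eigenvector branch uses the identity
# y^H A = w y^H  <=>  A^H y = conj(w) y, i.e. the left eigenvectors of A are
# the right eigenvectors of A^H, and the eigenvalues of A are the conjugates
# of those of A^H; hence eig(adjoint(operand)) followed by conjugating wH.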


def _eigh(operand: TfVal, lower: bool, sort_eigenvalues: bool, _in_avals,
          _out_aval):
  del sort_eigenvalues
  if operand.shape[-1] == 0:
    v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
  else:
    if not lower:
      operand = tf.linalg.adjoint(operand)
    w, v = tf.linalg.eigh(operand)
  cast_type = {
      tf.complex64: tf.float32,
      tf.complex128: tf.float64
  }.get(operand.dtype)
  if cast_type is not None:
    w = tf.cast(w, cast_type)
  return v, w


tf_impl_with_avals[lax.linalg.eigh_p] = _eigh
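
# Note on _eigh above: tf.linalg.eigh reports eigenvalues in the operand's
# dtype, while lax.linalg.eigh expects real-valued eigenvalues for Hermitian
# inputs, hence the complex64 -> float32 / complex128 -> float64 cast of w.
# The shape[-1] == 0 branch sidesteps tf.linalg.eigh for empty matrices.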


def _lu(operand: TfVal, _in_avals, _out_aval):
  return _convert_jax_impl(lax_linalg._lu_python, extra_name_stack="lu")(
      operand, _in_avals=_in_avals, _out_aval=_out_aval)


tf_impl_with_avals[lax.linalg.lu_p] = _lu


def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool,
                      transpose_a: bool, conjugate_a: bool, unit_diagonal: bool,
                      _in_avals: Sequence[core.ShapedArray],
                      _out_aval: core.ShapedArray):
  if unit_diagonal:
    a_aval, _ = _in_avals
    a_shape = _eval_shape(a_aval.shape)
    a = tf.linalg.set_diag(a, tf.ones(a_shape[:-1], dtype=a.dtype))
  if not left_side:
    rank = len(a.shape)
    transpose_dimensions = list(range(rank - 2)) + [rank - 1, rank - 2]
    a = tf.transpose(a, transpose_dimensions)
    b = tf.transpose(b, transpose_dimensions)
    lower = not lower
  # adjoint == transpose for real dtypes, so special care need only be taken
  # for complex types.
  if a.dtype in [tf.complex64, tf.complex128]:
    if (transpose_a and not conjugate_a) or (not transpose_a and conjugate_a):
      a = tf.math.conj(a)
  result = tf.linalg.triangular_solve(a, b, lower=lower, adjoint=transpose_a)
  if not left_side:
    result = tf.transpose(result, transpose_dimensions)
  return result


tf_impl_with_avals[lax.linalg.triangular_solve_p] = _triangular_solve
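
# Note on _triangular_solve above (illustrative): tf.linalg.triangular_solve
# solves only from the left and exposes only an `adjoint` flag, so the other
# cases are reduced to it: a right-hand solve x @ A = b is rewritten as
# A^T @ x^T = b^T (transposing a and b and flipping `lower`), and a plain
# transpose or a plain conjugation is obtained by conjugating `a` before the
# adjoint solve.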


def _linear_solve(*args: TfVal, const_lengths, jaxprs, _in_avals, _out_aval):
  return _convert_jax_impl(lax_control_flow._custom_linear_solve_impl,
                           extra_name_stack="linear_solve")(
                               *args,
                               const_lengths=const_lengths,
                               jaxprs=jaxprs,
                               _in_avals=_in_avals,
                               _out_aval=_out_aval)


tf_impl_with_avals[lax.linear_solve_p] = _linear_solve


def _tridiagonal_solve(*args: TfVal, _in_avals, _out_aval, **params):
  return _convert_jax_impl(lax_linalg._tridiagonal_solve_jax,
                           multiple_results=False,
                           extra_name_stack="tridiagonal_solve")(
                               *args,
                               _in_avals=_in_avals,
                               _out_aval=_out_aval)


tf_impl_with_avals[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve


def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr,
                     jvp_jaxpr_thunk: Callable,
                     num_consts: int) -> Sequence[TfVal]:
  # TODO(necula): ensure that there is no AD transformation in scope
  del jvp_jaxpr_thunk, num_consts
  return _interpret_jaxpr(call_jaxpr, *args, extra_name_stack="custom_jvp")


tf_impl[custom_derivatives.custom_jvp_call_p] = _custom_jvp_call


def _custom_vjp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr,
                           **_) -> Sequence[TfVal]:
  # TODO(necula): ensure that there is no AD transformation in scope
  return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_vjp")


tf_impl[custom_derivatives.custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr


def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
  raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
                  "function.")


tf_impl[ad.custom_lin_p] = _custom_lin


def split_to_logical_devices(tensor: TfVal,
                             partition_dimensions: pxla.PartitionsOrReplicated):
  """Like TPUStrategy.experimental_split_to_logical_devices.

  For jax2tf purposes we want to avoid needing to thread the `strategy` object
  through the generated computation. It seems that the original function needs
  the strategy object only for error checking, which we assume is done upstream
  by JAX.

  Args:
    tensor: Input tensor to annotate.
    partition_dimensions: A list of integers, with one integer per tensor
      dimension, specifying in how many parts the dimension should be split.
      The product of the integers must equal the number of devices per replica.

  Returns:
    An annotated tensor.
  """
  # TODO: this is only for sharded_jit. Either remove, or implement in terms
  # of _shard_values.
  if partition_dimensions is None:
    return xla_sharding.replicate(tensor, use_sharding_op=True)
  num_partition_splits = np.prod(partition_dimensions)
  tile_assignment = np.arange(num_partition_splits).reshape(
      partition_dimensions)
  return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
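
# Note on split_to_logical_devices above (illustrative, hypothetical values):
# for a tensor of shape (8, 16) on a replica with 8 logical devices,
# partition_dimensions=[2, 4] produces tile_assignment =
# np.arange(8).reshape((2, 4)), i.e. dim 0 is split in 2 and dim 1 in 4;
# partition_dimensions=None marks the tensor as replicated instead.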


def _shard_value(val: TfVal,
                 aval: core.ShapedArray,
                 sd: sharding.XLACompatibleSharding) -> TfVal:
  """Apply sharding to a TfVal."""
  if pxla._is_unspecified(sd):
    return val

  sharding_proto: xla_client.OpSharding = cast(
      xla_client.OpSharding, sd._to_xla_op_sharding(aval.ndim))
  # Do not apply XlaSharding for REPLICATED. This is an agreed convention, and
  # also improves usability under TF eager. See b/255511660.
  if pxla.is_op_sharding_replicated(sharding_proto):
    return val

  # To use xla_sharding.py, we must have an xla_data_pb2.OpSharding.
  xla_sharding_proto: xla_data_pb2.OpSharding = (
      xla_data_pb2.OpSharding(
          type=int(sharding_proto.type),
          tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
          tile_assignment_devices=sharding_proto.tile_assignment_devices,
          replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
          last_tile_dims=sharding_proto.last_tile_dims))
  if tf_context.executing_eagerly():
    raise ValueError(
        "A jit function with sharded (not replicated) arguments or results "
        "must be used under a `tf.function` context. See "
        "https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning"
        " for a discussion")

  return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor(
      val, use_sharding_op=True)


def _pjit(*args: TfVal,
          jaxpr: core.ClosedJaxpr,
          in_shardings: Sequence[sharding.XLACompatibleSharding],
          out_shardings: Sequence[sharding.XLACompatibleSharding],
          resource_env: maps.ResourceEnv,
          donated_invars,
          name: str,
          in_positional_semantics,
          out_positional_semantics,
          _in_avals: Sequence[core.ShapedArray],
          _out_aval: Sequence[core.ShapedArray]) -> TfVal:
  del donated_invars
  if resource_env.physical_mesh.is_multi_process:
    raise NotImplementedError("jax2tf translation for pjit over multi-process "
                              "meshes is not supported yet")
  # Apply sharding annotations to the arguments.
  sharded_args: Sequence[TfVal] = tuple(
      map(_shard_value, args, _in_avals, in_shardings))
  results = _interpret_jaxpr(jaxpr, *sharded_args,
                             extra_name_stack=util.wrap_name(name, "pjit"))
  sharded_results: Sequence[TfVal] = tuple(
      map(_shard_value, results, _out_aval, out_shardings))
  return tuple(sharded_results)


tf_impl_with_avals[pjit.pjit_p] = _pjit
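
# Note on _pjit above: the arguments are wrapped with XlaSharding annotations
# per in_shardings, the jaxpr body is inlined by _interpret_jaxpr under a
# "pjit"-wrapped name, and the results are wrapped again per out_shardings, so
# that XLA's SPMD partitioner can pick up the shardings when the enclosing
# tf.function is compiled.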


def _pjit_sharding_constraint(arg: TfVal, *,
                              sharding: sharding.MeshPspecSharding,
                              resource_env: maps.ResourceEnv,
                              _in_avals: Sequence[core.ShapedArray],
                              _out_aval: core.ShapedArray,
                              **kwargs) -> TfVal:
  return _shard_value(arg, _in_avals[0], sharding)


tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint


def _dimension_size_jax2tf(op: TfVal, *, dimension):
  return tf.shape(op)[dimension]


tf_impl[shape_poly.dimension_size_p] = _dimension_size_jax2tf


def _dim_as_value_jax2tf(dim: shape_poly.DimSize):
  dim_tf, = _eval_shape((dim,))
  return dim_tf


tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf


def _register_checkpoint_pytrees():
  """Registers TF custom container types as pytrees."""
  m = tf.Module()
  # The types here are automagically changed by TensorFlow's checkpointing
  # infrastructure.
  m.a = (tf.Module(), tf.Module())
  m.b = [tf.Module(), tf.Module()]
  m.c = {"a": tf.Module()}
  tuple_wrapper = type(m.a)
  list_wrapper = type(m.b)
  dict_wrapper = type(m.c)

  # TF AutoTrackable swaps container types out for wrappers.
  assert tuple_wrapper is not tuple
  assert list_wrapper is not list
  assert dict_wrapper is not dict

  jax.tree_util.register_pytree_node(
      tuple_wrapper, lambda xs: (tuple(xs), None), lambda _, xs: tuple(xs))

  jax.tree_util.register_pytree_node(
      list_wrapper, lambda xs: (tuple(xs), None), lambda _, xs: list(xs))

  jax.tree_util.register_pytree_node(
      dict_wrapper,
      lambda s: (tuple(s.values()), tuple(s.keys())),
      lambda k, xs: dict_wrapper(zip(k, xs)))


_register_checkpoint_pytrees()
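
# Note on _register_checkpoint_pytrees above (illustrative): after this runs,
# the wrapper types that TF AutoTrackable substitutes for plain containers are
# valid pytrees, e.g.:
#   m = tf.Module()
#   m.b = [tf.Variable(1.0), tf.Variable(2.0)]   # stored as TF's list wrapper
#   jax.tree_util.tree_leaves(m.b)               # -> the two tf.Variables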