# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
from functools import partial
import contextlib
import os
import re
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import jax
from jax import lax
from jax import config
from jax import core, custom_derivatives
from jax import linear_util as lu
from jax import random, tree_util
from jax import numpy as jnp
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import partial_eval
from jax.interpreters import pxla
from jax.interpreters import xla

import jax._src.prng
import jax._src.random
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import dispatch
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import slicing as lax_slicing
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client

from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import impl_no_xla

import numpy as np
import tensorflow as tf  # type: ignore[import]

# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla  # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2  # type: ignore[import]
from tensorflow.core.framework import attr_value_pb2  # type: ignore[import]
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding  # type: ignore[import]
from tensorflow.python.framework import ops as tf_ops  # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import

NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape

# A temporary internal flag, to enable the wrapping of jax.jit functions
# with tf.function(jit_compile=True). See #7389. This change has triggered a
# number of failures in TF. We keep this until we are confident that it does
# not create problems.
# TODO(b/207464757): figure out why this change breaks test
_WRAP_JAX_JIT_WITH_TF_FUNCTION = False

# The scope name needs to be a valid TensorFlow name. See
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731
_VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
_INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/-]")

map = util.safe_map
zip = util.safe_zip


def _sanitize_scope_name(name):
  scope_name = _INVALID_SCOPE_CHAR.sub("_", name)
  if not _VALID_SCOPE_REGEX.match(scope_name):
    scope_name = ".{}".format(scope_name)
  return scope_name
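
# For illustration (hypothetical inputs, not from the original source):
# _sanitize_scope_name("jit(f)/while/body") keeps '/' but replaces '(' and ')'
# with '_', giving "jit_f_/while/body"; a name whose first character is not in
# [A-Za-z0-9.] (e.g. one starting with '_') is prefixed with '.' to satisfy
# _VALID_SCOPE_REGEX.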


# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
DType = Any
PrecisionType = int  # Enum xla_data.PrecisionConfig.Precision

def _is_tfval(v: TfVal) -> bool:
  if isinstance(v, (tf.Tensor, tf.Variable)):
    return True
  try:
    # Include all convertible types, even if not supported on accelerators.
    with tf.device("CPU"):
      tf.constant(v)
    return True
  except:
    return False


# The implementation rules for primitives. The rule will be called with the
# arguments (TfVal) and must return TfVal (or a sequence thereof,
# if primitive.multiple_results). The exceptions are primarily the
# control-flow primitives.
tf_impl: Dict[core.Primitive, Callable[..., Any]] = {}

# Some primitive implementation rules need the abstract values of arguments
# and the results. This is the case for the primitives implemented using
# _convert_jax_impl and those that need to adjust the shape of the outputs
# due to missing TF shape inference rules for TFXLA ops. The rules for these
# primitives should be added to `tf_impl_with_avals`.
# The abstract values are passed to the implementation as two special kwargs
# `_in_avals` (a tuple of core.ShapedArray) and `_out_aval` (a
# core.ShapedArray, or a tuple thereof when primitive.multiple_results).
tf_impl_with_avals: Dict[core.Primitive, Callable[..., Any]] = {}

# XLA is not linked in all environments when converting a primitive. If this is
# the case, we first search for implementation rules for primitives in the
# following map. These implementations are workarounds, making use of TF ops
# that do work when XLA is not linked in.
tf_impl_no_xla = impl_no_xla.tf_impl_no_xla

# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
# conversion happens is:
#    jax2tf.process_primitive  (top of stack)
#    jax tracing machinery ...
#    tf.custom_gradient machinery ...
#    jax2tf.converted_fun
#    tf function machinery ...
#    user code invokes the converted function on TF tensors
#
# We need to skip over not only JAX internal frames, but TF internal frames
# also.
# We register the TensorFlow source path lazily
_has_registered_tf_source_path = False

class _ThreadLocalState(threading.local):
  def __init__(self):
    self.name_stack = ""
    # XLA is not linked in all environments; when converting a primitive, if this
    # variable is disabled, we try harder to use only standard TF ops if they are
    # applicable to the concrete use case; if the resulting conversion path ends up
    # requiring a TFXLA operation, an exception is thrown instead.
    self.enable_xla = True

    # Keep track if we are inside a call_tf. In that context we disable the
    # safety check that we are not inside JAX transformations.
    self.inside_call_tf = False

    # Maps dimension variables to TF expressions
    self.shape_env: Sequence[Tuple[str, TfVal]] = ()

    # Whether to actually include XLA op metadata in the generated TF ops
    # TODO(b/189306134): implement support for XLA metadata
    self.include_xla_op_metadata = False

    # A cache for the tf.convert_to_tensor for constants. We try to preserve
    # sharing for constants, to enable tf.Graph to take advantage of it.
    # See https://github.com/google/jax/issues/7992.
    self.constant_cache = None  # None means that we don't use a cache. We
    # may be outside a conversion scope.


_thread_local_state = _ThreadLocalState()

def _get_current_name_stack() -> Union[NameStack, str]:
  if config.jax_experimental_name_stack:
    return source_info_util.current_name_stack()
  return _thread_local_state.name_stack

@contextlib.contextmanager
def inside_call_tf():
  # Set the inside_call_tf flag for a context.
  prev = _thread_local_state.inside_call_tf
  _thread_local_state.inside_call_tf = True
  try:
    yield
  finally:
    _thread_local_state.inside_call_tf = prev
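
# A hedged usage sketch (not from the original source): machinery that calls
# back into TF from within a JAX trace (e.g., a call_tf-style helper) can
# temporarily lift the nesting safety check like so:
#
#   with inside_call_tf():
#     ...  # jax2tf.convert may be entered here despite an active JAX trace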

@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun: Callable,
            *,
            polymorphic_shapes=None,
            with_gradient=True,
            enable_xla=True
            ) -> Callable:
  """Transforms `fun` to be executed by TensorFlow.

  See
  [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
  for more details about usage and common problems.

  Args:
    fun: Function to be transformed. Its arguments and return value should be
      JAX arrays, or nested standard Python containers (tuple/list/dict)
      thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during conversion.

      .. warning:: The shape-polymorphic conversion is an experimental feature.
        It is meant to be sound, but it is known to reject some JAX programs
        that are shape polymorphic. The details of this feature can change.

      It should be `None` (all arguments are monomorphic), a single PolyShape
      or string (applies to all arguments), or a tuple/list of the same length
      as the function arguments. For each argument the shape specification
      should be `None` (monomorphic argument), or a Python object with the
      same pytree structure as the argument.
      See [how optional parameters are matched to
      arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).

      A shape specification for an array argument should be an object
      `PolyShape(dim0, dim1, ..., dimn)`
      where each `dim` is a dimension specification: a positive integer
      denoting a monomorphic dimension of the given size, or a string denoting
      a dimension variable assumed to range over non-zero dimension sizes, or
      the special placeholder string "_" denoting a monomorphic dimension
      whose size is given by the actual argument. As a shortcut, an Ellipsis
      suffix in the list of dimension specifications stands for a list of "_"
      placeholders.

      For convenience, a shape specification can also be given as a string
      representation, e.g.: "batch, ...", "batch, height, width, _", possibly
      with surrounding parentheses: "(batch, ...)".

      The conversion fails if it cannot ensure that it would produce the same
      sequence of TF ops for any non-zero values of the dimension variables.

      polymorphic_shapes are only supported for positional arguments; shape
      polymorphism is not supported for keyword arguments.

      See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
      for more details.

    in_shapes: DEPRECATED in favor of `polymorphic_shapes`.
    with_gradient: if set (default), add a tf.custom_gradient to the converted
      function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
      TensorFlow AD is supported for the output TensorFlow function, and the
      value of the gradient will be JAX-accurate.
    enable_xla: if set (default), the converter will use the simplest
      conversion and use XLA TF ops when necessary. These ops are known to
      create issues for the TFLite and TFjs converters. For those cases, unset
      this parameter so the converter tries harder to use non-XLA TF ops to
      convert the function and aborts if this is not possible.

  Returns:
    A version of `fun` that expects TfVals as arguments (or
    tuple/lists/dicts) thereof, and returns TfVals as outputs, and uses
    only TensorFlow ops.
  """
  api._check_callable(fun)
  fun_name = getattr(fun, "__name__", "unknown")
  name_stack = util.wrap_name(fun_name, "jax2tf")
  def converted_fun(*args: TfVal, **kwargs: TfVal) -> TfVal:
    # TODO: is there a better way to check if we are inside a transformation?
    if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
      # It is OK to nest convert when we are inside a call_tf
      raise ValueError("convert must be used outside all JAX transformations. " +
                       f"Trace state: {core.thread_local_state.trace_state.trace_stack}")

    # We support kwargs by wrapping the function to take only positional arguments.
    # This is in part because jax.vjp does not support kwargs.
    nr_positional_args = len(args)
    kw_names = kwargs.keys()
    args = tuple(args) + tuple(kwargs[kw] for kw in kw_names)

    def fun_no_kwargs(*args_and_kwargs):
      assert len(args_and_kwargs) == nr_positional_args + len(kw_names)
      args = args_and_kwargs[:nr_positional_args]
      kwargs = {kw: args_and_kwargs[nr_positional_args + i]
                for i, kw in enumerate(kw_names)}
      return fun(*args, **kwargs)

    def check_arg(a):
      if not _is_tfval(a):
        msg = (f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
               "be NumPy array, scalar, tf.Variable, or tf.Tensor")
        raise TypeError(msg)

    tree_util.tree_map(check_arg, args)

    args_flat, in_tree = tree_util.tree_flatten((args, {}))
    # May need to cast the arguments to have the type assumed by JAX
    args_and_dtypes_flat = tuple(map(_tfval_to_tensor_jax_dtype, args_flat))
    args_flat, arg_dtypes_flat = util.unzip2(args_and_dtypes_flat)
    # Name input tensors; do this after we have cast the arguments
    def _apply_name(a: TfVal, suffix) -> TfVal:
      return tf.identity(a, f"jax2tf_arg_{suffix}")
    args_flat = tuple(_apply_name(a, i) for i, a in enumerate(args_flat))

    if polymorphic_shapes is None:
      polymorphic_shapes_ = (polymorphic_shapes,) * len(args)
    elif isinstance(polymorphic_shapes, (PolyShape, str)):
      polymorphic_shapes_ = (polymorphic_shapes,) * len(args)  # type: ignore
    else:
      if not isinstance(polymorphic_shapes, Sequence) or len(polymorphic_shapes) != len(args) - len(kw_names):
        msg = ("polymorphic_shapes must be a sequence with the same length as the positional argument list "
               f"({nr_positional_args}). Got polymorphic_shapes={repr(polymorphic_shapes)}.")
        raise TypeError(msg)
      polymorphic_shapes_ = tuple(polymorphic_shapes) + (None,) * len(kw_names)

    # Expand the polymorphic_shapes to match the argument pytree
    polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
                                                          in_tree.children()[0],
                                                          polymorphic_shapes_))

    def fix_tf1_shape(arg: TfVal) -> Sequence[Optional[int]]:
      tf_arg_shape = np.shape(arg)
      return tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d for d in tf_arg_shape)
    args_shapes_flat = tuple(fix_tf1_shape(a) for a in args_flat)

    # Construct the abstract values for the flat arguments, possibly based on
    # the input shapes and the polymorphic_shapes if given. May create new shape
    # variables. May cast the args_flat to JAX types, using JAX's interpretation
    # of types of constants.
    args_avals_flat = shape_poly.args_avals(
        args_shapes_flat, arg_dtypes_flat, polymorphic_shapes_flat)

    dim_vars, get_dim_values = shape_poly.prepare_dim_var_env(args_avals_flat)
    dim_values, _ = util.unzip2(_interpret_fun(lu.wrap_init(get_dim_values),
                                               args_flat, args_avals_flat, ""))
    shape_env = zip(dim_vars, dim_values)

    # This function may take pytrees of TfVals. We can only set
    # tf.custom_gradient on functions that take a flat argument list.
    f = lu.wrap_init(fun_no_kwargs)
    # out_tree_thunk() will be the output tree, after running _interpret_fun.
    flat_fun, out_tree_thunk = api_util.flatten_fun(f, in_tree)
    # out_tree_thunk will be ready after _interpret_fun below.

    # Prepare the grad_fn for tf.custom_gradient.
    def converted_grad_fn(*out_cts_flat: TfVal,
                          _out_cts_avals: Sequence[core.ShapedArray],
                          variables=None):
      if variables:
        raise ValueError(
            "Unexpected variables used in forward pass. "
            "This should not happen for first-order differentiation. "
            f"variables={variables}")

      out_tree = out_tree_thunk()
      if polymorphic_shapes is None:
        vjp_polymorphic_shapes = None
      else:
        args_flat_polymorphic_shapes = polymorphic_shapes_flat
        out_cts_flat_polymorphic_shapes = tuple(str(out_aval.shape)  # Note: may be polynomials, not just DimVar
                                                for out_aval in _out_cts_avals)  # type: ignore
        vjp_polymorphic_shapes = [
            args_flat_polymorphic_shapes, out_cts_flat_polymorphic_shapes
        ]

      def fun_vjp_jax(args_flat_jax, out_cts_flat_jax):
        # One may think that we can get the pullback while we are converting
        # the main function in the first place. That is problematic, because the
        # pullback may contain captured tracers from the conversion of the
        # main function. Those tracers will confuse the conversion of the
        # pullback. So, we construct the vjp anew and we convert it separately.
        args_jax, kwargs_jax = tree_util.tree_unflatten(in_tree, args_flat_jax)
        assert not kwargs_jax
        _, pullback_jax = jax.vjp(fun_no_kwargs, *args_jax)

        def fix_out_ct(out_ct_jax, out_ct_aval: core.ShapedArray):
          # If the primal function has outputs of integer or bool types, and if we are
          # under a tf.function context, then TF will pass None in _out_cts_flat
          # in place of these values. We should change these to float0 or
          # else JAX gets unhappy. See issue #6975.
          if out_ct_jax is not None:
            return out_ct_jax
          assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, f"out_ct={out_ct_jax}"
          # Note that out_ct_aval.shape contains dimension variables from the
          # primal function scope. It is OK to use them here because we
          # use the same shape variables for the VJP function.
          return jnp.zeros(out_ct_aval.shape, dtype=_tf_np_dtype_for_float0)

        out_cts_fixed_flat = tuple(map(fix_out_ct, out_cts_flat_jax, _out_cts_avals))

        out_cts_fixed = tree_util.tree_unflatten(out_tree, out_cts_fixed_flat)
        in_cts_jax = pullback_jax(out_cts_fixed)

        in_cts_flat_jax, in_cts_tree = tree_util.tree_flatten(in_cts_jax)
        def fix_in_ct(in_ct, arg_aval: core.ShapedArray):
          if jnp.issubdtype(arg_aval.dtype, jnp.inexact):
            return in_ct
          else:
            assert in_ct.dtype == dtypes.float0
            return jnp.zeros(arg_aval.shape, _tf_np_dtype_for_float0)

        in_cts_fixed_flat_jax = tuple(map(fix_in_ct, in_cts_flat_jax, args_avals_flat))
        return in_cts_fixed_flat_jax

      # TODO: enable higher-order gradients
      with tf.name_scope("jax2tf_vjp"):
        in_cts_flat = convert(
            fun_vjp_jax,
            with_gradient=False,
            polymorphic_shapes=vjp_polymorphic_shapes)(args_flat, out_cts_flat)
      return in_cts_flat

    try:
      assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"

      prev_enable_xla = _thread_local_state.enable_xla
      _thread_local_state.enable_xla = enable_xla

      prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
      # TODO(b/189306134): implement support for XLA metadata
      _thread_local_state.include_xla_op_metadata = False

      _thread_local_state.shape_env = shape_env
      global _has_registered_tf_source_path
      if not _has_registered_tf_source_path:
        source_info_util.register_exclusion(os.path.dirname(tf.__file__))
        _has_registered_tf_source_path = True

      if with_gradient:

        @tf.custom_gradient
        def converted_fun_flat_with_custom_gradient(*args_flat: TfVal) -> TfVal:
          out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
                                          name_stack,
                                          fresh_constant_cache=True)
          outs, out_avals = util.unzip2(out_with_avals)
          return (tuple(outs),
                  partial(converted_grad_fn, _out_cts_avals=tuple(out_avals)))

        out_flat = converted_fun_flat_with_custom_gradient(*args_flat)
      else:
        out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
                                        name_stack, fresh_constant_cache=True)
        outs, out_avals = util.unzip2(out_with_avals)
        message = ("The jax2tf-converted function does not support gradients. "
                   "Use the `with_gradient` parameter to enable gradients")
        # We use PreventGradient, which is propagated through a SavedModel.
        out_flat = [
            tf.raw_ops.PreventGradient(input=o, message=message)
            for o in outs
        ]
    finally:
      _thread_local_state.shape_env = ()
      _thread_local_state.enable_xla = prev_enable_xla
      _thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata

    out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat]
    out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
    return out

  return converted_fun
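
# A hedged usage sketch (illustrative only; `f_jax` is a hypothetical user
# function, not part of this module):
#
#   f_jax = lambda x: jnp.sin(x) ** 2
#   f_tf = convert(f_jax, polymorphic_shapes=["(b, 4)"])
#   y = f_tf(tf.ones((3, 4), dtype=tf.float32))  # a tf.Tensor of shape (3, 4)
#
# With `with_gradient=True` (the default), tf.GradientTape can differentiate
# `f_tf`, and the gradient values match those of `jax.grad(f_jax)`.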


def dtype_of_val(val: TfVal) -> DType:
  """Computes the TensorFlow dtype using JAX's typing rules.

  If the value is a tf.Tensor, it starts with its dtype. If the value is a
  constant it uses JAX to infer its dtype. The resulting dtype follows the
  JAX type inference rules, and depends on the value of the
  JAX_ENABLE_X64 flag.

  See README.md for how 64-bit values are treated.
  """
  tval, _ = _tfval_to_tensor_jax_dtype(val)
  return tval.dtype
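
# For illustration (hedged; exact results depend on the JAX_ENABLE_X64 flag):
# with x64 disabled, dtype_of_val(1.0) is tf.float32, and
# dtype_of_val(np.float64(1.0)) is also tf.float32 because JAX canonicalizes
# 64-bit types to their 32-bit counterparts.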

# Internals

@contextlib.contextmanager
def _extended_name_stack(extra_name_stack: Optional[str]):
  if config.jax_experimental_name_stack:
    name_ctx = (source_info_util.extend_name_stack(extra_name_stack)
                if extra_name_stack
                else contextlib.nullcontext())
    with name_ctx:
      yield
    return
  prev_name_stack = _thread_local_state.name_stack
  if extra_name_stack:
    if not prev_name_stack:
      _thread_local_state.name_stack = extra_name_stack
    else:
      _thread_local_state.name_stack = util.extend_name_stack(
          _thread_local_state.name_stack, extra_name_stack)
  try:
    yield
  finally:
    _thread_local_state.name_stack = prev_name_stack


def _interpret_fun(
    fun: lu.WrappedFun, in_vals: Sequence[TfVal],
    in_avals: Sequence[core.ShapedArray],
    extra_name_stack: Optional[str],
    fresh_constant_cache: bool = False
) -> Sequence[Tuple[TfVal, core.ShapedArray]]:
  with core.new_base_main(TensorFlowTrace) as main:  # type: ignore
    fun = _interpret_subtrace(fun, main, in_avals)
    with _extended_name_stack(extra_name_stack):
      with core.new_sublevel():
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
            _call_wrapped_with_new_constant_cache(fun, in_vals,
                                                  fresh_constant_cache=fresh_constant_cache)

    del main

  return tuple(out_vals)

def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
                                          in_vals: Sequence[TfVal],
                                          fresh_constant_cache: bool = False
                                          ) -> Sequence[Tuple[TfVal, core.ShapedArray]]:
  try:
    prev_constant_cache = _thread_local_state.constant_cache
    prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set()
    # Start a new cache, so that we don't share constants across tf.function
    # boundaries.
    if fresh_constant_cache:
      _thread_local_state.constant_cache = {}

    out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
        fun.call_wrapped(*in_vals)
  finally:
    if prev_constant_cache is not None and not fresh_constant_cache:
      newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys
      # Delete the newly added keys
      for k in newly_added_keys:
        del prev_constant_cache[k]
    _thread_local_state.constant_cache = prev_constant_cache
  return out_vals

def _convert_jax_impl(jax_impl: Callable, *,
                      multiple_results=True,
                      extra_name_stack: Optional[str] = None) -> Callable:
  """Convert the JAX implementation of a primitive.

  Args:
    jax_impl: typically the impl-rule for a primitive, with signature
      `(*args: JaxVal, **kwargs) -> Sequence[JaxVal]`. This function implements
      a primitive in terms of other primitives.
    multiple_results: whether `jax_impl` returns a sequence of results.
    extra_name_stack: additional element to add to the name stack for the
      converted ops.

  Returns:
    a function with signature `(*args: TfVal, _in_avals, _out_aval, **kwargs)
    -> Sequence[TfVal]`.
  """

  def wrapped(*tf_args: TfVal, _in_avals: Sequence[core.ShapedArray],
              _out_aval: core.ShapedArray,
              **kwargs) -> Sequence[TfVal]:

    # We wrap the jax_impl under _interpret_fun to abstract the TF values
    # from jax_impl and turn them into JAX abstract values.
    def jax_impl_jax_args(*jax_args):
      jax_results = jax_impl(*jax_args, **kwargs)
      return jax_results if multiple_results else [jax_results]

    tf_results_with_avals = _interpret_fun(
        lu.wrap_init(jax_impl_jax_args), tf_args, _in_avals,
        extra_name_stack)
    tf_results, _ = util.unzip2(tf_results_with_avals)
    return tf_results if multiple_results else tf_results[0]

  return wrapped
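
# For illustration (hedged sketch, not from the original source): a primitive
# whose JAX impl is itself pure-JAX code gets a TF rule this way; `some_p` and
# `_some_impl` are hypothetical stand-ins for the real registrations below:
#
#   tf_impl_with_avals[some_p] = _convert_jax_impl(
#       _some_impl, multiple_results=False, extra_name_stack="some")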


@lu.transformation
def _interpret_subtrace(main: core.MainTrace,
                        in_avals: Sequence[core.ShapedArray],
                        *in_vals: TfVal):
  trace = TensorFlowTrace(main, core.cur_sublevel())
  in_tracers = tuple(
      TensorFlowTracer(trace, val, aval)
      for val, aval in zip(in_vals, in_avals))
  outs = yield in_tracers, {}  # type: Sequence[TfVal]
  out_tracers: Iterable[TensorFlowTracer] = (
      map(trace.full_raise, outs))  # type: ignore
  out_vals_with_avals: Sequence[Tuple[TfVal, core.ShapedArray]] = (
      tuple((t.val, t.aval) for t in out_tracers))
  yield out_vals_with_avals


def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal,
                     extra_name_stack: Optional[str]) -> Sequence[TfVal]:
  """Evaluates a Jaxpr with tf.Tensor arguments.

  The output is a sequence of TfVal, suitable for use with TF.
  """
  fun: lu.WrappedFun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  out_with_avals = _interpret_fun(fun, args, jaxpr.in_avals, extra_name_stack)
  return tuple(v for v, _ in out_with_avals)


def _aval_to_tf_shape(aval: core.ShapedArray) -> Tuple[Optional[int], ...]:
  """Generate a TF shape, possibly containing None for polymorphic dimensions."""
  return tuple(map(lambda d: None if shape_poly.is_poly_dim(d) else d,
                   aval.shape))  # type: ignore[attr-defined]

# In the TF world, we represent float0 as zeros of this type.
_tf_np_dtype_for_float0 = np.int32

def _to_tf_dtype(jax_dtype):
  # Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
  # due to float0 and 64-bit behavior.
  if jax_dtype == dtypes.float0:
    jax_dtype = _tf_np_dtype_for_float0
  return tf.dtypes.as_dtype(jax_dtype)


def _to_jax_dtype(tf_dtype):
  # Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
  # due to float0 and 64-bit behavior.
  return dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
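
# Round-trip illustration (hedged; assumes the default JAX_ENABLE_X64=False):
#   _to_tf_dtype(dtypes.float0) -> tf.int32    (float0 has no TF equivalent)
#   _to_jax_dtype(tf.float64)   -> np.float32  (canonicalized to 32-bit)
# so _to_jax_dtype(_to_tf_dtype(d)) need not equal d.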


def _tfval_to_tensor_jax_dtype(val: TfVal,
                               jax_dtype: Optional[DType] = None,
                               memoize_constants=False) -> Tuple[TfVal, DType]:
  """Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type.

  If `jax_dtype` is missing, uses JAX typing rules.
  See README.md for details regarding 64-bit values.

  Args:
    val: a scalar, ndarray, tf.Tensor, or tf.Variable
    jax_dtype: an optional dtype to use. If missing, uses JAX type inference
      rules for constants.
    memoize_constants: whether to memoize TF constants. We can't do this
      everywhere, we may be outside of a conversion scope.

  Returns:
    a tuple with a tf.Tensor with the type as needed by JAX, and the JAX type.
  """
  if isinstance(val, (tf.Tensor, tf.Variable)):
    jax_dtype = jax_dtype or _to_jax_dtype(val.dtype)  # Give JAX a chance to pick the type
    conversion_dtype = _to_tf_dtype(jax_dtype)
    if conversion_dtype != val.dtype:
      return tf.cast(val, conversion_dtype), jax_dtype
    else:
      return val, jax_dtype
  else:  # A constant
    jax_dtype = jax_dtype or xla.abstractify(val).dtype
    # TODO(document): We assume that the value of a constant does not
    # change through the scope of the function. But it may be an ndarray, ...
    # JAX has the same problem when generating HLO.
    const_key = (id(val), jax_dtype)
    # Since we use id(val) as a cache key, we have to make sure that we keep
    # the previous `val` alive. Otherwise, for an ndarray, it can get garbage
    # collected and reused for a different value, which would create correctness
    # issues. We keep the `val` alive by storing in the cache the pair
    # `(val, tf_val)`.
    do_memoize = (memoize_constants and np.shape(val) and _thread_local_state.constant_cache is not None)
    if do_memoize:
      _, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
    else:
      tf_val = None
    if tf_val is None:
      conversion_dtype = _to_tf_dtype(jax_dtype)
      # The float0 type is not known to TF.
      if jax_dtype == dtypes.float0:
        val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
      tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
      if do_memoize:
        _thread_local_state.constant_cache[const_key] = (val, tf_val)
    return tf_val, jax_dtype
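
# Cache-behavior illustration (hedged; `c` stands for a hypothetical ndarray
# constant): within one conversion scope, two uses of the same `c` hit the
# same `const_key = (id(c), jax_dtype)` and share one tf.Tensor, so a
# tf.Graph embeds the constant only once; scalars (np.shape(c) == ()) are
# never memoized.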


def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
  assert all(map(lambda x: x is not None, shape)), (
      f"Argument shape should be a valid JAX shape but got {shape}")
  dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
  eval_shape, dim_avals = shape_poly.get_shape_evaluator(dim_vars, shape)
  shape_values, _ = util.unzip2(_interpret_fun(lu.wrap_init(eval_shape),
                                               dim_values, dim_avals, ""))  # type: ignore
  return shape_values
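
# For illustration (hedged; assumes a shape environment binding "b"): with
# _thread_local_state.shape_env == (("b", <tf.Tensor>),), _eval_shape(("b", 4))
# returns a pair of TF values evaluating to (b, 4), usable as a TF op's shape
# argument even though "b" is only known at graph-execution time.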


# TODO(b/26854495): pylint doesn't understand slots and inheritance.
# pylint: disable=assigning-non-slot


class TensorFlowTracer(core.Tracer):
  """Tracer class that boxes a TF value and a JAX abstract value.

  In addition to the TF value we carry the JAX abstract value because there is
  one case when it cannot be recovered from the value: when we are converting
  with polymorphic shapes, in which case the shape of the value may have
  dimensions set to `None`, for which the JAX abstract value may contain more
  precise information.

  When the value has a partially-known shape, the dimensions marked as `None`
  must correspond to non-constant dimensions in the abstract value.

  See README.md for details.
  """
  # val: TfVal
  # _aval: core.ShapedArray
  __slots__ = ["val", "_aval"]

  def __init__(self, trace: "TensorFlowTrace", val: TfVal,
               aval: core.AbstractValue):
    self._trace = trace
    self._aval = aval
    if isinstance(val, (tf.Tensor, tf.Variable)):
      val_shape = val.shape

      if config.jax_enable_checks:
        assert len(self._aval.shape) == len(val_shape), f"_aval.shape={self._aval.shape} different rank than val_shape={val_shape}"
        # To compare types, we must handle float0 in JAX and x64 in TF
        if self._aval.dtype == dtypes.float0:
          assert _to_tf_dtype(self._aval.dtype) == val.dtype, f"expected {self._aval.dtype} == {val.dtype}"
        else:
          assert self._aval.dtype == _to_jax_dtype(val.dtype), f"expected {self._aval.dtype} == {val.dtype}"

        for aval_dim, val_dim in zip(self._aval.shape, val_shape):  # type: ignore[attr-defined]
          if val_dim is None:
            assert shape_poly.is_poly_dim(aval_dim), f"expected {self._aval.shape} == {val_shape}"  # type: ignore[attr-defined]
          elif not shape_poly.is_poly_dim(aval_dim):
            assert aval_dim == val_dim, f"expected {self._aval.shape} == {val_shape}"  # type: ignore[attr-defined]
          else:
            # We have a TF value with known shape, and the abstract shape is a shape variable.
            try:
              aval_int = int(_eval_shape([aval_dim]))  # type: ignore
            except (TypeError, KeyError):
              continue
            assert aval_int == val_dim, f"expected {self._aval.shape} == {val_shape}. Found {aval_int} != {val_dim}."  # type: ignore

    self.val = _tfval_to_tensor_jax_dtype(val,
                                          self._aval.dtype,
                                          memoize_constants=True)[0]  # type: ignore[attr-defined]

  @property
  def aval(self):
    return self._aval

  def full_lower(self):
    return self


class TensorFlowTrace(core.Trace):
  """Trace class that underlies the jax2tf transformation.

  We are going to ensure that jax2tf.convert is never nested inside other
  transformations. This is sufficient for intended use cases (converting
  fully-transformed JAX code). It also simplifies our job because we do not have
  to handle situations where we apply primitives on a mix of TF values and
  JAX tracers from an outer transformation. E.g., for addition both the TF
  values and the JAX tracers have an override and they get confused if they
  see values from the other world.

  Hence a TFT trace does not interact with non-TFT traces at a lower level. For
  higher-order control-flow primitives we invoke recursively
  _interpret_fun on the body of the conditional, which will create a nested TFT.

  We do want to allow transformations nested inside a TensorFlowTrace (TFT), but
  those will introduce their own MainTrace, and any operations involving those
  will be done on those traces, i.e., not a concern for TFT.
  """
  def pure(self, val: TfVal) -> TensorFlowTracer:
    """Lifts a non-Tracer into the TensorFlowTracer.

    This function may be called by way of trace.full_raise.
    """
    tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True)
    return TensorFlowTracer(
        self, val, core.ShapedArray(tf_val.shape, jax_dtype,
                                    weak_type=dtypes.is_weakly_typed(val)))

  def lift(self, val: core.Tracer) -> TensorFlowTracer:
    # This would be called when we need to raise a tracer from a lower-level
    # main into the TensorFlowTrace. Since the TensorFlowTrace is never nested
    # inside another transform, there are no lower-level main traces.
    assert False

  def sublift(self, val: TensorFlowTracer) -> TensorFlowTracer:
    # This is called when we need to raise a tracer from the same main,
    # but a lower sublevel. This could come from a nested jit.
    return TensorFlowTracer(self, val.val, val._aval)

  def process_primitive(self, primitive: core.Primitive,
                        tracers: Sequence[TensorFlowTracer],
                        params) -> TensorFlowTracer:
    impl, impl_needs_avals = self.get_primitive_impl(primitive)
    args_avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
    # This is a bit conservative, doing abstract_eval even in op-by-op execution
    # but we needed it for, e.g., shape_polymorphism where only JAX's
    # abstract evaluation rules can properly track polymorphic shapes.
    # Unfortunately under op-by-op execution this is a rare occasion where we
    # need abstract evaluation.
    out_aval, _ = primitive.abstract_eval(*args_avals, **params)
    args_tf: Sequence[TfVal] = [t.val for t in tracers]
    def invoke_impl() -> TfVal:
      if impl_needs_avals:
        return impl(
            *args_tf,
            _in_avals=args_avals,  # type: ignore
            _out_aval=out_aval,
            **params)
      else:
        return impl(*args_tf, **params)

    current_name_stack = _get_current_name_stack()
    if config.jax_experimental_name_stack:
      # We don't use `str(name_stack)` because it uses parentheses for
      # transformations, which aren't allowed in `name_scope`.
      scope = '/'.join([s.name for s in current_name_stack.stack])  # type: ignore[union-attr]
    else:
      scope = str(current_name_stack)
    # We need to add a '/' to the name stack string to force `tf.name_scope`
    # to interpret it as an absolute scope, not a relative scope.
    scope = scope + '/'
    name_scope = (
        tf.name_scope(_sanitize_scope_name(scope)) if
        config.jax_experimental_name_stack else contextlib.nullcontext())

    with name_scope:
      if _thread_local_state.include_xla_op_metadata:
        op_metadata = xla.make_op_metadata(primitive, params,
                                           name_stack=current_name_stack,
                                           source_info=source_info_util.current())
        op_metadata_proto = xla_data_pb2.OpMetadata(
            op_type=op_metadata.op_type,
            op_name=op_metadata.op_name,
            source_file=op_metadata.source_file,
            source_line=op_metadata.source_line
        )
        with tf_ops.get_default_graph()._attr_scope(
            {"_XlaOpMetadata": attr_value_pb2.AttrValue(
                s=op_metadata_proto.SerializeToString())}):
          val_out = invoke_impl()
      else:
        val_out = invoke_impl()

    if primitive.multiple_results:
      out = [
          TensorFlowTracer(self, v, a)
          for v, a in zip(val_out, out_aval)
      ]  # type: ignore
    else:
      out = TensorFlowTracer(self, val_out, out_aval)  # type: ignore

    # Check that the impl rule returned a value of expected shape and dtype
    # TODO: adapt this to match polymorphic shapes
    if config.jax_enable_checks:
      if primitive.multiple_results:
        for o, expected_aval in zip(out, out_aval):  # type: ignore
          assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), (
              f"{primitive}: out.aval = {o.aval}; expected {expected_aval}")
      else:
        assert out.aval == out_aval, (  # type: ignore
            f"{primitive}: out.aval = {out.aval}; expected {out_aval}"
        )  # type: ignore
    return out  # type: ignore

  def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun,
                   tracers: Sequence[TensorFlowTracer], params):
    assert call_primitive.multiple_results
    vals: Sequence[TfVal] = [t.val for t in tracers]
    avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
    interpreted_fun = _interpret_subtrace(fun, self.main, avals)
    extra_name_stack = None
    if call_primitive == core.named_call_p:
      extra_name_stack = util.wrap_name(params["name"], "named")
    elif call_primitive == xla.xla_call_p:
      extra_name_stack = util.wrap_name(params["name"], "jit")
    with _extended_name_stack(extra_name_stack):
      with core.new_sublevel():
        if call_primitive == core.named_call_p:
          with tf.name_scope(_sanitize_scope_name(params["name"])):
            vals_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
                interpreted_fun.call_wrapped(*vals)
        elif call_primitive == xla.xla_call_p:
          if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
            # Make a nested tf.function(jit_compile=True)
            store_tf_res_avals = None
            def f_tf(*tf_args):
              nonlocal store_tf_res_avals
              tf_res_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
                  _call_wrapped_with_new_constant_cache(interpreted_fun, tf_args,
                                                        fresh_constant_cache=False)
              tf_res_vals, tf_res_avals = util.unzip2(tf_res_out)
              store_tf_res_avals = tf_res_avals
              return tf_res_vals
            tf_vals_out = tf.function(f_tf, autograph=False, jit_compile=True)(*vals)
            vals_out = zip(tf_vals_out, store_tf_res_avals)
          else:
            vals_out = interpreted_fun.call_wrapped(*vals)
        else:
          vals_out = interpreted_fun.call_wrapped(*vals)
    return [TensorFlowTracer(self, v, a) for v, a in vals_out]

  def post_process_call(self, call_primitive: core.Primitive,
                        out_tracers: Sequence[TensorFlowTracer], params):
    # We encountered a call primitive, e.g., remat_call_p, whose results
    # (out_tracers) include TensorFlowTracers that were not passed through
    # its arguments (captured from the environment).
    vals = tuple(t.val for t in out_tracers)
    main = self.main

    def todo(vals: Sequence[TfVal]):
      # TODO: is name_stack correct?
      trace = TensorFlowTrace(main, core.cur_sublevel())
      return [
          TensorFlowTracer(trace, v, out_tracer.aval)
          for v, out_tracer in zip(vals, out_tracers)
      ]

    return vals, todo

  def process_map(self, map_primitive, f, tracers, params):
    raise NotImplementedError("process_map")

  def post_process_map(self, map_primitive, out_tracers, params):
    raise NotImplementedError("post_process_map")

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    # Drop the custom differentiation rule and act like a call primitive. This
    # behavior is desirable because jax2tf stages code out of the JAX system, so
    # there are no more JAX differentiation transformations to be applied.
    del jvp  # Unused.
    return self.process_call(core.call_p, fun, tracers, {})

  def post_process_custom_jvp_call(self, out_tracers, _):
    assert False  # unreachable assuming jax2tf runs with clean trace state

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    # Drop the custom differentiation rule and act like a call primitive. This
    # behavior is desirable because jax2tf stages code out of the JAX system, so
    # there are no more JAX differentiation transformations to be applied.
    del fwd, bwd, out_trees  # Unused.
    return self.process_call(core.call_p, fun, tracers, {})

  def post_process_custom_vjp_call(self, out_tracers, _):
    assert False  # unreachable assuming jax2tf runs with clean trace state

  def post_process_custom_vjp_call_fwd(self, *_, **__):
    assert False  # unreachable assuming jax2tf runs with clean trace state

  def get_primitive_impl(self, p: core.Primitive) -> Tuple[Callable, bool]:
    # Returns the primitive implementation and whether the implementation
    # takes abstract values (see definition of tf_impl_with_avals)
    if not _thread_local_state.enable_xla:
      try:
        return tf_impl_no_xla[p], True  # Always require avals.
      except KeyError:
        pass
    try:
      return tf_impl[p], False
    except KeyError:
      try:
        return tf_impl_with_avals[p], True
      except KeyError as err:
        msg = "TensorFlow interpretation rule for '{}' not implemented"
        raise NotImplementedError(msg.format(p)) from err

def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
  assert False, f"Encountered unexpected primitive {p}"


# Call primitives are inlined
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p,
                   partial_eval.remat_call_p, maps.xmap_p]:
  tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)

# Primitives that are not yet implemented must be explicitly declared here.
tf_not_yet_impl = [
    "clz",
    "igamma_grad_a",
    "random_gamma_grad",
    "reduce_precision",
    "schur",
    "name",
    "unreachable",
    "bint",
    "getslice",
    "full_to_shard",
    "shard_to_full",

    # Not high priority?
    "after_all",
    "all_to_all",
    "create_token",
    "custom_transpose_call",
    "custom_vmap_call",
    "infeed",
    "linear_call",
    "outfeed",
    "pmax_p",
    "pmin",
    "ppermute",
    "psum",
    "pmax",
    "pgather",
    "reduce_scatter",
    "axis_index",
    "pdot",
    "all_gather",
    "lu_pivots_to_permutation",
    "xla_pmap",
]

tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
tf_impl[ad_util.zeros_like_p] = tf.zeros_like


def _add(x: TfVal, y: TfVal) -> TfVal:
  return tf.raw_ops.AddV2(x=x, y=y)


tf_impl[ad_util.add_jaxvals_p] = _add
tf_impl[dispatch.device_put_p] = lambda x, device=None: x
tf_impl[lax_internal.copy_p] = lambda x: x

def _neg(x: TfVal) -> TfVal:
  if x.dtype.is_unsigned:
    signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[x.dtype]
    x_signed = tf.cast(x, signed_dtype)
    res_signed = tf.math.negative(x_signed)
    return tf.cast(res_signed, x.dtype)
  else:
    return tf.math.negative(x)

tf_impl[lax.neg_p] = _neg


def _sign(x: TfVal) -> TfVal:
  if x.dtype.is_unsigned:
    # TF and XLA do not support tf.math.sign for unsigned types.
    return tf.where(
        tf.math.equal(x, 0), tf.constant(0, dtype=x.dtype),
        tf.constant(1, dtype=x.dtype))
  else:
    return tf.math.sign(x)


tf_impl[lax.sign_p] = _sign
tf_impl[lax.floor_p] = tf.math.floor
tf_impl[lax.ceil_p] = tf.math.ceil


def _round(operand, *, rounding_method,
           _in_avals: Sequence[core.ShapedArray],
           _out_aval: core.ShapedArray):
  if rounding_method is lax.RoundingMethod.AWAY_FROM_ZERO:
    # JAX uses a single HLO op Round here
    sign = _sign(operand)
    operand *= sign
    floor = tf.math.floor(operand)
    operand -= floor
    cond = tf.math.equal(operand, tf.constant(np.array(0.5), operand.dtype))
    return sign * (
        tf.where(cond, tf.constant(np.array(1), operand.dtype),
                 tf.math.round(operand)) + floor)
  else:  # rounding_method is RoundingMethod.TO_NEAREST_EVEN
    rounding_fun = _convert_jax_impl(
        lax_internal._round_to_nearest_even, multiple_results=False)
    return rounding_fun(operand, _in_avals=_in_avals, _out_aval=_out_aval)
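
# Worked example (hedged, float32 semantics): for operand 2.5,
# AWAY_FROM_ZERO gives 3.0 (ties move away from zero), while TO_NEAREST_EVEN
# gives 2.0 (ties go to the even neighbor); for -2.5 the results are -3.0 and
# -2.0 respectively.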

tf_impl_with_avals[lax.round_p] = _round
tf_impl[lax.nextafter_p] = tf.math.nextafter


def _population_count(x):
  orig_dtype = x.dtype
  return tf.cast(tf.raw_ops.PopulationCount(x=x), orig_dtype)


tf_impl[lax.population_count_p] = _population_count
tf_impl[lax.is_finite_p] = tf.math.is_finite


def _abs(x: TfVal) -> TfVal:
  # TF and XLA do not support tf.math.abs for unsigned types.
  return tf.math.abs(x) if not x.dtype.is_unsigned else x


tf_impl[lax.abs_p] = _abs
tf_impl[lax.pow_p] = tf.math.pow


def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],
                 _out_aval: core.ShapedArray):
  # Follows the implementation in lax._integer_pow_translation_rule
  if y == 0:
    return tf.broadcast_to(
        tf.constant(1, dtype=x.dtype, shape=()), _eval_shape(_out_aval.shape))
  is_reciprocal = y < 0
  if is_reciprocal:
    y = -y
  acc = None
  while y > 0:
    if y & 1:
      acc = x if acc is None else tf.math.multiply(acc, x)
    y >>= 1
    if y > 0:
      x = tf.math.multiply(x, x)
  return tf.math.reciprocal(acc) if is_reciprocal else acc
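
# Worked example (hedged): y = 5 (binary 101) performs exponentiation by
# squaring: bit 0 sets acc = x; two squarings give x^2 then x^4; bit 2
# multiplies acc by x^4 to get x^5 -- three multiplications instead of the
# four a naive product would need.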


tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
tf_impl[lax.exp_p] = tf.math.exp
tf_impl[lax.expm1_p] = tf.math.expm1
tf_impl[lax.log_p] = tf.math.log
tf_impl[lax.log1p_p] = tf.math.log1p
tf_impl[lax.tan_p] = tf.math.tan
tf_impl[lax.tanh_p] = tf.math.tanh
tf_impl[lax.sin_p] = tf.math.sin
tf_impl[lax.sinh_p] = tf.math.sinh
tf_impl[lax.cos_p] = tf.math.cos
tf_impl[lax.cosh_p] = tf.math.cosh
tf_impl_with_avals[lax.acos_p] = _convert_jax_impl(
    lax_internal.acos_impl, multiple_results=False)
tf_impl_with_avals[lax.asin_p] = _convert_jax_impl(
    lax_internal.asin_impl, multiple_results=False)
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
    lax_internal.atan_impl, multiple_results=False)

def _atan2(y, x, **kwargs):
  if x.dtype.is_complex or y.dtype.is_complex:
    complex_component_dtype = {
        tf.complex64: tf.float32,
        tf.complex128: tf.float64
    }.get(y.dtype)
    zero = tf.constant(0, complex_component_dtype)
    one = tf.constant(1, complex_component_dtype)
    i = tf.complex(zero, one)
    return -i * tf.math.log((x + i * y) / tf.math.sqrt(x * x + y * y))
  else:
    return tf.math.atan2(y, x)


tf_impl[lax.atan2_p] = _atan2
tf_impl[lax.acosh_p] = tf.math.acosh
tf_impl[lax.atanh_p] = tf.math.atanh
tf_impl[lax.asinh_p] = tf.math.asinh

tf_impl[lax.sqrt_p] = tf.math.sqrt
tf_impl[lax.rsqrt_p] = tf.math.rsqrt

def _cbrt(x):
  return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3)
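
# Note (hedged rationale): tf.math.pow(x, 1/3) alone would return NaN for
# negative x (a fractional power of a negative base), so the sign is split
# off and the cube root is taken on |x|.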

tf_impl[lax.cbrt_p] = _cbrt

tf_impl[lax.lgamma_p] = tf.math.lgamma
tf_impl[lax.digamma_p] = tf.math.digamma
tf_impl[lax.igamma_p] = tf.math.igamma
tf_impl[lax.igammac_p] = tf.math.igammac
tf_impl[lax.regularized_incomplete_beta_p] = tf.math.betainc
tf_impl[lax.erf_p] = tf.math.erf
tf_impl[lax.erfc_p] = tf.math.erfc
tf_impl[lax.erf_inv_p] = tf.math.erfinv
tf_impl[lax.bessel_i0e_p] = tf.math.bessel_i0e
tf_impl[lax.bessel_i1e_p] = tf.math.bessel_i1e

tf_impl[lax.complex_p] = tf.complex


def _conj(x, **kwargs):
  # The only dtypes that are allowed are: float32, float64, complex64, and
  # complex128.
  if x.dtype == tf.float32:
    return tf.cast(x, tf.complex64)
  elif x.dtype == tf.float64:
    return tf.cast(x, tf.complex128)
  else:
    return tf.math.conj(x)


tf_impl[lax.conj_p] = _conj
tf_impl[lax.real_p] = tf.math.real
tf_impl[lax.imag_p] = tf.math.imag

tf_impl[lax.add_p] = _add
tf_impl[lax.sub_p] = tf.math.subtract
tf_impl[lax.mul_p] = tf.math.multiply


def _iota(*, dtype, shape, dimension):
  dtype = _to_tf_dtype(dtype)
  # Some dtypes are unsupported, like uint32, so we just fall back to int32.
  # TODO(mattjj, necula): improve tf.range dtype handling
  shape_tf = _eval_shape(shape)
  vec = tf.range(tf.cast(shape_tf[dimension], tf.int32), dtype=tf.int32)
  vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))]
  return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape_tf), dtype)


tf_impl[lax.iota_p] = _iota


def _div(lhs, rhs):
  if lhs.dtype.is_integer:
    quotient = tf.math.floordiv(lhs, rhs)
    select = tf.math.logical_and(
        tf.not_equal(_sign(lhs), _sign(rhs)),
        tf.not_equal(tf.math.floormod(lhs, rhs), 0))
    return tf.where(select, quotient + 1, quotient)
  else:
    return tf.math.truediv(lhs, rhs)


def _rem(lhs, rhs):
  return _sign(lhs) * tf.math.floormod(_abs(lhs), _abs(rhs))
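
# Worked example (hedged, int32): lax.div rounds toward zero while TF's
# floordiv rounds toward -inf, so for lhs=-7, rhs=2: floordiv gives -4, the
# signs differ and -7 mod 2 != 0, so _div corrects to -3; _rem(-7, 2) takes
# the sign of the dividend, giving -1 (and -3 * 2 + -1 == -7).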

tf_impl[lax.div_p] = _div
tf_impl[lax.rem_p] = _rem


def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
            _in_avals: Sequence[core.ShapedArray],
            _out_aval: core.ShapedArray,) -> TfVal:
  # For complex numbers use lexicographic ordering, like JAX
  if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating):
    return _convert_jax_impl(
        partial(lax_internal._minmax_complex_lowering,
                lax_cmp_pick_x=lax.lt if is_min else lax.gt),
        multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
  elif x.dtype.as_numpy_dtype == np.bool_:
    return (tf.math.logical_and if is_min else tf.math.logical_or)(x, y)
  else:
    return (tf.math.minimum if is_min else tf.math.maximum)(x, y)

def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal:
  # For reducers we will need min/max for scalars only. In that case we
  # can construct the AbstractValues ourselves, even in the presence of
  # shape polymorphism.
  assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}"
  aval = core.ShapedArray((), _to_jax_dtype(x.dtype))
  return _minmax(x, y, is_min=is_min,
                 _in_avals=[aval, aval], _out_aval=aval)

tf_impl_with_avals[lax.max_p] = partial(_minmax, is_min=False)
tf_impl_with_avals[lax.min_p] = partial(_minmax, is_min=True)

# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    tf.int8: tf.uint8,
    tf.int16: tf.uint16,
    tf.int32: tf.uint32,
    tf.int64: tf.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {u: s for s, u in _SIGNED_TO_UNSIGNED_TABLE.items()}


# Note: Bitwise operations only yield identical results on unsigned integers!
# pylint: disable=protected-access
def _shift_right_arithmetic_raw(x, y):
  if x.dtype.is_unsigned:
    assert x.dtype == y.dtype
    orig_dtype = x.dtype
    signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[orig_dtype]
    x = tf.cast(x, signed_dtype)
    y = tf.cast(y, signed_dtype)
    res = tf.bitwise.right_shift(x, y)
    return tf.cast(res, orig_dtype)
  else:
    return tf.bitwise.right_shift(x, y)


def _shift_right_arithmetic(x, y):
  # TF shift is "implementation defined" if the shift amount is negative
  # or larger or equal to the size of the value. We implement the XLA
  # semantics to return the shift by the max value (x_bits - 1).
  # TODO: it is likely better to add XlaOps for shifts
  x_bits = 8 * x.dtype.size
  clamp_y = tf.where(_shift_in_bounds(x, y), y, x_bits - 1)
  return _shift_right_arithmetic_raw(x, clamp_y)
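
# Worked example (hedged, int8, so x_bits = 8): a shift amount y = 10 is out
# of bounds, so the XLA semantics clamp it to x_bits - 1 = 7; an arithmetic
# shift of -128 by 7 then yields -1 (the sign bit is replicated), whereas TF
# alone would be free to return anything for the out-of-bounds shift.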

tf_impl[lax.shift_right_arithmetic_p] = _shift_right_arithmetic


def _shift_right_logical_raw(x, y):
  if x.dtype.is_unsigned:
    return tf.bitwise.right_shift(x, y)
  else:
    assert x.dtype == y.dtype
    orig_dtype = x.dtype
    unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[orig_dtype]
    x = tf.cast(x, unsigned_dtype)
    y = tf.cast(y, unsigned_dtype)
    res = tf.bitwise.right_shift(x, y)
    return tf.cast(res, orig_dtype)


def _shift_right_logical(x, y):
  # TF shift is "implementation defined" if the shift amount is negative
  # or larger or equal to the size of the value. We implement the XLA semantics
  # to return 0.
  # TODO: it is likely better to add XlaOps for shifts
  return tf.where(
      _shift_in_bounds(x, y), _shift_right_logical_raw(x, y), tf.zeros_like(x))


tf_impl[lax.shift_right_logical_p] = _shift_right_logical


def _shift_left(x, y):
  # TF shift is "implementation defined" if the shift amount is negative
  # or larger or equal to the size of the value. We implement the XLA semantics
  # to return 0.
  # TODO: it is likely better to add XlaOps for shifts
  return tf.where(
      _shift_in_bounds(x, y), tf.bitwise.left_shift(x, y), tf.zeros_like(x))


tf_impl[lax.shift_left_p] = _shift_left


def _shift_in_bounds(x: TfVal, y: TfVal) -> TfVal:
  # Return the TF expression for when y is within bounds (0 <= y < |x|)
  x_bits = 8 * x.dtype.size
  # TF does not have comparisons for uint16 and uint32 (despite what the
  # documentation says)
  y_comp = tf.cast(
      y, _UNSIGNED_TO_SIGNED_TABLE[y.dtype]) if y.dtype.is_unsigned else y
  y_lt_x_bits = tf.math.less(y_comp, x_bits)
  y_ge_0 = tf.math.greater_equal(y_comp, 0)
  return tf.logical_and(y_lt_x_bits, y_ge_0)


def _not(x):
  """Computes bitwise not with support for booleans.

  Numpy and JAX support bitwise not for booleans by applying a logical not!
  This means that applying bitwise_not yields an unexpected result:
    jnp.bitwise_not(jnp.array([True, False]))
    >> DeviceArray([False, True], dtype=bool)

  if you assume that booleans are simply cast to integers.
    jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool)
    >> DeviceArray([True, True], dtype=bool)
  """
  if x.dtype == tf.bool:
    return tf.logical_not(x)
  else:
    return tf.bitwise.invert(x)


tf_impl[lax.not_p] = _not


def handle_boolean_args(f, argnums: Sequence[int], boolean_f=None):
  """Computes functions with some bool args and bool results using int8.

  This is needed because some TF ops do not work for bool args, e.g.,
  inequalities, min/max.

  Args:
    f: a TF callable to wrap. It will be called with non-boolean arguments.
    argnums: the positional arguments that may be booleans.
    boolean_f: [Optional] a TF callable compatible with boolean
      arguments.

  Returns: a TF callable that can take a mix of boolean positional arguments
    (in the positions specified by `argnums`) and some non-boolean positional
    arguments. If there are no boolean arguments, just calls `f`. Otherwise,
    it calls `boolean_f` if defined. Otherwise, casts the boolean
    arguments to `int8`, calls `f`, then casts the result to `bool`.
  """
  argnums = tf.nest.flatten(argnums)

  def wrapper(*args: TfVal, **kwargs):
    argnum_types = {args[i].dtype for i in argnums}
    if tf.bool not in argnum_types:
      return f(*args, **kwargs)
    else:
      # All argnums should be boolean
      assert len(argnum_types) == 1, argnum_types
      if boolean_f is not None:
        return boolean_f(*args, **kwargs)
      else:
        args_cast = [(tf.cast(a, tf.int8) if i in argnums else a)
                     for i, a in enumerate(args)]
        if "_in_avals" in kwargs:

          def cast_aval(aval):
            assert aval.dtype == np.bool_
            return core.ShapedArray(aval.shape, np.int8)

          _in_avals_cast = [
              cast_aval(aval) if i in argnums else aval
              for i, aval in enumerate(kwargs["_in_avals"])
          ]
          _out_aval_cast = tf.nest.map_structure(cast_aval, kwargs["_out_aval"])
          kwargs = dict(
              kwargs, _in_avals=_in_avals_cast, _out_aval=_out_aval_cast)
        out = f(*args_cast, **kwargs)
        return tf.nest.map_structure(lambda o: tf.cast(o, tf.bool), out)

  return wrapper
|
||
|
||
|
||
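# For example (illustrative): wrapping tf.math.maximum as
#   handle_boolean_args(tf.math.maximum, argnums=(0, 1))
# computes a boolean maximum by casting True/False to int8 1/0, calling
# tf.math.maximum, and casting the result back to tf.bool.

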
tf_impl[lax.or_p] = handle_boolean_args(
    tf.bitwise.bitwise_or, argnums=(0, 1), boolean_f=tf.logical_or)
tf_impl[lax.and_p] = handle_boolean_args(
    tf.bitwise.bitwise_and, argnums=(0, 1), boolean_f=tf.logical_and)
tf_impl[lax.xor_p] = handle_boolean_args(
    tf.bitwise.bitwise_xor, argnums=(0, 1), boolean_f=tf.math.logical_xor)

tf_impl[lax.eq_p] = tf.math.equal
tf_impl[lax.ne_p] = tf.math.not_equal

boolean_greater = lambda x, y: tf.logical_and(x, tf.logical_not(y))  # Only one combo: T,F -> T
boolean_less = lambda x, y: tf.logical_and(tf.logical_not(x), y)  # Only one combo: F,T -> T
boolean_greater_or_equal = lambda x, y: tf.logical_not(boolean_less(x, y))  # All cases except F,T
boolean_less_or_equal = lambda x, y: tf.logical_not(boolean_greater(x, y))  # All cases except T,F

tf_impl[lax.gt_p] = handle_boolean_args(
    tf.math.greater, argnums=(0, 1), boolean_f=boolean_greater)
tf_impl[lax.lt_p] = handle_boolean_args(
    tf.math.less, argnums=(0, 1), boolean_f=boolean_less)
tf_impl[lax.ge_p] = handle_boolean_args(
    tf.math.greater_equal, argnums=(0, 1), boolean_f=boolean_greater_or_equal)
tf_impl[lax.le_p] = handle_boolean_args(
    tf.math.less_equal, argnums=(0, 1), boolean_f=boolean_less_or_equal)

tf_impl[lax.linalg.cholesky_p] = tf.linalg.cholesky


def _convert_element_type(operand, *, new_dtype, weak_type=False):
  old_dtype = operand.dtype.as_numpy_dtype
  if (dtypes.issubdtype(old_dtype, np.complexfloating) and
      not dtypes.issubdtype(new_dtype, np.complexfloating)):
    operand = tf.math.real(operand)
  if (dtypes.issubdtype(old_dtype, np.floating) and
      not (dtypes.issubdtype(new_dtype, np.floating) or dtypes.issubdtype(
          new_dtype, np.complexfloating) or new_dtype == np.bool_)):
    sign = _sign(operand)
    operand = sign * tf.math.floor(sign * operand)
  return tf.dtypes.cast(operand, _to_tf_dtype(new_dtype))


tf_impl[lax.convert_element_type_p] = _convert_element_type


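# For example (illustrative): converting -2.7 (float32) to int32 computes
# sign * floor(sign * x) = -floor(2.7) = -2.0 before the final cast, i.e.,
# rounding toward zero.

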
def _bitcast_convert_type(operand, new_dtype):
  if operand.dtype == new_dtype:
    return operand
  return tf.bitcast(operand, _to_tf_dtype(new_dtype))


tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type


def _clamp(minval, operand, maxval, *, _in_avals, _out_aval):
  # The below permits mirroring the behavior of JAX when maxval < minval
  op_shape_tf_val = _eval_shape(_in_avals[1].shape)
  maxval = tf.broadcast_to(maxval, op_shape_tf_val)
  minval = tf.math.minimum(tf.broadcast_to(minval, op_shape_tf_val), maxval)
  return tf.clip_by_value(operand, minval, maxval)


tf_impl_with_avals[lax.clamp_p] = _clamp


def _concatenate(*operands, dimension):
  return tf.concat(operands, axis=dimension)


tf_impl[lax.concatenate_p] = _concatenate


def _conv_general_dimension_numbers_proto(dimension_numbers):
  """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
  assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  proto = xla_data_pb2.ConvolutionDimensionNumbers()
  proto.input_batch_dimension = lhs_spec[0]
  proto.input_feature_dimension = lhs_spec[1]
  proto.output_batch_dimension = out_spec[0]
  proto.output_feature_dimension = out_spec[1]
  proto.kernel_output_feature_dimension = rhs_spec[0]
  proto.kernel_input_feature_dimension = rhs_spec[1]
  proto.input_spatial_dimensions.extend(lhs_spec[2:])
  proto.kernel_spatial_dimensions.extend(rhs_spec[2:])
  proto.output_spatial_dimensions.extend(out_spec[2:])
  return proto


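# For example (illustrative): for an NCHW lhs_spec of (0, 1, 2, 3), the proto
# gets input_batch_dimension=0, input_feature_dimension=1, and
# input_spatial_dimensions=(2, 3).

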
def _precision_config_proto(precision: Optional[Tuple[PrecisionType,
                                                       PrecisionType]]):
  """Converts a pair of PrecisionTypes to an xla_data_pb2.PrecisionConfig."""
  if precision is None:
    return None

  proto = xla_data_pb2.PrecisionConfig()
  proto.operand_precision.append(int(precision[0]))
  proto.operand_precision.append(int(precision[1]))
  return proto


def _conv_general_dilated(lhs, rhs, *,
                          window_strides, padding, lhs_dilation,
                          rhs_dilation,
                          dimension_numbers: lax.ConvDimensionNumbers,
                          feature_group_count: int,
                          batch_group_count: int,
                          lhs_shape: Sequence[int],
                          rhs_shape: Sequence[int],
                          precision: Optional[Tuple[PrecisionType, PrecisionType]],
                          preferred_element_type: Optional[DType],
                          _in_avals: Sequence[core.ShapedArray],
                          _out_aval: core.ShapedArray):
  """Implementation of lax.conv_general_dilated_p using XlaConv."""
  out_tf_shape = _aval_to_tf_shape(_out_aval)
  dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
  precision_config_proto = _precision_config_proto(precision)

  def gen_conv(lhs, rhs, preferred_element_type: Optional[DType]):
    tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2])
    if tf_version >= (2, 8):
      # TODO(necula): remove when 2.8.0 is the stable TF version (and supports
      # batch_group_count).
      out = tfxla.conv(
          lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
          dnums_proto,
          feature_group_count=feature_group_count,
          batch_group_count=batch_group_count,
          precision_config=precision_config_proto,
          preferred_element_type=preferred_element_type,
          use_v2=True)
    else:
      if batch_group_count != 1:
        raise ValueError(
            "The batch_group_count parameter for conv requires TF version "
            "at least 2.8.0. You may want to use tf-nightly.")
      out = tfxla.conv(
          lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
          dnums_proto,
          feature_group_count=feature_group_count,
          precision_config=precision_config_proto,
          preferred_element_type=preferred_element_type,
          use_v2=True)
    # TODO: implement shape inference for XlaConv
    out.set_shape(out_tf_shape)
    if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
      out = tf.stop_gradient(out)  # See #7839
    return out

  # Follow the lowering for complex convolutions from
  # lax._conv_general_dilated_translation. We can use the same conversion on all
  # platforms because on XLA:TPU the compiler does the same as a rewrite.
  preferred_float_et: Optional[Any]
  if np.issubdtype(_in_avals[0].dtype, np.complexfloating):
    if preferred_element_type is not None:
      # Convert complex dtype to types used for real and imaginary parts
      assert np.issubdtype(preferred_element_type, np.complexfloating)
      preferred_float_et = (
          np.float64 if preferred_element_type == np.complex128 else np.float32)
    else:
      preferred_float_et = None
    lhs_real, lhs_imag = tf.math.real(lhs), tf.math.imag(lhs)
    rhs_real, rhs_imag = tf.math.real(rhs), tf.math.imag(rhs)
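    # With lhs = ar + i*ai and rhs = cr + i*ci (illustrative derivation):
    #   k1 - k3 = (ar + ai)*cr - ai*(cr + ci) = ar*cr - ai*ci  (real part)
    #   k1 + k2 = (ar + ai)*cr + ar*(ci - cr) = ai*cr + ar*ci  (imag part)
    # so the complex convolution needs only three real convolutions.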
    k1 = gen_conv(_add(lhs_real, lhs_imag), rhs_real, preferred_float_et)
    k2 = gen_conv(lhs_real, tf.math.subtract(rhs_imag, rhs_real),
                  preferred_float_et)
    k3 = gen_conv(lhs_imag, _add(rhs_real, rhs_imag), preferred_float_et)
    return tf.complex(tf.math.subtract(k1, k3), _add(k1, k2))
  else:
    return gen_conv(lhs, rhs, preferred_element_type)


tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated


def _dot_general(lhs, rhs, *, dimension_numbers,
                 precision: Optional[Tuple[PrecisionType, PrecisionType]],
                 preferred_element_type: Optional[DType],
                 _in_avals: Sequence[core.ShapedArray],
                 _out_aval: core.ShapedArray):
  """Implementation of lax.dot_general_p using tfxla.dot_general."""
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  dnums_proto = xla_data_pb2.DotDimensionNumbers()
  dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
  dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
  dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
  dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
  precision_config_proto = _precision_config_proto(precision)
  res = tfxla.dot_general(
      lhs,
      rhs,
      dnums_proto,
      precision_config_proto,
      preferred_element_type=preferred_element_type,
      use_v2=True)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    res = tf.stop_gradient(res)  # See #7839
  return res


tf_impl_with_avals[lax.dot_general_p] = _dot_general


def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
                      _in_avals: Sequence[core.ShapedArray],
                      _out_aval: core.ShapedArray):
  # for i in range(len(operand.shape)):
  #   result.shape[bcast_dims[i]] <- operand.shape[i]
  # bcast_dims must be strictly increasing.
  # len(bcast_dims) == len(operand.shape)
  op_shape = _in_avals[0].shape
  add_1s_shape = [1] * len(shape)
  for i, broadcast_dim_i in enumerate(broadcast_dimensions):
    add_1s_shape[broadcast_dim_i] = op_shape[i]
  with_1s = tf.reshape(operand, _eval_shape(add_1s_shape))
  return tf.broadcast_to(with_1s, _eval_shape(shape))


tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim


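# For example (illustrative): broadcasting an operand of shape (3,) to shape
# (2, 3) with broadcast_dimensions=(1,) first reshapes the operand to (1, 3)
# and then uses tf.broadcast_to to obtain (2, 3).

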
def _reshape(operand, *, new_sizes, dimensions):
  if dimensions is None:
    dimensions = tf.range(tf.rank(operand))
  new_sizes_tf = _eval_shape(new_sizes)
  return tf.reshape(tf.transpose(operand, dimensions), new_sizes_tf)


tf_impl[lax.reshape_p] = _reshape


def _squeeze(operand, *, dimensions, _in_avals, _out_aval):
  op_shape = _in_avals[0].shape
  new_shape = tuple(d for i, d in enumerate(op_shape) if i not in dimensions)
  new_shape_tf = _eval_shape(new_shape)
  return tf.reshape(operand, new_shape_tf)


tf_impl_with_avals[lax.squeeze_p] = _squeeze


def _pad(operand, padding_value, *, padding_config,
         _in_avals: Sequence[core.ShapedArray],
         _out_aval: core.ShapedArray):
  low, high, interior = util.unzip3(padding_config)
  out = tfxla.pad(operand, padding_value, low, high, interior)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    out = tf.stop_gradient(out)  # See #7839
  return out


tf_impl_with_avals[lax.pad_p] = _pad


def _rev(operand, *, dimensions):
  return tf.reverse(operand, dimensions)


tf_impl[lax.rev_p] = _rev


def _where(which, *cases):
  if which.dtype == tf.bool:
    assert len(cases) <= 2
    return cases[0] if len(cases) == 1 else tf.where(which, cases[1], cases[0])

  def _select(offset, cases):
    assert len(cases) > 0
    if len(cases) == 1:
      return cases[0]
    mid = len(cases) // 2
    return tf.where(tf.less(which, offset + mid),
                    _select(offset, cases[:mid]),
                    _select(offset + mid, cases[mid:]))

  return _select(0, cases)


tf_impl[lax.select_n_p] = _where


def _transpose(operand, *, permutation):
  return tf.transpose(operand, perm=permutation)


tf_impl[lax.transpose_p] = _transpose

axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)

# reduce_sum and reduce_prod are not supported for bool
tf_impl[lax.reduce_sum_p] = axes_to_axis(tf.reduce_sum)
tf_impl[lax.reduce_prod_p] = axes_to_axis(tf.reduce_prod)
tf_impl[lax.reduce_max_p] = handle_boolean_args(
    axes_to_axis(tf.reduce_max), argnums=[0],
    boolean_f=axes_to_axis(tf.reduce_any))  # Max is T if any one is T
tf_impl[lax.reduce_min_p] = handle_boolean_args(
    axes_to_axis(tf.reduce_min), argnums=[0],
    boolean_f=axes_to_axis(tf.reduce_all))  # Min is F if not all are T
tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)


def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
               index_dtype: DType,
               _in_avals: Sequence[core.ShapedArray],
               _out_aval: core.ShapedArray):
  # Follow the JAX implementation, using a XlaReduce with a custom comparator
  if is_min:
    extra_name_stack = "argmin"
    value_comparator = lax.lt
    get_identity = lax_internal._get_min_identity
  else:
    extra_name_stack = "argmax"
    value_comparator = lax.gt
    get_identity = lax_internal._get_max_identity

  res = _convert_jax_impl(
      partial(lax_internal._compute_argminmax, value_comparator, get_identity),
      multiple_results=False,
      extra_name_stack=extra_name_stack)(
          operand,
          index_dtype=index_dtype,
          axes=axes,
          _in_avals=_in_avals,
          _out_aval=_out_aval)
  return res


tf_impl_with_avals[lax.argmin_p] = partial(_argminmax, True)
tf_impl_with_avals[lax.argmax_p] = partial(_argminmax, False)


_add_fn = tf.function(_add, autograph=False)
_ge_fn = tf.function(tf.math.greater_equal, autograph=False)


def _select_and_gather_add(
    tangents: TfVal, operand: TfVal, select_prim: core.Primitive,
    window_dimensions: Sequence[int], window_strides: Sequence[int],
    base_dilation: Sequence[int], window_dilation: Sequence[int],
    padding: Sequence[Tuple[int, int]], _in_avals: Sequence[core.ShapedArray],
    _out_aval: core.ShapedArray):
  # Note: this function follows the pattern in
  # jax.lax._select_and_gather_add_translation.
  dtype = operand.dtype
  nbits = dtypes.finfo(dtype.as_numpy_dtype).bits

  # Specializing the function for 64 bits. Only up to 32 bits are supported
  # on TPU, so we let the code throw a different exception on that platform.
  max_bits = 64

  assert nbits <= max_bits
  double_word_reduction = nbits * 2 <= max_bits

  const = lambda dtype, x: tf.constant(np.array(x), dtype)

  if double_word_reduction:
    word_dtype = lax_internal._UINT_DTYPES[nbits]
    double_word_dtype = lax_internal._UINT_DTYPES[nbits * 2]

    # Packs two values into a tuple.
    def pack(a, b):
      a = _bitcast_convert_type(a, word_dtype)
      b = _bitcast_convert_type(b, word_dtype)
      a = _convert_element_type(a, new_dtype=double_word_dtype)
      b = _convert_element_type(b, new_dtype=double_word_dtype)
      a = tf.bitwise.left_shift(a, const(double_word_dtype, nbits))
      return tf.bitwise.bitwise_or(a, b)

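    # For example (illustrative): for float32 operands (nbits == 32), pack
    # bitcasts each value to uint32, widens both to uint64, and returns
    # (a << 32) | b, so the (operand, tangent) pair travels through
    # reduce_window as a single uint64 value.
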
    # Unpacks the first element of a tuple.
    def fst(t):
      assert t.dtype == double_word_dtype
      st = _shift_right_logical(t, const(double_word_dtype, nbits))
      return _bitcast_convert_type(
          _convert_element_type(st, new_dtype=word_dtype), dtype)

    # Unpacks the second element of a tuple.
    def snd(t):
      return _bitcast_convert_type(
          _convert_element_type(t, new_dtype=word_dtype), dtype)

  else:
    raise NotImplementedError(
        f"TODO: need to pack {nbits * 2} bits but this platform can only go up to {max_bits} bits."
    )

  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim

  def reducer(x, y):
    which = tf_impl[select_prim]
    return tf_impl[lax.select_n_p](which(fst(x), fst(y)), y, x)

  init = -np.inf if select_prim is lax.ge_p else np.inf
  init_identity = lambda x: pack(const(dtype, init), const(dtype, 0))

  out = _specialized_reduce_window(
      reducer,
      init_identity,
      pack(operand, tangents),
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      padding=padding,
      base_dilation=base_dilation,
      window_dilation=window_dilation,
      _in_avals=_in_avals,
      _out_aval=_out_aval)

  return snd(out)


tf_impl_with_avals[lax.select_and_gather_add_p] = _select_and_gather_add


def _get_shape_from_tensor_or_array(x):
  if isinstance(x.shape, tf.TensorShape):
    return tuple(x.shape.as_list())
  return tuple(x.shape)


def _common_reduce_window(operand, init_val, reducer, window_dimensions,
                          window_strides, padding, base_dilation,
                          window_dilation, _in_avals, _out_aval):
  o_spec = tf.TensorSpec((), dtype=operand.dtype)
  reducer_fn = tf.function(
      reducer, autograph=False).get_concrete_function(o_spec, o_spec)

  if not isinstance(init_val, (tf.Tensor, tf.Variable)):
    init_val = tf.constant(init_val, operand.dtype)
  out = tfxla.reduce_window(
      operand,
      init_val,
      reducer_fn,
      window_dimensions,
      window_strides,
      base_dilations=base_dilation,
      window_dilations=window_dilation,
      padding=padding)
  # TODO: implement shape inference for XlaReduceWindow
  out.set_shape(_aval_to_tf_shape(_out_aval))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    out = tf.stop_gradient(out)  # See #7839
  return out


def _reduce_window(*args, jaxpr, consts, window_dimensions,
                   window_strides, padding, base_dilation, window_dilation,
                   _in_avals, _out_aval):
  """TensorFlow implementation of reduce_window.

  Args:
    args: the N-dimensional array operands (elements of type T), followed by
      the starting values of the reduction (init_values).
    jaxpr: the jaxpr corresponding to the reduction function
    consts: the constants associated with jaxpr.
    window_dimensions: array of integers for window dimension values
    window_strides: array of integers for window stride values
    padding: array of pairs of integers for padding values
    base_dilation: array of integers for base dilation values
    window_dilation: array of integers for window dilation values

  Returns:
    The reduced operand.
  """
  assert len(consts) == 0, "Reduction computation cannot have constants"
  operands, init_values = util.split_list(args, [len(args) // 2])

  if len(operands) != 1:
    raise NotImplementedError("jax2tf does not support variadic reduce_window")

  def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:
    closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
    res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None)
    return res

  return (_common_reduce_window(operands[0], init_values[0], reducer,
                                window_dimensions, window_strides, padding,
                                base_dilation, window_dilation, _in_avals,
                                _out_aval[0]),)


def _specialized_reduce_window(reducer,
                               identity,
                               operand,
                               *,
                               window_dimensions,
                               window_strides,
                               padding,
                               base_dilation,
                               window_dilation,
                               _in_avals,
                               _out_aval,
                               name=None):
  """Wraps the TensorFlow reduce window operation based on a reducer and an
  identity function defining the initial value of the reduction depending on
  the dtype of the operand.

  Args:
    reducer: reduction function of type TfVal -> TfVal -> TfVal
    identity: function that takes a TensorFlow dtype as a parameter and returns
      the starting value of the reduction.
    operand: N dimensional array containing elements of type T
    window_dimensions: array of integers for window dimension values
    window_strides: array of integers for window stride values
    padding: array of pairs of integers for padding values
    base_dilation: array of integers for base dilation values
    window_dilation: array of integers for window dilation values
    name: the name of the specialized reduce window primitive for which this
      conversion function is called. This information may help to choose a
      different conversion path (optional)

  Returns:
    The reduced operand.
  """
  return _common_reduce_window(operand, identity(operand.dtype), reducer,
                               window_dimensions, window_strides, padding,
                               base_dilation, window_dilation, _in_avals,
                               _out_aval)


def _get_max_identity(tf_dtype):
  numpy_tf_dtype = tf_dtype.as_numpy_dtype
  if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact):
    return numpy_tf_dtype(-np.inf)
  elif dtypes.issubdtype(numpy_tf_dtype, np.integer):
    return dtypes.iinfo(numpy_tf_dtype).min
  else:
    assert dtypes.issubdtype(
        numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined max identity")
    return False


def _get_min_identity(tf_dtype):
  numpy_tf_dtype = tf_dtype.as_numpy_dtype
  if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact):
    return numpy_tf_dtype(np.inf)
  elif dtypes.issubdtype(numpy_tf_dtype, np.integer):
    return dtypes.iinfo(numpy_tf_dtype).max
  else:
    assert dtypes.issubdtype(
        numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined min identity")
    return True


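# For example (illustrative): _get_max_identity(tf.int32) is the most negative
# int32 and _get_min_identity(tf.float32) is +inf; these are the neutral
# elements of max and min respectively.

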
# pylint: disable=protected-access
tf_impl_with_avals[lax.reduce_window_sum_p] = (
    partial(_specialized_reduce_window, _add, lambda x: 0,
            name="reduce_window_sum"))
tf_impl_with_avals[lax.reduce_window_min_p] = (
    partial(_specialized_reduce_window,
            partial(_minmax_scalar, is_min=True),
            _get_min_identity,
            name="reduce_window_min"))
tf_impl_with_avals[lax.reduce_window_max_p] = (
    partial(_specialized_reduce_window,
            partial(_minmax_scalar, is_min=False),
            _get_max_identity,
            name="reduce_window_max"))
tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
# pylint: enable=protected-access

def _reduce(*operands: TfVal,
            computation: Callable,
            jaxpr: core.Jaxpr,
            consts: Sequence[Any],
            dimensions: Sequence[int],
            _in_avals: Sequence[core.ShapedArray],
            _out_aval: core.ShapedArray) -> Sequence[TfVal]:
  del computation
  assert not consts
  assert len(operands) % 2 == 0
  # operands: op1, op2, ..., init_val1, init_val2, ...
  # reducer takes op1[i], op2[i], ..., init_val1, init_val2, ...
  nr_operands = len(operands) // 2
  init_vals = operands[nr_operands:]
  operands = operands[0:nr_operands]

  reducer_arg_spec = tuple([tf.TensorSpec((), op.dtype) for op in init_vals] * 2)

  def reducer_computation(*args: TfVal) -> TfVal:
    closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
    res = _interpret_jaxpr(closed_jaxpr, *args, extra_name_stack=None)
    return res

  xla_reducer_computation = (
      tf.function(reducer_computation,
                  autograph=False).get_concrete_function(*reducer_arg_spec))

  outs = tfxla.variadic_reduce(operands, init_vals,
                               dimensions_to_reduce=dimensions,
                               reducer=xla_reducer_computation)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    outs = tuple(tf.stop_gradient(out) for out in outs)  # See #7839
  return outs


tf_impl_with_avals[lax.reduce_p] = _reduce


# We use lax.cumred_tpu_impl to convert cummax, cummin, cumsum and cumprod.
# This is efficient on TPU, but the complexity is O(n^2) on other backends.
# This may be implemented using associative_scan instead to favor different
# backends.
def _cumred(lax_reduce_fn: Callable,
            lax_reduce_window_fn: Callable,
            extra_name_stack: str):
  if config.jax2tf_associative_scan_reductions:
    return _convert_jax_impl(partial(lax_control_flow.associative_scan,
                                     lax_reduce_fn),
                             multiple_results=False,
                             extra_name_stack=extra_name_stack)
  else:
    return _convert_jax_impl(partial(lax_control_flow.cumred_tpu_impl,
                                     lax_reduce_window_fn),
                             multiple_results=False,
                             extra_name_stack=extra_name_stack)


tf_impl_with_avals[lax.cummax_p] = _cumred(
    lax_reduce_window_fn=lax_windowed_reductions._reduce_window_max,
    lax_reduce_fn=lax.max,
    extra_name_stack="cummax")
tf_impl_with_avals[lax.cummin_p] = _cumred(
    lax_reduce_window_fn=lax_windowed_reductions._reduce_window_min,
    lax_reduce_fn=lax.min,
    extra_name_stack="cummin")
tf_impl_with_avals[lax.cumsum_p] = _cumred(
    lax_reduce_window_fn=lax_windowed_reductions._reduce_window_sum,
    lax_reduce_fn=lax.add,
    extra_name_stack="cumsum")
tf_impl_with_avals[lax.cumprod_p] = _cumred(
    lax_reduce_window_fn=lax_windowed_reductions._reduce_window_prod,
    lax_reduce_fn=lax.mul,
    extra_name_stack="cumprod")


def _select_and_scatter(operand, source, init_value, select_jaxpr,
                        select_consts, scatter_jaxpr, scatter_consts,
                        window_dimensions, window_strides, padding):
  raise NotImplementedError("TODO: jax2tf can not convert _select_and_scatter")


tf_impl[lax.select_and_scatter_p] = _select_and_scatter


@partial(handle_boolean_args, argnums=(0, 1))
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
                            window_strides, padding, _in_avals, _out_aval):
  init_value = tf.zeros((), operand.dtype)
  select_fn = (
      tf.function(tf_impl[select_prim], autograph=False).get_concrete_function(
          init_value, init_value))
  scatter_fn = _add_fn.get_concrete_function(init_value, init_value)
  out = tfxla.select_and_scatter(operand, window_dimensions, window_strides,
                                 padding, source, init_value, select_fn,
                                 scatter_fn)
  out.set_shape(_aval_to_tf_shape(_out_aval))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    out = tf.stop_gradient(out)  # See #7839
  return out


tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add


def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
  res = _convert_jax_impl(
      partial(jax._src.prng._threefry2x32_lowering, use_rolled_loops=False),
      multiple_results=True, extra_name_stack="threefry")(
          *args, _in_avals=_in_avals, _out_aval=_out_aval)
  return res


tf_impl_with_avals[jax._src.prng.threefry2x32_p] = _threefry2x32_jax_impl

# Use the vmap implementation, otherwise on TPU the performance is really bad.
# With use_vmap=True we get about the same performance for JAX and jax2tf.
tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl(
    partial(jax._src.random._gamma_impl, use_vmap=True),
    multiple_results=False, extra_name_stack="random_gamma")


def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm) -> Sequence[TfVal]:
  is_uint32_key = key.dtype == _to_tf_dtype(jnp.uint32)
  if is_uint32_key:
    key = tf.reshape(key, (2, 2))
    key = tfxla.bitcast_convert_type(key, _to_tf_dtype(jnp.uint64))
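    # For example (illustrative): a uint32 key of shape (4,) becomes a (2, 2)
    # array whose rows are bitcast to two uint64 values, the layout that
    # tfxla.rng_bit_generator expects.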
  shape_tf = _eval_shape(shape)
  # JAX uses XLA algorithm enums; tfxla uses tf.random.Algorithm
  if algorithm == lax.RandomAlgorithm.RNG_THREE_FRY:
    algorithm_tf = tf.random.Algorithm.THREEFRY
  elif algorithm == lax.RandomAlgorithm.RNG_PHILOX:
    algorithm_tf = tf.random.Algorithm.PHILOX
  elif algorithm == lax.RandomAlgorithm.RNG_DEFAULT:
    algorithm_tf = tf.random.Algorithm.AUTO_SELECT
  else:
    assert False
  (new_key, res) = tfxla.rng_bit_generator(algorithm_tf.value, key, shape_tf,
                                           dtype=_to_tf_dtype(dtype))
  if is_uint32_key:
    new_key = tfxla.bitcast_convert_type(new_key, _to_tf_dtype(jnp.uint32))
    new_key = tf.reshape(new_key, (4,))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    # See #7839
    new_key = tf.stop_gradient(new_key)
    res = tf.stop_gradient(res)
  return new_key, res


tf_impl[lax.rng_bit_generator_p] = _rng_bit_generator


def _rng_uniform(minval: TfVal, maxval: TfVal, *, shape) -> TfVal:
  shape_tf = _eval_shape(shape)
  return tf.random.uniform(shape_tf, minval=minval, maxval=maxval,
                           dtype=minval.dtype)

tf_impl[lax.rng_uniform_p] = _rng_uniform


def _gather_dimensions_proto(indices_shape, dimension_numbers):
  proto = xla_data_pb2.GatherDimensionNumbers()
  proto.offset_dims.extend(dimension_numbers.offset_dims)
  proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims)
  proto.start_index_map.extend(dimension_numbers.start_index_map)
  assert indices_shape
  proto.index_vector_dim = len(indices_shape) - 1
  return proto


@partial(handle_boolean_args, argnums=[0])
def _gather(operand, start_indices, *, dimension_numbers, slice_sizes: core.Shape,
            indices_are_sorted, unique_indices, mode, fill_value,
            _in_avals: Sequence[core.ShapedArray],
            _out_aval: core.ShapedArray):
  """Tensorflow implementation of gather."""
  if mode == lax.GatherScatterMode.FILL_OR_DROP:
    gather_fill_fn = _convert_jax_impl(lax_slicing._gather_fill,
                                       multiple_results=False)
    return gather_fill_fn(
        operand, start_indices, dimension_numbers=dimension_numbers,
        slice_sizes=slice_sizes, unique_indices=unique_indices,
        indices_are_sorted=indices_are_sorted, fill_value=fill_value,
        output_shape=_out_aval.shape, _in_avals=_in_avals, _out_aval=_out_aval)

  proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
  slice_sizes_tf = _eval_shape(slice_sizes)
  out = tfxla.gather(operand, start_indices, proto, slice_sizes_tf,
                     indices_are_sorted)
  out.set_shape(_aval_to_tf_shape(_out_aval))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    out = tf.stop_gradient(out)  # See #7839
  return out


tf_impl_with_avals[lax.gather_p] = _gather


def _slice(operand, start_indices, limit_indices, strides, _in_avals,
           _out_aval):
  if strides is None:
    strides = [1] * len(start_indices)
  slices = tuple(
      map(slice, _eval_shape(start_indices), _eval_shape(limit_indices),
          _eval_shape(strides)))
  out = operand[slices]
  # TODO(b/184503314): improve shape inference for __getitem__
  # E.g., operand.shape=(b, 5, 3), start_indices=(0, 1, 1),
  # limit_indices=(b, 5, 3), strides=(1, 2, 1)
  out.set_shape(_aval_to_tf_shape(_out_aval))
  return out


tf_impl_with_avals[lax.slice_p] = _slice


def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape,
                   _in_avals: Sequence[core.ShapedArray],
                   _out_aval: core.ShapedArray):
  start_indices = tf.stack(start_indices)
  slice_sizes_tf = _eval_shape(slice_sizes)

  res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes_tf)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    res = tf.stop_gradient(res)  # See #7839
  return res


tf_impl_with_avals[lax.dynamic_slice_p] = _dynamic_slice


def _dynamic_update_slice(operand, update, *start_indices,
                          _in_avals: Sequence[core.ShapedArray],
                          _out_aval: core.ShapedArray):
  out = tfxla.dynamic_update_slice(operand, update, tf.stack(start_indices))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    out = tf.stop_gradient(out)  # See #7839
  return out


tf_impl_with_avals[lax.dynamic_update_slice_p] = _dynamic_update_slice


def _scatter_dimensions_proto(indices_shape, dimension_numbers):
  proto = xla_data_pb2.ScatterDimensionNumbers()
  proto.update_window_dims.extend(dimension_numbers.update_window_dims)
  proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims)
  proto.scatter_dims_to_operand_dims.extend(
      dimension_numbers.scatter_dims_to_operand_dims)
  assert indices_shape
  proto.index_vector_dim = len(indices_shape) - 1
  return proto


def _scatter(operand, scatter_indices, updates, *, update_jaxpr, update_consts,
             dimension_numbers, indices_are_sorted, unique_indices, mode,
             _in_avals: Sequence[core.ShapedArray],
             _out_aval: core.ShapedArray):
  del unique_indices

  if mode == lax.GatherScatterMode.CLIP:
    clip_fn = _convert_jax_impl(lax_slicing._clamp_scatter_indices,
                                multiple_results=False)
    scatter_indices = clip_fn(
        operand, scatter_indices, updates, dnums=dimension_numbers,
        _in_avals=_in_avals, _out_aval=_in_avals[1])

  assert len(update_consts) == 0, "Update computation cannot have constants"

  proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers)

  def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal:
    closed_jaxpr = core.ClosedJaxpr(update_jaxpr, update_consts)
    res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None)
    return res

  o_spec = tf.TensorSpec((), dtype=operand.dtype)
  xla_update_computation = (
      tf.function(update_computation,
                  autograph=False).get_concrete_function(o_spec, o_spec))
  out = tfxla.scatter(
      operand,
      scatter_indices,
      updates,
      xla_update_computation,
      proto,
      indices_are_sorted=indices_are_sorted)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    out = tf.stop_gradient(out)  # See #7839
  return out


tf_impl_with_avals[lax.scatter_p] = _scatter
tf_impl_with_avals[lax.scatter_min_p] = _scatter
tf_impl_with_avals[lax.scatter_max_p] = _scatter
tf_impl_with_avals[lax.scatter_mul_p] = _scatter
tf_impl_with_avals[lax.scatter_add_p] = _scatter


def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
          linear: Sequence[bool]) -> Sequence[TfVal]:
  del linear
  # tf.cond needs lambdas with no arguments.
  branches_tf = [
      partial(_interpret_jaxpr, jaxpr, *operands,
              # Same name stack as the XLA translation of cond_p
              extra_name_stack=f"branch_{i}_fun")
      for i, jaxpr in enumerate(branches)
  ]
  if config.jax_experimental_name_stack:
    # Same name stack as XLA translation of cond_p
    branches_tf = list(map(source_info_util.extend_name_stack("cond"),
                           branches_tf))
  return tf.switch_case(index, branches_tf)


tf_impl[lax.cond_p] = _cond


def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
           body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]:
  cond_consts, body_consts, init_carry = util.split_list(
      args, [cond_nconsts, body_nconsts])
  if cond_jaxpr.out_avals[0].shape:  # type: ignore[attr-defined]
    # The conditional is not a scalar, this must be a batched while
    return _batched_cond_while(
        *args,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr)

  # The conditional must return a single value to TF
  def cond_tf_func(*args: TfVal) -> TfVal:
    pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *args,
                             # Same name stack as the XLA translation of while_p
                             extra_name_stack="while/cond")
    return pred

  body_tf_func = partial(_interpret_jaxpr, body_jaxpr, *body_consts,
                         extra_name_stack="while/body")
  return tf.while_loop(cond_tf_func, body_tf_func, init_carry)


def _batched_cond_while(*args: TfVal, cond_nconsts: int,
                        cond_jaxpr: core.ClosedJaxpr, body_nconsts: int,
                        body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]:
  """Interprets a while_loop with a batched condition.

  A batched while has a conditional that returns a tensor of booleans, and
  a body that returns a list of tensors whose leading dimensions match those
  of the conditional tensor.

  We need to turn it into a while with scalar boolean conditional. We will
  expand the loop carry to include a prefix with the current tensor boolean
  condition. We prepend to the loop the first calculation of the tensor boolean
  condition. The loop condition will use a "reduce_any" to calculate a scalar
  boolean from the tensor boolean condition. The end of the loop body will
  compute the new carry using a "tf.where", and we compute the new tensor
  boolean condition.
  """
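  # For example (illustrative): if cond_jaxpr returns a bool tensor of shape
  # (3,), the loop carry becomes (pred_b, *carry); the loop runs while
  # tf.reduce_any(pred_b) holds, and each carry component is updated only
  # where pred_b (broadcast to its shape) is still True.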
  cond_consts, body_consts, init_carry = util.split_list(
      args, [cond_nconsts, body_nconsts])
  # Initial computation of batched condition
  init_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *init_carry,
                                  extra_name_stack="while/body_pred")

  def new_cond_tf_func(pred_b: TfVal, *carry: TfVal) -> TfVal:
    pred = tf.reduce_any(pred_b, axis=list(range(len(pred_b.shape))))
    return pred

  def new_body_tf_func(pred_b: TfVal, *carry: TfVal) -> Sequence[TfVal]:
    new_carry: Sequence[TfVal] = _interpret_jaxpr(body_jaxpr, *body_consts,
                                                  *carry,
                                                  extra_name_stack="while/body")
    # We repeat those carries for which the loop termination condition is false
    def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal:
      pred_b_bcast = _broadcast_in_dim(
          pred_b,
          shape=c_aval.shape,  # a JAX shape
          broadcast_dimensions=list(range(len(pred_b.shape))),
          _in_avals=cond_jaxpr.out_avals,
          _out_aval=core.ShapedArray(c_aval.shape, np.bool_))
      return tf.where(pred_b_bcast, new_c, c)

    selected_carry: Sequence[TfVal] = list(
        map(select_one_carry, new_carry, carry, body_jaxpr.out_avals))
    next_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *selected_carry,
                                    extra_name_stack="body_pred")
    return (next_pred_b, *selected_carry)

  _, *res_carry = tf.while_loop(new_cond_tf_func, new_body_tf_func,
                                (init_pred_b, *init_carry))
  return res_carry


tf_impl[lax.while_p] = _while

# We use the scan impl rule to rewrite in terms of while.
tf_impl_with_avals[lax.scan_p] = _convert_jax_impl(
    lax_control_flow._scan_impl,
    extra_name_stack="scan")

tf_impl_with_avals[ad_checkpoint.remat_p] = \
    _convert_jax_impl(partial(lax_control_flow.remat_impl,
                              # TODO: jax2tf cannot discriminate by platform
                              platform="cpu"),
                      multiple_results=True,
                      extra_name_stack="checkpoint")

# TODO: Remove once tensorflow is 2.10.0 everywhere.
if hasattr(tfxla, 'optimization_barrier'):
  tf_impl[lax_control_flow.optimization_barrier_p] = tfxla.optimization_barrier

def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:
  # Some types originally incompatible with tf.math.top_k can be promoted
  # to a compatible type without loss of precision.
  def promote_tf_dtype(tf_dtype):
    if tf_dtype in [tf.bool, tf.uint8, tf.uint16]:
      return tf.uint32
    if tf_dtype in [tf.int8, tf.int16]:
      return tf.int32
    if tf_dtype is tf.float16:
      return tf.float32
    return None

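  # For example (illustrative): a tf.uint8 operand is promoted to tf.uint32
  # before calling tf.math.top_k, and the values (but not the int32 indices)
  # are cast back to tf.uint8 below.
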
  conversion_dtype = promote_tf_dtype(operand.dtype)
  if conversion_dtype:
    values, indices = tf.math.top_k(
        tf.dtypes.cast(operand, conversion_dtype), k=k, sorted=True)
    return tf.dtypes.cast(values, operand.dtype), indices
  else:
    return tf.math.top_k(operand, k=k, sorted=True)


tf_impl[lax.top_k_p] = _top_k


def _approx_top_k(operand: TfVal, k: int, reduction_dimension: int,
                  recall_target: float, is_max_k: bool,
                  reduction_input_size_override: int,
                  aggregate_to_topk: bool) -> Tuple[TfVal, TfVal]:
  if is_max_k:
    return tf.math.approx_max_k(operand, k, reduction_dimension, recall_target,
                                reduction_input_size_override,
                                aggregate_to_topk)
  else:
    return tf.math.approx_min_k(operand, k, reduction_dimension, recall_target,
                                reduction_input_size_override,
                                aggregate_to_topk)


tf_impl[lax.approx_top_k_p] = _approx_top_k


def _sort(*operands: TfVal, dimension: int, is_stable: bool,
          num_keys: int) -> Tuple[TfVal, ...]:
  assert 1 <= num_keys <= len(operands)
  assert 0 <= dimension < len(
      operands[0].shape
  ), f"Invalid {dimension} for ndim {len(operands[0].shape)}"

  comparator_spec: List[tf.TensorSpec] = []
  comparator_jax_in_avals: List[core.ShapedArray] = []
  for op in operands:
    o_spec = tf.TensorSpec((), dtype=op.dtype)
    comparator_spec.extend([o_spec, o_spec])
    o_aval = core.ShapedArray((), _to_jax_dtype(op.dtype))
    comparator_jax_in_avals.extend([o_aval, o_aval])

  # Use the same comparator that JAX uses when compiling to XLA, to get the
  # proper NaN/Inf total order, and the lexicographic ordering.
  # The comparator is a 2N-argument TF function, with arguments [2k] and
  # [2k + 1] corresponding to two scalars from operand[k].
  def lexicographic_comparator(*tf_args: TfVal) -> TfVal:
    return _convert_jax_impl(
        lax_internal._sort_lt_comparator, multiple_results=False)(
            *tf_args,
            _in_avals=comparator_jax_in_avals,
            _out_aval=core.ShapedArray((), np.bool_),
            num_keys=num_keys)

  xla_comparator_computation = (
      tf.function(lexicographic_comparator,
                  autograph=False).get_concrete_function(*comparator_spec))
  results = tfxla.variadic_sort(
      operands,
      dimension=dimension,
      is_stable=is_stable,
      comparator=xla_comparator_computation)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    results = tuple(tf.stop_gradient(out) for out in results)  # See #7839
  return results


tf_impl[lax.sort_p] = _sort


def _fft(x, fft_type, fft_lengths):
  FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3]))
  if fft_type == IRFFT:
    expected_lengths = x.shape[-len(fft_lengths):-1] + ((x.shape[-1] - 1) * 2,)
  else:
    expected_lengths = x.shape[-len(fft_lengths):]
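  # For example (illustrative): for IRFFT with fft_lengths=(8,), the last
  # input dimension must be 5 (= 8 // 2 + 1), since (5 - 1) * 2 == 8.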
  if expected_lengths != fft_lengths:
    raise NotImplementedError(
        f"Unsupported fft_lengths={fft_lengths} for fft_type={fft_type} of "
        f"array with shape={x.shape}.")
  tf_funcs = {
      FFT: [tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d],
      IFFT: [tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d],
      RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d],
      IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d]
  }
  return tf_funcs[fft_type][len(fft_lengths) - 1](x)


tf_impl[lax.fft_p] = _fft


def _qr(operand, full_matrices):
  return tf.linalg.qr(operand, full_matrices=full_matrices)


tf_impl[lax.linalg.qr_p] = _qr


def _svd(operand, full_matrices, compute_uv):
  result = tf.linalg.svd(operand, full_matrices, compute_uv)
  if not compute_uv:
    return result,
  s, u, v = result
  return s, u, tf.linalg.adjoint(v)


tf_impl[lax.linalg.svd_p] = _svd


def _eig(operand: TfVal, compute_left_eigenvectors: bool,
         compute_right_eigenvectors: bool):
  if compute_left_eigenvectors and compute_right_eigenvectors:
    # TODO(bchetioui): didn't find a 100% reliable, easy and satisfying way to
    # sort the left eigenvectors in the right order. The jax.numpy.linalg API
    # suggests to me that left eigenvectors are anyway seldom used, so I
    # think it is acceptable to leave as unimplemented for now.
    msg = ("Conversion of eig is not implemented when both "
           "compute_left_eigenvectors and compute_right_eigenvectors are set "
           "to True.")
    raise NotImplementedError(msg)
  elif not (compute_left_eigenvectors or compute_right_eigenvectors):
    return tuple([tf.linalg.eigvals(operand)])
  elif compute_right_eigenvectors:
    return tuple(tf.linalg.eig(operand))
  else:  # compute_left_eigenvectors == True
    wH, vl = tf.linalg.eig(tf.linalg.adjoint(operand))
    wHH = tf.math.conj(wH)
    return tuple([wHH, vl])


tf_impl[lax.linalg.eig_p] = _eig


def _eigh(operand: TfVal, lower: bool, sort_eigenvalues: bool, _in_avals,
          _out_aval):
  del sort_eigenvalues
  if operand.shape[-1] == 0:
    v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
  else:
    if not lower:
      operand = tf.linalg.adjoint(operand)
    w, v = tf.linalg.eigh(operand)
  cast_type = {
      tf.complex64: tf.float32,
      tf.complex128: tf.float64
  }.get(operand.dtype)
  if cast_type is not None:
    w = tf.cast(w, cast_type)
  return v, w


tf_impl_with_avals[lax.linalg.eigh_p] = _eigh


def _lu(operand: TfVal, _in_avals, _out_aval):
  return _convert_jax_impl(lax_linalg._lu_python, extra_name_stack="lu")(
      operand, _in_avals=_in_avals, _out_aval=_out_aval)


tf_impl_with_avals[lax.linalg.lu_p] = _lu


def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool,
                      transpose_a: bool, conjugate_a: bool, unit_diagonal: bool,
                      _in_avals: Sequence[core.ShapedArray],
                      _out_aval: core.ShapedArray):
  if unit_diagonal:
    a_aval, _ = _in_avals
    a_shape = _eval_shape(a_aval.shape)
    a = tf.linalg.set_diag(a, tf.ones(a_shape[:-1], dtype=a.dtype))
  if not left_side:
    rank = len(a.shape)
    transpose_dimensions = list(range(rank - 2)) + [rank - 1, rank - 2]
    a = tf.transpose(a, transpose_dimensions)
    b = tf.transpose(b, transpose_dimensions)
    lower = not lower
  # adjoint == transpose for real dtypes, so special care need only be taken
  # for complex types.
  if a.dtype in [tf.complex64, tf.complex128]:
    if (transpose_a and not conjugate_a) or (not transpose_a and conjugate_a):
      a = tf.math.conj(a)
  result = tf.linalg.triangular_solve(a, b, lower=lower, adjoint=transpose_a)
  if not left_side:
    result = tf.transpose(result, transpose_dimensions)
  return result


tf_impl_with_avals[lax.linalg.triangular_solve_p] = _triangular_solve


def _linear_solve(*args: TfVal, const_lengths, jaxprs, _in_avals, _out_aval):
  return _convert_jax_impl(lax_control_flow._custom_linear_solve_impl,
                           extra_name_stack="linear_solve")(
                               *args,
                               const_lengths=const_lengths,
                               jaxprs=jaxprs,
                               _in_avals=_in_avals,
                               _out_aval=_out_aval)


tf_impl_with_avals[lax.linear_solve_p] = _linear_solve

def _tridiagonal_solve(*args: TfVal, _in_avals, _out_aval, **params):
  return _convert_jax_impl(lax_linalg._tridiagonal_solve_jax,
                           multiple_results=False,
                           extra_name_stack="tridiagonal_solve")(
                               *args,
                               _in_avals=_in_avals,
                               _out_aval=_out_aval)


tf_impl_with_avals[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve

def _custom_jvp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr,
                           jvp_jaxpr_thunk: Callable,
                           num_consts: int) -> Sequence[TfVal]:
  # TODO(necula): ensure that there is no AD transformation in scope
  return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_jvp")


tf_impl[custom_derivatives.custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr


def _custom_vjp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr,
                           **_) -> Sequence[TfVal]:
  # TODO(necula): ensure that there is no AD transformation in scope
  return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_vjp")


tf_impl[custom_derivatives.custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr


def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
  raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
                  "function.")


tf_impl[ad.custom_lin_p] = _custom_lin


def split_to_logical_devices(tensor: TfVal,
                             partition_dimensions: pxla.PartitionsOrReplicated):
  """Like TPUMPStrategy.experimental_split_to_logical_devices.

  For jax2tf purposes we want to avoid needing to thread the `strategy` object
  through the generated computation. It seems that the original function needs
  the strategy object only for error checking, which we assume is done upstream
  by JAX.

  Args:
    tensor: Input tensor to annotate.
    partition_dimensions: A list of integers, with one integer per tensor
      dimension, specifying in how many parts the dimension should be split.
      The product of integers must equal the number of devices per replica.
    use_sharding_op: whether to use a sharding op, or not.

  Returns:
    an annotated tensor.
  """
  # TODO: this is only for sharded_jit. Either remove, or implement in terms
  # of _shard_values.
  if partition_dimensions is None:
    return xla_sharding.replicate(tensor, use_sharding_op=True)
  num_partition_splits = np.prod(partition_dimensions)
  tile_assignment = np.arange(num_partition_splits).reshape(
      partition_dimensions)
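  # For example (illustrative): partition_dimensions=[1, 2] yields the tile
  # assignment [[0, 1]], i.e., the second dimension is split across two
  # devices.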
  return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)


def _shard_value(mesh: maps.Mesh,
                 val: TfVal,
                 aval: core.ShapedArray,
                 axis_resources: pjit.ParsedPartitionSpec) -> TfVal:
  """Apply sharding to a TfVal."""
  sharding_proto: xla_client.OpSharding = pjit.get_aval_sharding_proto(
      aval, axis_resources, mesh)
  # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
  xla_sharding_proto: xla_data_pb2.OpSharding = (
      xla_data_pb2.OpSharding(
          type=int(sharding_proto.type),
          tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
          tile_assignment_devices=sharding_proto.tile_assignment_devices,
          replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
          last_tile_dims=sharding_proto.last_tile_dims))
  return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor(
      val, use_sharding_op=True)


def _pjit(*args: TfVal,
          jaxpr: core.ClosedJaxpr,
          in_axis_resources: Sequence[pjit.ParsedPartitionSpec],
          out_axis_resources: Sequence[pjit.ParsedPartitionSpec],
          resource_env: maps.ResourceEnv,
          donated_invars,
          name: str,
          in_positional_semantics,
          out_positional_semantics,
          _in_avals: Sequence[core.ShapedArray],
          _out_aval: core.ShapedArray) -> TfVal:
  del donated_invars
  if resource_env.physical_mesh.is_multi_process:
    raise NotImplementedError("jax2tf translation for pjit over multi-process "
                              "meshes is not supported yet")
  # TODO: add `name` to the name stack
  shard_value_for_mesh = partial(_shard_value, resource_env.physical_mesh)
  # Apply sharding annotation to the arguments
  sharded_args: Sequence[TfVal] = tuple(
      map(shard_value_for_mesh, args, _in_avals, in_axis_resources))
  results = _interpret_jaxpr(jaxpr, *sharded_args,
                             extra_name_stack=util.wrap_name(name, "pjit"))
  sharded_results: Sequence[TfVal] = tuple(
      map(shard_value_for_mesh, results, _out_aval, out_axis_resources))
  return tuple(sharded_results)


tf_impl_with_avals[pjit.pjit_p] = _pjit


def _pjit_sharding_constraint(arg: TfVal, *,
                              axis_resources: pjit.ParsedPartitionSpec,
                              resource_env: maps.ResourceEnv,
                              _in_avals: Sequence[core.ShapedArray],
                              _out_aval: core.ShapedArray,
                              **kwargs) -> TfVal:
  return _shard_value(resource_env.physical_mesh, arg, _in_avals[0],
                      axis_resources)


tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint

def _dimension_size_jax2tf(op: TfVal, *, dimension):
  return tf.shape(op)[dimension]

tf_impl[shape_poly.dimension_size_p] = _dimension_size_jax2tf

def _dim_as_value_jax2tf(dim: shape_poly.DimSize):
  dim_tf, = _eval_shape((dim,))
  return dim_tf

tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf

def _register_checkpoint_pytrees():
  """Registers TF custom container types as pytrees."""
  m = tf.Module()
  # The types here are automagically changed by TensorFlow's checkpointing
  # infrastructure.
  m.a = (tf.Module(), tf.Module())
  m.b = [tf.Module(), tf.Module()]
  m.c = {"a": tf.Module()}
  tuple_wrapper = type(m.a)
  list_wrapper = type(m.b)
  dict_wrapper = type(m.c)

  # TF AutoTrackable swaps container types out for wrappers.
  assert tuple_wrapper is not tuple
  assert list_wrapper is not list
  assert dict_wrapper is not dict

  jax.tree_util.register_pytree_node(tuple_wrapper, lambda xs:
                                     (tuple(xs), None), lambda _, xs: tuple(xs))

  jax.tree_util.register_pytree_node(list_wrapper, lambda xs: (tuple(xs), None),
                                     lambda _, xs: list(xs))

  jax.tree_util.register_pytree_node(
      dict_wrapper,
      lambda s: (tuple(s.values()), tuple(s.keys())),
      lambda k, xs: dict_wrapper(zip(k, xs)))


_register_checkpoint_pytrees()