George Necula 235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
import functools
import re
import string
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import jax
from jax import ad_util, api_util, config
from jax._src import api
from jax import core, custom_derivatives, dtypes
from jax import linear_util as lu
from jax import numpy as jnp
from jax import random, tree_util
from jax._src import util
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import fft as lax_fft
from jax._src.lax import lax
from jax._src.lax import linalg as lax_linalg
import jax._src.random
from jax.api_util import flatten_fun
from jax.interpreters import ad
from jax.interpreters import pxla
from jax.interpreters import sharded_jit
from jax.interpreters import xla
from jax.lib import xla_client
from . import shape_poly
import numpy as np
import tensorflow as tf # type: ignore[import]
# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import
PolyShape = shape_poly.PolyShape
# The scope name needs to be a valid TensorFlow name. See
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731
_VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
_INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/>-]")
def _sanitize_scope_name(name):
scope_name = _INVALID_SCOPE_CHAR.sub("_", name)
if not _VALID_SCOPE_REGEX.match(scope_name):
scope_name = ".{}".format(scope_name)
return scope_name
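# For example (illustrative): _sanitize_scope_name("jit(f)/while[0]") returns
# "jit_f_/while_0_", since "(", ")", "[" and "]" are not valid scope characters.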
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
DType = Any
PrecisionType = int # Enum xla_data.PrecisionConfig.Precision
def _is_tfval(v: TfVal) -> bool:
if isinstance(v, (tf.Tensor, tf.Variable)):
return True
try:
# Note: this conversion is overkill and just intended as a type check; this
# code is in principle only run if config.jax_enable_checks is True.
# TODO: it is not true that this code is run only with jax_enable_checks.
_safe_convert_to_tensor(v)
return True
except ValueError:
return False
def _safe_convert_to_tensor(val, dtype=None) -> TfVal:
dtype = dtype if dtype else (val.dtype if hasattr(val, "dtype") else None)
conversion_type = to_tf_dtype(dtype) if dtype else None
# The float0 type is not known to TF.
if dtype and dtype == dtypes.float0:
val = np.zeros(np.shape(val), conversion_type.as_numpy_dtype)
return tf.convert_to_tensor(val, dtype=conversion_type)
# The implementation rules for primitives. The rule will be called with the
# arguments (TfVal) and must return TfVal (or a sequence thereof,
# if primitive.multiple_results). The vast majority of primitives do not need
# to worry about core.unit inputs or results. The exceptions are primarily the
# control-flow primitives.
tf_impl: Dict[core.Primitive,
Callable[..., Any]] = {}
# Some primitive implementation rules need the abstract values of arguments
# and the results. This is the case for the primitives implemented using
# _convert_jax_impl and those that need to adjust the shape of the outputs
# due to missing TF shape inference rules for TFXLA ops. The rules for these
# primitives should be added to `tf_impl_with_avals`.
# The abstract values are passed to the implementation as two special kwargs
# `_in_avals` (a tuple of core.AbstractValue) and `_out_aval` (a
# core.AbstractValue, or a tuple thereof when primitive.multiple_results).
tf_impl_with_avals: Dict[core.Primitive,
Callable[..., Any]] = {}
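# As a sketch of the registration pattern used throughout this file: a
# primitive that maps directly to a TF function is registered as, e.g.,
#   tf_impl[lax.neg_p] = tf.math.negative
# whereas a rule that must know shapes or dtypes is added to
# tf_impl_with_avals and takes the extra `_in_avals` and `_out_aval` kwargs.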
# XLA is not linked in all environments; when converting a primitive, if this
# flag is False, we try harder to use only standard TF ops if they are
# applicable to the concrete use case; if the resulting conversion path ends up
# requiring a TFXLA operation, an exception is thrown instead.
_enable_xla = True
def _xla_path_disabled_error(primitive_name: str) -> Exception:
assert not _enable_xla
return NotImplementedError(
f"Call to {primitive_name} can only be converted through TFXLA, but "
"XLA is disabled")
@functools.partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun: Callable, *,
polymorphic_shapes: Optional[Sequence[Any]]=None,
with_gradient=True, enable_xla=True) -> Callable:
"""Transforms `fun` to be executed by TensorFlow.
See [README](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md)
for more details about usage and common problems.
Args:
fun: Function to be transformed. Its arguments and return value should be
JAX arrays, or nested standard Python containers (tuple/list/dict)
thereof (pytrees).
polymorphic_shapes: Specifies input shapes to be treated polymorphically
during conversion.
.. warning::
The shape-polymorphic conversion is an experimental feature. It is meant
to be sound, but it is known to reject some JAX programs that are
shape polymorphic. The details of this feature can change.
It should be a Python object with the same pytree structure as,
or a prefix of, the tuple of arguments to the function,
but with a shape specification corresponding to each argument.
The default value is `None`, which is a shortcut for a tuple of `None`,
one for each argument, denoting that all shapes are monomorphic.
See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
A shape specification for an array argument
should be an object `PolyShape(dim0, dim1, ..., dimn)`
where each `dim` is a dimension specification: a positive integer denoting
a monomorphic dimension of the given size,
or a string denoting a dimension variable assumed to range over non-zero
dimension sizes,
or the special placeholder string "_" denoting a monomorphic dimension
whose size is given by the actual argument.
As a shortcut, an Ellipsis suffix in the
list of dimension specifications stands for a list of "_" placeholders.
For convenience, a shape specification can also be given as a string
representation, e.g.: "batch, ...", "batch, height, width, _", possibly
with surrounding parentheses: "(batch, ...)".
The conversion fails if it cannot ensure that it would produce the same
sequence of TF ops for any non-zero values of the dimension variables.
See [the README](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
with_gradient: if set, will add a tf.custom_gradient to the converted
function, by converting the ``jax.vjp(fun)``. Only first-order
differentiation is supported for now. If the converted function is
saved in a SavedModel, the custom gradients are currently lost and
an error will be raised if a gradient computation is attempted.
This is due to a current bug in TensorFlow.
enable_xla: if set to `False`, the converter will try harder to use pure TF
ops to convert the function, and raise an error if it cannot be converted
without resorting to XLA ops (default: True).
Returns:
A version of `fun` that expects TfVals as arguments (or tuples/lists/dicts
thereof), and returns TfVals as outputs.
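Example (a sketch; assumes a JAX function `f` over one batched array):

converted_f = jax2tf.convert(f, polymorphic_shapes=["(batch, 28, 28)"])
# The result can be used like any TF callable, e.g., wrapped in tf.function
# for saving as a SavedModel:
tf_f = tf.function(converted_f, autograph=False)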
"""
global _enable_xla
_enable_xla = enable_xla
api._check_callable(fun)
def converted_fun(*args: TfVal) -> TfVal:
# TODO: is there a better way to check if we are inside a transformation?
if not core.trace_state_clean():
raise ValueError("convert must be used outside all JAX transformations. "
f"Trace state: {core.thread_local_state.trace_state}")
def check_arg(a):
if not _is_tfval(a):
msg = (f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
"be NumPy array, scalar, tf.Variable, or tf.Tensor")
raise TypeError(msg)
tree_util.tree_map(check_arg, args)
# Name input tensors
args = tuple(
tree_util.tree_map(lambda x, i=i: tf.identity(x, f"jax2tf_arg_{i}"), a) # type: ignore
for i, a in enumerate(args))
# This function may take pytrees of TfVals. We can only set
# tf.custom_gradient on functions that take a flat argument list.
args_flat, in_tree = tree_util.tree_flatten((args, {}))
if polymorphic_shapes is None:
polymorphic_shapes_ = (None,) * len(args)
else:
if not isinstance(polymorphic_shapes, Sequence) or len(args) != len(polymorphic_shapes):
msg = ("polymorphic_shapes must be a sequence with the same length as the argument list "
f"({len(args)}). Got polymorphic_shapes={polymorphic_shapes}.")
raise TypeError(msg)
polymorphic_shapes_ = tuple(polymorphic_shapes)
# Expand the polymorphic_shapes to match the argument pytree
polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
in_tree.children()[0],
polymorphic_shapes_))
# Construct the abstract values for the flat arguments, possibly based on
# the input shapes and the polymorphic_shapes if given. May create new shape
# variables.
args_avals_flat, shapeenv = _args_to_avals_and_env(args_flat,
polymorphic_shapes_flat)
f = lu.wrap_init(fun)
# out_tree_thunk() will be the output tree, after running _interpret_fun.
flat_fun, out_tree_thunk = flatten_fun(f, in_tree)
# Prepare the grad_fn for tf.custom_gradient.
def converted_grad_fn(*out_cts_flat: TfVal,
_out_cts_avals: Sequence[core.AbstractValue],
variables=None):
if variables:
raise ValueError("Unexpected variables used in forward pass. "
"This should not happen for first-order differentiation. "
f"variables={variables}")
def fun_vjp_jax(args_jax, out_cts_jax):
# One may think that we can get the pullback while we are converting
# the main function in the first place. That is problematic, because the
# pullback may contain captured tracers from the conversion of the
# main function. Those tracers will confuse the conversion of the
# pullback. So, we construct the vjp anew.
_, pullback_jax = jax.vjp(fun, *args_jax)
return pullback_jax(out_cts_jax)
if polymorphic_shapes is None:
vjp_polymorphic_shapes = None
else:
args_polymorphic_shapes = tree_util.tree_unflatten(in_tree.children()[0], polymorphic_shapes_flat)
out_cts_polymorphic_shapes = tree_util.tree_unflatten(
out_tree_thunk(),
tuple(str(out_aval.shape) for out_aval in _out_cts_avals)) # type: ignore
vjp_polymorphic_shapes = [args_polymorphic_shapes, out_cts_polymorphic_shapes]
out_cts = tree_util.tree_unflatten(out_tree_thunk(), out_cts_flat)
# TODO: enable higher-order gradients
with tf.name_scope("jax2tf_vjp"):
in_cts = convert(fun_vjp_jax, with_gradient=False,
polymorphic_shapes=vjp_polymorphic_shapes)(args, out_cts)
return in_cts
try:
global _shape_env
assert not _shape_env, f"Unexpected shape environment {_shape_env}"
_shape_env = shapeenv
if with_gradient:
@tf.custom_gradient
def converted_fun_flat_with_custom_gradient(*args_flat: TfVal) -> TfVal:
out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat)
outs, out_avals = util.unzip2(out_with_avals)
return (tuple(outs),
functools.partial(converted_grad_fn, _out_cts_avals=tuple(out_avals)))
out_flat = converted_fun_flat_with_custom_gradient(*args_flat)
else:
out_flat_raw = _interpret_fun(flat_fun, args_flat, args_avals_flat)
message = ("The jax2tf-converted function does not support gradients. "
"Use `with_gradient` parameter to enable gradients")
# We use PreventGradient, which is propagated through a SavedModel.
out_flat = [tf.raw_ops.PreventGradient(input=o, message=message)
for o, _ in out_flat_raw]
finally:
_shape_env = {}
out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat]
out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
return out
return converted_fun
# Internals
def _interpret_fun(fun: lu.WrappedFun,
in_vals: Sequence[TfVal],
in_avals: Sequence[core.AbstractValue]
) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
fun = _interpret_subtrace(fun, main, in_avals)
with core.new_sublevel():
out_vals: Sequence[Tuple[TfVal, core.AbstractValue]] = \
fun.call_wrapped(*in_vals)
del main
return tuple(out_vals)
def _convert_jax_impl(jax_impl: Callable, *, multiple_results=True) -> Callable:
"""Convert the JAX implementation of a primitive.
Args:
jax_impl: typically the impl-rule for a primitive, with signature
`(*args: JaxVal, **kwargs) -> Sequence[JaxVal]`. This function implements
a primitive in terms of other primitives.
multiple_results: whether `jax_impl` returns a sequence of results.
Returns:
a function with signature `(*args: TfVal, _in_avals, _out_aval, **kwargs) -> Sequence[TfVal]`.
"""
def wrapped(*tf_args: TfVal,
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue, **kwargs) -> Sequence[TfVal]:
# We wrap jax_impl with _interpret_fun so that it sees JAX tracers carrying
# the abstract values, instead of the raw TF values.
def jax_impl_jax_args(*jax_args):
jax_results = jax_impl(*jax_args, **kwargs)
return jax_results if multiple_results else [jax_results]
tf_results_with_avals = _interpret_fun(lu.wrap_init(jax_impl_jax_args), tf_args, _in_avals)
tf_results, _ = util.unzip2(tf_results_with_avals)
return tf_results if multiple_results else tf_results[0]
return wrapped
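# For instance (a sketch; `lax_linalg._lu_python`, the JAX impl rule for LU
# decomposition at the time of writing, is assumed here):
#   tf_impl_with_avals[lax_linalg.lu_p] = _convert_jax_impl(
#       lax_linalg._lu_python, multiple_results=True)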
@lu.transformation
def _interpret_subtrace(main: core.MainTrace,
in_avals: Sequence[core.AbstractValue],
*in_vals: TfVal):
trace = TensorFlowTrace(main, core.cur_sublevel())
in_tracers = tuple(TensorFlowTracer(trace, val, aval)
for val, aval in util.safe_zip(in_vals, in_avals))
# The outs may be core.unit, see comment in TensorFlowTrace.pure.
outs = yield in_tracers, {} # type: Sequence[Union[TfVal, core.Unit]]
out_tracers: Iterable[TensorFlowTracer] = map(trace.full_raise, outs) # type: ignore
out_vals_with_avals: Sequence[Tuple[TfVal, core.AbstractValue]] = (
tuple((t.val, t.aval) for t in out_tracers))
yield out_vals_with_avals
def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]:
"""Evaluates a Jaxpr with tf.Tensor arguments.
The output is a sequence of TfVal (no `core.unit`), suitable for use with TF.
"""
fun: lu.WrappedFun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
out_with_avals = _interpret_fun(fun, args, jaxpr.in_avals)
return tuple(v for v, _ in out_with_avals)
### tracer
def _aval_to_tf_shape(aval: core.AbstractValue) -> Tuple[Optional[int], ...]:
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
return tuple(map(lambda d: None if isinstance(d, shape_poly.DimVar) else d,
aval.shape)) # type: ignore[attr-defined]
def _tfval_shape_dtype(val: TfVal) -> Tuple[Sequence[Optional[int]], DType]:
"""
Called for constants that occur in the program, or for input values to the
converted function. The returned shape may have unknown components, but
only when called for inputs.
"""
if isinstance(val, (tf.Tensor, tf.Variable)):
# May be partially known
return tuple(val.shape), to_jax_dtype(val.dtype)
else: # Must be a numeric value
assert not config.jax_enable_checks or _is_tfval(val), f"Non TfVal: {val}"
raw_aval = xla.abstractify(val)
return raw_aval.shape, raw_aval.dtype # type: ignore[attr-defined]
# A dimension environment maps dimension variables to TF expressions that
# compute the value of the dimension. These expressions refer to the TF
# function arguments.
_ShapeEnv = Dict[shape_poly.DimVar, TfVal]
def _args_to_avals_and_env(args: Sequence[TfVal],
polymorphic_shapes: Sequence[Optional[Union[str, PolyShape]]]) -> \
Tuple[Sequence[core.AbstractValue], _ShapeEnv]:
"""Computes abstract values and a dimension environment for arguments.
Args:
args: the arguments, TF inputs.
polymorphic_shapes: the polymorphic specifications for the arguments.
Returns: a tuple of a sequence of abstract values corresponding to the arguments
and a dimension environment.
"""
shapeenv: _ShapeEnv = {}
def input_aval(arg: TfVal, polymorphic_shape: Optional[str]) -> core.AbstractValue:
"""The abstract value for an input."""
raw_shape, dtype = _tfval_shape_dtype(arg)
aval_shape = shape_poly.parse_spec(polymorphic_shape, raw_shape)
for i, d in enumerate(aval_shape):
if type(d) is int:
assert d == np.shape(arg)[i]
elif type(d) is shape_poly.DimVar and d not in shapeenv:
# Even if the shape of `arg` is known, we still use `tf.shape` for
# safety, because the promise is that we will convert the function
# to work for any value of the dimension.
shapeenv[d] = tf.shape(arg)[i] # type: ignore[index]
else:
# TODO: add an assertion tf.shape(arg)[i] == env[d]
pass
return core.ShapedArray(aval_shape, dtype)
avals = tuple(map(input_aval, args, polymorphic_shapes)) # type: ignore
return avals, shapeenv
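# For example (illustrative): for an argument with TF shape (None, 8) and
# polymorphic shape spec "(b, _)", `input_aval` returns ShapedArray((b, 8))
# where `b` is a shape_poly.DimVar, and `shapeenv` maps that variable to the
# TF expression tf.shape(arg)[0].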
# A shape environment maps shape variables to TfVal.
_shape_env = {} # type: _ShapeEnv
def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
assert all(map(lambda x: x is not None, shape)), (
f"Argument shape should be a valid JAX shape but got {shape}")
return tuple(_shape_env[d] if type(d) is shape_poly.DimVar else d # type: ignore[index]
for d in shape)
def shape_as_value(x):
"""Injects the shape of `x` as an array value.
**Experimental: please give feedback, and expect changes!**
This allows the use of a shape expression as array argument to JAX functions.
A typical example is for implementing a mean operation:
jnp.sum(x) / np.prod(jax2tf.shape_as_value(x))
"""
# return shape_as_value_p.bind(x)
raise NotImplementedError("shape_as_value is deprecated")
# # TODO: move this to masking or to some common library, if approved
# shape_as_value_p = core.Primitive("shape_as_value")
# shape_as_value_p.multiple_results = True
# def _shape_as_value_impl(x):
# x_shape = np.shape(x)
# def dim_to_int(dim: shape_poly.DimSize) -> int:
# dim_int = _poly_dim_to_tf_dim(dim)
# if dim_int is None:
# msg = ("shape_as_value is not implemented for non-constant shapes "
# "except for masking and jax2tf. "
# f"Has shape: {x_shape}")
# raise TypeError(msg)
# else:
# return dim_int
# return tuple(map(dim_to_int, x_shape))
#
# shape_as_value_p.def_impl(_shape_as_value_impl)
#
# def _shape_as_value_abstract(x_aval: core.AbstractValue) -> Sequence[core.AbstractValue]:
# rank = len(x_aval.shape) # type: ignore[attr-defined]
# return (core.ShapedArray((), dtypes.canonicalize_dtype(np.int_), weak_type=True),) * rank
#
# shape_as_value_p.def_abstract_eval(_shape_as_value_abstract)
#
# def _shape_as_value_translation(comp, x):
# return xla_client._xla.ops.Tuple(comp,
# tuple(xb.constant(comp, d)
# for d in comp.GetShape(x).dimensions()))
#
# xla.translations[shape_as_value_p] = _shape_as_value_translation
#
# def _shape_as_value_jvp_rule(primals, tangents):
# # The shape does not depend on the contents of the input
# x, = primals
# zero = ad.Zero.from_value(0.)
# return shape_as_value(x), (zero,) * len(x.shape)
#
# ad.primitive_jvps[shape_as_value_p] = _shape_as_value_jvp_rule
#
# def _shape_as_value__batching_rule(batched_args, batch_dims):
# xv, = batched_args
# batch_dim, = batch_dims
# batch_size = xv.shape[batch_dim]
# batched_shape = shape_as_value(xv)
# one_shape = batched_shape[0:batch_dim] + batched_shape[batch_dim+1:]
# res = tuple(jnp.broadcast_to(d, (batch_size, 1)) for d in one_shape)
# return res, (0,) * len(one_shape)
#
# batching.primitive_batchers[shape_as_value_p] = _shape_as_value__batching_rule
#
# def _shape_as_value_masking_rule(operands, operands_logical_shapes):
# x_logical_shape, = operands_logical_shapes
# return tuple(x_logical_shape)
#
# masking.masking_rules[shape_as_value_p] = _shape_as_value_masking_rule
#
# def _shape_as_value_tf(x: TfVal,
# _in_avals: Sequence[core.AbstractValue],
# _out_aval: core.AbstractValue) -> TfVal:
# x_aval = _in_avals[0]
# def dim_to_tfval(dim: shape_poly.DimSize, dim_idx: int) -> TfVal:
# dim_int = _poly_dim_to_tf_dim(dim)
# if dim_int is not None:
# return tf.convert_to_tensor(dim_int)
# else:
# return tf.shape(x)[dim_idx]
# return tuple(dim_to_tfval(dim, dim_idx)
# for dim_idx, dim in enumerate(x_aval.shape)) # type: ignore[attr-defined]
#
# tf_impl_with_avals[shape_as_value_p] = _shape_as_value_tf
# TODO(b/26854495): pylint doesn't understand slots and inheritance.
# pylint: disable=assigning-non-slot
class TensorFlowTracer(core.Tracer):
"""Tracer class that boxes a TF value and a JAX abstract value.
In addition to the TF value we carry the JAX abstract value because there are
two cases when it cannot be recovered from the value: (a) when the abstract
value is core.abstract_unit, in which case the value is tf.nan; (b) when we
are converting with polymorphic shapes, in which case the shape of the value
may have dimensions set to `None`, while the JAX abstract value may contain
more precise information.
When the value has a partially-known shape, the dimensions marked as `None`
must correspond to non-constant dimensions in the abstract value.
See README.md for details.
"""
# val: TfVal
# _aval: core.AbstractValue
__slots__ = ["val", "_aval"]
def __init__(self, trace: 'TensorFlowTrace', val: TfVal,
aval: core.AbstractValue):
self._trace = trace
self._aval = aval
if aval is core.abstract_unit:
self.val = val
elif isinstance(val, (tf.Tensor, tf.Variable)):
val_shape, val_dtype = _tfval_shape_dtype(val)
aval_dtype = np.dtype(self._aval.dtype) # type: ignore[attr-defined]
if (val_dtype != aval_dtype and
not config.x64_enabled and
(val_dtype == tf.int32 and aval_dtype == jnp.int64 or
val_dtype == tf.int64 and aval_dtype == jnp.int32 or
val_dtype == tf.float32 and aval_dtype == jnp.float64 or
val_dtype == tf.float64 and aval_dtype == jnp.float32 or
val_dtype == tf.complex128 and aval_dtype == jnp.complex64)):
# If JAX does not have 64-bit (x64) mode enabled, it forces 64-bit
# values to use 32-bit precision. In order to make the TF conversion
# follow JAX's rules, we cast the TF values down to 32 bits as well.
val = tf.cast(val, dtype=aval_dtype)
val_dtype = aval_dtype
if config.jax_enable_checks:
assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}"
for aval_dim, val_dim in util.safe_zip(self._aval.shape, val_shape): # type: ignore[attr-defined]
if val_dim is None:
assert isinstance(aval_dim,
shape_poly.DimVar), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
elif not isinstance(aval_dim, shape_poly.DimVar):
assert aval_dim == val_dim, f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
else:
# We have a TF value with known shape, and the abstract shape is a shape variable.
try:
aval_int = int(_eval_shape([aval_dim])) # type: ignore
except TypeError:
continue
assert aval_int == val_dim, f"expected {self._aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." # type: ignore
self.val = val
else: # Must be a numeric value
self.val = _safe_convert_to_tensor(val, dtype=self._aval.dtype) # type: ignore[attr-defined]
@property
def aval(self):
return self._aval
def full_lower(self):
return self
class TensorFlowTrace(core.Trace):
"""Trace class that underlies the jax2tf transformation.
We are going to ensure that jax2tf.convert is never nested inside other
transformations. This is sufficient for intended use cases (converting
fully-transformed JAX code). It also simplifies our job because we do not have
to handle situations where we apply primitives on a mix of TF values and
JAX tracers from an outer transformation. E.g., for addition both the TF values
and the JAX tracers have an override and they get confused if they see values
from the other world.
Hence a TFT trace does not interact with non-TFT traces at a lower level. For
higher-order control-flow primitives we recursively invoke
_interpret_fun on the body of the conditional, which will create a nested TFT.
We do want to allow transformations nested inside a TensorFlowTrace (TFT), but
those will introduce their own MainTrace, and any operations involving those
will be done on those traces, i.e., not a concern for TFT.
"""
def pure(self, val: Union[TfVal, core.Unit]) -> TensorFlowTracer:
"""Lifts a non-Tracer into the TensorFlowTracer.
This function may be called by way of trace.full_raise.
The value may be a core.unit. During JAX transformations we sometimes
produce a Jaxpr that has arguments of abstract value core.abstract_unit
and results equal to core.unit. These are arguments and results that are
not used in the computation.
In TF world, we represent core.unit as NaN. This is safe, as these values
should never be used.
"""
if val is core.unit:
return TensorFlowTracer(self, tf.constant(np.nan, tf.float32), core.abstract_unit)
else:
shape, dtype = _tfval_shape_dtype(val)
return TensorFlowTracer(self, val, core.ShapedArray(shape, dtype))
def lift(self, val: core.Tracer) -> TensorFlowTracer:
# This would be called when we need to raise a tracer from a lower-level
# main into the TensorFlowTrace. Since the TensorFlowTrace is never nested
# inside another transform, there are no lower-level main traces.
assert False
def sublift(self, val: TensorFlowTracer) -> TensorFlowTracer:
# This is called when we need to raise a tracer from the same main trace,
# but a lower sublevel. This could come from a nested jit.
return TensorFlowTracer(self, val.val, val._aval)
def process_primitive(self, primitive: core.Primitive,
tracers: Sequence[TensorFlowTracer],
params) -> TensorFlowTracer:
impl, impl_needs_avals = self.get_primitive_impl(primitive)
args_avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers)
out_aval = primitive.abstract_eval(*args_avals, **params)
args_tf: Sequence[TfVal] = [t.val for t in tracers]
if impl_needs_avals:
val_out: TfVal = impl(*args_tf, _in_avals=args_avals, # type: ignore
_out_aval=out_aval, **params)
else:
val_out = impl(*args_tf, **params)
if primitive.multiple_results:
out = [TensorFlowTracer(self, v, a)
for v, a in util.safe_zip(val_out, out_aval)] # type: ignore
else:
out = TensorFlowTracer(self, val_out, out_aval) # type: ignore
# Check that the impl rule returned a value of expected shape and dtype
# TODO: adapt this to match polymorphic shapes
if config.jax_enable_checks:
if primitive.multiple_results:
for o, expected_aval in zip(out, out_aval): # type: ignore
assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), (
f"{primitive}: out.aval = {o.aval}; expected {expected_aval}")
else:
assert out.aval == out_aval, ( # type: ignore
f"{primitive}: out.aval = {out.aval}; expected {out_aval}") # type: ignore
return out # type: ignore
def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun,
tracers: Sequence[TensorFlowTracer], params):
assert call_primitive.multiple_results
vals: Sequence[TfVal] = [t.val for t in tracers]
f = _interpret_subtrace(f, self.main, tuple(t.aval for t in tracers))
with core.new_sublevel():
if call_primitive == core.named_call_p:
with tf.name_scope(_sanitize_scope_name(params["name"])):
vals_out: Sequence[Tuple[TfVal, core.AbstractValue]] = \
f.call_wrapped(*vals)
elif call_primitive == sharded_jit.sharded_call_p:
vals_out = _sharded_call(f, vals, **params)
else:
vals_out = f.call_wrapped(*vals)
return [TensorFlowTracer(self, v, a) for v, a in vals_out]
def post_process_call(self, call_primitive: core.Primitive,
out_tracers: Sequence[TensorFlowTracer], params):
# We encountered a call primitive, e.g., remat_call_p, whose results
# (out_tracers) include TensorFlowTracers that were not passed through
# its arguments (captured from the environment).
vals = tuple(t.val for t in out_tracers)
main = self.main
def todo(vals: Sequence[TfVal]):
trace = TensorFlowTrace(main, core.cur_sublevel())
return [TensorFlowTracer(trace, v, out_tracer.aval)
for v, out_tracer in util.safe_zip(vals, out_tracers)]
return vals, todo
def process_map(self, map_primitive, f, tracers, params):
raise NotImplementedError("process_map")
def post_process_map(self, map_primitive, out_tracers, params):
raise NotImplementedError("post_process_map")
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
# Drop the custom differentiation rule and act like a call primitive. This
# behavior is desirable because jax2tf stages code out of the JAX system, so
# there are no more JAX differentiation transformations to be applied.
del jvp # Unused.
return self.process_call(core.call_p, fun, tracers, {})
def post_process_custom_jvp_call(self, out_tracers, params):
assert False # unreachable assuming jax2tf runs with clean trace state
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
# Drop the custom differentiation rule and act like a call primitive. This
# behavior is desirable because jax2tf stages code out of the JAX system, so
# there are no more JAX differentiation transformations to be applied.
del fwd, bwd, out_trees # Unused.
return self.process_call(core.call_p, fun, tracers, {})
def post_process_custom_vjp_call(self, out_tracers, params):
assert False # unreachable assuming jax2tf runs with clean trace state
def get_primitive_impl(self, p: core.Primitive) -> Tuple[Callable, bool]:
# Returns the primitive implementation and whether the implementation
# takes abstract values (see definition of tf_impl_with_avals)
try:
return tf_impl[p], False
except KeyError:
try:
return tf_impl_with_avals[p], True
except KeyError as err:
msg = "TensorFlow interpretation rule for '{}' not implemented"
raise NotImplementedError(msg.format(p)) from err
def to_tf_dtype(jax_dtype):
if jax_dtype == dtypes.float0:
jax_dtype = dtypes.bfloat16
return tf.dtypes.as_dtype(jax_dtype)
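# For example: to_tf_dtype(np.int32) is tf.int32, while the JAX float0 dtype,
# which TF does not know, is represented as tf.bfloat16 (see the note in
# _safe_convert_to_tensor above).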
def to_jax_dtype(tf_dtype):
return tf_dtype.as_numpy_dtype
def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
assert False, f"Encountered unexpected primitive {p}"
for unexpected in xla.call_translations: # Call primitives are inlined
tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected)
# Primitives that are not yet implemented must be explicitly declared here.
tf_not_yet_impl = [
"reduce", "rng_uniform", "clz",
"igamma_grad_a",
"random_gamma_grad",
"reduce_precision",
# Not high priority?
"after_all", "all_to_all", "create_token",
"infeed", "outfeed", "pmax_p",
"pmin", "ppermute", "psum", "pmax", "pgather",
"axis_index", "pdot", "all_gather",
"lu_pivots_to_permutation",
"rng_bit_generator",
"xla_pmap",
"call_tf",
]
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
tf_impl[ad_util.zeros_like_p] = tf.zeros_like
def _add(x: TfVal, y: TfVal) -> TfVal:
return tf.raw_ops.AddV2(x=x, y=y)
tf_impl[ad_util.add_jaxvals_p] = _add
tf_impl[xla.device_put_p] = lambda x, device=None: x
tf_impl[lax.neg_p] = tf.math.negative
tf_impl[lax.sign_p] = tf.math.sign
tf_impl[lax.floor_p] = tf.math.floor
tf_impl[lax.ceil_p] = tf.math.ceil
def _round(operand, *, rounding_method):
if rounding_method is lax.RoundingMethod.AWAY_FROM_ZERO:
sign = tf.math.sign(operand)
operand *= sign
floor = tf.math.floor(operand)
operand -= floor
cond = tf.math.equal(operand, tf.constant(np.array(0.5), operand.dtype))
return sign * (tf.where(cond, tf.constant(np.array(1), operand.dtype),
tf.math.round(operand)) + floor)
else:
return tf.math.round(operand)
tf_impl[lax.round_p] = _round
tf_impl[lax.nextafter_p] = tf.math.nextafter
def _population_count(x):
orig_dtype = x.dtype
return tf.cast(tf.raw_ops.PopulationCount(x=x), orig_dtype)
tf_impl[lax.population_count_p] = _population_count
tf_impl[lax.is_finite_p] = tf.math.is_finite
tf_impl[lax.abs_p] = tf.math.abs
tf_impl[lax.pow_p] = tf.math.pow
tf_impl[lax.integer_pow_p] = tf.math.pow
tf_impl[lax.exp_p] = tf.math.exp
tf_impl[lax.expm1_p] = tf.math.expm1
tf_impl[lax.log_p] = tf.math.log
tf_impl[lax.log1p_p] = tf.math.log1p
tf_impl[lax.tan_p] = tf.math.tan
tf_impl[lax.tanh_p] = tf.math.tanh
tf_impl[lax.sin_p] = tf.math.sin
tf_impl[lax.sinh_p] = tf.math.sinh
tf_impl[lax.cos_p] = tf.math.cos
tf_impl[lax.cosh_p] = tf.math.cosh
tf_impl[lax.acos_p] = tf.math.acos
tf_impl[lax.asin_p] = tf.math.asin
tf_impl[lax.atan_p] = tf.math.atan
tf_impl[lax.atan2_p] = tf.math.atan2
tf_impl[lax.acosh_p] = tf.math.acosh
tf_impl[lax.atanh_p] = tf.math.atanh
tf_impl[lax.asinh_p] = tf.math.asinh
tf_impl[lax.sqrt_p] = tf.math.sqrt
tf_impl[lax.rsqrt_p] = tf.math.rsqrt
tf_impl[lax.lgamma_p] = tf.math.lgamma
tf_impl[lax.digamma_p] = tf.math.digamma
tf_impl[lax.igamma_p] = tf.math.igamma
tf_impl[lax.igammac_p] = tf.math.igammac
tf_impl[lax.regularized_incomplete_beta_p] = tf.math.betainc
tf_impl[lax.erf_p] = tf.math.erf
tf_impl[lax.erfc_p] = tf.math.erfc
tf_impl[lax.erf_inv_p] = tf.math.erfinv
tf_impl[lax.bessel_i0e_p] = tf.math.bessel_i0e
tf_impl[lax.bessel_i1e_p] = tf.math.bessel_i1e
tf_impl[lax.complex_p] = tf.complex
def _conj(x, **kwargs):
# The only dtypes that are allowed are: float32, float64, complex64, and
# complex128.
if x.dtype == tf.float32:
return tf.cast(x, tf.complex64)
elif x.dtype == tf.float64:
return tf.cast(x, tf.complex128)
else:
return tf.math.conj(x)
tf_impl[lax.conj_p] = _conj
tf_impl[lax.real_p] = tf.math.real
tf_impl[lax.imag_p] = tf.math.imag
tf_impl[lax.add_p] = _add
tf_impl[lax.sub_p] = tf.math.subtract
tf_impl[lax.mul_p] = tf.math.multiply
def _iota(*, dtype, shape, dimension):
dtype = to_tf_dtype(dtype)
# Some dtypes are unsupported, like uint32, so we just fall back to int32.
# TODO(mattjj, necula): improve tf.range dtype handling
shape_tf = _eval_shape(shape)
vec = tf.range(tf.cast(shape_tf[dimension], tf.int32), dtype=tf.int32)
vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))]
return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape_tf), dtype)
tf_impl[lax.iota_p] = _iota
def _div(lhs, rhs):
if lhs.dtype.is_integer:
quotient = tf.math.floordiv(lhs, rhs)
select = tf.math.logical_and(
tf.not_equal(tf.math.sign(lhs), tf.math.sign(rhs)),
tf.not_equal(tf.math.floormod(lhs, rhs), 0))
return tf.where(select, quotient + 1, quotient)
else:
return tf.math.truediv(lhs, rhs)
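# For example (illustrative): lax.div(-7, 2) == -3 (truncation toward zero),
# whereas tf.math.floordiv(-7, 2) == -4; the correction above adds 1 to the
# quotient when the operand signs differ and the remainder is non-zero.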
def _rem(lhs, rhs):
return tf.math.sign(lhs) * tf.math.floormod(tf.math.abs(lhs),
tf.math.abs(rhs))
tf_impl[lax.div_p] = _div
tf_impl[lax.rem_p] = _rem
tf_impl[lax.max_p] = tf.math.maximum
tf_impl[lax.min_p] = tf.math.minimum
# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
tf.int8: tf.uint8,
tf.int16: tf.uint16,
tf.int32: tf.uint32,
tf.int64: tf.uint64,
}
# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {u: s for s, u in _SIGNED_TO_UNSIGNED_TABLE.items()}
# Note: Bitwise operations only yield identical results on unsigned integers!
# pylint: disable=protected-access
def _shift_right_arithmetic_raw(x, y):
if x.dtype.is_unsigned:
assert x.dtype == y.dtype
orig_dtype = x.dtype
signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[orig_dtype]
x = tf.cast(x, signed_dtype)
y = tf.cast(y, signed_dtype)
res = tf.bitwise.right_shift(x, y)
return tf.cast(res, orig_dtype)
else:
return tf.bitwise.right_shift(x, y)
def _shift_right_arithmetic(x, y):
# TF shift is "implementation defined" if the shift amount is negative
# or larger than or equal to the bit width of the value. We implement the XLA
# semantics by clamping the shift amount to the maximum value (x_bits - 1).
# TODO: it is likely better to add XlaOps for shifts
x_bits = 8 * x.dtype.size
clamp_y = tf.where(_shift_in_bounds(x, y), y, x_bits - 1)
return _shift_right_arithmetic_raw(x, clamp_y)
tf_impl[lax.shift_right_arithmetic_p] = _shift_right_arithmetic
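# For example (illustrative): for an int8 `x`, a shift amount y == 9 is out
# of bounds, so it is clamped to x_bits - 1 == 7, which fills the result with
# copies of the sign bit, as XLA specifies.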
def _shift_right_logical_raw(x, y):
if x.dtype.is_unsigned:
return tf.bitwise.right_shift(x, y)
else:
assert x.dtype == y.dtype
orig_dtype = x.dtype
unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[orig_dtype]
x = tf.cast(x, unsigned_dtype)
y = tf.cast(y, unsigned_dtype)
res = tf.bitwise.right_shift(x, y)
return tf.cast(res, orig_dtype)
def _shift_right_logical(x, y):
# TF shift is "implementation defined" if the shift amount is negative
# or larger than or equal to the bit width of the value. We implement the XLA
# semantics, which return 0 in that case.
# TODO: it is likely better to add XlaOps for shifts
return tf.where(_shift_in_bounds(x, y),
_shift_right_logical_raw(x, y),
tf.zeros_like(x))
tf_impl[lax.shift_right_logical_p] = _shift_right_logical
def _shift_left(x, y):
# TF shift is "implementation defined" if the shift amount is negative
# or larger than or equal to the bit width of the value. We implement the XLA
# semantics, which return 0 in that case.
# TODO: it is likely better to add XlaOps for shifts
return tf.where(_shift_in_bounds(x, y),
tf.bitwise.left_shift(x, y),
tf.zeros_like(x))
tf_impl[lax.shift_left_p] = _shift_left
def _shift_in_bounds(x: TfVal, y: TfVal) -> TfVal:
# Return the TF expression for when y is within bounds (0 <= y < bit width of x)
x_bits = 8 * x.dtype.size
# TF does not have comparisons for uint16 and uint32 (despite what the
# documentation says)
y_comp = tf.cast(y, _UNSIGNED_TO_SIGNED_TABLE[y.dtype]) if y.dtype.is_unsigned else y
y_lt_x_bits = tf.math.less(y_comp, x_bits)
y_ge_0 = tf.math.greater_equal(y_comp, 0)
return tf.logical_and(y_lt_x_bits, y_ge_0)
def _not(x):
"""Computes bitwise not with support for booleans.
Numpy and JAX support bitwise not for booleans by applying a logical not!
This means that applying bitwise_not yields an unexpected result:
jnp.bitwise_not(jnp.array([True, False]))
>> DeviceArray([False, True], dtype=bool)
if you assume that booleans are simply cast to integers.
jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool)
>> DeviceArray([True, True], dtype=bool)
"""
if x.dtype == tf.bool:
return tf.logical_not(x)
else:
return tf.bitwise.invert(x)
tf_impl[lax.not_p] = _not
def bool_to_int8(f, argnums):
"""Computes bool valued functions using int8."""
argnums = tf.nest.flatten(argnums)
def wrapper(*args, **kwargs):
if not any(args[i].dtype == tf.bool for i in argnums):
return f(*args, **kwargs)
else:
args_cast = [(tf.cast(a, tf.int8) if i in argnums else a)
for i, a in enumerate(args)]
if "_in_avals" in kwargs:
def cast_aval(aval):
return core.ShapedArray(aval.shape, np.int8)
_in_avals_cast = [cast_aval(aval) if i in argnums else aval
for i, aval in enumerate(kwargs["_in_avals"])]
_out_aval_cast = tf.nest.map_structure(cast_aval, kwargs["_out_aval"])
kwargs = dict(kwargs, _in_avals=_in_avals_cast, _out_aval=_out_aval_cast)
out = f(*args_cast, **kwargs)
return tf.nest.map_structure(lambda o: tf.cast(o, tf.bool), out)
return wrapper
tf_impl[lax.or_p] = bool_to_int8(tf.bitwise.bitwise_or, argnums=(0, 1))
tf_impl[lax.and_p] = bool_to_int8(tf.bitwise.bitwise_and, argnums=(0, 1))
tf_impl[lax.xor_p] = bool_to_int8(tf.bitwise.bitwise_xor, argnums=(0, 1))
tf_impl[lax.eq_p] = tf.math.equal
tf_impl[lax.ne_p] = tf.math.not_equal
tf_impl[lax.ge_p] = tf.math.greater_equal
tf_impl[lax.gt_p] = tf.math.greater
tf_impl[lax.le_p] = tf.math.less_equal
tf_impl[lax.lt_p] = tf.math.less
tf_impl[lax_linalg.cholesky_p] = tf.linalg.cholesky
def _convert_element_type(operand, *, new_dtype, weak_type=False):
old_dtype = operand.dtype.as_numpy_dtype
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
operand = tf.math.real(operand)
if (dtypes.issubdtype(old_dtype, np.floating) and
not (dtypes.issubdtype(new_dtype, np.floating) or
dtypes.issubdtype(new_dtype, np.complexfloating) or
new_dtype == np.bool_)):
sign = tf.math.sign(operand)
operand = sign * tf.math.floor(sign * operand)
return tf.dtypes.cast(operand, to_tf_dtype(new_dtype))
tf_impl[lax.convert_element_type_p] = _convert_element_type
def _bitcast_convert_type(operand, new_dtype):
return tf.bitcast(operand, to_tf_dtype(new_dtype))
tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type
def _clamp(minval, operand, maxval, *, _in_avals, _out_aval):
# The below mirrors the behavior of JAX when maxval < minval.
op_shape_tf_val = _eval_shape(_in_avals[1].shape)
maxval = tf.broadcast_to(maxval, op_shape_tf_val)
minval = tf.math.minimum(tf.broadcast_to(minval, op_shape_tf_val), maxval)
return tf.clip_by_value(operand, minval, maxval)
tf_impl_with_avals[lax.clamp_p] = _clamp
def _concatenate(*operands, dimension):
return tf.concat(operands, axis=dimension)
tf_impl[lax.concatenate_p] = _concatenate
def _conv_general_dimension_numbers_proto(dimension_numbers):
"""Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
lhs_spec, rhs_spec, out_spec = dimension_numbers
proto = xla_data_pb2.ConvolutionDimensionNumbers()
proto.input_batch_dimension = lhs_spec[0]
proto.input_feature_dimension = lhs_spec[1]
proto.output_batch_dimension = out_spec[0]
proto.output_feature_dimension = out_spec[1]
proto.kernel_output_feature_dimension = rhs_spec[0]
proto.kernel_input_feature_dimension = rhs_spec[1]
proto.input_spatial_dimensions.extend(lhs_spec[2:])
proto.kernel_spatial_dimensions.extend(rhs_spec[2:])
proto.output_spatial_dimensions.extend(out_spec[2:])
return proto
def _precision_config_proto(precision: Optional[Tuple[PrecisionType, PrecisionType]]):
"""Convert an integer to an XLA.PrecisionConfig."""
if precision is None:
return None
proto = xla_data_pb2.PrecisionConfig()
proto.operand_precision.append(int(precision[0]))
proto.operand_precision.append(int(precision[1]))
return proto
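# For example (illustrative): precision=(lax.Precision.HIGHEST,
# lax.Precision.HIGHEST) yields a proto with operand_precision [2, 2], since
# Precision.HIGHEST has value 2 in xla_data.PrecisionConfig.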
# _try_tf_conv returns a Tensor when it succeeds, or a string describing why
# it did not succeed otherwise.
def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
out_shape) -> Union[str, TfVal]:
# TODO(bchetioui): this function is not exhaustive wrt which convolution cases
# can be translated into TF primitives. Further investigation is needed to
# fully flesh it out.
if lhs.dtype not in [tf.float16, tf.float32, tf.float64]:
return f"tf.nn.convolution is not supported for dtype {lhs.dtype}"
if feature_group_count != 1:
return "tf.nn.convolution does not support grouped convolutions"
# TODO(bchetioui): is there something to do with batch_group_count?
if batch_group_count != 1:
return "Unimplemented support for batch_group_count != 1"
nb_spatial_dimensions = len(lhs.shape) - 2
# TF can only deal with 1D, 2D and 3D convolutions
if nb_spatial_dimensions < 1 or nb_spatial_dimensions > 3:
return ("TensorFlow can only handle convolutions with 1, 2, or 3 "
"spatial dimensions")
# TODO(bchetioui): handle different stride cases
if list(window_strides) != [1] * nb_spatial_dimensions:
return ("Unimplemented support for window_strides != "
f"{tuple([1] * nb_spatial_dimensions)}")
success = lambda res: (res, None)
failure = lambda msg: (None, msg)
def convert_padding():
# TODO(bchetioui): in this instance, we cannot use padtype_to_pads, as
# string padding is not implemented for transposed convolution.
if list(lhs_dilation) != [1] * nb_spatial_dimensions:
return failure("Padding conversion is not supported for transposed "
"convolution.")
lhs_perm, rhs_perm, _ = dimension_numbers
effective_rhs_shape = [(k-1) * r + 1 for k, r in
zip(np.take(rhs.shape, rhs_perm)[2:], rhs_dilation)]
lhs_shape = np.take(lhs.shape, lhs_perm)[2:]
# TF only allows 'VALID' and 'SAME' padding
for pad_str in ['VALID', 'SAME']:
gen_padding = lax.padtype_to_pads(
lhs_shape, effective_rhs_shape, window_strides, pad_str)
if list(gen_padding) == list(padding):
return success(pad_str)
return failure("Input padding not supported in TensorFlow.")
def convert_dim_nums():
lhs_spec, rhs_spec, out_spec = dimension_numbers
# TF only allows filters with shape:
# spatial_filter_shape + [in_channels, out_channels]. In JAX however,
# rhs_spec is represented as a tuple containing the following:
# [out_channels, in_channels] + spatial_filter_shape.
supported_rhs_shape = ([nb_spatial_dimensions + 1, nb_spatial_dimensions] +
list(range(nb_spatial_dimensions)))
if list(rhs_spec) != supported_rhs_shape:
return failure("Input filter (RHS) shape format not supported in "
"TensorFlow")
# TF only supports same LHS and output data format
if lhs_spec != out_spec:
return failure("TensorFlow requires the same data format for LHS and "
"output.")
# Alphabet extracted from the documentation of tf.conv{1,2,3}d
spatial_dim_alphabet = 'DHW'[-nb_spatial_dimensions:]
# TF only supports the following data formats:
# - [batch_size, in_channels] + input_spatial_shape
# TODO(bchetioui): TF currently does not support the above on CPU. To avoid
# failing on this platform, this path is commented out for now.
#if list(lhs_spec) == list(range(len(lhs_spec))):
# return "NC" + spatial_dim_alphabet
# - [batch_size] + input_spatial_shape + [in_channels]
if list(lhs_spec) == ([0, len(lhs_spec) - 1] +
list(range(1, len(lhs_spec) - 1))):
return success("N" + spatial_dim_alphabet + "C")
return failure("Data format is unsupported by TensorFlow")
def convert_dilation_and_compute_result(tf_padding, tf_dim_nums):
no_dilation = [1] * nb_spatial_dimensions
# TODO(bchetioui): is there a generic way to do a transposed atrous
# convolution in TensorFlow?
if not (list(lhs_dilation) == no_dilation or
list(rhs_dilation) == no_dilation):
return "Both LHS and RHS dilations are set"
# This is a non-dilated or atrous convolution
if list(lhs_dilation) == no_dilation:
return tf.nn.convolution(
lhs, rhs, strides=window_strides, padding=tf_padding,
data_format=tf_dim_nums, dilations=rhs_dilation)
# TODO(bchetioui): the below path is unreachable for now, as passing a lhs
# dilation to this function will result in convert_padding returning None
# systematically. This must be investigated further.
# Dilation of the LHS is transposed convolution
return tf.nn.conv_transpose(
lhs, rhs, out_shape, window_strides, padding=tf_padding,
data_format=tf_dim_nums, dilations=lhs_dilation)
tf_padding, error = convert_padding()
if tf_padding is None:
return error
tf_dim_nums, error = convert_dim_nums()
if tf_dim_nums is None:
return error
return convert_dilation_and_compute_result(tf_padding, tf_dim_nums)
def _conv_general_dilated(lhs, rhs, *,
window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count,
batch_group_count, lhs_shape, rhs_shape,
precision: Optional[Tuple[PrecisionType, PrecisionType]],
preferred_element_type, _in_avals, _out_aval):
"""Implementation of lax.conv_general_dilated_p using XlaConv."""
if not _enable_xla:
info_or_result = _try_tf_conv(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count, _aval_to_tf_shape(_out_aval)
)
if not isinstance(info_or_result, str):
return info_or_result
else:
raise _xla_path_disabled_error("conv_general_dilated")
dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
precision_config_proto = _precision_config_proto(precision)
assert batch_group_count == 1 # TODO(phawkins): implement batch_group_count
out = tfxla.conv(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dnums_proto, feature_group_count=feature_group_count,
precision_config=precision_config_proto)
# TODO: implement shape inference for XlaConv
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *,
dimension_numbers,
precision: Optional[Tuple[PrecisionType, PrecisionType]],
preferred_element_type: Optional[DType],
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
if _enable_xla:
dnums_proto = xla_data_pb2.DotDimensionNumbers()
dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
precision_config_proto = _precision_config_proto(precision)
res = tfxla.dot_general(lhs, rhs, dnums_proto, precision_config_proto,
preferred_element_type=preferred_element_type)
# TODO: in the presence of None dimensions, XlaDot shape inference returns
# unknown shape.
res.set_shape(_aval_to_tf_shape(_out_aval))
return res
# This condition ensures that:
# 1) the batch dimensions are ordered in the same way in lhs and rhs (this is
# not strictly necessary, but we would have to reshape the array if that
# were not the case);
# 2) lhs and rhs have the same number of dimensions +/- 1
# 3) the number of non-batch dimensions in both tensors is either 1 or 2
# 4) the contracting dimensions are consistent with those of a classic
# matrix/matrix, vector/matrix or matrix/vector multiplication.
if (lhs_batch == rhs_batch == tuple(range(len(lhs_batch)))
and lhs_ndim - rhs_ndim in [-1, 0, 1]
and 1 <= lhs_ndim - len(lhs_batch) <= 2
and 1 <= rhs_ndim - len(rhs_batch) <= 2
and lhs_contracting == (len(lhs.shape) - 1,)
and rhs_contracting == (len(lhs_batch),)):
# All the inputs to tf.linalg.matmul must have 2 inner dimensions,
# after their batch dimensions, so we need to expand the dimensions
# appropriately. We can get to this branch with three combinations of
# inner shapes:
# - lhs.inner_shape == [a, b], rhs.inner_shape == [b, c]
# - in this case, the resulting inner shape is [a, c];
# - lhs.inner_shape == [b] , rhs.inner_shape == [b, c]
# - in this case, we need to expand lhs to [1, b], and the resulting
# shape is [c]. We need to squeeze the result of tf.linalg.matmul
# as it will have shape [1, c];
# - lhs.shape == [batch] + [a, b], rhs.shape == [batch] + [b]
# - in this case, we need to expand rhs to [b, 1], and the resulting
# shape is [a]. We need to squeeze the result of tf.linalg.matmul
# as it will have shape [a, 1];
# - lhs.shape == [batch] + [b] , rhs.shape == [batch] + [b]
# - in this case, we need to expand lhs to [1, b] and rhs to [b, 1],
# and the resulting shape is (). We need to squeeze the result of
# tf.linalg.matmul as it will have shape [1, 1].
squeeze_idxs = []
if lhs_ndim - len(lhs_batch) == 1:
lhs = tf.expand_dims(lhs, lhs_ndim - 1)
squeeze_idxs.append(len(lhs.shape) - 2)
if rhs_ndim - len(rhs_batch) == 1:
rhs = tf.expand_dims(rhs, rhs_ndim)
squeeze_idxs.append(len(rhs.shape) - 1)
result = tf.linalg.matmul(lhs, rhs)
if len(squeeze_idxs) != 0:
assert all([result.shape[i] == 1 for i in squeeze_idxs])
result = tf.squeeze(result, squeeze_idxs)
return result
new_id = iter(string.ascii_letters)
lhs_axis_ids = [next(new_id) for _ in lhs.shape]
rhs_axis_ids = [next(new_id) for _ in rhs.shape]
lhs_out_axis_ids = lhs_axis_ids[:]
rhs_out_axis_ids = rhs_axis_ids[:]
for lhs_axis, rhs_axis in zip(lhs_contracting, rhs_contracting):
shared_id = next(new_id)
lhs_axis_ids[lhs_axis] = shared_id
rhs_axis_ids[rhs_axis] = shared_id
lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
batch_ids = []
for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
shared_id = next(new_id)
lhs_axis_ids[lhs_axis] = shared_id
rhs_axis_ids[rhs_axis] = shared_id
lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
batch_ids.append(shared_id)
not_none = lambda x: x is not None
out_axis_ids = list(filter(
not_none, batch_ids + lhs_out_axis_ids + rhs_out_axis_ids))
assert lhs.dtype == rhs.dtype
spec = "{},{}->{}".format("".join(lhs_axis_ids),
"".join(rhs_axis_ids),
"".join(out_axis_ids))
return tf.linalg.einsum(spec, lhs, rhs)
tf_impl_with_avals[lax.dot_general_p] = _dot_general
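# An illustrative trace of the einsum fallback above: with lhs.shape == (3, 4),
# rhs.shape == (3, 5), lhs_contracting == (0,) and rhs_contracting == (0,),
# the generated axis ids are lhs="ab" and rhs="cd", the contraction renames
# both contracted axes to "e", and the resulting spec is "eb,ed->bd".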
def _broadcast(operand, *, sizes):
result_shape = tf.TensorShape(sizes).concatenate(operand.shape)
return tf.broadcast_to(operand, result_shape)
tf_impl[lax.broadcast_p] = _broadcast
def _broadcast_in_dim(operand, *, shape, broadcast_dimensions):
inshape = [1] * len(shape)
for orig_shape_i, broadcast_dim_i in zip(operand.shape, broadcast_dimensions):
if orig_shape_i != 1: inshape[broadcast_dim_i] = shape[broadcast_dim_i]
inshape_tf = _eval_shape(inshape)
shape_tf = _eval_shape(shape)
return tf.broadcast_to(tf.reshape(operand, inshape_tf), shape_tf)
tf_impl[lax.broadcast_in_dim_p] = _broadcast_in_dim
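# For example (illustrative): broadcasting an operand of shape (3,) to shape
# (2, 3) with broadcast_dimensions=(1,) computes inshape == [1, 3], reshapes
# the operand to (1, 3), and then broadcasts it to (2, 3).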
def _reshape(operand, *, new_sizes, dimensions):
if dimensions is None:
dimensions = tf.range(tf.rank(operand))
new_sizes_tf = _eval_shape(new_sizes)
return tf.reshape(tf.transpose(operand, dimensions), new_sizes_tf)
tf_impl[lax.reshape_p] = _reshape
def _squeeze(operand, *, dimensions, _in_avals, _out_aval):
op_shape = _in_avals[0].shape
new_shape = tuple(d for i, d in enumerate(op_shape) if i not in dimensions)
new_shape_tf = _eval_shape(new_shape)
return tf.reshape(operand, new_shape_tf)
tf_impl_with_avals[lax.squeeze_p] = _squeeze
def _pad(operand, padding_value, *, padding_config,
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
del _in_avals
low, high, interior = util.unzip3(padding_config)
if all(lo >= 0 and hi >= 0 and i == 0 for lo, hi, i in padding_config):
return tf.pad(operand, util.safe_zip(low, high),
mode="CONSTANT", constant_values=padding_value)
if not _enable_xla:
raise _xla_path_disabled_error("pad")
out = tfxla.pad(operand, padding_value, low, high, interior)
# TODO(b/184499027): improve shape inference for XlaPad
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.pad_p] = _pad
def _rev(operand, *, dimensions):
return tf.reverse(operand, dimensions)
tf_impl[lax.rev_p] = _rev
tf_impl[lax.select_p] = tf.where
def _transpose(operand, *, permutation):
return tf.transpose(operand, perm=permutation)
tf_impl[lax.transpose_p] = _transpose
axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)
tf_impl[lax.reduce_sum_p] = (
bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=0))
tf_impl[lax.reduce_prod_p] = (
bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=0))
tf_impl[lax.reduce_max_p] = (
bool_to_int8(axes_to_axis(tf.reduce_max), argnums=0))
tf_impl[lax.reduce_min_p] = (
bool_to_int8(axes_to_axis(tf.reduce_min), argnums=0))
tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)
def _argminmax(fn, operand, axes, index_dtype):
axis, = axes
output_type = tf.int32
if dtypes.iinfo(index_dtype).bits > 32:
output_type = tf.int64
# TODO(phawkins): handle axes larger than 2^31.
result = fn(operand, axis=axis, output_type=output_type)
return tf.cast(result, to_tf_dtype(index_dtype))
tf_impl[lax.argmin_p] = functools.partial(_argminmax, tf.math.argmin)
tf_impl[lax.argmax_p] = functools.partial(_argminmax, tf.math.argmax)
_add_fn = tf.function(_add, autograph=False)
_ge_fn = tf.function(tf.math.greater_equal, autograph=False)
def _select_and_gather_add(tangents: TfVal,
operand: TfVal,
select_prim: core.Primitive,
window_dimensions: Sequence[int],
window_strides: Sequence[int],
base_dilation: Sequence[int],
window_dilation: Sequence[int],
padding: Sequence[Tuple[int, int]],
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
# Note: this function follows the pattern in
# jax.lax._select_and_gather_add_translation.
dtype = operand.dtype
nbits = dtypes.finfo(dtype.as_numpy_dtype).bits
# We specialize the function for 64 bits. Only up to 32 bits are supported
# on TPU, so we let the code raise a different exception on that platform.
max_bits = 64
assert nbits <= max_bits
double_word_reduction = nbits * 2 <= max_bits
const = lambda dtype, x: tf.constant(np.array(x), dtype)
if double_word_reduction:
word_dtype = lax._UINT_DTYPES[nbits]
double_word_dtype = lax._UINT_DTYPES[nbits * 2]
# Packs two values into a tuple.
def pack(a, b):
a = _bitcast_convert_type(a, word_dtype)
b = _bitcast_convert_type(b, word_dtype)
a = _convert_element_type(a, new_dtype=double_word_dtype)
b = _convert_element_type(b, new_dtype=double_word_dtype)
a = tf.bitwise.left_shift(a, const(double_word_dtype, nbits))
return tf.bitwise.bitwise_or(a, b)
# Unpacks the first element of a tuple.
def fst(t):
assert t.dtype == double_word_dtype
st = _shift_right_logical(t, const(double_word_dtype, nbits))
return _bitcast_convert_type(
_convert_element_type(st, new_dtype=word_dtype), dtype
)
# Unpacks the second element of a tuple.
def snd(t):
return _bitcast_convert_type(
_convert_element_type(t, new_dtype=word_dtype), dtype
)
else:
raise NotImplementedError(f"TODO: need to pack {nbits * 2} bits but this platform can only go up to {max_bits} bits.")
assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
def reducer(x, y):
which = tf_impl[select_prim]
return tf_impl[lax.select_p](which(fst(x), fst(y)), x=x, y=y)
init = -np.inf if select_prim is lax.ge_p else np.inf
  init_identity = lambda _: pack(const(dtype, init), const(dtype, 0))
out = _specialized_reduce_window(reducer, init_identity,
pack(operand, tangents),
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation,
_in_avals=_in_avals, _out_aval=_out_aval)
return snd(out)
tf_impl_with_avals[lax.select_and_gather_add_p] = _select_and_gather_add
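# A worked sketch of the packing above for f32 operands (nbits = 32,
# word_dtype = uint32, double_word_dtype = uint64):
#   packed = (uint64(bitcast<uint32>(operand)) << 32)
#            | uint64(bitcast<uint32>(tangent))
# The reducer compares only the high words (fst), i.e. the operand values,
# but selects whole packed words, so after the reduce_window the low word
# (snd) yields the tangent co-located with the selected operand element.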
def _get_shape_from_tensor_or_array(x):
if isinstance(x.shape, tf.TensorShape):
return tuple(x.shape.as_list())
return tuple(x.shape)
def _common_reduce_window(operand, init_val, reducer, window_dimensions,
window_strides, padding, base_dilation,
window_dilation, _in_avals, _out_aval):
if not _enable_xla:
raise _xla_path_disabled_error("reduce_window")
o_spec = tf.TensorSpec((), dtype=operand.dtype)
reducer_fn = tf.function(reducer, autograph=False).get_concrete_function(o_spec, o_spec)
if not isinstance(init_val, tf.Tensor):
assert not config.jax_enable_checks or _is_tfval(init_val), f"Non TfVal: {init_val}"
init_val = tf.constant(init_val, operand.dtype)
out = tfxla.reduce_window(operand, init_val,
reducer_fn, window_dimensions,
window_strides, base_dilations=base_dilation,
window_dilations=window_dilation, padding=padding)
# TODO: implement shape inference for XlaReduceWindow
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation, window_dilation,
_in_avals, _out_aval):
"""TensorFlow implementation of reduce_window.
Args:
operand: N dimensional array containing elements of type T
init_value: starting value of the reduction
jaxpr: the jaxpr corresponding to the reduction function
consts: the constants associated with jaxpr.
window_dimensions: array of integers for window dimension values
window_strides: array of integers for window stride values
padding: array of pairs of integers for padding values
base_dilation: array of integers for base dilation values
window_dilation: array of integers for window dilation values
Returns:
The reduced operand.
"""
assert len(consts) == 0, "Reduction computation cannot have constants"
def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2)
return res
return _common_reduce_window(
operand, init_value, reducer, window_dimensions, window_strides, padding,
base_dilation, window_dilation, _in_avals, _out_aval
)
# _try_tf_pool returns a Tensor when it succeeds, or a string describing why
# it did not succeed. It currently only supports reduce_window_max and
# reduce_window_sum.
# TODO(bchetioui): this function is not exhaustive wrt which
# reduce_window_max or reduce_window_sum cases can be translated into a call to
# max_pool or avg_pool. Further investigation is needed to fully flesh it out.
def _try_tf_pool(op_name, operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation) -> Union[str, TfVal]:
  # Unlike in the main path, tf.int8 is actually a valid operand type for
  # tf.nn.max_pool.
if op_name == "reduce_window_max" and operand.dtype in [
tf.bool, tf.uint32, tf.uint64, tf.complex64, tf.complex128
]:
return f"tf.nn.max_pool does not support operands of type {operand.dtype}"
if op_name == "reduce_window_sum" and operand.dtype not in [
tf.float16, tf.float32, tf.float64
]:
return f"tf.nn.avg_pool does not support operands of type {operand.dtype}"
has_batch_dim = window_dimensions[0] == 1
has_channel_dim = window_dimensions[-1] == 1
nb_spatial_dimensions = len(operand.shape) - has_batch_dim - has_channel_dim
if nb_spatial_dimensions < 1 or nb_spatial_dimensions > 3:
return ("TensorFlow can only handle pooling for arrays with 1, 2, or "
"3 spatial dimensions")
# TODO(bchetioui): does a simple conversion with another base dilation exist?
if list(base_dilation) != [1] * len(operand.shape):
return "Unimplemented support for base dilation"
# TODO(bchetioui): does a simple conversion with another window_dilation
# exist? The whole story seems similar to convolution.
if list(window_dilation) != [1] * len(operand.shape):
return "Unimplemented support for window dilation"
if list(padding) != [(0, 0)] * len(operand.shape):
return "Unimplemented support for padding"
# ReduceWindow in XLA takes an array of rank N as a parameter, but
# tf.nn.max_pool / tf.nn.avg_pool take an array of rank N+2, with a default
# shape of the form [batch_size] + input_spatial_shape + [num_channels]
tf_operand = operand
tf_window_dimensions = list(window_dimensions)
tf_window_strides = list(window_strides)
if not has_batch_dim:
tf_operand = tf.expand_dims(tf_operand, 0)
tf_window_dimensions = [1] + tf_window_dimensions
tf_window_strides = [1] + tf_window_strides
if not has_channel_dim:
tf_operand = tf.expand_dims(tf_operand, -1)
tf_window_dimensions.append(1)
tf_window_strides.append(1)
tf_data_format = "N" + "DHW"[-nb_spatial_dimensions:] + "C"
tf_padding = "VALID"
if op_name == "reduce_window_max":
result = tf.nn.max_pool(tf_operand, tf_window_dimensions, tf_window_strides,
tf_padding, tf_data_format)
elif op_name == "reduce_window_sum":
avg = tf.nn.avg_pool(tf_operand, tf_window_dimensions, tf_window_strides,
tf_padding, tf_data_format)
result = avg * np.prod(tf_window_dimensions)
else:
return f"Unimplemented support for {op_name}"
if not has_batch_dim:
result = tf.squeeze(result, 0)
if not has_channel_dim:
result = tf.squeeze(result, -1)
return result
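# For example (a sketch, with shapes chosen for illustration): a
# reduce_window_max over f32[8, 8] with window_dimensions=(2, 2) and
# window_strides=(2, 2) has neither a batch nor a channel dimension, so the
# operand is expanded to [1, 8, 8, 1] and converted as
#   tf.nn.max_pool(op, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", "NHWC")
# before squeezing the two added dimensions back out.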
def _specialized_reduce_window(reducer, identity, operand, *, window_dimensions,
window_strides, padding, base_dilation,
window_dilation, _in_avals, _out_aval,
name=None):
"""Wraps the TensorFlow reduce window operation based on a reducer and an
identity function defining the initial value of the reduction depending on
the dtype of the operand.
Args:
reducer: reduction function of type TfVal -> TfVal -> TfVal
identity: function that takes a TensorFlow dtype as a parameter and returns
the starting value of the reduction.
operand: N dimensional array containing elements of type T
window_dimensions: array of integers for window dimension values
window_strides: array of integers for window stride values
padding: array of pairs of integers for padding values
base_dilation: array of integers for base dilation values
window_dilation: array of integers for window dilation values
name: the name of the specialized reduce window primitive for which this
conversion function is called. This information may help to choose a
different conversion path (optional)
Returns:
The reduced operand.
"""
if name in ["reduce_window_max", "reduce_window_sum"]:
res = _try_tf_pool(name, operand, window_dimensions, window_strides,
padding, base_dilation, window_dilation)
if not isinstance(res, str):
return res
return _common_reduce_window(
operand, identity(operand.dtype), reducer, window_dimensions,
window_strides, padding, base_dilation, window_dilation, _in_avals,
_out_aval
)
def _get_max_identity(tf_dtype):
numpy_tf_dtype = tf_dtype.as_numpy_dtype
if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact):
return numpy_tf_dtype(-np.inf)
elif dtypes.issubdtype(numpy_tf_dtype, np.integer):
return dtypes.iinfo(numpy_tf_dtype).min
else:
assert dtypes.issubdtype(numpy_tf_dtype, np.bool_), (
f"{tf_dtype} has no defined max identity"
)
return False
def _get_min_identity(tf_dtype):
numpy_tf_dtype = tf_dtype.as_numpy_dtype
if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact):
return numpy_tf_dtype(np.inf)
elif dtypes.issubdtype(numpy_tf_dtype, np.integer):
return dtypes.iinfo(numpy_tf_dtype).max
else:
assert dtypes.issubdtype(numpy_tf_dtype, np.bool_), (
f"{tf_dtype} has no defined min identity"
)
return True
# pylint: disable=protected-access
tf_impl_with_avals[lax.reduce_window_sum_p] = (
functools.partial(_specialized_reduce_window, _add, lambda x: 0,
name="reduce_window_sum"))
tf_impl_with_avals[lax.reduce_window_min_p] = (
functools.partial(_specialized_reduce_window, tf.math.minimum,
_get_min_identity, name="reduce_window_min"))
tf_impl_with_avals[lax.reduce_window_max_p] = (
functools.partial(_specialized_reduce_window, tf.math.maximum,
_get_max_identity, name="reduce_window_max"))
tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
# pylint: enable=protected-access
# We use lax_control_flow._cumred_tpu_translation_rule to convert cummax,
# cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is
# O(n^2) on other backends. It could be implemented with associative_scan
# instead, which would favor the other backends.
tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl(
functools.partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_min), multiple_results=False)
tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
functools.partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_max), multiple_results=False)
# TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for
# certain dtypes: bfloat16, float16, float32, float64, and int32. Other dtypes
# will fail when running in compiled mode, but are otherwise compatible with
# the operation. A non-XLA path could thus be defined for all dtypes, though
# the tests would then crash for the unsupported ones.
tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl(
functools.partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_sum), multiple_results=False)
tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl(
functools.partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_prod), multiple_results=False)
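# To make the O(n^2) cost concrete (a sketch of what the rule above does,
# under the assumption that it pads the reduced axis on the low side):
# cumsum over f32[4] becomes a reduce_window_sum with window_dimensions=(4,)
# and padding=((3, 0),), so out[i] = sum(x[max(0, i - 3) : i + 1]) =
# sum(x[0 : i + 1]); each output element re-reads up to n inputs.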
def _select_and_scatter(
operand, source, init_value, select_jaxpr, select_consts, scatter_jaxpr,
scatter_consts, window_dimensions, window_strides, padding):
raise NotImplementedError("TODO: jax2tf can not convert _select_and_scatter")
tf_impl[lax.select_and_scatter_p] = _select_and_scatter
@functools.partial(bool_to_int8, argnums=(0, 1))
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
window_strides, padding, _in_avals, _out_aval):
if not _enable_xla:
raise _xla_path_disabled_error("select_and_scatter_add")
init_value = tf.zeros((), operand.dtype)
select_fn = (tf.function(tf_impl[select_prim], autograph=False)
.get_concrete_function(init_value, init_value))
scatter_fn = _add_fn.get_concrete_function(init_value, init_value)
out = tfxla.select_and_scatter(operand, window_dimensions, window_strides,
padding, source, init_value, select_fn,
scatter_fn)
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
res = _convert_jax_impl(
functools.partial(jax._src.random._threefry2x32_lowering,
use_rolled_loops=False),
multiple_results=True)(*args, _in_avals=_in_avals, _out_aval=_out_aval)
return res
tf_impl_with_avals[jax.random.threefry2x32_p] = _threefry2x32_jax_impl
# Use the vmap implementation, otherwise on TPU the performance is really bad.
# With use_vmap=True, we get about the same performance for JAX and jax2tf.
tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl(
functools.partial(jax._src.random._gamma_impl, use_vmap=True),
multiple_results=False)
def _gather_dimensions_proto(indices_shape, dimension_numbers):
proto = xla_data_pb2.GatherDimensionNumbers()
proto.offset_dims.extend(dimension_numbers.offset_dims)
proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims)
proto.start_index_map.extend(dimension_numbers.start_index_map)
assert indices_shape
proto.index_vector_dim = len(indices_shape) - 1
return proto
@functools.partial(bool_to_int8, argnums=0)
def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
_in_avals, _out_aval):
"""Tensorflow implementation of gather."""
del _in_avals
if not _enable_xla:
raise _xla_path_disabled_error("gather")
proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
slice_sizes_tf = _eval_shape(slice_sizes)
out = tfxla.gather(operand, start_indices, proto, slice_sizes_tf, False)
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.gather_p] = _gather
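# An illustrative gather (a sketch; the dimension numbers are typical of what
# jnp row indexing produces, not taken from this file): for operand f32[5, 3]
# and start_indices int32[7, 1], with
#   GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,),
#                          start_index_map=(0,))
# and slice_sizes=(1, 3), index_vector_dim is set to 1 (= rank(indices) - 1)
# and the XlaGather output has shape f32[7, 3], i.e. x[idx] row selection.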
def _slice(operand, start_indices, limit_indices, strides,
_in_avals, _out_aval):
if strides is None:
strides = [1] * len(start_indices)
slices = tuple(map(slice,
_eval_shape(start_indices),
_eval_shape(limit_indices),
_eval_shape(strides)))
out = operand[slices]
# TODO(b/184503314): improve shape inference for __getitem__
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.slice_p] = _slice
def _dynamic_slice(operand, *start_indices, slice_sizes,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
# Here we could use tf.slice. Similarly, for lax.gather we can sometimes use
# tf.gather. But those have different semantics for index-out-of-bounds than
  # JAX (and XLA). We have tried to force compilation by wrapping into
# tf.xla.experimental.compile, or tf.function(jit_compile=True), but
# those solutions are brittle because they do not work when nested into an
# outer compilation (see b/162814494 and b/163006262). They also do not
# survive well being put in a SavedModel. Hence, we now use TFXLA slicing
# and gather ops.
if not _enable_xla:
raise _xla_path_disabled_error("dynamic_slice")
res = tfxla.dynamic_slice(operand, tf.stack(start_indices),
size_indices=_eval_shape(slice_sizes))
# TODO: implement shape inference for XlaDynamicSlice
res.set_shape(_aval_to_tf_shape(_out_aval))
return res
tf_impl_with_avals[lax.dynamic_slice_p] = _dynamic_slice
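# The semantics difference mentioned above, sketched on an example: XLA's
# DynamicSlice clamps start indices so the slice stays in bounds, e.g. for
# operand f32[5], start_indices=(4,) and slice_sizes=(3,) the start is clamped
# to 2 and the result is operand[2:5]; tf.slice would instead raise an error.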
def _scatter_dimensions_proto(indices_shape, dimension_numbers):
proto = xla_data_pb2.ScatterDimensionNumbers()
proto.update_window_dims.extend(dimension_numbers.update_window_dims)
proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims)
proto.scatter_dims_to_operand_dims.extend(
dimension_numbers.scatter_dims_to_operand_dims)
assert indices_shape
proto.index_vector_dim = len(indices_shape) - 1
return proto
def _scatter(operand, scatter_indices, updates, *,
update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted, unique_indices,
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
del unique_indices, _in_avals
assert len(update_consts) == 0, "Update computation cannot have constants"
if not _enable_xla:
raise _xla_path_disabled_error("scatter")
proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers)
def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal:
closed_jaxpr = core.ClosedJaxpr(update_jaxpr, update_consts)
res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2)
return res
o_spec = tf.TensorSpec((), dtype=operand.dtype)
xla_update_computation = (
tf.function(update_computation, autograph=False).get_concrete_function(o_spec, o_spec))
out = tfxla.scatter(operand, scatter_indices, updates, xla_update_computation, proto,
indices_are_sorted=indices_are_sorted)
# TODO: implement shape analysis for XlaScatter
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.scatter_p] = _scatter
tf_impl_with_avals[lax.scatter_min_p] = _scatter
tf_impl_with_avals[lax.scatter_max_p] = _scatter
tf_impl_with_avals[lax.scatter_mul_p] = _scatter
tf_impl_with_avals[lax.scatter_add_p] = _scatter
def _dynamic_update_slice(operand, update, *start_indices):
if not _enable_xla:
raise _xla_path_disabled_error("dynamic_update_slice")
return tfxla.dynamic_update_slice(operand, update, tf.stack(start_indices))
tf_impl[lax.dynamic_update_slice_p] = _dynamic_update_slice
def _cond(index: TfVal, *operands: TfVal,
branches: Sequence[core.ClosedJaxpr],
linear: Sequence[bool]) -> Sequence[TfVal]:
del linear
# tf.cond needs lambdas with no arguments.
branches_tf = [functools.partial(_interpret_jaxpr, jaxpr, *operands)
for jaxpr in branches]
return tf.switch_case(index, branches_tf)
tf_impl[lax_control_flow.cond_p] = _cond
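# For example (a sketch): a two-branch lax.cond arrives here as a switch with
# branches[0] the false branch, and becomes
#   tf.switch_case(index, [functools.partial(_interpret_jaxpr, false_jaxpr, *operands),
#                          functools.partial(_interpret_jaxpr, true_jaxpr, *operands)])
# where each partial is the zero-argument callable that tf.switch_case expects.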
def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]:
cond_consts, body_consts, init_carry = util.split_list(args, [cond_nconsts,
body_nconsts])
if cond_jaxpr.out_avals[0].shape: # type: ignore[attr-defined]
    # The conditional is not a scalar, so this must be a batched while.
return _batched_cond_while(*args,
cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)
  # The conditional must return a single scalar boolean value to TF.
def cond_tf_func(*args: TfVal) -> TfVal:
pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *args)
return pred
body_tf_func = functools.partial(_interpret_jaxpr, body_jaxpr, *body_consts)
return tf.while_loop(cond_tf_func, body_tf_func, init_carry)
def _batched_cond_while(*args: TfVal,
cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
body_nconsts: int, body_jaxpr: core.ClosedJaxpr
) -> Sequence[TfVal]:
"""Interprets a while_loop with a batched condition.
A batched while has a conditional that returns a tensor of booleans, and
a body that returns a list of tensors whose leading dimensions match those
of the conditional tensor.
  We need to turn it into a while with a scalar boolean conditional. We expand
  the loop carry to include a prefix with the current tensor boolean condition,
  and we prepend to the loop the first computation of the tensor boolean
  condition. The loop condition uses a "reduce_any" to compute a scalar boolean
  from the tensor boolean condition. The end of the loop body computes the new
  carry using a "tf.where", and then computes the new tensor boolean condition.
  """
cond_consts, body_consts, init_carry = util.split_list(args, [cond_nconsts,
body_nconsts])
# Initial computation of batched condition
init_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *init_carry)
assert init_pred_b is not core.unit
def new_cond_tf_func(pred_b: TfVal, *carry: TfVal) -> TfVal:
pred = tf.reduce_any(pred_b, axis=list(range(len(pred_b.shape))))
return pred
def new_body_tf_func(pred_b: TfVal, *carry: TfVal) -> Sequence[TfVal]:
new_carry: Sequence[TfVal] = _interpret_jaxpr(body_jaxpr,
*body_consts, *carry)
def select_one_carry(new_c: TfVal, c: TfVal) -> TfVal:
pred_b_bcast = _broadcast_in_dim(pred_b,
shape=new_c.shape,
broadcast_dimensions=list(range(len(pred_b.shape))))
return tf.where(pred_b_bcast, new_c, c)
selected_carry: Sequence[TfVal] = list(
util.safe_map(select_one_carry, new_carry, carry))
next_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *selected_carry)
return (next_pred_b, *selected_carry)
_, *res_carry = tf.while_loop(new_cond_tf_func, new_body_tf_func,
(init_pred_b, *init_carry))
return res_carry
tf_impl[lax_control_flow.while_p] = _while
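# A concrete instance of the batched-while rewrite above (a sketch, with
# shapes chosen for illustration): for a carry c: f32[3] and a batched
# predicate pred_b: bool[3], the TF loop carries (pred_b, c); the condition is
# tf.reduce_any(pred_b), and the body computes new_c, keeps
# tf.where(pred_b_broadcast, new_c, c), and then recomputes pred_b from the
# selected carry.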
# We use the scan impl rule to rewrite scan in terms of a while loop.
tf_impl_with_avals[lax_control_flow.scan_p] = _convert_jax_impl(lax_control_flow._scan_impl)
def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:
# Some types originally incompatible with tf.math.top_k can be promoted
# to a compatible type without loss of precision.
def promote_tf_dtype(tf_dtype):
if tf_dtype in [tf.bool, tf.uint8, tf.uint16]:
return tf.uint32
if tf_dtype in [tf.int8, tf.int16]:
return tf.int32
if tf_dtype is tf.float16:
return tf.float32
return None
conversion_dtype = promote_tf_dtype(operand.dtype)
if conversion_dtype:
values, indices = tf.math.top_k(tf.dtypes.cast(operand, conversion_dtype),
k=k, sorted=True)
return tf.dtypes.cast(values, operand.dtype), indices
else:
return tf.math.top_k(operand, k=k, sorted=True)
tf_impl[lax.top_k_p] = _top_k
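# For example (a sketch): an int8 operand is promoted to int32, since the
# promotion is lossless; tf.math.top_k then runs on int32 and only the values
# are cast back to int8 (the indices are returned as int32 either way).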
def _sort(*operands: TfVal, dimension: int, is_stable: bool,
num_keys: int) -> Tuple[TfVal, ...]:
if not _enable_xla:
raise _xla_path_disabled_error("sort")
assert 1 <= num_keys <= len(operands)
assert 0 <= dimension < len(
operands[0].shape
), f"Invalid {dimension} for ndim {len(operands[0].shape)}"
comparator_spec: List[tf.TensorSpec] = []
comparator_jax_in_avals: List[core.AbstractValue] = []
for op in operands:
o_spec = tf.TensorSpec((), dtype=op.dtype)
comparator_spec.extend([o_spec, o_spec])
o_aval = core.ShapedArray((), to_jax_dtype(op.dtype))
comparator_jax_in_avals.extend([o_aval, o_aval])
# Use the same comparator that JAX uses when compiling to XLA, to get the
# proper NaN/Inf total order, and the lexicographic ordering.
# The comparator is a 2N-argument TF function, with arguments [2k] and [2k +1]
# corresponding to two scalars from operand[k].
def lexicographic_comparator(*tf_args: TfVal) -> TfVal:
return _convert_jax_impl(
lax._sort_lt_comparator, multiple_results=False)(
*tf_args,
_in_avals=comparator_jax_in_avals,
_out_aval=core.ShapedArray((), np.bool_),
num_keys=num_keys)
xla_comparator_computation = (
tf.function(lexicographic_comparator,
autograph=False).get_concrete_function(*comparator_spec))
results = tfxla.variadic_sort(operands, dimension=dimension,
is_stable=is_stable,
comparator=xla_comparator_computation)
return results
tf_impl[lax.sort_p] = _sort
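# For example (a sketch): sorting (keys: f32[10], values: i32[10]) with
# num_keys=1 passes both operands to XlaVariadicSort; the comparator is a
# 4-argument function over scalars (keys[i], keys[j], values[i], values[j])
# and, via lax._sort_lt_comparator, orders on the keys with the proper
# NaN/-0.0 total order while the values are permuted along.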
def _fft(x, fft_type, fft_lengths):
FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3]))
if fft_type == IRFFT:
expected_lengths = x.shape[-len(fft_lengths):-1] + ((x.shape[-1] - 1) * 2,)
else:
expected_lengths = x.shape[-len(fft_lengths):]
if expected_lengths != fft_lengths:
raise NotImplementedError(
f"Unsupported fft_lengths={fft_lengths} for fft_type={fft_type} of "
f"array with shape={x.shape}.")
tf_funcs = {FFT: [tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d],
IFFT: [tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d],
RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d],
IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d]}
return tf_funcs[fft_type][len(fft_lengths) - 1](x)
tf_impl[lax_fft.fft_p] = _fft
def _qr(operand, full_matrices):
return tf.linalg.qr(operand, full_matrices=full_matrices)
tf_impl[lax_linalg.qr_p] = _qr
def _svd(operand, full_matrices, compute_uv):
result = tf.linalg.svd(operand, full_matrices, compute_uv)
if not compute_uv:
return result,
s, u, v = result
return s, u, tf.linalg.adjoint(v)
tf_impl[lax_linalg.svd_p] = _svd
def _eig(operand: TfVal, compute_left_eigenvectors: bool,
compute_right_eigenvectors: bool):
if compute_left_eigenvectors and compute_right_eigenvectors:
# TODO(bchetioui): didn't find a 100% reliable, easy and satisfying way to
# sort the left eigenvectors in the right order. The jax.numpy.linalg API
# suggests to me that left eigenvectors are anyway seldom used, so I
# think it is acceptable to leave as unimplemented for now.
msg = ("Conversion of eig is not implemented when both "
"compute_left_eigenvectors and compute_right_eigenvectors are set "
"to True.")
raise NotImplementedError(msg)
elif not (compute_left_eigenvectors or compute_right_eigenvectors):
return tuple([tf.linalg.eigvals(operand)])
elif compute_right_eigenvectors:
return tuple(tf.linalg.eig(operand))
else: # compute_left_eigenvectors == True
wH, vl = tf.linalg.eig(tf.linalg.adjoint(operand))
wHH = tf.math.conj(wH)
return tuple([wHH, vl])
tf_impl[lax_linalg.eig_p] = _eig
def _eigh(operand: TfVal, lower: bool, _in_avals, _out_aval):
if operand.shape[-1] == 0:
v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
else:
if not lower:
operand = tf.linalg.adjoint(operand)
w, v = tf.linalg.eigh(operand)
cast_type = { tf.complex64: tf.float32,
tf.complex128: tf.float64 }.get(operand.dtype)
if cast_type is not None:
w = tf.cast(w, cast_type)
return v, w
tf_impl_with_avals[lax_linalg.eigh_p] = _eigh
def _lu(operand: TfVal, _in_avals, _out_aval):
return _convert_jax_impl(lax_linalg._lu_python)(operand, _in_avals=_in_avals,
_out_aval=_out_aval)
tf_impl_with_avals[lax_linalg.lu_p] = _lu
def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool,
transpose_a: bool, conjugate_a: bool,
unit_diagonal: bool,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
if unit_diagonal:
a_aval, _ = _in_avals
a_shape = _eval_shape(a_aval.shape)
a = tf.linalg.set_diag(a, tf.ones(a_shape[:-1], dtype=a.dtype))
if not left_side:
rank = len(a.shape)
transpose_dimensions = list(range(rank - 2)) + [rank - 1, rank - 2]
a = tf.transpose(a, transpose_dimensions)
b = tf.transpose(b, transpose_dimensions)
lower = not lower
# adjoint == transpose for real dtypes, so special care need only be taken
# for complex types.
if a.dtype in [tf.complex64, tf.complex128]:
if (transpose_a and not conjugate_a) or (not transpose_a and conjugate_a):
a = tf.math.conj(a)
result = tf.linalg.triangular_solve(a, b, lower=lower, adjoint=transpose_a)
if not left_side:
result = tf.transpose(result, transpose_dimensions)
return result
tf_impl_with_avals[lax_linalg.triangular_solve_p] = _triangular_solve
def _linear_solve(*args: TfVal, const_lengths, jaxprs, _in_avals, _out_aval):
return _convert_jax_impl(lax_control_flow._custom_linear_solve_impl)(
*args, const_lengths=const_lengths, jaxprs=jaxprs, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[lax_control_flow.linear_solve_p] = _linear_solve
def _custom_jvp_call_jaxpr(*args: TfVal,
fun_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable,
num_consts: int) -> Sequence[TfVal]:
# TODO(necula): ensure that there is no AD transformation in scope
return _interpret_jaxpr(fun_jaxpr, *args)
tf_impl[custom_derivatives.custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr
def _custom_vjp_call_jaxpr(*args: TfVal,
fun_jaxpr: core.ClosedJaxpr,
**_) -> Sequence[TfVal]:
# TODO(necula): ensure that there is no AD transformation in scope
return _interpret_jaxpr(fun_jaxpr, *args)
tf_impl[custom_derivatives.custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr
def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
"function.")
tf_impl[ad.custom_lin_p] = _custom_lin
def split_to_logical_devices(
tensor: TfVal,
partition_dimensions: pxla.PartitionsOrReplicated):
"""Like TPUMPStrategy.experimental_split_to_logical_devices.
For jax2tf purposes we want to avoid needing to thread the `strategy` object
through the generated computation. It seems that the original function needs
the strategy object only for error checking, which we assume is done upstream
by JAX.
Args:
tensor: Input tensor to annotate.
partition_dimensions: A list of integers, with one integer per tensor
dimension, specifying in how many parts the dimension should be split. The
product of integers must equal the number of devices per replica.
Returns:
an annotated tensor.
"""
# This corresponds to the sharding annotations in
# xla_bridge._sharding_to_proto.
if partition_dimensions is None:
return xla_sharding.replicate(tensor, use_sharding_op=True)
num_partition_splits = np.prod(partition_dimensions)
tile_assignment = np.arange(num_partition_splits).reshape(
partition_dimensions)
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
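# For example (a sketch): partition_dimensions=[2, 1] for a rank-2 tensor
# yields tile_assignment = np.arange(2).reshape([2, 1]) = [[0], [1]], i.e. the
# rows are split across two logical devices and the columns are not split.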
def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
in_parts: Sequence[pxla.PartitionsOrReplicated],
out_parts_thunk,
**_) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
sharded_vals = util.safe_map(split_to_logical_devices, vals, in_parts)
vals_out = f.call_wrapped(*sharded_vals) # caller handles new_sublevel
out_parts_flat = out_parts_thunk()
assert len(out_parts_flat) == len(vals_out), f"expected {len(out_parts_flat)} == {len(vals_out)}"
sharded_vals_out = [
(split_to_logical_devices(val, val_part), val_aval)
for (val, val_aval), val_part in util.safe_zip(vals_out, out_parts_flat)
]
return sharded_vals_out
def _sharding_constraint(arg: TfVal, *,
partitions: pxla.PartitionsOrReplicated):
return split_to_logical_devices(arg, partitions)
tf_impl[sharded_jit.sharding_constraint_p] = _sharding_constraint
def _register_checkpoint_pytrees():
"""Registers TF custom container types as pytrees."""
m = tf.Module()
# The types here are automagically changed by TensorFlow's checkpointing
# infrastructure.
m.a = (tf.Module(), tf.Module())
m.b = [tf.Module(), tf.Module()]
m.c = {"a": tf.Module()}
tuple_wrapper = type(m.a)
list_wrapper = type(m.b)
dict_wrapper = type(m.c)
# TF AutoTrackable swaps container types out for wrappers.
assert tuple_wrapper is not tuple
assert list_wrapper is not list
assert dict_wrapper is not dict
jax.tree_util.register_pytree_node(
tuple_wrapper, lambda xs: (tuple(xs), None), lambda _, xs: tuple(xs))
jax.tree_util.register_pytree_node(
list_wrapper, lambda xs: (tuple(xs), None), lambda _, xs: list(xs))
jax.tree_util.register_pytree_node(
dict_wrapper,
lambda s: (tuple(s.values()), tuple(s.keys())),
lambda k, xs: dict(zip(k, xs)))
_register_checkpoint_pytrees()
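# For example (a sketch): after `m = tf.Module(); m.b = [tf.Module()]`, m.b is
# a ListWrapper rather than a plain list, and without the registrations above
# jax.tree_util.tree_map would treat it as a leaf; with them it is flattened
# like an ordinary list.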