
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
from functools import partial
import contextlib
import os
import re
import string
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import jax
from jax._src import ad_util
from jax import api_util, config
from jax._src import api
from jax import core, custom_derivatives, dtypes
from jax import linear_util as lu
from jax import random, tree_util
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import fft as lax_fft
from jax._src.lax import lax
from jax._src.lax import linalg as lax_linalg
import jax._src.random
from jax.api_util import flatten_fun
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import pxla
from jax.interpreters import sharded_jit
from jax.interpreters import xla
from jax.lib import xla_client
from . import shape_poly
import numpy as np
import tensorflow as tf # type: ignore[import]
# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
from tensorflow.core.framework import attr_value_pb2 # type: ignore[import]
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # type: ignore[import]
from tensorflow.python.framework import ops as tf_ops # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import
PolyShape = shape_poly.PolyShape
# The scope name needs to be a valid TensorFlow name. See
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731
_VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
_INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/>-]")
map = util.safe_map
zip = util.safe_zip
def _sanitize_scope_name(name):
scope_name = _INVALID_SCOPE_CHAR.sub("_", name)
if not _VALID_SCOPE_REGEX.match(scope_name):
scope_name = ".{}".format(scope_name)
return scope_name
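# A minimal illustrative sketch of the sanitization above (the helper below is
# hypothetical and not used elsewhere in this module).
def _example_sanitize_scope_name():
  # Valid names pass through unchanged.
  assert _sanitize_scope_name("my_scope/op.1") == "my_scope/op.1"
  # Invalid characters (here: space and "!") are replaced with "_".
  assert _sanitize_scope_name("my op!") == "my_op_"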
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
DType = Any
PrecisionType = int # Enum xla_data.PrecisionConfig.Precision
# A dimension environment maps dimension variables to TF expressions that
# compute the value of the dimension. These expressions refer to the TF
# function arguments.
_ShapeEnv = Dict[str, TfVal]
def _is_tfval(v: TfVal) -> bool:
if isinstance(v, (tf.Tensor, tf.Variable)):
return True
try:
# Include all convertible types, even if not supported on accelerators.
with tf.device("CPU"):
tf.constant(v)
return True
except:
return False
# The implementation rules for primitives. The rule will be called with the
# arguments (TfVal) and must return TfVal (or a sequence thereof,
# if primitive.multiple_results). The vast majority of primitives do not need
# to worry about core.unit inputs or results. The exceptions are primarily the
# control-flow primitives.
tf_impl: Dict[core.Primitive, Callable[..., Any]] = {}
# Some primitive implementation rules need the abstract values of arguments
# and the results. This is the case for the primitives implemented using
# _convert_jax_impl and those that need to adjust the shape of the outputs
# due to missing TF shape inference rules for TFXLA ops. The rules for these
# primitives should be added to `tf_impl_with_avals`.
# The abstract values are passed to the implementation as two special kwargs
# `_in_avals` (a tuple of core.AbstractValue) and `_out_aval` (a
# core.AbstractValue, or a tuple thereof when primitive.multiple_results).
tf_impl_with_avals: Dict[core.Primitive, Callable[..., Any]] = {}
# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
# conversion happens is:
# jax2tf.process_primitive (top of stack)
# jax tracing machinery ...
# tf.custom_gradient machinery ...
# jax2tf.converted_fun
# tf function machinery ...
# user code invokes the converted function on TF tensors
#
# We need to skip over not only JAX internal frames, but also TF internal
# frames.
# We register the TensorFlow source path lazily
_has_registered_tf_source_path = False
class _ThreadLocalState(threading.local):
def __init__(self):
self.name_stack = ""
# XLA is not linked in all environments; when converting a primitive, if this
# flag is False, we try harder to use only standard TF ops when they are
# applicable to the concrete use case; if the resulting conversion path ends up
# requiring a TFXLA operation, an exception is raised instead.
self.enable_xla = True
# Keep track of whether we are inside a call_tf. In that context we disable the
# safety check that we are not inside JAX transformations.
self.inside_call_tf = False
# Maps dimension variables to TF expressions
self.shape_env: _ShapeEnv = {}
# Whether to actually include XLA op metadata in the generated TF ops
self.include_xla_op_metadata = True
_thread_local_state = _ThreadLocalState()
def _get_current_name_stack():
return _thread_local_state.name_stack
def _xla_disabled_error(primitive_name: str,
extra_msg: Optional[str] = None) -> Exception:
assert not _thread_local_state.enable_xla
msg = f"Call to {primitive_name} cannot be converted with enable_xla=False."
if extra_msg:
msg += f" {extra_msg}"
return NotImplementedError(msg)
@contextlib.contextmanager
def inside_call_tf():
# Set the inside_call_tf flag for a context.
prev = _thread_local_state.inside_call_tf
_thread_local_state.inside_call_tf = True
try:
yield
finally:
_thread_local_state.inside_call_tf = prev
@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun: Callable,
*,
polymorphic_shapes: Optional[Sequence[Any]] = None,
with_gradient=True,
enable_xla=True
) -> Callable:
"""Transforms `fun` to be executed by TensorFlow.
See
[README](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md)
for more details about usage and common problems.
Args:
fun: Function to be transformed. Its arguments and return value should be
JAX arrays, or nested standard Python containers (tuple/list/dict) thereof
(pytrees).
polymorphic_shapes: Specifies input shapes to be treated polymorphically
during conversion.
.. warning:: The shape-polymorphic conversion is an experimental feature.
It is meant to be sound, but it is known to reject some JAX programs
that are shape polymorphic. The details of this feature can change. It
should be a Python object with the same pytree structure as, or a prefix
of, the tuple of arguments to the function, but with a shape
specification corresponding to each argument. The default value is
`None`, which is a shortcut for a tuple of `None`, one for each argument,
denoting that all shapes are monomorphic.
See [how optional parameters are matched to
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
A shape specification for an array argument should be an object
`PolyShape(dim0, dim1, ..., dimn)`
where each `dim` is a dimension specification: a positive integer denoting
a monomorphic dimension of the given size, or a string denoting a
dimension variable assumed to range over non-zero dimension sizes, or
the special placeholder string "_" denoting a monomorphic dimension
whose size is given by the actual argument. As a shortcut, an Ellipsis
suffix in the list of dimension specifications stands for a list of "_"
placeholders. For convenience, a shape specification can also be given
as a string
representation, e.g.: "batch, ...", "batch, height, width, _", possibly
with surrounding parentheses: "(batch, ...)".
The conversion fails if it cannot ensure that it would produce the same
sequence of TF ops for any non-zero values of the dimension variables.
polymorphic_shapes are only supported for positional arguments; shape
polymorphism is not supported for keyword arguments.
See [the README](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
in_shapes: DEPRECATED in favor of `polymorphic_shapes`.
with_gradient: if set, will add a tf.custom_gradient to the converted
function, by converting the ``jax.vjp(fun)``. Only first-order
differentiation is supported for now. If the converted function is saved
in a SavedModel, the custom gradients are currently lost and an error will
be raised if a gradient computation is attempted. This is due to a current
bug in TensorFlow.
enable_xla: if set to False, the converter will try harder to use pure TF ops
to convert the function, and raise an error if it cannot be converted
without resorting to XLA ops (default: True).
Returns:
A version of `fun` that expects TfVals as arguments (or tuples/lists/dicts
thereof), and returns TfVals as outputs.
"""
api._check_callable(fun)
fun_name = getattr(fun, "__name__", "unknown")
name_stack = util.extend_name_stack(util.wrap_name(fun_name, "jax2tf"))
def converted_fun(*args: TfVal, **kwargs: TfVal) -> TfVal:
# TODO: is there a better way to check if we are inside a transformation?
if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
# It is Ok to nest convert when we are inside a call_tf
raise ValueError("convert must be used outside all JAX transformations." +
f"Trace state: {core.thread_local_state.trace_state.trace_stack}")
def check_arg(a):
if not _is_tfval(a):
msg = (f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
"be NumPy array, scalar, tf.Variable, or tf.Tensor")
raise TypeError(msg)
tree_util.tree_map(check_arg, args)
tree_util.tree_map(check_arg, list(kwargs.values()))
args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
# May need to cast the arguments to have the type assumed by JAX
args_and_dtypes_flat = tuple(map(_tfval_to_tensor_jax_dtype, args_flat))
args_flat, arg_dtypes_flat = util.unzip2(args_and_dtypes_flat)
# Name input tensors; do this after we have cast the arguments
def _apply_name(a: TfVal, suffix) -> TfVal:
return tf.identity(a, f"jax2tf_arg_{suffix}")
args_flat = tuple(_apply_name(a, i) for i, a in enumerate(args_flat))
if polymorphic_shapes is None:
polymorphic_shapes_ = (None,) * len(args)
else:
if not isinstance(polymorphic_shapes, Sequence) or len(args) != len(polymorphic_shapes):
msg = ("polymorphic_shapes must be a sequence with the same length as the positional argument list "
f"({len(args)}). Got polymorphic_shapes={polymorphic_shapes}.")
raise TypeError(msg)
polymorphic_shapes_ = tuple(polymorphic_shapes)
# Expand the polymorphic_shapes to match the argument pytree
polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
in_tree.children()[0],
polymorphic_shapes_))
# Add kwargs shapes.
polymorphic_shapes_flat = polymorphic_shapes_flat + tuple(
(None,) * (len(args_flat) - len(polymorphic_shapes_flat)))
# Construct the abstract values for the flat arguments, possibly based on
# the input shapes and the polymorphic_shapes if given. May create new shape
# variables. May cast the args_flat to JAX types, using JAX's interpretation
# of types of constants.
args_avals_flat, shapeenv = _args_to_avals_and_env(
args_flat, arg_dtypes_flat, polymorphic_shapes_flat)
# This function may take pytrees of TfVals. We can only set
# tf.custom_gradient on functions that take a flat argument list.
f = lu.wrap_init(fun)
# out_tree_thunk() will be the output tree, after running _interpret_fun.
flat_fun, out_tree_thunk = flatten_fun(f, in_tree)
# Prepare the grad_fn for tf.custom_gradient.
def converted_grad_fn(*out_cts_flat: TfVal,
_out_cts_avals: Sequence[core.AbstractValue],
variables=None):
if variables:
raise ValueError(
"Unexpected variables used in forward pass. "
"This should not happen for first-order differentiation. "
f"variables={variables}")
def fun_vjp_jax(args_jax, out_cts_jax):
# One may think that we can get the pullback while we are converting
# the main function in the first place. That is problematic, because the
# pullback may contain captured tracers from the conversion of the
# main function. Those tracers will confuse the conversion of the
# pullback. So, we construct the vjp anew.
_, pullback_jax = jax.vjp(fun, *args_jax)
return pullback_jax(out_cts_jax)
if polymorphic_shapes is None:
vjp_polymorphic_shapes = None
else:
args_polymorphic_shapes = tree_util.tree_unflatten(
in_tree.children()[0], polymorphic_shapes_flat)
out_cts_polymorphic_shapes = tree_util.tree_unflatten(
out_tree_thunk(),
tuple(str(out_aval.shape)
for out_aval in _out_cts_avals)) # type: ignore
vjp_polymorphic_shapes = [
args_polymorphic_shapes, out_cts_polymorphic_shapes
]
out_cts = tree_util.tree_unflatten(out_tree_thunk(), out_cts_flat)
# TODO: enable higher-order gradients
# TODO: I think that this does not work with kwargs
with tf.name_scope("jax2tf_vjp"):
in_cts = convert(
fun_vjp_jax,
with_gradient=False,
polymorphic_shapes=vjp_polymorphic_shapes)(args, out_cts)
return in_cts
try:
assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"
prev_enable_xla = _thread_local_state.enable_xla
_thread_local_state.enable_xla = enable_xla
prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
_thread_local_state.include_xla_op_metadata = False
_thread_local_state.shape_env = shapeenv
global _has_registered_tf_source_path
if not _has_registered_tf_source_path:
source_info_util.register_exclusion(os.path.dirname(tf.__file__))
_has_registered_tf_source_path = True
if with_gradient:
@tf.custom_gradient
def converted_fun_flat_with_custom_gradient(*args_flat: TfVal) -> TfVal:
out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
name_stack)
outs, out_avals = util.unzip2(out_with_avals)
return (tuple(outs),
partial(converted_grad_fn, _out_cts_avals=tuple(out_avals)))
out_flat = converted_fun_flat_with_custom_gradient(*args_flat)
else:
out_flat_raw = _interpret_fun(flat_fun, args_flat, args_avals_flat,
name_stack)
message = ("The jax2tf-converted function does not support gradients. "
"Use `with_gradient` parameter to enable gradients")
# We use PreventGradient, which is propagated through a SavedModel.
out_flat = [
tf.raw_ops.PreventGradient(input=o, message=message)
for o, _ in out_flat_raw
]
finally:
_thread_local_state.shape_env = {}
_thread_local_state.enable_xla = prev_enable_xla
_thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata
out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat]
out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
return out
return converted_fun
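# A minimal usage sketch (illustrative only; the function and shapes below are
# hypothetical). It shows the intended call pattern: convert a JAX function and
# call the result on TF values, here with a shape-polymorphic batch dimension.
def _example_convert_usage():
  import jax.numpy as jnp
  f_jax = lambda x: jnp.sin(jnp.cos(x))
  # "b" is a dimension variable: the converted function works for any batch size.
  f_tf = convert(f_jax, polymorphic_shapes=["(b, 4)"])
  x = tf.ones((3, 4), dtype=tf.float32)
  return f_tf(x)  # a tf.Tensor of shape (3, 4)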
def dtype_of_val(val: TfVal) -> DType:
"""Computes the TensorFlow dtype using JAX's typing rules.
If the value is a tf.Tensor, it starts with its dtype. If the value is a
constant it uses JAX to infer its dtype. The resulting dtype follows the
JAX type inference rules, and depends on the value of the
JAX_ENABLE_X64 flag.
See README.md for how 64-bit values are treated.
"""
tval, _ = _tfval_to_tensor_jax_dtype(val)
return tval.dtype
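# An illustrative sketch of the typing rules above, assuming the default
# JAX_ENABLE_X64=False (hypothetical helper, not used elsewhere).
def _example_dtype_of_val():
  assert dtype_of_val(1.0) == tf.float32  # Python floats follow JAX rules
  assert dtype_of_val(1) == tf.int32      # likewise Python ints
  assert dtype_of_val(True) == tf.bool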
# Internals
@contextlib.contextmanager
def _extended_name_stack(extra_name_stack: Optional[str]):
prev_name_stack = _thread_local_state.name_stack
if extra_name_stack:
if not prev_name_stack:
_thread_local_state.name_stack = extra_name_stack
else:
_thread_local_state.name_stack = util.extend_name_stack(
_thread_local_state.name_stack, extra_name_stack)
try:
yield
finally:
_thread_local_state.name_stack = prev_name_stack
def _interpret_fun(
fun: lu.WrappedFun, in_vals: Sequence[TfVal],
in_avals: Sequence[core.AbstractValue],
extra_name_stack: Optional[str]
) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
fun = _interpret_subtrace(fun, main, in_avals)
with _extended_name_stack(extra_name_stack):
with core.new_sublevel():
out_vals: Sequence[Tuple[TfVal, core.AbstractValue]] = \
fun.call_wrapped(*in_vals)
del main
return tuple(out_vals)
def _convert_jax_impl(jax_impl: Callable, *,
multiple_results=True,
extra_name_stack: Optional[str] = None) -> Callable:
"""Convert the JAX implementation of a primitive.
Args:
jax_impl: typically the impl-rule for a primitive, with signature
`(*args: JaxVal, **kwargs) -> Sequence[JaxVal]`. This function implements
a primitive in terms of other primitives.
multiple_results: whether `jax_impl` returns a sequence of results.
extra_name_stack: additional element to add to the name stack for the
converted ops.
Returns:
a function with signature `(*args: TfVal, _in_avals, _out_aval, **kwargs)
-> Sequence[TfVal]`.
"""
def wrapped(*tf_args: TfVal, _in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue,
**kwargs) -> Sequence[TfVal]:
# We wrap the jax_impl under _interpret_fun to abstract the TF values
# from jax_impl and turn them into JAX abstract values.
def jax_impl_jax_args(*jax_args):
jax_results = jax_impl(*jax_args, **kwargs)
return jax_results if multiple_results else [jax_results]
tf_results_with_avals = _interpret_fun(
lu.wrap_init(jax_impl_jax_args), tf_args, _in_avals,
extra_name_stack)
tf_results, _ = util.unzip2(tf_results_with_avals)
return tf_results if multiple_results else tf_results[0]
return wrapped
@lu.transformation
def _interpret_subtrace(main: core.MainTrace,
in_avals: Sequence[core.AbstractValue],
*in_vals: TfVal):
trace = TensorFlowTrace(main, core.cur_sublevel())
in_tracers = tuple(
TensorFlowTracer(trace, val, aval)
for val, aval in zip(in_vals, in_avals))
# The outs may be core.unit, see comment in TensorFlowTrace.pure.
outs = yield in_tracers, {} # type: Sequence[Union[TfVal, core.Unit]]
out_tracers: Iterable[TensorFlowTracer] = (
map(trace.full_raise, outs)) # type: ignore
out_vals_with_avals: Sequence[Tuple[TfVal, core.AbstractValue]] = (
tuple((t.val, t.aval) for t in out_tracers))
yield out_vals_with_avals
def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal,
extra_name_stack: Optional[str]) -> Sequence[TfVal]:
"""Evaluates a Jaxpr with tf.Tensor arguments.
The output is a sequence of TfVal (no `core.unit`), suitable for use with TF.
"""
fun: lu.WrappedFun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
out_with_avals = _interpret_fun(fun, args, jaxpr.in_avals, extra_name_stack)
return tuple(v for v, _ in out_with_avals)
def _aval_to_tf_shape(aval: core.AbstractValue) -> Tuple[Optional[int], ...]:
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
return tuple(map(lambda d: None if shape_poly.is_poly_dim(d) else d,
aval.shape)) # type: ignore[attr-defined]
def _to_tf_dtype(jax_dtype):
# Note that _to_tf_dtype and _to_jax_dtype are not inverses of each other,
# due to float0 and 64-bit behavior.
if jax_dtype == dtypes.float0:
jax_dtype = dtypes.bfloat16
return tf.dtypes.as_dtype(jax_dtype)
def _to_jax_dtype(tf_dtype):
# Note that _to_tf_dtype and _to_jax_dtype are not inverses of each other,
# due to float0 and 64-bit behavior.
return dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
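# An illustrative sketch of why the two helpers above are not inverses,
# assuming JAX_ENABLE_X64=False (hypothetical helper, not used elsewhere).
def _example_dtype_round_trip():
  # float0 has no TF equivalent and is mapped to bfloat16.
  assert _to_tf_dtype(dtypes.float0) == tf.bfloat16
  # 64-bit TF dtypes canonicalize down to 32-bit JAX dtypes when x64 is disabled.
  assert _to_jax_dtype(tf.float64) == np.dtype(np.float32)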
def _tfval_to_tensor_jax_dtype(val: TfVal,
jax_dtype: Optional[DType] = None) -> Tuple[TfVal, DType]:
"""Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type.
If `jax_dtype` is missing, uses JAX typing rules.
See README.md for details regarding 64-bit values.
Args:
val: a scalar, ndarray, tf.Tensor, or tf.Variable
jax_dtype: an optional dtype to use. If missing, uses JAX type inference
rules for constants.
Returns:
a tuple with a tf.Tensor with the type as needed by JAX, and the JAX type.
"""
if isinstance(val, (tf.Tensor, tf.Variable)):
jax_dtype = jax_dtype or _to_jax_dtype(val.dtype) # Give JAX a chance to pick the type
conversion_dtype = _to_tf_dtype(jax_dtype)
if conversion_dtype != val.dtype:
return tf.cast(val, conversion_dtype), jax_dtype
else:
return val, jax_dtype
else: # A constant
jax_dtype = jax_dtype or xla.abstractify(val).dtype
conversion_dtype = _to_tf_dtype(jax_dtype)
# The float0 type is not known to TF.
if jax_dtype == dtypes.float0:
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
return tf.convert_to_tensor(val, dtype=conversion_dtype), jax_dtype
def _args_to_avals_and_env(
args: Sequence[TfVal],
arg_jax_dtypes: Sequence[DType],
polymorphic_shapes: Sequence[Optional[Union[str, PolyShape]]]) -> \
Tuple[Sequence[core.AbstractValue], _ShapeEnv]:
"""Computes canonicalized args, abstract values and a dimension environment for arguments.
Args:
args: the arguments, TF inputs. Must be tf.Tensor or tf.Variable.
arg_jax_dtypes: the inferred JAX dtypes for the args.
polymorphic_shapes: the polymorphic specifications for the arguments.
Returns: a tuple of: a sequence of abstract values corresponding to the
arguments, and a dimension environment.
"""
shapeenv: _ShapeEnv = {}
def input_aval(arg: TfVal,
arg_jax_dtype: DType,
polymorphic_shape: Optional[str]) -> core.AbstractValue:
"""The abstract value for an input."""
arg_shape = np.shape(arg)
aval_shape = shape_poly.parse_spec(polymorphic_shape, arg_shape)
for i, d in enumerate(aval_shape):
if type(d) is int:
assert d == arg_shape[i]
elif shape_poly.is_poly_dim(d) and d not in shapeenv:
# Even if the shape of `arg` is known, we still use `tf.shape` for
# safety, because the promise is that we will convert the function
# to work for any value of the dimension.
v = d.to_var() # type: ignore
assert v is not None
shapeenv[v] = tf.shape(arg)[i] # type: ignore[index]
else:
# TODO: add an assertion tf.shape(arg)[i] == env[d]
pass
return core.ShapedArray(aval_shape, arg_jax_dtype)
avals = tuple(map(input_aval, args, arg_jax_dtypes, polymorphic_shapes)) # type: ignore
return avals, shapeenv
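# An illustrative sketch of the shape environment construction above
# (hypothetical helper, not used elsewhere): with a polymorphic spec, the
# abstract value carries the dimension variable and the environment maps it
# to a TF expression computed from the actual argument.
def _example_args_to_avals_and_env():
  x = tf.ones((3, 4), dtype=tf.float32)
  avals, env = _args_to_avals_and_env([x], [np.dtype(np.float32)], ["(b, 4)"])
  # avals[0] is a core.ShapedArray whose first dimension is the variable "b";
  # env maps "b" to the TF expression tf.shape(x)[0].
  return avals, env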
def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
assert all(map(lambda x: x is not None, shape)), (
f"Argument shape should be a valid JAX shape but got {shape}")
return shape_poly.eval_shape(shape, _thread_local_state.shape_env)
def shape_as_value(x):
"""Injects the shape of `x` as an array value.
**Experimental: please give feedback, and expect changes!**
This allows the use of a shape expression as array argument to JAX functions.
A typical example is for implementing a mean operation:
jnp.sum(x) / np.prod(jax2tf.shape_as_value(x))
"""
# return shape_as_value_p.bind(x)
raise NotImplementedError("shape_as_value is deprecated")
# # TODO: move this to masking or to some common library, if approved
# shape_as_value_p = core.Primitive("shape_as_value")
# shape_as_value_p.multiple_results = True
# def _shape_as_value_impl(x):
# x_shape = np.shape(x)
# def dim_to_int(dim: shape_poly.DimSize) -> int:
# dim_int = _poly_dim_to_tf_dim(dim)
# if dim_int is None:
# msg = ("shape_as_value is not implemented for non-constant shapes "
# "except for masking and jax2tf. "
# f"Has shape: {x_shape}")
# raise TypeError(msg)
# else:
# return dim_int
# return tuple(map(dim_to_int, x_shape))
#
# shape_as_value_p.def_impl(_shape_as_value_impl)
#
# def _shape_as_value_abstract(x_aval: core.AbstractValue) -> Sequence[core.AbstractValue]:
# rank = len(x_aval.shape) # type: ignore[attr-defined]
# return (core.ShapedArray((), dtypes.canonicalize_dtype(np.int_), weak_type=True),) * rank
#
# shape_as_value_p.def_abstract_eval(_shape_as_value_abstract)
#
# def _shape_as_value_translation(comp, x):
# return xla_client._xla.ops.Tuple(comp,
# tuple(xb.constant(comp, d)
# for d in comp.GetShape(x).dimensions()))
#
# xla.translations[shape_as_value_p] = _shape_as_value_translation
#
# def _shape_as_value_jvp_rule(primals, tangents):
# # The shape does not depend on the contents of the input
# x, = primals
# zero = ad.Zero.from_value(0.)
# return shape_as_value(x), (zero,) * len(x.shape)
#
# ad.primitive_jvps[shape_as_value_p] = _shape_as_value_jvp_rule
#
# def _shape_as_value__batching_rule(batched_args, batch_dims):
# xv, = batched_args
# batch_dim, = batch_dims
# batch_size = xv.shape[batch_dim]
# batched_shape = shape_as_value(xv)
# one_shape = batched_shape[0:batch_dim] + batched_shape[batch_dim+1:]
# res = tuple(jnp.broadcast_to(d, (batch_size, 1)) for d in one_shape)
# return res, (0,) * len(one_shape)
#
# batching.primitive_batchers[shape_as_value_p] = _shape_as_value__batching_rule
#
# def _shape_as_value_masking_rule(operands, operands_logical_shapes):
# x_logical_shape, = operands_logical_shapes
# return tuple(x_logical_shape)
#
# masking.masking_rules[shape_as_value_p] = _shape_as_value_masking_rule
#
# def _shape_as_value_tf(x: TfVal,
# _in_avals: Sequence[core.AbstractValue],
# _out_aval: core.AbstractValue) -> TfVal:
# x_aval = _in_avals[0]
# def dim_to_tfval(dim: shape_poly.DimSize, dim_idx: int) -> TfVal:
# dim_int = _poly_dim_to_tf_dim(dim)
# if dim_int is not None:
# return tf.convert_to_tensor(dim_int)
# else:
# return tf.shape(x)[dim_idx]
# return tuple(dim_to_tfval(dim, dim_idx)
# for dim_idx, dim in enumerate(x_aval.shape)) # type: ignore[attr-defined]
#
# tf_impl_with_avals[shape_as_value_p] = _shape_as_value_tf
# TODO(b/26854495): pylint doesn't understand slots and inheritance.
# pylint: disable=assigning-non-slot
class TensorFlowTracer(core.Tracer):
"""Tracer class that boxes a TF value and a JAX abstract value.
In addition to the TF value we carry the JAX abstract value because there are
two cases when it cannot be recovered from the value: (a) when the abstract
value is core.abstract_unit, in which case the value is tf.nan; (b) when we
are converting with polymorphic shapes, in which case the shape of the value
may have dimensions set to `None`, while the JAX abstract value may contain
more precise information.
When the value has a partially-known shape, the dimensions marked as `None`
must correspond to non-constant dimensions in the abstract value.
See README.md for details.
"""
# val: TfVal
# _aval: core.AbstractValue
__slots__ = ["val", "_aval"]
def __init__(self, trace: "TensorFlowTrace", val: TfVal,
aval: core.AbstractValue):
self._trace = trace
self._aval = aval
if aval is core.abstract_unit:
self.val = val
elif isinstance(val, (tf.Tensor, tf.Variable)):
val_shape = val.shape
val_dtype = _to_jax_dtype(val.dtype)
aval_dtype = np.dtype(self._aval.dtype) # type: ignore[attr-defined]
if config.jax_enable_checks:
assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}"
for aval_dim, val_dim in zip(
self._aval.shape, val_shape): # type: ignore[attr-defined]
if val_dim is None:
assert shape_poly.is_poly_dim(aval_dim), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
elif not shape_poly.is_poly_dim(aval_dim):
assert aval_dim == val_dim, f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
else:
# We have a TF value with known shape, and the abstract shape is a shape variable.
try:
aval_int = int(_eval_shape([aval_dim])) # type: ignore
except TypeError:
continue
assert aval_int == val_dim, f"expected {self._aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." # type: ignore
self.val = val
else: # Must be a numeric value
self.val = _tfval_to_tensor_jax_dtype(
val, self._aval.dtype)[0] # type: ignore[attr-defined]
@property
def aval(self):
return self._aval
def full_lower(self):
return self
class TensorFlowTrace(core.Trace):
"""Trace class that underlies the jax2tf transformation.
We are going to ensure that jax2tf.convert is never nested inside other
transformations. This is sufficient for intended use cases (converting
fully-transformed JAX code). It also simplifies our job because we do not have
to handle situations where we apply primitives on a mix of TF values and
JAX tracers from an outer transformation. E.g., for addition, both the TF
values and the JAX tracers have an override, and they would get confused if
they saw values from the other world.
Hence a TFT trace does not interact with non-TFT traces at a lower level. For
higher-order control-flow primitives we recursively invoke _interpret_fun on
the body of the conditional, which will create a nested TFT.
We do want to allow transformations nested inside a TensorFlowTrace (TFT), but
those will introduce their own MainTrace, and any operations involving those
will be done on those traces, i.e., not a concern for TFT.
"""
def pure(self, val: Union[TfVal, core.Unit]) -> TensorFlowTracer:
"""Lifts a non-Tracer into the TensorFlowTracer.
This function may be called by way of trace.full_raise.
The value may be a core.unit. During JAX transformations we sometimes
produce a Jaxpr that has arguments of abstract value core.abstract_unit
and results equal to core.unit. These are arguments and results that are
not used in the computation.
In TF world, we represent core.unit as NaN. This is safe, as these values
should never be used.
"""
if val is core.unit:
return TensorFlowTracer(self, tf.constant(np.nan, tf.float32),
core.abstract_unit)
else:
tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val)
return TensorFlowTracer(self, val,
core.ShapedArray(tf_val.shape, jax_dtype))
def lift(self, val: core.Tracer) -> TensorFlowTracer:
# This would be called when we need to raise a tracer from a lower-level
# main into the TensorFlowTrace. Since the TensorFlowTrace is never nested
# inside another transform, there are no lower-level main traces.
assert False
def sublift(self, val: TensorFlowTracer) -> TensorFlowTracer:
# This is called when we need to raise a tracer from the same main trace,
# but a lower sublevel. This could come from a nested jit.
return TensorFlowTracer(self, val.val, val._aval)
def process_primitive(self, primitive: core.Primitive,
tracers: Sequence[TensorFlowTracer],
params) -> TensorFlowTracer:
impl, impl_needs_avals = self.get_primitive_impl(primitive)
args_avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers)
out_aval = primitive.abstract_eval(*args_avals, **params)
args_tf: Sequence[TfVal] = [t.val for t in tracers]
def invoke_impl() -> TfVal:
if impl_needs_avals:
return impl(
*args_tf,
_in_avals=args_avals, # type: ignore
_out_aval=out_aval,
**params)
else:
return impl(*args_tf, **params)
if _thread_local_state.include_xla_op_metadata:
op_metadata = xla.make_op_metadata(primitive, params,
name_stack=_get_current_name_stack(),
source_info=source_info_util.current())
op_metadata_proto = xla_data_pb2.OpMetadata(
op_type=op_metadata.op_type,
op_name=op_metadata.op_name,
source_file=op_metadata.source_file,
source_line=op_metadata.source_line
)
with tf_ops.get_default_graph()._attr_scope(
{"_XlaOpMetadata": attr_value_pb2.AttrValue(
s=op_metadata_proto.SerializeToString())}):
val_out = invoke_impl()
else:
val_out = invoke_impl()
if primitive.multiple_results:
out = [
TensorFlowTracer(self, v, a)
for v, a in zip(val_out, out_aval)
] # type: ignore
else:
out = TensorFlowTracer(self, val_out, out_aval) # type: ignore
# Check that the impl rule returned a value of expected shape and dtype
# TODO: adapt this to match polymorphic shapes
if config.jax_enable_checks:
if primitive.multiple_results:
for o, expected_aval in zip(out, out_aval): # type: ignore
assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), (
f"{primitive}: out.aval = {o.aval}; expected {expected_aval}")
else:
assert out.aval == out_aval, ( # type: ignore
f"{primitive}: out.aval = {out.aval}; expected {out_aval}"
) # type: ignore
return out # type: ignore
def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun,
tracers: Sequence[TensorFlowTracer], params):
assert call_primitive.multiple_results
vals: Sequence[TfVal] = [t.val for t in tracers]
avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers)
fun = _interpret_subtrace(fun, self.main, avals)
extra_name_stack = None
if call_primitive == core.named_call_p:
extra_name_stack = util.wrap_name(params["name"], "named")
elif call_primitive == xla.xla_call_p:
extra_name_stack = util.wrap_name(params["name"], "jit")
with _extended_name_stack(extra_name_stack):
with core.new_sublevel():
if call_primitive == core.named_call_p:
with tf.name_scope(_sanitize_scope_name(params["name"])):
vals_out: Sequence[Tuple[TfVal, core.AbstractValue]] = \
fun.call_wrapped(*vals)
elif call_primitive == sharded_jit.sharded_call_p:
vals_out = _sharded_call(fun, vals, **params)
else:
vals_out = fun.call_wrapped(*vals)
return [TensorFlowTracer(self, v, a) for v, a in vals_out]
def post_process_call(self, call_primitive: core.Primitive,
out_tracers: Sequence[TensorFlowTracer], params):
# We encountered a call primitive, e.g., remat_call_p, whose results
# (out_tracers) include TensorFlowTracers that were not passed through
# its arguments (captured from the environment).
vals = tuple(t.val for t in out_tracers)
main = self.main
def todo(vals: Sequence[TfVal]):
# TODO: is name_stack correct?
trace = TensorFlowTrace(main, core.cur_sublevel())
return [
TensorFlowTracer(trace, v, out_tracer.aval)
for v, out_tracer in zip(vals, out_tracers)
]
return vals, todo
def process_map(self, map_primitive, f, tracers, params):
raise NotImplementedError("process_map")
def post_process_map(self, map_primitive, out_tracers, params):
raise NotImplementedError("post_process_map")
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
# Drop the custom differentiation rule and act like a call primitive. This
# behavior is desirable because jax2tf stages code out of the JAX system, so
# there are no more JAX differentiation transformations to be applied.
del jvp # Unused.
return self.process_call(core.call_p, fun, tracers, {})
def post_process_custom_jvp_call(self, out_tracers, params):
assert False # unreachable assuming jax2tf runs with clean trace state
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
# Drop the custom differentiation rule and act like a call primitive. This
# behavior is desirable because jax2tf stages code out of the JAX system, so
# there are no more JAX differentiation transformations to be applied.
del fwd, bwd, out_trees # Unused.
return self.process_call(core.call_p, fun, tracers, {})
def post_process_custom_vjp_call(self, out_tracers, params):
assert False # unreachable assuming jax2tf runs with clean trace state
def get_primitive_impl(self, p: core.Primitive) -> Tuple[Callable, bool]:
# Returns the primitive implementation and whether the implementation
# takes abstract values (see definition of tf_impl_with_avals)
try:
return tf_impl[p], False
except KeyError:
try:
return tf_impl_with_avals[p], True
except KeyError as err:
msg = "TensorFlow interpretation rule for '{}' not implemented"
raise NotImplementedError(msg.format(p)) from err
def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
assert False, f"Encountered unexpected primitive {p}"
for unexpected in xla.call_translations: # Call primitives are inlined
if unexpected is pjit.pjit_p:
continue
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
# Primitives that are not yet implemented must be explicitly declared here.
tf_not_yet_impl = [
"reduce",
"rng_uniform",
"clz",
"igamma_grad_a",
"random_gamma_grad",
"reduce_precision",
# Not high priority?
"after_all",
"all_to_all",
"create_token",
"infeed",
"outfeed",
"pmax_p",
"pmin",
"ppermute",
"psum",
"pmax",
"pgather",
"axis_index",
"pdot",
"all_gather",
"lu_pivots_to_permutation",
"rng_bit_generator",
"xla_pmap",
]
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
tf_impl[ad_util.zeros_like_p] = tf.zeros_like
def _add(x: TfVal, y: TfVal) -> TfVal:
return tf.raw_ops.AddV2(x=x, y=y)
tf_impl[ad_util.add_jaxvals_p] = _add
tf_impl[xla.device_put_p] = lambda x, device=None: x
tf_impl[lax.neg_p] = tf.math.negative
def _sign(x: TfVal) -> TfVal:
if x.dtype.is_unsigned:
# TF and XLA do not support tf.math.sign for unsigned types.
return tf.where(
tf.math.equal(x, 0), tf.constant(0, dtype=x.dtype),
tf.constant(1, dtype=x.dtype))
else:
return tf.math.sign(x)
tf_impl[lax.sign_p] = _sign
tf_impl[lax.floor_p] = tf.math.floor
tf_impl[lax.ceil_p] = tf.math.ceil
def _round(operand, *, rounding_method):
if rounding_method is lax.RoundingMethod.AWAY_FROM_ZERO:
sign = _sign(operand)
operand *= sign
floor = tf.math.floor(operand)
operand -= floor
cond = tf.math.equal(operand, tf.constant(np.array(0.5), operand.dtype))
return sign * (
tf.where(cond, tf.constant(np.array(1), operand.dtype),
tf.math.round(operand)) + floor)
else:
return tf.math.round(operand)
tf_impl[lax.round_p] = _round
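# An illustrative sketch of the AWAY_FROM_ZERO handling above: tf.math.round
# uses round-half-to-even, so ties must be special-cased (hypothetical helper,
# not used elsewhere).
def _example_round_away_from_zero():
  x = tf.constant([0.5, 1.5, 2.5, -0.5], tf.float32)
  away = _round(x, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO)
  # away == [1., 2., 3., -1.], whereas tf.math.round(x) == [0., 2., 2., -0.]
  return away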
tf_impl[lax.nextafter_p] = tf.math.nextafter
def _population_count(x):
orig_dtype = x.dtype
return tf.cast(tf.raw_ops.PopulationCount(x=x), orig_dtype)
tf_impl[lax.population_count_p] = _population_count
tf_impl[lax.is_finite_p] = tf.math.is_finite
def _abs(x: TfVal) -> TfVal:
# TF and XLA do not support tf.math.abs for unsigned types.
return tf.math.abs(x) if not x.dtype.is_unsigned else x
tf_impl[lax.abs_p] = _abs
tf_impl[lax.pow_p] = tf.math.pow
def _integer_pow(x, *, y: int, _in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
# Follows the implementation in lax._integer_pow_translation_rule
if y == 0:
return tf.broadcast_to(
tf.constant(1, dtype=x.dtype, shape=()), _eval_shape(_out_aval.shape))
is_reciprocal = y < 0
if is_reciprocal:
y = -y
acc = None
while y > 0:
if y & 1:
acc = x if acc is None else tf.math.multiply(acc, x)
y >>= 1
if y > 0:
x = tf.math.multiply(x, x)
return tf.math.reciprocal(acc) if is_reciprocal else acc
tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
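# An illustrative sketch of the exponentiation-by-squaring loop above
# (hypothetical helper, not used elsewhere): y=5 is decomposed as x * (x**2)**2.
def _example_integer_pow():
  x = tf.constant(2.0, tf.float32)
  aval = core.ShapedArray((), np.float32)
  res = _integer_pow(x, y=5, _in_avals=(aval,), _out_aval=aval)
  # res == 32.0; a negative y would take tf.math.reciprocal of the result.
  return res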
tf_impl[lax.exp_p] = tf.math.exp
tf_impl[lax.expm1_p] = tf.math.expm1
tf_impl[lax.log_p] = tf.math.log
tf_impl[lax.log1p_p] = tf.math.log1p
tf_impl[lax.tan_p] = tf.math.tan
tf_impl[lax.tanh_p] = tf.math.tanh
tf_impl[lax.sin_p] = tf.math.sin
tf_impl[lax.sinh_p] = tf.math.sinh
tf_impl[lax.cos_p] = tf.math.cos
tf_impl[lax.cosh_p] = tf.math.cosh
tf_impl_with_avals[lax.acos_p] = _convert_jax_impl(lax.acos_translation_rule,
multiple_results=False)
tf_impl_with_avals[lax.asin_p] = _convert_jax_impl(lax.asin_translation_rule,
multiple_results=False)
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(lax.atan_translation_rule,
multiple_results=False)
# TODO: remove this when TF supports atan2 for [b]float16
def _atan2(y: TfVal, x: TfVal) -> TfVal:
if y.dtype in (tf.bfloat16, tf.float16):
return tf.cast(tf.math.atan2(tf.cast(y, tf.float32),
tf.cast(x, tf.float32)),
y.dtype)
else:
return tf.math.atan2(y, x)
tf_impl[lax.atan2_p] = _atan2
tf_impl[lax.acosh_p] = tf.math.acosh
tf_impl[lax.atanh_p] = tf.math.atanh
tf_impl[lax.asinh_p] = tf.math.asinh
tf_impl[lax.sqrt_p] = tf.math.sqrt
tf_impl[lax.rsqrt_p] = tf.math.rsqrt
tf_impl[lax.lgamma_p] = tf.math.lgamma
tf_impl[lax.digamma_p] = tf.math.digamma
tf_impl[lax.igamma_p] = tf.math.igamma
tf_impl[lax.igammac_p] = tf.math.igammac
tf_impl[lax.regularized_incomplete_beta_p] = tf.math.betainc
tf_impl[lax.erf_p] = tf.math.erf
tf_impl[lax.erfc_p] = tf.math.erfc
tf_impl[lax.erf_inv_p] = tf.math.erfinv
tf_impl[lax.bessel_i0e_p] = tf.math.bessel_i0e
tf_impl[lax.bessel_i1e_p] = tf.math.bessel_i1e
tf_impl[lax.complex_p] = tf.complex
def _conj(x, **kwargs):
# The only dtypes that are allowed are: float32, float64, complex64, and
# complex128.
if x.dtype == tf.float32:
return tf.cast(x, tf.complex64)
elif x.dtype == tf.float64:
return tf.cast(x, tf.complex128)
else:
return tf.math.conj(x)
tf_impl[lax.conj_p] = _conj
tf_impl[lax.real_p] = tf.math.real
tf_impl[lax.imag_p] = tf.math.imag
tf_impl[lax.add_p] = _add
tf_impl[lax.sub_p] = tf.math.subtract
tf_impl[lax.mul_p] = tf.math.multiply
def _iota(*, dtype, shape, dimension):
dtype = _to_tf_dtype(dtype)
# Some dtypes are unsupported, like uint32, so we just fall back to int32.
# TODO(mattjj, necula): improve tf.range dtype handling
shape_tf = _eval_shape(shape)
vec = tf.range(tf.cast(shape_tf[dimension], tf.int32), dtype=tf.int32)
vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))]
return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape_tf), dtype)
tf_impl[lax.iota_p] = _iota
def _div(lhs, rhs):
if lhs.dtype.is_integer:
quotient = tf.math.floordiv(lhs, rhs)
select = tf.math.logical_and(
tf.not_equal(_sign(lhs), _sign(rhs)),
tf.not_equal(tf.math.floormod(lhs, rhs), 0))
return tf.where(select, quotient + 1, quotient)
else:
return tf.math.truediv(lhs, rhs)
def _rem(lhs, rhs):
return _sign(lhs) * tf.math.floormod(_abs(lhs), _abs(rhs))
tf_impl[lax.div_p] = _div
tf_impl[lax.rem_p] = _rem
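# An illustrative sketch of the integer division adjustment above (hypothetical
# helper, not used elsewhere): XLA/JAX division truncates toward zero, while
# tf.math.floordiv rounds toward negative infinity.
def _example_div_rem():
  lhs = tf.constant(-7, tf.int32)
  rhs = tf.constant(2, tf.int32)
  q = _div(lhs, rhs)  # == -3, whereas tf.math.floordiv(-7, 2) == -4
  r = _rem(lhs, rhs)  # == -1, taking the sign of the dividend as in JAX
  return q, r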
def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue,) -> TfVal:
# For complex numbers use lexicographic ordering, like JAX
if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating):
return _convert_jax_impl(
partial(lax._minmax_complex_lowering,
lax_cmp_pick_x=lax.lt if is_min else lax.gt),
multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
elif x.dtype.as_numpy_dtype == np.bool_:
return (tf.math.logical_and if is_min else tf.math.logical_or)(x, y)
else:
return (tf.math.minimum if is_min else tf.math.maximum)(x, y)
def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal:
# For reducers we will need min/max for scalars only. In that case we
# can construct the AbstractValues ourselves, even in the presence of
# shape polymorphism.
assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}"
aval = core.ShapedArray((), _to_jax_dtype(x.dtype))
return _minmax(x, y, is_min=is_min,
_in_avals=[aval, aval], _out_aval=aval)
tf_impl_with_avals[lax.max_p] = partial(_minmax, is_min=False)
tf_impl_with_avals[lax.min_p] = partial(_minmax, is_min=True)
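# An illustrative sketch of the boolean min/max handling above (hypothetical
# helper, not used elsewhere): tf.math.minimum/maximum do not support tf.bool,
# so min becomes logical and, and max becomes logical or.
def _example_minmax_bool():
  x = tf.constant([True, True, False])
  y = tf.constant([True, False, False])
  aval = core.ShapedArray((3,), np.bool_)
  mn = _minmax(x, y, is_min=True, _in_avals=(aval, aval), _out_aval=aval)
  # mn == [True, False, False], i.e. tf.math.logical_and(x, y)
  return mn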
# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
tf.int8: tf.uint8,
tf.int16: tf.uint16,
tf.int32: tf.uint32,
tf.int64: tf.uint64,
}
# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {u: s for s, u in _SIGNED_TO_UNSIGNED_TABLE.items()}
# Note: Bitwise operations only yield identical results on unsigned integers!
# pylint: disable=protected-access
def _shift_right_arithmetic_raw(x, y):
if x.dtype.is_unsigned:
assert x.dtype == y.dtype
orig_dtype = x.dtype
signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[orig_dtype]
x = tf.cast(x, signed_dtype)
y = tf.cast(y, signed_dtype)
res = tf.bitwise.right_shift(x, y)
return tf.cast(res, orig_dtype)
else:
return tf.bitwise.right_shift(x, y)
def _shift_right_arithmetic(x, y):
# TF shift is "implementation defined" if the shift amount is negative
# or greater than or equal to the size of the value. We implement the XLA
# semantics to return the shift by the max value (x_bits - 1).
# TODO: it is likely better to add XlaOps for shifts
x_bits = 8 * x.dtype.size
clamp_y = tf.where(_shift_in_bounds(x, y), y, x_bits - 1)
return _shift_right_arithmetic_raw(x, clamp_y)
tf_impl[lax.shift_right_arithmetic_p] = _shift_right_arithmetic
def _shift_right_logical_raw(x, y):
if x.dtype.is_unsigned:
return tf.bitwise.right_shift(x, y)
else:
assert x.dtype == y.dtype
orig_dtype = x.dtype
unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[orig_dtype]
x = tf.cast(x, unsigned_dtype)
y = tf.cast(y, unsigned_dtype)
res = tf.bitwise.right_shift(x, y)
return tf.cast(res, orig_dtype)
def _shift_right_logical(x, y):
# TF shift is "implementation defined" if the shift amount is negative
# or greater than or equal to the size of the value. We implement the XLA semantics
# to return 0.
# TODO: it is likely better to add XlaOps for shifts
return tf.where(
_shift_in_bounds(x, y), _shift_right_logical_raw(x, y), tf.zeros_like(x))
tf_impl[lax.shift_right_logical_p] = _shift_right_logical
def _shift_left(x, y):
# TF shift is "implementation defined" if the shift amount is negative
# or greater than or equal to the size of the value. We implement the XLA semantics
# to return 0.
# TODO: it is likely better to add XlaOps for shifts
return tf.where(
_shift_in_bounds(x, y), tf.bitwise.left_shift(x, y), tf.zeros_like(x))
tf_impl[lax.shift_left_p] = _shift_left
def _shift_in_bounds(x: TfVal, y: TfVal) -> TfVal:
# Return the TF expression for when y is within bounds (0 <= y < bit width of x)
x_bits = 8 * x.dtype.size
# TF does not have comparisons for uint16 and uint32 (despite what the
# documentation says)
y_comp = tf.cast(
y, _UNSIGNED_TO_SIGNED_TABLE[y.dtype]) if y.dtype.is_unsigned else y
y_lt_x_bits = tf.math.less(y_comp, x_bits)
y_ge_0 = tf.math.greater_equal(y_comp, 0)
return tf.logical_and(y_lt_x_bits, y_ge_0)
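# An illustrative sketch of the shift semantics above (hypothetical helper,
# not used elsewhere): out-of-bounds shift amounts are clamped (arithmetic) or
# zeroed (logical/left) to match XLA, since TF leaves them implementation defined.
def _example_shift_semantics():
  x = tf.constant(-8, tf.int32)
  big = tf.constant(40, tf.int32)  # >= 32 bits, out of bounds for int32
  sar = _shift_right_arithmetic(x, big)  # == -1, i.e. shift by x_bits - 1
  srl = _shift_right_logical(x, big)     # == 0, as XLA specifies
  return sar, srl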
def _not(x):
"""Computes bitwise not with support for booleans.
Numpy and JAX support bitwise not for booleans by applying a logical not!
This means that applying bitwise_not yields an unexpected result:
jnp.bitwise_not(jnp.array([True, False]))
>> DeviceArray([False, True], dtype=bool)
if you assume that booleans are simply cast to integers.
jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool)
>> DeviceArray([True, True], dtype=bool)
"""
if x.dtype == tf.bool:
return tf.logical_not(x)
else:
return tf.bitwise.invert(x)
tf_impl[lax.not_p] = _not
def bool_to_int8(f, argnums: Sequence[int]):
"""Computes functions with some bool args and bool results using int8.
This is needed because some TF ops do not work for bool args, e.g.,
inequalities, min/max.
Args:
f: a TF callable to wrap. It will be called with non-boolean arguments.
argnums: the positional arguments that may be booleans.
Returns: a TF callable that can take a mix of boolean positional arguments
(in the positions specified by `argnums`) and some non-boolean positional
arguments. If there are no boolean arguments, just calls `f`. Otherwise,
casts the boolean arguments to `int8`, calls `f`, then casts the result to
`bool`.
"""
argnums = tf.nest.flatten(argnums)
def wrapper(*args: TfVal, **kwargs):
argnum_types = {args[i].dtype for i in argnums}
if tf.bool not in argnum_types:
return f(*args, **kwargs)
else:
# All argnums should be boolean
assert len(argnum_types) == 1, argnum_types
args_cast = [(tf.cast(a, tf.int8) if i in argnums else a)
for i, a in enumerate(args)]
if "_in_avals" in kwargs:
def cast_aval(aval):
assert aval.dtype == np.bool_
return core.ShapedArray(aval.shape, np.int8)
_in_avals_cast = [
cast_aval(aval) if i in argnums else aval
for i, aval in enumerate(kwargs["_in_avals"])
]
_out_aval_cast = tf.nest.map_structure(cast_aval, kwargs["_out_aval"])
kwargs = dict(
kwargs, _in_avals=_in_avals_cast, _out_aval=_out_aval_cast)
out = f(*args_cast, **kwargs)
return tf.nest.map_structure(lambda o: tf.cast(o, tf.bool), out)
return wrapper
tf_impl[lax.or_p] = bool_to_int8(tf.bitwise.bitwise_or, argnums=(0, 1))
tf_impl[lax.and_p] = bool_to_int8(tf.bitwise.bitwise_and, argnums=(0, 1))
tf_impl[lax.xor_p] = bool_to_int8(tf.bitwise.bitwise_xor, argnums=(0, 1))
tf_impl[lax.eq_p] = tf.math.equal
tf_impl[lax.ne_p] = tf.math.not_equal
tf_impl[lax.ge_p] = bool_to_int8(tf.math.greater_equal, argnums=(0, 1))
tf_impl[lax.gt_p] = bool_to_int8(tf.math.greater, argnums=(0, 1))
tf_impl[lax.le_p] = bool_to_int8(tf.math.less_equal, argnums=(0, 1))
tf_impl[lax.lt_p] = bool_to_int8(tf.math.less, argnums=(0, 1))
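# An illustrative sketch of the boolean comparison handling above (hypothetical
# helper, not used elsewhere): inequalities on booleans go through the int8
# cast in bool_to_int8.
def _example_bool_comparison():
  x = tf.constant([True, True, False])
  y = tf.constant([True, False, False])
  gt = tf_impl[lax.gt_p](x, y)
  # gt == [False, True, False]: computed as an int8 comparison, cast back to bool.
  return gt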
tf_impl[lax_linalg.cholesky_p] = tf.linalg.cholesky
def _convert_element_type(operand, *, new_dtype, weak_type=False):
old_dtype = operand.dtype.as_numpy_dtype
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
operand = tf.math.real(operand)
if (dtypes.issubdtype(old_dtype, np.floating) and
not (dtypes.issubdtype(new_dtype, np.floating) or dtypes.issubdtype(
new_dtype, np.complexfloating) or new_dtype == np.bool_)):
sign = _sign(operand)
operand = sign * tf.math.floor(sign * operand)
return tf.dtypes.cast(operand, _to_tf_dtype(new_dtype))
tf_impl[lax.convert_element_type_p] = _convert_element_type
def _bitcast_convert_type(operand, new_dtype):
if operand.dtype == new_dtype:
return operand
return tf.bitcast(operand, _to_tf_dtype(new_dtype))
tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type
def _clamp(minval, operand, maxval, *, _in_avals, _out_aval):
# The below permits mirroring the behavior of JAX when maxval < minval
op_shape_tf_val = _eval_shape(_in_avals[1].shape)
maxval = tf.broadcast_to(maxval, op_shape_tf_val)
minval = tf.math.minimum(tf.broadcast_to(minval, op_shape_tf_val), maxval)
return tf.clip_by_value(operand, minval, maxval)
tf_impl_with_avals[lax.clamp_p] = _clamp
def _concatenate(*operands, dimension):
return tf.concat(operands, axis=dimension)
tf_impl[lax.concatenate_p] = _concatenate
def _conv_general_dimension_numbers_proto(dimension_numbers):
"""Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
lhs_spec, rhs_spec, out_spec = dimension_numbers
proto = xla_data_pb2.ConvolutionDimensionNumbers()
proto.input_batch_dimension = lhs_spec[0]
proto.input_feature_dimension = lhs_spec[1]
proto.output_batch_dimension = out_spec[0]
proto.output_feature_dimension = out_spec[1]
proto.kernel_output_feature_dimension = rhs_spec[0]
proto.kernel_input_feature_dimension = rhs_spec[1]
proto.input_spatial_dimensions.extend(lhs_spec[2:])
proto.kernel_spatial_dimensions.extend(rhs_spec[2:])
proto.output_spatial_dimensions.extend(out_spec[2:])
return proto
def _precision_config_proto(precision: Optional[Tuple[PrecisionType,
PrecisionType]]):
"""Convert an integer to an XLA.PrecisionConfig."""
if precision is None:
return None
proto = xla_data_pb2.PrecisionConfig()
proto.operand_precision.append(int(precision[0]))
proto.operand_precision.append(int(precision[1]))
return proto
def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
preferred_element_type: Optional[DType], out_shape) -> TfVal:
def error(msg):
suffix = ("See source code for the precise conditions under which "
"convolutions can be converted without XLA.")
return _xla_disabled_error("conv_general_dilated", f"{msg} - {suffix}")
# TODO(bchetioui): this function is not exhaustive wrt which convolution cases
# can be translated into TF primitives. Further investigation is needed to
# fully flesh it out.
if lhs.dtype not in [tf.float16, tf.float32, tf.float64]:
raise error(f"tf.nn.convolution is not supported for dtype {lhs.dtype}")
if feature_group_count != 1:
raise error("tf.nn.convolution does not support grouped convolutions")
# TODO(bchetioui): is there something to do with batch_group_count?
if batch_group_count != 1:
raise error("Unimplemented support for batch_group_count != 1")
nb_spatial_dimensions = len(lhs.shape) - 2
# TF can only deal with 1D, 2D and 3D convolution
if nb_spatial_dimensions < 1 or nb_spatial_dimensions > 3:
raise error("TensorFlow can only handle convolutions with 1, 2, or 3 "
"spatial dimensions")
# TODO(bchetioui): handle different stride cases
if list(window_strides) != [1] * nb_spatial_dimensions:
raise error("Unimplemented support for window_strides != "
f"{tuple([1] * nb_spatial_dimensions)}")
if preferred_element_type is not None and preferred_element_type != lhs.dtype:
raise error("Unimplemented support for preferred_element_type")
def convert_padding() -> str:
# TODO(bchetioui): in this instance, we cannot use padtype_to_pads as
# string padding is not implemented for transposed convolution.
if list(lhs_dilation) != [1] * nb_spatial_dimensions:
raise error("Padding conversion is not supported for transposed "
"convolution.")
lhs_perm, rhs_perm, _ = dimension_numbers
effective_rhs_shape = [
(k - 1) * r + 1
for k, r in zip(np.take(rhs.shape, rhs_perm)[2:], rhs_dilation)
]
lhs_shape = np.take(lhs.shape, lhs_perm)[2:]
# TF only allows 'VALID' and 'SAME' padding
for pad_str in ["VALID", "SAME"]:
gen_padding = lax.padtype_to_pads(lhs_shape, effective_rhs_shape,
window_strides, pad_str)
if list(gen_padding) == list(padding):
return pad_str
raise error("Input padding not supported in TensorFlow.")
def convert_dim_nums() -> str:
lhs_spec, rhs_spec, out_spec = dimension_numbers
# TF only allows filters with shape:
# spatial_filter_shape + [in_channels, out_channels]. In JAX however,
# rhs_spec is represented as a tuple containing the following:
# [out_channels, in_channels] + spatial_filter_shape.
supported_rhs_shape = ([nb_spatial_dimensions + 1, nb_spatial_dimensions] +
list(range(nb_spatial_dimensions)))
if list(rhs_spec) != supported_rhs_shape:
raise error("Input filter (RHS) shape format not supported in "
"TensorFlow.")
# TF only supports same LHS and output data format
if lhs_spec != out_spec:
raise error("TensorFlow requires the same data format for LHS and "
"output.")
# Alphabet extracted from the documentation of tf.conv{1,2,3}d
spatial_dim_alphabet = "DHW"[-nb_spatial_dimensions:]
# TF only supports the following data formats:
# - [batch_size, in_channels] + input_spatial_shape
# TODO(bchetioui): TF currently does not support the above on CPU. To avoid
# failing on this platform, this path is commented out for now.
# if list(lhs_spec) == list(range(len(lhs_spec))):
# return "NC" + spatial_dim_alphabet
# - [batch_size] + input_spatial_shape + [in_channels]
if list(lhs_spec) == ([0, len(lhs_spec) - 1] +
list(range(1,
len(lhs_spec) - 1))):
return "N" + spatial_dim_alphabet + "C"
raise error("Data format is unsupported by TensorFlow.")
def convert_dilation_and_compute_result(tf_padding: str,
tf_dim_nums: str) -> TfVal:
no_dilation = [1] * nb_spatial_dimensions
# TODO(bchetioui): is there a generic way to do a transposed atrous
# convolution in TensorFlow?
if not (list(lhs_dilation) == no_dilation or
list(rhs_dilation) == no_dilation):
raise error("Both LHS and RHS dilations are set.")
# This is a non-dilated or atrous convolution
if list(lhs_dilation) == no_dilation:
return tf.nn.convolution(
lhs,
rhs,
strides=window_strides,
padding=tf_padding,
data_format=tf_dim_nums,
dilations=rhs_dilation)
# TODO(bchetioui): the below path is unreachable for now, as passing a lhs
# dilation to this function will result in convert_padding returning None
# systematically. This must be investigated further.
# Dilation of the LHS is transposed convolution
return tf.nn.conv_transpose(
lhs,
rhs,
out_shape,
window_strides,
padding=tf_padding,
data_format=tf_dim_nums,
dilations=lhs_dilation)
tf_padding = convert_padding()
tf_dim_nums = convert_dim_nums()
return convert_dilation_and_compute_result(tf_padding, tf_dim_nums)
def _conv_general_dilated(lhs, rhs, *,
window_strides, padding, lhs_dilation,
rhs_dilation,
dimension_numbers: lax.ConvDimensionNumbers,
feature_group_count: int,
batch_group_count: int,
lhs_shape: Sequence[int],
rhs_shape: Sequence[int],
precision: Optional[Tuple[PrecisionType, PrecisionType]],
preferred_element_type: Optional[DType],
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
"""Implementation of lax.conv_general_dilated_p using XlaConv."""
out_tf_shape = _aval_to_tf_shape(_out_aval)
if not _thread_local_state.enable_xla:
return _try_tf_conv(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
preferred_element_type, out_tf_shape)
dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
precision_config_proto = _precision_config_proto(precision)
assert batch_group_count == 1 # TODO(necula): implement batch_group_count
def gen_conv(lhs, rhs, preferred_element_type: Optional[DType]):
out = tfxla.conv(
lhs,
rhs,
window_strides,
padding,
lhs_dilation,
rhs_dilation,
dnums_proto,
feature_group_count=feature_group_count,
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type)
# TODO: implement shape inference for XlaConv
out.set_shape(out_tf_shape)
return out
# Follow the lowering for complex convolutions from
# lax._conv_general_dilated_translation. We can use the same conversion on all
# platforms because on XLA:TPU the compiler does the same as a rewrite.
if np.issubdtype(_in_avals[0].dtype, np.complexfloating):
if preferred_element_type is not None:
# Convert complex dtype to types used for real and imaginary parts
assert np.issubdtype(preferred_element_type, np.complexfloating)
preferred_float_et = (
np.float64 if preferred_element_type == np.complex128 else np.float32)
else:
preferred_float_et = None
lhs_real, lhs_imag = tf.math.real(lhs), tf.math.imag(lhs)
rhs_real, rhs_imag = tf.math.real(rhs), tf.math.imag(rhs)
k1 = gen_conv(_add(lhs_real, lhs_imag), rhs_real, preferred_float_et)
k2 = gen_conv(lhs_real, tf.math.subtract(rhs_imag, rhs_real),
preferred_float_et)
k3 = gen_conv(lhs_imag, _add(rhs_real, rhs_imag), preferred_float_et)
return tf.complex(tf.math.subtract(k1, k3), _add(k1, k2))
else:
return gen_conv(lhs, rhs, preferred_element_type)
tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
precision: Optional[Tuple[PrecisionType, PrecisionType]],
preferred_element_type: Optional[DType],
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
if _thread_local_state.enable_xla:
dnums_proto = xla_data_pb2.DotDimensionNumbers()
dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
precision_config_proto = _precision_config_proto(precision)
res = tfxla.dot_general(
lhs,
rhs,
dnums_proto,
precision_config_proto,
preferred_element_type=preferred_element_type)
# TODO: in presence of None dimensions, XlaDot shape inference returns
# unknown shape.
res.set_shape(_aval_to_tf_shape(_out_aval))
return res
# This condition ensures that:
# 1) the batch dimensions are ordered in the same way in lhs and rhs (this is
# not strictly necessary, but we would have to reshape the array if that
  #    were not the case);
# 2) lhs and rhs have the same number of dimensions +/- 1
# 3) the number of non-batch dimensions in both tensors is either 1 or 2
# 4) the contracting dimensions are consistent with those of a classic
# matrix/matrix, vector/matrix or matrix/vector multiplication.
if (lhs_batch == rhs_batch == tuple(range(len(lhs_batch))) and
lhs_ndim - rhs_ndim in [-1, 0, 1] and
1 <= lhs_ndim - len(lhs_batch) <= 2 and
1 <= rhs_ndim - len(rhs_batch) <= 2 and
lhs_contracting == (len(lhs.shape) - 1,) and
rhs_contracting == (len(lhs_batch),)):
# All the inputs to tf.linalg.matmul must have 2 inner dimensions,
# after their batch dimensions, so we need to expand the dimensions
    # appropriately. We can get to this branch with four combinations of
    # inner shapes:
# - lhs.inner_shape == [a, b], rhs.inner_shape == [b, c]
# - in this case, the resulting inner shape is [a, c];
# - lhs.inner_shape == [b] , rhs.inner_shape == [b, c]
# - in this case, we need to expand lhs to [1, b], and the resulting
# shape is [c]. We need to squeeze the result of tf.linalg.matmul
# as it will have shape [1, c];
    # - lhs.inner_shape == [a, b], rhs.inner_shape == [b]
# - in this case, we need to expand rhs to [b, 1], and the resulting
# shape is [a]. We need to squeeze the result of tf.linalg.matmul
# as it will have shape [a, 1];
    # - lhs.inner_shape == [b]   , rhs.inner_shape == [b]
# - in this case, we need to expand lhs to [1, b] and rhs to [b, 1],
# and the resulting shape is (). We need to squeeze the result of
# tf.linalg.matmul as it will have shape [1, 1].
squeeze_idxs = []
if lhs_ndim - len(lhs_batch) == 1:
lhs = tf.expand_dims(lhs, lhs_ndim - 1)
squeeze_idxs.append(len(lhs.shape) - 2)
if rhs_ndim - len(rhs_batch) == 1:
rhs = tf.expand_dims(rhs, rhs_ndim)
squeeze_idxs.append(len(rhs.shape) - 1)
result = tf.linalg.matmul(lhs, rhs)
if len(squeeze_idxs) != 0:
assert all([result.shape[i] == 1 for i in squeeze_idxs])
result = tf.squeeze(result, squeeze_idxs)
return result
new_id = iter(string.ascii_letters)
lhs_axis_ids = [next(new_id) for _ in lhs.shape]
rhs_axis_ids = [next(new_id) for _ in rhs.shape]
lhs_out_axis_ids = lhs_axis_ids[:]
rhs_out_axis_ids = rhs_axis_ids[:]
for lhs_axis, rhs_axis in zip(lhs_contracting, rhs_contracting):
shared_id = next(new_id)
lhs_axis_ids[lhs_axis] = shared_id
rhs_axis_ids[rhs_axis] = shared_id
lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
batch_ids = []
for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
shared_id = next(new_id)
lhs_axis_ids[lhs_axis] = shared_id
rhs_axis_ids[rhs_axis] = shared_id
lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
batch_ids.append(shared_id)
not_none = lambda x: x is not None
out_axis_ids = list(
filter(not_none, batch_ids + lhs_out_axis_ids + rhs_out_axis_ids))
assert lhs.dtype == rhs.dtype
spec = "{},{}->{}".format("".join(lhs_axis_ids), "".join(rhs_axis_ids),
"".join(out_axis_ids))
return tf.linalg.einsum(spec, lhs, rhs)
tf_impl_with_avals[lax.dot_general_p] = _dot_general
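# For example, contracting two axes at once -- lhs of shape (M, K1, K2), rhs of
# shape (K1, K2, N), dimension_numbers (((1, 2), (0, 1)), ((), ())) -- cannot be
# expressed with tf.linalg.matmul above, and the fallback builds the einsum
# spec "agh,ghf->af" (one letter per axis, with contracting and batch axes of
# lhs and rhs sharing a letter).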
def _broadcast(operand, *, sizes):
result_shape = tf.TensorShape(sizes).concatenate(operand.shape)
return tf.broadcast_to(operand, result_shape)
tf_impl[lax.broadcast_p] = _broadcast
def _broadcast_in_dim(operand, *, shape, broadcast_dimensions):
inshape = [1] * len(shape)
for orig_shape_i, broadcast_dim_i in zip(operand.shape, broadcast_dimensions):
if orig_shape_i != 1:
inshape[broadcast_dim_i] = shape[broadcast_dim_i]
inshape_tf = _eval_shape(inshape)
shape_tf = _eval_shape(shape)
return tf.broadcast_to(tf.reshape(operand, inshape_tf), shape_tf)
tf_impl[lax.broadcast_in_dim_p] = _broadcast_in_dim
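# For example, broadcasting an operand of shape (3,) to shape (2, 3) with
# broadcast_dimensions=(1,) first reshapes the operand to (1, 3) and then lets
# tf.broadcast_to replicate it along the leading axis.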
def _reshape(operand, *, new_sizes, dimensions):
if dimensions is None:
dimensions = tf.range(tf.rank(operand))
new_sizes_tf = _eval_shape(new_sizes)
return tf.reshape(tf.transpose(operand, dimensions), new_sizes_tf)
tf_impl[lax.reshape_p] = _reshape
def _squeeze(operand, *, dimensions, _in_avals, _out_aval):
op_shape = _in_avals[0].shape
new_shape = tuple(d for i, d in enumerate(op_shape) if i not in dimensions)
new_shape_tf = _eval_shape(new_shape)
return tf.reshape(operand, new_shape_tf)
tf_impl_with_avals[lax.squeeze_p] = _squeeze
def _pad(operand, padding_value, *, padding_config,
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
del _in_avals
low, high, interior = util.unzip3(padding_config)
if _thread_local_state.enable_xla:
out = tfxla.pad(operand, padding_value, low, high, interior)
return out
if all(lo >= 0 and hi >= 0 and i == 0 for lo, hi, i in padding_config):
return tf.pad(
operand,
zip(low, high),
mode="CONSTANT",
constant_values=padding_value)
raise _xla_disabled_error("pad", "Only use cases without interior or negative padding can be converted without XLA.")
tf_impl_with_avals[lax.pad_p] = _pad
def _rev(operand, *, dimensions):
return tf.reverse(operand, dimensions)
tf_impl[lax.rev_p] = _rev
tf_impl[lax.select_p] = tf.where
def _transpose(operand, *, permutation):
return tf.transpose(operand, perm=permutation)
tf_impl[lax.transpose_p] = _transpose
axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)
tf_impl[lax.reduce_sum_p] = (
bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=[0]))
tf_impl[lax.reduce_prod_p] = (
bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=[0]))
tf_impl[lax.reduce_max_p] = (
bool_to_int8(axes_to_axis(tf.reduce_max), argnums=[0]))
tf_impl[lax.reduce_min_p] = (
bool_to_int8(axes_to_axis(tf.reduce_min), argnums=[0]))
tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)
def _argminmax(fn, operand, axes, index_dtype):
axis, = axes
output_type = tf.int32
if dtypes.iinfo(index_dtype).bits > 32:
output_type = tf.int64
# TODO(phawkins): handle axes larger than 2^31.
result = fn(operand, axis=axis, output_type=output_type)
return tf.cast(result, _to_tf_dtype(index_dtype))
tf_impl[lax.argmin_p] = partial(_argminmax, tf.math.argmin)
tf_impl[lax.argmax_p] = partial(_argminmax, tf.math.argmax)
_add_fn = tf.function(_add, autograph=False)
_ge_fn = tf.function(tf.math.greater_equal, autograph=False)
def _select_and_gather_add(
tangents: TfVal, operand: TfVal, select_prim: core.Primitive,
window_dimensions: Sequence[int], window_strides: Sequence[int],
base_dilation: Sequence[int], window_dilation: Sequence[int],
padding: Sequence[Tuple[int, int]], _in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
# Note: this function follows the pattern in
# jax.lax._select_and_gather_add_translation.
dtype = operand.dtype
nbits = dtypes.finfo(dtype.as_numpy_dtype).bits
  # Specializing the function for 64 bits. Only up to 32 bits are supported on
  # TPU, so on that platform we let the code below throw a different exception.
max_bits = 64
assert nbits <= max_bits
double_word_reduction = nbits * 2 <= max_bits
const = lambda dtype, x: tf.constant(np.array(x), dtype)
if double_word_reduction:
word_dtype = lax._UINT_DTYPES[nbits]
double_word_dtype = lax._UINT_DTYPES[nbits * 2]
# Packs two values into a tuple.
def pack(a, b):
a = _bitcast_convert_type(a, word_dtype)
b = _bitcast_convert_type(b, word_dtype)
a = _convert_element_type(a, new_dtype=double_word_dtype)
b = _convert_element_type(b, new_dtype=double_word_dtype)
a = tf.bitwise.left_shift(a, const(double_word_dtype, nbits))
return tf.bitwise.bitwise_or(a, b)
# Unpacks the first element of a tuple.
def fst(t):
assert t.dtype == double_word_dtype
st = _shift_right_logical(t, const(double_word_dtype, nbits))
return _bitcast_convert_type(
_convert_element_type(st, new_dtype=word_dtype), dtype)
# Unpacks the second element of a tuple.
def snd(t):
return _bitcast_convert_type(
_convert_element_type(t, new_dtype=word_dtype), dtype)
else:
raise NotImplementedError(
f"TODO: need to pack {nbits * 2} bits but this platform can only go up to {max_bits} bits."
)
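  # For example, with float32 operands (nbits == 32) each (operand, tangent)
  # pair is packed into a single uint64, the operand in the high 32 bits and
  # the tangent in the low 32 bits; the reducer below compares packed values
  # via `fst` (the operand part), and `snd` finally recovers the tangent.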
assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
def reducer(x, y):
which = tf_impl[select_prim]
return tf_impl[lax.select_p](which(fst(x), fst(y)), x=x, y=y)
init = -np.inf if select_prim is lax.ge_p else np.inf
init_identity = lambda x: pack(const(dtype, init), const(dtype, 0))
out = _specialized_reduce_window(
reducer,
init_identity,
pack(operand, tangents),
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation,
_in_avals=_in_avals,
_out_aval=_out_aval)
return snd(out)
tf_impl_with_avals[lax.select_and_gather_add_p] = _select_and_gather_add
def _get_shape_from_tensor_or_array(x):
if isinstance(x.shape, tf.TensorShape):
return tuple(x.shape.as_list())
return tuple(x.shape)
def _common_reduce_window(operand, init_val, reducer, window_dimensions,
window_strides, padding, base_dilation,
window_dilation, _in_avals, _out_aval):
o_spec = tf.TensorSpec((), dtype=operand.dtype)
reducer_fn = tf.function(
reducer, autograph=False).get_concrete_function(o_spec, o_spec)
if not isinstance(init_val, tf.Tensor):
init_val = tf.constant(init_val, operand.dtype)
out = tfxla.reduce_window(
operand,
init_val,
reducer_fn,
window_dimensions,
window_strides,
base_dilations=base_dilation,
window_dilations=window_dilation,
padding=padding)
# TODO: implement shape inference for XlaReduceWindow
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions,
window_strides, padding, base_dilation, window_dilation,
_in_avals, _out_aval):
"""TensorFlow implementation of reduce_window.
Args:
operand: N dimensional array containing elements of type T
init_value: starting value of the reduction
jaxpr: the jaxpr corresponding to the reduction function
consts: the constants associated with jaxpr.
window_dimensions: array of integers for window dimension values
window_strides: array of integers for window stride values
padding: array of pairs of integers for padding values
base_dilation: array of integers for base dilation values
window_dilation: array of integers for window dilation values
Returns:
The reduced operand.
"""
assert len(consts) == 0, "Reduction computation cannot have constants"
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("reduce_window")
def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None)
return res
return _common_reduce_window(operand, init_value, reducer, window_dimensions,
window_strides, padding, base_dilation,
window_dilation, _in_avals, _out_aval)
# _try_tf_pool currently only supports reduce_window_max and reduce_window_sum.
# TODO(bchetioui): this function is not exhaustive wrt which
# reduce_window_max or reduce_window_sum cases can be translated into a call to
# max_pool or avg_pool. Further investigation is needed to fully flesh it out.
def _try_tf_pool(op_name, operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation) -> TfVal:
def error(msg):
suffix = ("See source code for the precise conditions under which "
"reduce_window can be converted without XLA.")
return _xla_disabled_error("reduce_window", f"{msg} - {suffix}")
dtype = operand.dtype
  # Contrary to the main path, tf.int8 is actually a valid type for
  # tf.nn.max_pool.
if op_name == "reduce_window_max" and dtype in [
tf.bool, tf.uint32, tf.uint64, tf.complex64, tf.complex128
]:
raise error(f"tf.nn.max_pool does not support operands of type {dtype}")
if op_name == "reduce_window_sum" and operand.dtype not in [
tf.float16, tf.float32, tf.float64
]:
raise error(f"tf.nn.avg_pool does not support operands of type {dtype}")
has_batch_dim = window_dimensions[0] == 1
has_channel_dim = window_dimensions[-1] == 1
nb_spatial_dimensions = len(operand.shape) - has_batch_dim - has_channel_dim
if nb_spatial_dimensions < 1 or nb_spatial_dimensions > 3:
raise error("TensorFlow can only handle pooling for arrays with 1, 2, or "
"3 spatial dimensions")
# TODO(bchetioui): does a simple conversion with another base dilation exist?
if list(base_dilation) != [1] * len(operand.shape):
raise error("Unimplemented support for base dilation")
# TODO(bchetioui): does a simple conversion with another window_dilation
# exist? The whole story seems similar to convolution.
if list(window_dilation) != [1] * len(operand.shape):
raise error("Unimplemented support for window dilation")
if list(padding) != [(0, 0)] * len(operand.shape):
raise error("Unimplemented support for padding")
# ReduceWindow in XLA takes an array of rank N as a parameter, but
# tf.nn.max_pool / tf.nn.avg_pool take an array of rank N+2, with a default
# shape of the form [batch_size] + input_spatial_shape + [num_channels]
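  # For example, a rank-2 operand of shape (8, 8) with window_dimensions (2, 2)
  # (no batch and no channel dimension) is expanded to shape (1, 8, 8, 1) and
  # pooled with window (1, 2, 2, 1); the added dimensions are squeezed away at
  # the end.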
tf_operand = operand
tf_window_dimensions = list(window_dimensions)
tf_window_strides = list(window_strides)
if not has_batch_dim:
tf_operand = tf.expand_dims(tf_operand, 0)
tf_window_dimensions = [1] + tf_window_dimensions
tf_window_strides = [1] + tf_window_strides
if not has_channel_dim:
tf_operand = tf.expand_dims(tf_operand, -1)
tf_window_dimensions.append(1)
tf_window_strides.append(1)
tf_data_format = "N" + "DHW"[-nb_spatial_dimensions:] + "C"
tf_padding = "VALID"
if op_name == "reduce_window_max":
result = tf.nn.max_pool(tf_operand, tf_window_dimensions, tf_window_strides,
tf_padding, tf_data_format)
elif op_name == "reduce_window_sum":
avg = tf.nn.avg_pool(tf_operand, tf_window_dimensions, tf_window_strides,
tf_padding, tf_data_format)
result = avg * np.prod(tf_window_dimensions)
else:
raise error(f"Unimplemented support for {op_name}")
if not has_batch_dim:
result = tf.squeeze(result, 0)
if not has_channel_dim:
result = tf.squeeze(result, -1)
return result
def _specialized_reduce_window(reducer,
identity,
operand,
*,
window_dimensions,
window_strides,
padding,
base_dilation,
window_dilation,
_in_avals,
_out_aval,
name=None):
"""Wraps the TensorFlow reduce window operation based on a reducer and an
identity function defining the initial value of the reduction depending on
the dtype of the operand.
Args:
reducer: reduction function of type TfVal -> TfVal -> TfVal
identity: function that takes a TensorFlow dtype as a parameter and returns
the starting value of the reduction.
operand: N dimensional array containing elements of type T
window_dimensions: array of integers for window dimension values
window_strides: array of integers for window stride values
padding: array of pairs of integers for padding values
base_dilation: array of integers for base dilation values
window_dilation: array of integers for window dilation values
    name: the name of the specialized reduce window primitive for which this
      conversion function is called (optional). This information may be used to
      choose a different conversion path.
Returns:
The reduced operand.
"""
  if (not _thread_local_state.enable_xla and
      name in ["reduce_window_max", "reduce_window_sum"]):
return _try_tf_pool(name, operand, window_dimensions, window_strides,
padding, base_dilation, window_dilation)
return _common_reduce_window(operand, identity(operand.dtype), reducer,
window_dimensions, window_strides, padding,
base_dilation, window_dilation, _in_avals,
_out_aval)
def _get_max_identity(tf_dtype):
numpy_tf_dtype = tf_dtype.as_numpy_dtype
if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact):
return numpy_tf_dtype(-np.inf)
elif dtypes.issubdtype(numpy_tf_dtype, np.integer):
return dtypes.iinfo(numpy_tf_dtype).min
else:
assert dtypes.issubdtype(
numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined max identity")
return False
def _get_min_identity(tf_dtype):
numpy_tf_dtype = tf_dtype.as_numpy_dtype
if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact):
return numpy_tf_dtype(np.inf)
elif dtypes.issubdtype(numpy_tf_dtype, np.integer):
return dtypes.iinfo(numpy_tf_dtype).max
else:
assert dtypes.issubdtype(
numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined min identity")
return True
# pylint: disable=protected-access
tf_impl_with_avals[lax.reduce_window_sum_p] = (
partial(_specialized_reduce_window, _add, lambda x: 0,
name="reduce_window_sum"))
tf_impl_with_avals[lax.reduce_window_min_p] = (
partial(_specialized_reduce_window,
partial(_minmax_scalar, is_min=True),
_get_min_identity,
name="reduce_window_min"))
tf_impl_with_avals[lax.reduce_window_max_p] = (
partial(_specialized_reduce_window,
partial(_minmax_scalar, is_min=False),
_get_max_identity,
name="reduce_window_max"))
tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
# pylint: enable=protected-access
# We use lax_control_flow._cumred_tpu_translation_rule to convert cummax,
# cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is
# O(n^2) on other backends. It could instead be implemented using
# associative_scan to perform better on those backends.
tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl(
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_min),
multiple_results=False,
extra_name_stack="cummin")
tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_max),
multiple_results=False,
extra_name_stack="cummin")
# TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for
# certain dtypes: bfloat16, float16, float32, float64, and int32. Other dtypes
# will fail when running in compiled mode, but are otherwise compatible with
# the operation. A non-XLA path can thus be defined for all dtypes, though the
# tests will crash.
tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl(
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_sum),
multiple_results=False,
extra_name_stack="cumsum")
tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl(
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_prod),
multiple_results=False,
extra_name_stack="cumprod")
def _select_and_scatter(operand, source, init_value, select_jaxpr,
select_consts, scatter_jaxpr, scatter_consts,
window_dimensions, window_strides, padding):
raise NotImplementedError("TODO: jax2tf can not convert _select_and_scatter")
tf_impl[lax.select_and_scatter_p] = _select_and_scatter
@partial(bool_to_int8, argnums=(0, 1))
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
window_strides, padding, _in_avals, _out_aval):
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("select_and_scatter_add")
init_value = tf.zeros((), operand.dtype)
select_fn = (
tf.function(tf_impl[select_prim], autograph=False).get_concrete_function(
init_value, init_value))
scatter_fn = _add_fn.get_concrete_function(init_value, init_value)
out = tfxla.select_and_scatter(operand, window_dimensions, window_strides,
padding, source, init_value, select_fn,
scatter_fn)
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
res = _convert_jax_impl(
partial(jax._src.random._threefry2x32_lowering, use_rolled_loops=False),
multiple_results=True, extra_name_stack="threefry")(
*args, _in_avals=_in_avals, _out_aval=_out_aval)
return res
tf_impl_with_avals[jax.random.threefry2x32_p] = _threefry2x32_jax_impl
# Use the vmap implementation; otherwise the performance on TPU is really bad.
# With use_vmap=True we get about the same performance for JAX and jax2tf.
tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl(
partial(jax._src.random._gamma_impl, use_vmap=True),
multiple_results=False, extra_name_stack="random_gamma")
def _gather_dimensions_proto(indices_shape, dimension_numbers):
proto = xla_data_pb2.GatherDimensionNumbers()
proto.offset_dims.extend(dimension_numbers.offset_dims)
proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims)
proto.start_index_map.extend(dimension_numbers.start_index_map)
assert indices_shape
proto.index_vector_dim = len(indices_shape) - 1
return proto
@partial(bool_to_int8, argnums=[0])
def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
indices_are_sorted, unique_indices,
_in_avals, _out_aval):
"""Tensorflow implementation of gather."""
del _in_avals, unique_indices
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("gather")
proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
slice_sizes_tf = _eval_shape(slice_sizes)
out = tfxla.gather(operand, start_indices, proto, slice_sizes_tf,
indices_are_sorted)
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.gather_p] = _gather
def _slice(operand, start_indices, limit_indices, strides, _in_avals,
_out_aval):
if strides is None:
strides = [1] * len(start_indices)
slices = tuple(map(slice,
_eval_shape(start_indices),
_eval_shape(limit_indices),
_eval_shape(strides)))
out = operand[slices]
# TODO(b/184503314): improve shape inference for __getitem__
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.slice_p] = _slice
def _dynamic_slice(operand, *start_indices, slice_sizes,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
start_indices = tf.stack(start_indices)
slice_sizes = _eval_shape(slice_sizes)
if _thread_local_state.enable_xla:
res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes)
# TODO: implement shape inference for XlaDynamicSlice
res.set_shape(_aval_to_tf_shape(_out_aval))
return res
# If XLA is disabled, we use `tf.slice` as a fallback, which has different out
# of bounds (OOB) behavior than `lax.dynamic_slice_p`:
# * If `slice size > max_len`, then both `tf.slice` and `lax.dynamic_slice_p`
# fail, so we ignore this case.
# * If `start_indices` is OOB, then `tf.slice` fails but `lax.dynamic_slice_p`
# clips the indices to [0, max_len].
# * If `start_indices + slice_size` is OOB, then `tf.slice` fails, but
  #     `lax.dynamic_slice_p` adjusts `start_indices` so that a full slice is
# returned.
# The code below manually clips the start indices so that the behavior is
# the same as `lax.dynamic_slice_p`.
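  # For example, with an operand of length 5 and slice_sizes == (3,), a start
  # index of 4 is clipped to max_start == 2, so we return elements [2:5], just
  # as lax.dynamic_slice would.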
operand_shape = _eval_shape(_in_avals[0].shape)
max_start = tf.subtract(operand_shape, slice_sizes)
# If `operand_shape` and `slice_sizes` are Python tuples of integers,
# `tf.subtract` returns a Tensor of dtype tf.int32, which may conflict with
  # the dtype of `start_indices` when we run in x64 mode, and make
  # `tf.clip_by_value` throw an error. Therefore we cast to the right dtype here
# explicitly.
max_start = tf.cast(max_start, dtype=start_indices.dtype)
start_indices = tf.clip_by_value(start_indices, 0, max_start)
return tf.slice(operand, start_indices, size=slice_sizes)
tf_impl_with_avals[lax.dynamic_slice_p] = _dynamic_slice
def _scatter_dimensions_proto(indices_shape, dimension_numbers):
proto = xla_data_pb2.ScatterDimensionNumbers()
proto.update_window_dims.extend(dimension_numbers.update_window_dims)
proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims)
proto.scatter_dims_to_operand_dims.extend(
dimension_numbers.scatter_dims_to_operand_dims)
assert indices_shape
proto.index_vector_dim = len(indices_shape) - 1
return proto
def _scatter(operand, scatter_indices, updates, *, update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted, unique_indices,
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
del unique_indices, _in_avals
assert len(update_consts) == 0, "Update computation cannot have constants"
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("scatter")
proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers)
def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal:
closed_jaxpr = core.ClosedJaxpr(update_jaxpr, update_consts)
res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None)
return res
o_spec = tf.TensorSpec((), dtype=operand.dtype)
xla_update_computation = (
tf.function(update_computation,
autograph=False).get_concrete_function(o_spec, o_spec))
out = tfxla.scatter(
operand,
scatter_indices,
updates,
xla_update_computation,
proto,
indices_are_sorted=indices_are_sorted)
# TODO: implement shape analysis for XlaScatter
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.scatter_p] = _scatter
tf_impl_with_avals[lax.scatter_min_p] = _scatter
tf_impl_with_avals[lax.scatter_max_p] = _scatter
tf_impl_with_avals[lax.scatter_mul_p] = _scatter
tf_impl_with_avals[lax.scatter_add_p] = _scatter
def _dynamic_update_slice(operand, update, *start_indices):
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("dynamic_update_slice")
return tfxla.dynamic_update_slice(operand, update, tf.stack(start_indices))
tf_impl[lax.dynamic_update_slice_p] = _dynamic_update_slice
def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
linear: Sequence[bool]) -> Sequence[TfVal]:
del linear
# tf.cond needs lambdas with no arguments.
branches_tf = [
partial(_interpret_jaxpr, jaxpr, *operands,
# Same name stack as the XLA translation of cond_p
extra_name_stack=f"branch_{i}_fun")
      for i, jaxpr in enumerate(branches)
]
return tf.switch_case(index, branches_tf)
tf_impl[lax_control_flow.cond_p] = _cond
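# For example, a two-way lax.cond reaches this rule as an integer index plus
# two branch jaxprs; each branch is bound to the operands with functools.partial
# so that tf.switch_case can call it as a zero-argument callable.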
def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]:
cond_consts, body_consts, init_carry = util.split_list(
args, [cond_nconsts, body_nconsts])
if cond_jaxpr.out_avals[0].shape: # type: ignore[attr-defined]
# The conditional is not a scalar, this must be a batched while
return _batched_cond_while(
*args,
cond_nconsts=cond_nconsts,
cond_jaxpr=cond_jaxpr,
body_nconsts=body_nconsts,
body_jaxpr=body_jaxpr)
# The conditional must return a single value to TF
def cond_tf_func(*args: TfVal) -> TfVal:
pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *args,
# Same name stack as the XLA translation of while_p
extra_name_stack="while/cond")
return pred
body_tf_func = partial(_interpret_jaxpr, body_jaxpr, *body_consts,
extra_name_stack="while/body")
return tf.while_loop(cond_tf_func, body_tf_func, init_carry)
def _batched_cond_while(*args: TfVal, cond_nconsts: int,
cond_jaxpr: core.ClosedJaxpr, body_nconsts: int,
body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]:
"""Interprets a while_loop with a batched condition.
A batched while has a conditional that returns a tensor of booleans, and
a body that returns a list of tensors whose leading dimensions match those
of the conditional tensor.
We need to turn it into a while with scalar boolean conditional. We will
expand the loop carry to include a prefix with the current tensor boolean
condition. We prepend to the loop the first calculation of the tensor boolean
condition. The loop condition will use a "reduce_any" to calculate a scalar
boolean from the tensor boolean condition. The end of the loop body will
compute the new carry using a "tf.where", and we compute the new tensor
boolean condition.
"""
cond_consts, body_consts, init_carry = util.split_list(
args, [cond_nconsts, body_nconsts])
# Initial computation of batched condition
init_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *init_carry,
extra_name_stack="while/body_pred")
assert init_pred_b is not core.unit
def new_cond_tf_func(pred_b: TfVal, *carry: TfVal) -> TfVal:
pred = tf.reduce_any(pred_b, axis=list(range(len(pred_b.shape))))
return pred
def new_body_tf_func(pred_b: TfVal, *carry: TfVal) -> Sequence[TfVal]:
new_carry: Sequence[TfVal] = _interpret_jaxpr(body_jaxpr, *body_consts,
*carry,
extra_name_stack="while/body")
def select_one_carry(new_c: TfVal, c: TfVal) -> TfVal:
pred_b_bcast = _broadcast_in_dim(
pred_b,
shape=new_c.shape,
broadcast_dimensions=list(range(len(pred_b.shape))))
return tf.where(pred_b_bcast, new_c, c)
selected_carry: Sequence[TfVal] = list(
map(select_one_carry, new_carry, carry))
next_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *selected_carry,
extra_name_stack="body_pred")
return (next_pred_b, *selected_carry)
_, *res_carry = tf.while_loop(new_cond_tf_func, new_body_tf_func,
(init_pred_b, *init_carry))
return res_carry
tf_impl[lax_control_flow.while_p] = _while
# We use the scan impl rule to rewrite in terms of while.
tf_impl_with_avals[lax_control_flow.scan_p] = _convert_jax_impl(
lax_control_flow._scan_impl,
extra_name_stack="scan")
def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:
# Some types originally incompatible with tf.math.top_k can be promoted
# to a compatible type without loss of precision.
def promote_tf_dtype(tf_dtype):
if tf_dtype in [tf.bool, tf.uint8, tf.uint16]:
return tf.uint32
if tf_dtype in [tf.int8, tf.int16]:
return tf.int32
if tf_dtype is tf.float16:
return tf.float32
return None
conversion_dtype = promote_tf_dtype(operand.dtype)
if conversion_dtype:
values, indices = tf.math.top_k(
tf.dtypes.cast(operand, conversion_dtype), k=k, sorted=True)
return tf.dtypes.cast(values, operand.dtype), indices
else:
return tf.math.top_k(operand, k=k, sorted=True)
tf_impl[lax.top_k_p] = _top_k
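# For example, a uint16 operand is cast to uint32 before calling tf.math.top_k;
# the values are cast back to uint16, while the indices (already int32) are
# returned unchanged.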
def _sort(*operands: TfVal, dimension: int, is_stable: bool,
num_keys: int) -> Tuple[TfVal, ...]:
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("sort")
assert 1 <= num_keys <= len(operands)
assert 0 <= dimension < len(
operands[0].shape
), f"Invalid {dimension} for ndim {len(operands[0].shape)}"
comparator_spec: List[tf.TensorSpec] = []
comparator_jax_in_avals: List[core.AbstractValue] = []
for op in operands:
o_spec = tf.TensorSpec((), dtype=op.dtype)
comparator_spec.extend([o_spec, o_spec])
o_aval = core.ShapedArray((), _to_jax_dtype(op.dtype))
comparator_jax_in_avals.extend([o_aval, o_aval])
# Use the same comparator that JAX uses when compiling to XLA, to get the
# proper NaN/Inf total order, and the lexicographic ordering.
  # The comparator is a 2N-argument TF function, with arguments [2k] and [2k + 1]
# corresponding to two scalars from operand[k].
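  # For example, with two operands (keys and values) and num_keys == 1, the
  # comparator receives four scalars (key_lhs, key_rhs, value_lhs, value_rhs)
  # and compares only the key pair.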
def lexicographic_comparator(*tf_args: TfVal) -> TfVal:
return _convert_jax_impl(
lax._sort_lt_comparator, multiple_results=False)(
*tf_args,
_in_avals=comparator_jax_in_avals,
_out_aval=core.ShapedArray((), np.bool_),
num_keys=num_keys)
xla_comparator_computation = (
tf.function(lexicographic_comparator,
autograph=False).get_concrete_function(*comparator_spec))
results = tfxla.variadic_sort(
operands,
dimension=dimension,
is_stable=is_stable,
comparator=xla_comparator_computation)
return results
tf_impl[lax.sort_p] = _sort
def _fft(x, fft_type, fft_lengths):
FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3]))
if fft_type == IRFFT:
expected_lengths = x.shape[-len(fft_lengths):-1] + ((x.shape[-1] - 1) * 2,)
else:
expected_lengths = x.shape[-len(fft_lengths):]
if expected_lengths != fft_lengths:
raise NotImplementedError(
f"Unsupported fft_lengths={fft_lengths} for fft_type={fft_type} of "
f"array with shape={x.shape}.")
tf_funcs = {
FFT: [tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d],
IFFT: [tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d],
RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d],
IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d]
}
return tf_funcs[fft_type][len(fft_lengths) - 1](x)
tf_impl[lax_fft.fft_p] = _fft
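# For example, fft_type == RFFT with two entries in fft_lengths dispatches to
# tf.signal.rfft2d; only 1, 2, or 3 transformed dimensions are supported here.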
def _qr(operand, full_matrices):
return tf.linalg.qr(operand, full_matrices=full_matrices)
tf_impl[lax_linalg.qr_p] = _qr
def _svd(operand, full_matrices, compute_uv):
result = tf.linalg.svd(operand, full_matrices, compute_uv)
if not compute_uv:
return result,
s, u, v = result
return s, u, tf.linalg.adjoint(v)
tf_impl[lax_linalg.svd_p] = _svd
def _eig(operand: TfVal, compute_left_eigenvectors: bool,
compute_right_eigenvectors: bool):
if compute_left_eigenvectors and compute_right_eigenvectors:
    # TODO(bchetioui): we have not found a reliable way to sort the left
    # eigenvectors in the right order. The jax.numpy.linalg API suggests that
    # left eigenvectors are seldom used, so leaving this unimplemented seems
    # acceptable for now.
msg = ("Conversion of eig is not implemented when both "
"compute_left_eigenvectors and compute_right_eigenvectors are set "
"to True.")
raise NotImplementedError(msg)
elif not (compute_left_eigenvectors or compute_right_eigenvectors):
return tuple([tf.linalg.eigvals(operand)])
elif compute_right_eigenvectors:
return tuple(tf.linalg.eig(operand))
else: # compute_left_eigenvectors == True
wH, vl = tf.linalg.eig(tf.linalg.adjoint(operand))
wHH = tf.math.conj(wH)
return tuple([wHH, vl])
tf_impl[lax_linalg.eig_p] = _eig
def _eigh(operand: TfVal, lower: bool, _in_avals, _out_aval):
if operand.shape[-1] == 0:
v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
else:
if not lower:
operand = tf.linalg.adjoint(operand)
w, v = tf.linalg.eigh(operand)
cast_type = {
tf.complex64: tf.float32,
tf.complex128: tf.float64
}.get(operand.dtype)
if cast_type is not None:
w = tf.cast(w, cast_type)
return v, w
tf_impl_with_avals[lax_linalg.eigh_p] = _eigh
def _lu(operand: TfVal, _in_avals, _out_aval):
return _convert_jax_impl(lax_linalg._lu_python, extra_name_stack="lu")(
operand, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[lax_linalg.lu_p] = _lu
def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool,
transpose_a: bool, conjugate_a: bool, unit_diagonal: bool,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
if unit_diagonal:
a_aval, _ = _in_avals
a_shape = _eval_shape(a_aval.shape)
a = tf.linalg.set_diag(a, tf.ones(a_shape[:-1], dtype=a.dtype))
if not left_side:
rank = len(a.shape)
transpose_dimensions = list(range(rank - 2)) + [rank - 1, rank - 2]
a = tf.transpose(a, transpose_dimensions)
b = tf.transpose(b, transpose_dimensions)
lower = not lower
# adjoint == transpose for real dtypes, so special care need only be taken
# for complex types.
if a.dtype in [tf.complex64, tf.complex128]:
if (transpose_a and not conjugate_a) or (not transpose_a and conjugate_a):
a = tf.math.conj(a)
result = tf.linalg.triangular_solve(a, b, lower=lower, adjoint=transpose_a)
if not left_side:
result = tf.transpose(result, transpose_dimensions)
return result
tf_impl_with_avals[lax_linalg.triangular_solve_p] = _triangular_solve
def _linear_solve(*args: TfVal, const_lengths, jaxprs, _in_avals, _out_aval):
return _convert_jax_impl(lax_control_flow._custom_linear_solve_impl,
extra_name_stack="linear_solve")(
*args,
const_lengths=const_lengths,
jaxprs=jaxprs,
_in_avals=_in_avals,
_out_aval=_out_aval)
tf_impl_with_avals[lax_control_flow.linear_solve_p] = _linear_solve
def _custom_jvp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable,
num_consts: int) -> Sequence[TfVal]:
# TODO(necula): ensure that there is no AD transformation in scope
return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_jvp")
tf_impl[custom_derivatives.custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr
def _custom_vjp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr,
**_) -> Sequence[TfVal]:
# TODO(necula): ensure that there is no AD transformation in scope
return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_vjp")
tf_impl[custom_derivatives.custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr
def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
"function.")
tf_impl[ad.custom_lin_p] = _custom_lin
def split_to_logical_devices(tensor: TfVal,
partition_dimensions: pxla.PartitionsOrReplicated):
"""Like TPUMPStrategy.experimental_split_to_logical_devices.
For jax2tf purposes we want to avoid needing to thread the `strategy` object
through the generated computation. It seems that the original function needs
the strategy object only for error checking, which we assume is done upstream
by JAX.
Args:
tensor: Input tensor to annotate.
partition_dimensions: A list of integers, with one integer per tensor
dimension, specifying in how many parts the dimension should be split. The
product of integers must equal the number of devices per replica.
Returns:
an annotated tensor.
"""
# TODO: this is only for sharded_jit. Either remove, or implement in terms
# of _shard_values.
if partition_dimensions is None:
return xla_sharding.replicate(tensor, use_sharding_op=True)
num_partition_splits = np.prod(partition_dimensions)
tile_assignment = np.arange(num_partition_splits).reshape(
partition_dimensions)
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
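# For example, annotating a 2-D tensor with partition_dimensions=[2, 1] gives
# num_partition_splits == 2 and tile_assignment [[0], [1]], i.e. the first
# dimension is split across two devices and the second is not split.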
def _shard_value(mesh: maps.Mesh,
val: TfVal,
aval: core.AbstractValue,
axis_resources: pjit.ParsedPartitionSpec) -> TfVal:
"""Apply sharding to a TfVal."""
sharding_proto: xla_client.OpSharding = pjit.get_aval_sharding_proto(
aval, axis_resources, mesh)
# To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
xla_sharding_proto: xla_data_pb2.OpSharding = (
xla_data_pb2.OpSharding(
type=int(sharding_proto.type),
tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
tile_assignment_devices=sharding_proto.tile_assignment_devices,
replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim))
return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor(
val, use_sharding_op=True)
def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
in_parts: Sequence[pxla.PartitionsOrReplicated],
out_parts_thunk,
**_) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
sharded_vals = map(split_to_logical_devices, vals, in_parts)
vals_out = f.call_wrapped(*sharded_vals) # caller handles new_sublevel
out_parts_flat = out_parts_thunk()
assert len(out_parts_flat) == len(
vals_out), f"expected {len(out_parts_flat)} == {len(vals_out)}"
sharded_vals_out = [
(split_to_logical_devices(val, val_part), val_aval)
for (val, val_aval), val_part in zip(vals_out, out_parts_flat)
]
return sharded_vals_out
def _sharded_jit_sharding_constraint(arg: TfVal, *,
partitions: pxla.PartitionsOrReplicated,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
del _in_avals, _out_aval
return split_to_logical_devices(arg, partitions)
tf_impl_with_avals[sharded_jit.sharding_constraint_p] = _sharded_jit_sharding_constraint
def _pjit(*args: TfVal,
jaxpr: core.ClosedJaxpr,
in_axis_resources: Sequence[pjit.ParsedPartitionSpec],
out_axis_resources: Sequence[pjit.ParsedPartitionSpec],
resource_env: maps.ResourceEnv,
donated_invars,
name: str,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray) -> TfVal:
del donated_invars
# TODO: add `name` to the name stack
shard_value_for_mesh = partial(_shard_value, resource_env.physical_mesh)
# Apply sharding annotation to the arguments
sharded_args: Sequence[TfVal] = tuple(
map(shard_value_for_mesh, args, _in_avals, in_axis_resources))
results = _interpret_jaxpr(jaxpr, *sharded_args,
extra_name_stack=util.wrap_name(name, "pjit"))
sharded_results: Sequence[TfVal] = tuple(
map(shard_value_for_mesh, results, _out_aval, out_axis_resources))
return tuple(sharded_results)
tf_impl_with_avals[pjit.pjit_p] = _pjit
def _pjit_sharding_constraint(arg: TfVal, *,
axis_resources: pjit.ParsedPartitionSpec,
resource_env: maps.ResourceEnv,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray,
**kwargs) -> TfVal:
return _shard_value(resource_env.physical_mesh, arg, _in_avals[0], axis_resources)
tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint
def _register_checkpoint_pytrees():
"""Registers TF custom container types as pytrees."""
m = tf.Module()
# The types here are automagically changed by TensorFlow's checkpointing
# infrastructure.
m.a = (tf.Module(), tf.Module())
m.b = [tf.Module(), tf.Module()]
m.c = {"a": tf.Module()}
tuple_wrapper = type(m.a)
list_wrapper = type(m.b)
dict_wrapper = type(m.c)
# TF AutoTrackable swaps container types out for wrappers.
assert tuple_wrapper is not tuple
assert list_wrapper is not list
assert dict_wrapper is not dict
jax.tree_util.register_pytree_node(tuple_wrapper, lambda xs:
(tuple(xs), None), lambda _, xs: tuple(xs))
jax.tree_util.register_pytree_node(list_wrapper, lambda xs: (tuple(xs), None),
lambda _, xs: list(xs))
jax.tree_util.register_pytree_node(
dict_wrapper,
lambda s: (tuple(s.values()), tuple(s.keys())),
lambda k, xs: dict(zip(k, xs)))
_register_checkpoint_pytrees()