# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
import __builtin__
|
|
|
|
import collections
|
|
from .util import partial
|
|
import itertools
|
|
import operator
|
|
import string
|
|
|
|
import numpy as onp
|
|
|
|
from . import core
|
|
from . import ad_util
|
|
from . import linear_util as lu
|
|
from .core import Primitive
|
|
from .abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
|
|
array_types, make_shaped_array)
|
|
from .api_util import flatten_fun, tree_to_jaxtuples
|
|
from .interpreters import partial_eval as pe
|
|
from .interpreters import xla
|
|
from .interpreters import ad
|
|
from .interpreters import batching
|
|
from .util import curry, safe_zip, unzip2
|
|
from .tree_util import build_tree
|
|
from .lib import xla_bridge
|
|
|
|
_max = __builtin__.max
_min = __builtin__.min
|
|
|
|
### traceables
|
|
|
|
def neg(x): return neg_p.bind(x)
|
|
def sign(x): return sign_p.bind(x)
|
|
def floor(x): return floor_p.bind(x)
|
|
def ceil(x): return ceil_p.bind(x)
|
|
def round(x): return round_p.bind(x)
|
|
|
|
def is_finite(x): return is_finite_p.bind(x)
|
|
|
|
def exp(x): return exp_p.bind(x)
|
|
def expm1(x): return expm1_p.bind(x)
|
|
def log(x): return log_p.bind(x)
|
|
def log1p(x): return log1p_p.bind(x)
|
|
def tanh(x): return tanh_p.bind(x)
|
|
def sin(x): return sin_p.bind(x)
|
|
def cos(x): return cos_p.bind(x)
|
|
def atan2(x, y): return atan2_p.bind(x, y)
|
|
|
|
def lgamma(x): return lgamma_p.bind(x)
|
|
def digamma(x): return digamma_p.bind(x)
|
|
def erf(x): return erf_p.bind(x)
|
|
def erfc(x): return erfc_p.bind(x)
|
|
def erf_inv(x): return erf_inv_p.bind(x)
|
|
|
|
def real(x): return real_p.bind(x)
|
|
def imag(x): return imag_p.bind(x)
|
|
def complex(x, y): return complex_p.bind(_brcast(x, y), _brcast(y, x))
|
|
def conj(x): return conj_p.bind(x)
|
|
def abs(x): return abs_p.bind(x)
|
|
def pow(x, y): return pow_p.bind(x, y)
|
|
|
|
def bitwise_not(x): return not_p.bind(x)
|
|
def bitwise_and(x, y): return and_p.bind(x, y)
|
|
def bitwise_or(x, y): return or_p.bind(x, y)
|
|
def bitwise_xor(x, y): return xor_p.bind(x, y)
|
|
|
|
def add(x, y): return add_p.bind(x, y)
|
|
def sub(x, y): return sub_p.bind(x, y)
|
|
def mul(x, y): return mul_p.bind(x, y)
|
|
def div(x, y): return div_p.bind(x, y)
|
|
def rem(x, y): return rem_p.bind(x, y)
|
|
|
|
def max(x, y): return max_p.bind(x, y)
|
|
def min(x, y): return min_p.bind(x, y)
|
|
|
|
def shift_left(x, y): return shift_left_p.bind(x, y)
|
|
def shift_right_arithmetic(x, y): return shift_right_arithmetic_p.bind(x, y)
|
|
def shift_right_logical(x, y): return shift_right_logical_p.bind(x, y)
|
|
|
|
def eq(x, y): return eq_p.bind(x, y)
|
|
def ne(x, y): return ne_p.bind(x, y)
|
|
def ge(x, y): return ge_p.bind(x, y)
|
|
def gt(x, y): return gt_p.bind(x, y)
|
|
def le(x, y): return le_p.bind(x, y)
|
|
def lt(x, y): return lt_p.bind(x, y)
|
|
|
|
def convert_element_type(operand, new_dtype):
|
|
new_dtype = xla_bridge.canonicalize_dtype(new_dtype)
|
|
old_dtype = _dtype(operand)
|
|
if old_dtype != new_dtype:
|
|
return convert_element_type_p.bind(
|
|
operand, new_dtype=new_dtype, old_dtype=old_dtype)
|
|
else:
|
|
return operand
|
|
|
|
def bitcast_convert_type(operand, new_dtype):
|
|
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
|
|
|
|
def clamp(min, operand, max):
|
|
return clamp_p.bind(min, operand, max)
|
|
|
|
def concatenate(operands, dimension):
|
|
return concatenate_p.bind(*operands, dimension=dimension,
|
|
operand_shapes=tuple(o.shape for o in operands))
|
|
|
|
def conv(lhs, rhs, window_strides, padding):
|
|
pads = padtype_to_pads(lhs.shape[2:], rhs.shape[2:], window_strides, padding)
|
|
return conv_general_dilated_p.bind(
|
|
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(pads),
|
|
lhs_dilation=(), rhs_dilation=(), dimension_numbers=None,
|
|
lhs_shape=lhs.shape, rhs_shape=rhs.shape)
|
|
|
|
def conv_with_general_padding(lhs, rhs, window_strides, padding,
|
|
lhs_dilation, rhs_dilation):
|
|
return conv_general_dilated_p.bind(
|
|
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
|
|
lhs_dilation=(), rhs_dilation=(), dimension_numbers=None,
|
|
lhs_shape=lhs.shape, rhs_shape=rhs.shape)
|
|
|
|
def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
|
|
rhs_dilation, dimension_numbers):
|
|
if isinstance(padding, str):
|
|
perms = conv_general_permutations(dimension_numbers)
|
|
lhs_perm, rhs_perm, _ = perms
|
|
padding = padtype_to_pads(onp.take(lhs.shape, lhs_perm)[2:],
|
|
onp.take(rhs.shape, rhs_perm)[2:],
|
|
window_strides, padding)
|
|
return conv_general_dilated_p.bind(
|
|
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
|
|
lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
|
|
dimension_numbers=dimension_numbers, lhs_shape=lhs.shape,
|
|
rhs_shape=rhs.shape)
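
# Illustrative usage sketch (not part of the module): with string
# dimension_numbers, the three entries name the lhs, rhs, and output layouts.
# Assuming NCHW inputs and OIHW filters, a stride-1 'SAME' convolution with no
# dilation looks like:
#
#   out = conv_general_dilated(lhs, rhs, window_strides=(1, 1), padding='SAME',
#                              lhs_dilation=(1, 1), rhs_dilation=(1, 1),
#                              dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
#
# which conceptually matches conv(lhs, rhs, (1, 1), 'SAME') for that layout.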
|
|
|
|
def dot(lhs, rhs): return dot_p.bind(lhs, rhs)
|
|
|
|
def dot_general(lhs, rhs, dimension_numbers):
|
|
lhs_dims, rhs_dims = dimension_numbers
|
|
dimension_numbers = (tuple(map(tuple, lhs_dims)), tuple(map(tuple, rhs_dims)))
|
|
return dot_general_p.bind(lhs, rhs, dimension_numbers=dimension_numbers)
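
# Illustrative sketch: dimension_numbers is ((lhs_contracting, rhs_contracting),
# (lhs_batch, rhs_batch)). For a batched matmul with lhs of shape (B, M, K) and
# rhs of shape (B, K, N), contracting over K and batching over B:
#
#   out = dot_general(lhs, rhs, (((2,), (1,)), ((0,), (0,))))
#   # out has shape (B, M, N); see dot_general_shape_rule below.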
|
|
|
|
def broadcast(operand, sizes):
|
|
return broadcast_p.bind(operand, sizes=tuple(sizes))
|
|
|
|
def broadcast_in_dim(operand, shape, broadcast_dimensions):
|
|
if operand.ndim == len(shape) and not len(broadcast_dimensions):
|
|
return operand
|
|
else:
|
|
return broadcast_in_dim_p.bind(
|
|
operand, shape=tuple(shape),
|
|
broadcast_dimensions=tuple(broadcast_dimensions))
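
# Illustrative sketch: broadcast_dimensions maps each operand dimension to the
# output dimension it occupies. For example, to broadcast a shape (3,) vector
# across the rows of a (2, 3) result:
#
#   y = broadcast_in_dim(x, shape=(2, 3), broadcast_dimensions=(1,))
#   # x's single dimension becomes output dimension 1; dimension 0 is new.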
|
|
|
|
def reshape(operand, new_sizes, dimensions=None):
|
|
same_shape = onp.shape(operand) == tuple(new_sizes)
|
|
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
|
|
if same_shape and same_dims:
|
|
return operand
|
|
else:
|
|
return reshape_p.bind(
|
|
operand, new_sizes=tuple(new_sizes),
|
|
dimensions=None if dimensions is None else tuple(dimensions),
|
|
old_sizes=onp.shape(operand))
|
|
|
|
def pad(operand, padding_value, padding_config):
|
|
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
|
|
|
|
def rev(operand, dimensions):
|
|
return rev_p.bind(operand, dimensions=tuple(dimensions))
|
|
|
|
def select(pred, on_true, on_false):
|
|
return select_p.bind(pred, on_true, on_false)
|
|
|
|
def slice(operand, start_indices, limit_indices, strides=None):
|
|
return slice_p.bind(operand, start_indices=tuple(start_indices),
|
|
limit_indices=tuple(limit_indices),
|
|
strides=None if strides is None else tuple(strides),
|
|
operand_shape=operand.shape)
|
|
|
|
def dynamic_slice(operand, start_indices, slice_sizes):
|
|
start_indices = _dynamic_slice_indices(operand, start_indices)
|
|
return dynamic_slice_p.bind(
|
|
operand, start_indices, slice_sizes=tuple(slice_sizes),
|
|
operand_shape=operand.shape)
|
|
|
|
def dynamic_update_slice(operand, update, start_indices):
|
|
start_indices = _dynamic_slice_indices(operand, start_indices)
|
|
return dynamic_update_slice_p.bind(operand, update, start_indices,
|
|
update_shape=update.shape)
|
|
|
|
def index_take(src, idxs, axes):
|
|
pvals = [_abstractify(arg) for arg in (src,) + idxs]
|
|
jaxpr, _, consts = pe.trace_unwrapped_to_jaxpr(partial(_index_take, axes), pvals)
|
|
return index_take_p.bind(src, *idxs, axes=tuple(axes),
|
|
input_shape=src.shape, jaxpr=jaxpr, consts=consts)
|
|
|
|
def _index_take(axes, src, *idxs):
|
|
n = idxs[0].shape[0]
|
|
slice_sizes = subvals(src.shape, zip(axes, [1] * len(axes)))
|
|
|
|
def body_fun(i, state):
|
|
src, idxs, out = state
|
|
src_ind = (dynamic_index_in_dim(x, i, 0, False) for x in idxs)
|
|
start_indices = subvals([0] * src.ndim, zip(axes, src_ind))
|
|
update = dynamic_slice(src, start_indices, slice_sizes)
|
|
update = reshape(update, (1,) + out.shape[1:])
|
|
out = dynamic_update_slice(out, update, [i] + [0] * (out.ndim - 1))
|
|
return src, idxs, out
|
|
|
|
out = full_like(src, 0, shape=(n,) + tuple(onp.delete(src.shape, axes)))
|
|
init_val = src, idxs, out
|
|
_, _, out = fori_loop(0, n, body_fun, init_val)
|
|
return out
|
|
|
|
def index_untake(src, dst, idxs, axes):
|
|
pvals = [_abstractify(arg) for arg in (src, dst) + idxs]
|
|
jaxpr, _, consts = pe.trace_unwrapped_to_jaxpr(partial(_index_untake, axes), pvals)
|
|
return index_untake_p.bind(src, dst, *idxs, axes=tuple(axes),
|
|
jaxpr=jaxpr, consts=consts)
|
|
|
|
def _index_untake(axes, src, dst, *idxs):
|
|
n = idxs[0].shape[0]
|
|
slice_sizes = subvals(dst.shape, zip(axes, [1] * len(axes)))
|
|
|
|
def body_fun(i, state):
|
|
src, dst, idxs = state
|
|
vals = dynamic_slice(src, [i] + [0] * (src.ndim - 1), (1,) + src.shape[1:])
|
|
vals = reshape(vals, subvals(dst.shape, zip(axes, [1] * len(axes))))
|
|
dst_ind = (dynamic_index_in_dim(x, i, 0, False) for x in idxs)
|
|
start_indices = subvals([0] * dst.ndim, zip(axes, dst_ind))
|
|
update = add(vals, dynamic_slice(dst, start_indices, slice_sizes))
|
|
dst = dynamic_update_slice(dst, update, start_indices)
|
|
return src, dst, idxs
|
|
|
|
init_val = src, dst, idxs
|
|
_, dst, _ = fori_loop(0, n, body_fun, init_val)
|
|
return dst
|
|
|
|
def transpose(operand, permutation):
|
|
return transpose_p.bind(operand, permutation=tuple(permutation))
|
|
|
|
def reduce(operand, init_value, computation, dimensions):
|
|
monoid_reducer = _get_monoid_reducer(computation, init_value)
|
|
if monoid_reducer:
|
|
return monoid_reducer(operand, dimensions)
|
|
else:
|
|
jaxpr, consts = _reduction_jaxpr(computation, init_value)
|
|
return reduce_p.bind(operand, init_value, jaxpr=jaxpr, consts=consts,
|
|
dimensions=tuple(dimensions))
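
# Illustrative sketch: when `computation` is one of the wrappers above and
# init_value is its concrete identity, reduce dispatches to a monoid-specific
# primitive (see _get_monoid_reducer below). For a float32 operand x:
#
#   reduce(x, onp.float32(0), add, dimensions=(0,))          # -> reduce_sum over axis 0
#   reduce(x, onp.float32(-onp.inf), max, dimensions=(0,))   # -> reduce_max over axis 0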
|
|
|
|
def _reduction_jaxpr(computation, init_value):
|
|
pval = _abstractify(init_value)
|
|
jaxpr, _, consts = pe.trace_unwrapped_to_jaxpr(computation, (pval, pval))
|
|
return jaxpr, consts
|
|
|
|
def _get_monoid_reducer(monoid_op, x):
|
|
aval = core.get_aval(x)
|
|
if (type(aval) is ConcreteArray) and aval.shape == ():
|
|
if monoid_op is add:
|
|
return aval.val == 0 and _reduce_sum
|
|
elif monoid_op is max:
|
|
return aval.val == _get_max_identity(aval.dtype) and _reduce_max
|
|
elif monoid_op is min:
|
|
return aval.val == _get_min_identity(aval.dtype) and _reduce_min
|
|
|
|
def _get_max_identity(dtype):
|
|
if onp.issubdtype(dtype, onp.floating):
|
|
return onp.array(-onp.inf, dtype)
|
|
elif onp.issubdtype(dtype, onp.integer):
|
|
return onp.array(onp.iinfo(dtype).min, dtype)
|
|
|
|
def _get_min_identity(dtype):
|
|
if onp.issubdtype(dtype, onp.floating):
|
|
return onp.array(onp.inf, dtype)
|
|
elif onp.issubdtype(dtype, onp.integer):
|
|
return onp.array(onp.iinfo(dtype).max, dtype)
|
|
|
|
def _reduce_sum(operand, axes):
|
|
return reduce_sum_p.bind(operand, axes=tuple(axes), input_shape=operand.shape)
|
|
|
|
def _reduce_max(operand, axes):
|
|
return reduce_max_p.bind(operand, axes=tuple(axes))
|
|
|
|
def _reduce_min(operand, axes):
|
|
return reduce_min_p.bind(operand, axes=tuple(axes))
|
|
|
|
def reduce_window(operand, init_value, computation, window_dimensions,
|
|
window_strides, padding):
|
|
monoid_reducer = _get_monoid_window_reducer(computation, init_value)
|
|
if monoid_reducer:
|
|
return monoid_reducer(operand, window_dimensions, window_strides, padding)
|
|
else:
|
|
jaxpr, consts = _reduction_jaxpr(computation, init_value)
|
|
return reduce_window_p.bind(
|
|
operand, init_value, jaxpr=jaxpr, consts=consts,
|
|
window_dimensions=tuple(window_dimensions),
|
|
window_strides=tuple(window_strides), padding=padding)
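
# Illustrative sketch (assuming a rank-2 float32 operand x): a 2x2 max pooling
# using the monoid fast path (max with its -inf identity):
#
#   pooled = reduce_window(x, onp.float32(-onp.inf), max,
#                          window_dimensions=(2, 2), window_strides=(2, 2),
#                          padding='VALID')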
|
|
|
|
def _get_monoid_window_reducer(monoid_op, x):
|
|
aval = core.get_aval(x)
|
|
if (type(aval) is ConcreteArray) and aval.shape == ():
|
|
if monoid_op is add:
|
|
return aval.val == 0 and _reduce_window_sum
|
|
elif monoid_op is max:
|
|
return aval.val == _get_max_identity(aval.dtype) and _reduce_window_max
|
|
elif monoid_op is min:
|
|
return aval.val == _get_min_identity(aval.dtype) and _reduce_window_min
|
|
|
|
def _reduce_window_sum(operand, window_dimensions, window_strides, padding):
|
|
return reduce_window_sum_p.bind(
|
|
operand, window_dimensions=tuple(window_dimensions),
|
|
window_strides=tuple(window_strides), padding=padding,
|
|
input_shape=operand.shape)
|
|
|
|
def _reduce_window_max(operand, window_dimensions, window_strides, padding):
|
|
return reduce_window_max_p.bind(
|
|
operand, window_dimensions=tuple(window_dimensions),
|
|
window_strides=tuple(window_strides), padding=padding)
|
|
|
|
def _reduce_window_min(operand, window_dimensions, window_strides, padding):
|
|
return reduce_window_min_p.bind(
|
|
operand, window_dimensions=tuple(window_dimensions),
|
|
window_strides=tuple(window_strides), padding=padding)
|
|
|
|
def _select_and_scatter(operand, select, window_dimensions, window_strides,
|
|
padding, source, init_value, scatter):
|
|
  select_jaxpr, select_consts = _reduction_jaxpr(select, init_value)
  scatter_jaxpr, scatter_consts = _reduction_jaxpr(scatter, init_value)
|
|
return select_and_scatter_p.bind(
|
|
operand, source, init_value, select_jaxpr=select_jaxpr,
|
|
select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,
|
|
scatter_consts=scatter_consts, window_dimensions=tuple(window_dimensions),
|
|
window_strides=tuple(window_strides), padding=padding)
|
|
|
|
def _select_and_scatter_add(source, operand, select_prim, window_dimensions,
|
|
window_strides, padding):
|
|
return select_and_scatter_add_p.bind(
|
|
source, operand, select_prim=select_prim,
|
|
window_dimensions=tuple(window_dimensions),
|
|
window_strides=tuple(window_strides), padding=padding)
|
|
|
|
def _select_and_gather_add(tangents, operand, select_prim, window_dimensions,
|
|
window_strides, padding):
|
|
return select_and_gather_add_p.bind(
|
|
tangents, operand, select_prim=select_prim,
|
|
window_dimensions=tuple(window_dimensions),
|
|
window_strides=tuple(window_strides), padding=padding)
|
|
|
|
def sort(operand, dimension=-1):
  return sort_p.bind(operand, dimension=dimension)
|
|
|
|
def sort_key_val(keys, values, dimension=-1):
|
|
# TODO new sort_key_val is variadic
|
|
result = sort_key_val_p.bind(keys, values, dimension=dimension)
|
|
sorted_keys, sorted_values = result
|
|
return sorted_keys, sorted_values
|
|
|
|
def _while_loop(cond_fun, body_fun, init_val):
|
|
init_val_flat, in_tree = tree_to_jaxtuples(init_val)
|
|
flat_body_fun, out_tree = flatten_fun(lu.wrap_init(body_fun), (in_tree,))
|
|
flat_cond_fun, _ = flatten_fun(lu.wrap_init(cond_fun), (in_tree,))
|
|
|
|
pval_flat = _abstractify(init_val_flat)
|
|
cond_jaxpr, _, cond_consts = pe.trace_to_jaxpr(flat_cond_fun, (pval_flat,))
|
|
body_jaxpr, pvout, body_consts = pe.trace_to_jaxpr(flat_body_fun, (pval_flat,))
|
|
abs_out, _ = pvout
|
|
|
|
params = OpaqueParam((abs_out, cond_jaxpr, cond_consts, body_jaxpr, body_consts))
|
|
out_flat = while_p.bind(init_val_flat, opaque_params=params)
|
|
  if out_tree() != in_tree:
    raise TypeError("body_fun input and output must have identical structure")
|
|
return build_tree(out_tree(), out_flat)
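
# Illustrative sketch: cond_fun and body_fun both take (and body_fun returns) a
# value with the same tree structure as init_val. A scalar counter loop might
# look like:
#
#   out = _while_loop(lambda x: lt(x, onp.float32(10.)),
#                     lambda x: add(x, onp.float32(1.)),
#                     onp.float32(0.))   # out == 10.0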
|
|
|
|
class OpaqueParam(object):
|
|
__slots__ = ["val", "id"]
|
|
def __init__(self, val):
|
|
self.val = val
|
|
self.id = next(opaque_param_ids)
|
|
def __hash__(self):
|
|
return self.id
|
|
opaque_param_ids = itertools.count()
|
|
|
|
|
|
### convenience wrappers around traceables
|
|
|
|
|
|
def full_like(x, fill_value, dtype=None, shape=None):
|
|
"""Create a full array like np.full based on the example array `x`.
|
|
|
|
Args:
|
|
x: example array-like, used for shape and dtype information.
|
|
fill_value: a scalar value to fill the entries of the output array.
|
|
dtype: optional, a dtype parameter for the output ndarray.
|
|
shape: optional, a shape parameter for the output ndarray.
|
|
|
|
Returns:
|
|
An ndarray with the same shape as `x` with its entries set equal to
|
|
`fill_value`, similar to the output of np.full.
|
|
"""
|
|
shape = onp.shape(x) if shape is None else shape
|
|
return broadcast(onp.array(fill_value, dtype or _dtype(x)), shape)
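
# Illustrative usage:
#
#   x = onp.zeros((2, 3), onp.float32)
#   full_like(x, 1.)               # (2, 3) array of 1.0, dtype float32
#   full_like(x, 0, shape=(4,))    # `shape` overrides x's shape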
|
|
|
|
|
|
def collapse(operand, start_dimension, stop_dimension):
|
|
lo, hi = start_dimension, stop_dimension
|
|
size = onp.product(operand.shape[lo:hi])
|
|
new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:]
|
|
return reshape(operand, new_shape)
|
|
|
|
|
|
def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0):
|
|
"""Convenience wrapper around slice applying to only one dimension."""
|
|
start_indices = [0] * operand.ndim
|
|
limit_indices = list(operand.shape)
|
|
strides = [1] * operand.ndim
|
|
|
|
start_indices[axis] = start_index
|
|
limit_indices[axis] = limit_index
|
|
strides[axis] = stride
|
|
|
|
return slice(operand, start_indices, limit_indices, strides)
|
|
|
|
|
|
def index_in_dim(operand, index, axis=0, keepdims=True):
|
|
"""Convenience wrapper around slice to perform int indexing."""
|
|
axis_size = operand.shape[axis]
|
|
wrapped_index = index + axis_size if index < 0 else index
|
|
if not 0 <= wrapped_index < axis_size:
|
|
msg = 'index {} is out of bounds for axis {} with size {}'
|
|
raise IndexError(msg.format(index, axis, axis_size))
|
|
result = slice_in_dim(operand, wrapped_index, wrapped_index + 1, 1, axis)
|
|
if keepdims:
|
|
return result
|
|
else:
|
|
return reshape(result, onp.delete(operand.shape, axis))
|
|
|
|
|
|
def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0):
|
|
"""Convenience wrapper around dynamic_slice applying to one dimension."""
|
|
start_indices = [onp.array([0])] * operand.ndim
|
|
slice_sizes = list(operand.shape)
|
|
|
|
start_indices[axis] = reshape(rem(start_index, operand.shape[axis]), [1])
|
|
slice_sizes[axis] = slice_size
|
|
|
|
start_indices = concatenate(start_indices, 0)
|
|
return dynamic_slice(operand, start_indices, slice_sizes)
|
|
|
|
|
|
def dynamic_index_in_dim(operand, index, axis=0, keepdims=True):
|
|
"""Convenience wrapper around dynamic_slice to perform int indexing."""
|
|
result = dynamic_slice_in_dim(operand, index, 1, axis)
|
|
if keepdims:
|
|
return result
|
|
else:
|
|
return reshape(result, onp.delete(operand.shape, axis))
|
|
|
|
|
|
def dynamic_update_slice_in_dim(operand, update, start_index, axis):
|
|
start_indices = [0] * _ndim(operand)
|
|
start_indices[axis] = start_index % operand.shape[axis]
|
|
return dynamic_update_slice(operand, update, start_indices)
|
|
|
|
|
|
def dynamic_update_index_in_dim(operand, update, index, axis):
|
|
if _ndim(update) != _ndim(operand):
|
|
assert _ndim(update) + 1 == _ndim(operand)
|
|
ax = axis % _ndim(operand)
|
|
update = reshape(update, operand.shape[:ax] + (1,) + operand.shape[ax:])
|
|
return dynamic_update_slice_in_dim(operand, update, index, axis)
|
|
|
|
|
|
def fori_loop(lower, upper, body_fun, init_val):
|
|
"""Loop from `lower` to `upper` by reduction to `while_loop`.
|
|
|
|
Arguments:
|
|
lower: loop index lower bound (inclusive)
|
|
upper: loop index upper bound (exclusive)
|
|
body_fun: function of type (int, T) -> T, where T is the type of `init_val`
|
|
init_val: initial loop value, of type T
|
|
|
|
Returns:
|
|
Loop value from the final iteration, of type T.
|
|
"""
|
|
# state: (upper limit, index, loop value)
|
|
  # The `lt` and `add` used here are the lax wrappers defined near the top of this module.
|
|
_, _, result = _while_loop(
|
|
lambda (upper, i, _): lt(i, upper),
|
|
lambda (upper, i, x): (upper, add(i, 1), body_fun(i, x)),
|
|
(upper, lower, init_val))
|
|
return result
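
# Illustrative sketch: summing the integers 0..9 by carrying an accumulator
# through the loop (a numpy scalar keeps the accumulator dtype explicit):
#
#   total = fori_loop(0, 10, lambda i, acc: add(acc, i), onp.int32(0))
#   # total == 45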
|
|
|
|
|
|
def foreach_loop(sequence, body_fun, init_val):
|
|
"""Loop over `sequence` by reduction to `while_loop`.
|
|
|
|
Arguments:
|
|
sequence: tuple of loop items, each of type U
|
|
body_fun: function of type (U, T) -> T, where T is the type of `init_val`
|
|
init_val: initial loop value, of type T
|
|
|
|
Returns:
|
|
Loop value from the final iteration, of type T.
|
|
"""
|
|
_, result = fori_loop(
|
|
0, len(sequence),
|
|
lambda i, (seq, val): body_fun(seq[i], val),
|
|
(sequence, init_val))
|
|
return result
|
|
|
|
|
|
def batch_matmul(lhs, rhs):
|
|
"""Batch matrix multiplication."""
|
|
if _min(lhs.ndim, rhs.ndim) < 2:
|
|
raise ValueError('Arguments to batch_matmul must be at least 2D, got {}, {}'
|
|
.format(lhs.ndim, rhs.ndim))
|
|
if lhs.ndim != rhs.ndim:
|
|
raise ValueError('Arguments to batch_matmul must have same ndim, got {}, {}'
|
|
.format(lhs.ndim, rhs.ndim))
|
|
lhs_contract = (lhs.ndim - 1,)
|
|
rhs_contract = (rhs.ndim - 2,)
|
|
batch = tuple(range(lhs.ndim - 2))
|
|
return dot_general(lhs, rhs, [(lhs_contract, rhs_contract), (batch, batch)])
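
# Illustrative sketch: for lhs of shape (B, M, K) and rhs of shape (B, K, N),
# the dot_general call above contracts K and treats all leading dimensions as
# batch, giving a (B, M, N) result:
#
#   out = batch_matmul(lhs, rhs)   # same as stacking dot(lhs[b], rhs[b]) over b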
|
|
|
|
|
|
# These trig functions also exist in the XLA client library, but we treat them
|
|
# as non-primitive to maintain a smaller set of autodiff primitives.
|
|
|
|
def sqrt(x):
|
|
return pow(x, _const(x, 0.5))
|
|
|
|
def rsqrt(x):
|
|
return pow(x, _const(x, -0.5))
|
|
|
|
def square(x):
|
|
return mul(x, x)
|
|
|
|
def reciprocal(x):
|
|
return div(_const(x, 1.), x)
|
|
|
|
def tan(x):
|
|
return div(sin(x), cos(x))
|
|
|
|
def asin(x):
|
|
# asin(x) = 2 * atan(x / (1 + sqrt(1 - x**2)))
|
|
  return mul(_const(x, 2.),
             atan2(x, add(_const(x, 1.), sqrt(sub(_const(x, 1.), square(x))))))
|
|
|
|
def acos(x):
|
|
# acos(x) = 2 * atan(sqrt(1 - x**2) / (1 + x))
|
|
return mul(_const(x, 2.),
|
|
atan2(sqrt(sub(_const(x, 1.), square(x))), add(_const(x, 1.), x)))
|
|
|
|
def atan(x):
|
|
return atan2(x, _const(x, 1.))
|
|
|
|
def sinh(x):
|
|
return mul(_const(x, 0.5), sub(exp(x), exp(neg(x))))
|
|
|
|
def cosh(x):
|
|
return mul(_const(x, 0.5), add(exp(x), exp(neg(x))))
|
|
|
|
def asinh(x):
|
|
# asinh(x) = log(x + sqrt(x**2 + 1))
|
|
return log(add(x, sqrt(add(mul(x, x), _const(x, 1.)))))
|
|
|
|
def acosh(x):
|
|
# acosh(x) = log(x + sqrt((x + 1) * (x - 1)))
|
|
return log(add(x, mul(sqrt(add(x, _const(x, 1.))),
|
|
sqrt(sub(x, _const(x, 1.))))))
|
|
|
|
|
|
# Add some methods to ShapedArray that rely on lax primitives
|
|
|
|
ShapedArray.broadcast = core.aval_method(broadcast)
|
|
ShapedArray.transpose = core.aval_method(transpose) # clobbered by lax_numpy
|
|
ShapedArray.reshape = core.aval_method(reshape) # clobbered by lax_numpy
|
|
|
|
def _iter(tracer):
|
|
if tracer.ndim == 0:
|
|
raise TypeError("iteration over a 0-d array") # same as numpy error
|
|
else:
|
|
n = tracer.shape[0]
|
|
return (index_in_dim(tracer, i, keepdims=False) for i in xrange(n))
|
|
ShapedArray._iter = staticmethod(_iter)
|
|
|
|
# Add some ad handlers that use (or could use) lax primitives
|
|
|
|
def zeros_like_array(x):
|
|
dtype = xla_bridge.canonicalize_dtype(_dtype(x))
|
|
return onp.broadcast_to(onp.zeros((), dtype), onp.shape(x))
|
|
|
|
for t in itertools.chain(array_types, [xla.DeviceArray]):
|
|
ad_util.jaxval_adders[t] = add
|
|
ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
|
|
|
batching.pytype_aval_mappings[xla.DeviceArray] = make_shaped_array
|
|
|
|
|
|
### primitives
|
|
|
|
|
|
_input_dtype = lambda *args, **_: xla_bridge.canonicalize_dtype(args[0].dtype)
|
|
_fixed_dtype = lambda dtype: lambda *args, **kwargs: xla_bridge.canonicalize_dtype(dtype)
|
|
_complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
|
|
|
|
def identity(x): return x
|
|
|
|
|
|
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
|
|
prim = Primitive(name)
|
|
prim.def_impl(partial(xla.apply_primitive, prim))
|
|
prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
|
|
xla.translations[prim] = translation_rule or partial(standard_translate, name)
|
|
return prim
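
# For example, the 'neg' primitive defined below is wired up this way: its impl
# falls back to xla.apply_primitive, its abstract eval combines the shape and
# dtype rules via standard_abstract_eval, and standard_translate maps the
# primitive name to an XLA builder method by capitalizing each
# underscore-separated term:
#
#   'neg' -> c.Neg,   'convert_element_type' -> c.ConvertElementType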
|
|
|
|
def standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs):
|
|
assert all(isinstance(arg, UnshapedArray) for arg in args), args
|
|
least_specialized = _max(
|
|
map(type, args), key=operator.attrgetter('array_abstraction_level'))
|
|
if least_specialized is ConcreteArray:
|
|
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
|
|
elif least_specialized is ShapedArray:
|
|
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
|
|
elif least_specialized is UnshapedArray:
|
|
return UnshapedArray(dtype_rule(*args, **kwargs))
|
|
else:
|
|
raise TypeError(args, least_specialized)
|
|
|
|
|
|
def standard_translate(name, c, *args, **kwargs):
|
|
xla_opname = ''.join(term.capitalize() for term in name.split('_'))
|
|
return getattr(c, xla_opname)(*args, **kwargs)
|
|
|
|
|
|
def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval):
|
|
if not any(onp.issubdtype(aval.dtype, t) for t in accepted_dtypes):
|
|
msg = '{} does not accept dtype {}. Accepted dtypes are subtypes of {}.'
|
|
typename = str(onp.dtype(aval.dtype).name)
|
|
accepted_typenames = (str(onp.dtype(t).name) for t in accepted_dtypes)
|
|
raise TypeError(msg.format(name, typename, ', '.join(accepted_typenames)))
|
|
return result_dtype(aval.dtype)
|
|
|
|
|
|
def unop(result_dtype, accepted_dtypes, name):
|
|
dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name)
|
|
prim = standard_primitive(operator.attrgetter('shape'), dtype_rule, name)
|
|
batching.defvectorized(prim)
|
|
return prim
|
|
standard_unop = partial(unop, identity)
|
|
|
|
|
|
def binop_dtype_rule(result_dtype, accepted_dtypes, name, *avals):
|
|
aval_dtypes = [aval.dtype for aval in avals]
|
|
for i, (aval_dtype, types) in enumerate(zip(aval_dtypes, accepted_dtypes)):
|
|
if not any(onp.issubdtype(aval_dtype, t) for t in types):
|
|
msg = ('{} does not accept dtype {} at position {}. '
|
|
'Accepted dtypes at position {} are subtypes of {}.')
|
|
typename = str(onp.dtype(aval_dtype).name)
|
|
typenames = ', '.join(str(onp.dtype(t).name) for t in types)
|
|
raise TypeError(msg.format(name, typename, i, i, typenames))
|
|
_check_same_dtypes(name, False, *aval_dtypes)
|
|
return result_dtype(*avals)
|
|
|
|
|
|
def broadcasting_shape_rule(name, *avals):
|
|
shapes = onp.array([aval.shape for aval in avals if aval.shape])
|
|
if not shapes.size:
|
|
return ()
|
|
if len({len(shape) for shape in shapes}) != 1:
|
|
msg = '{} got arrays of different rank: {}.'
|
|
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
|
|
result_shape = onp.max(shapes, axis=0)
|
|
if not onp.all((shapes == result_shape) | (shapes == 1)):
|
|
msg = '{} got incompatible shapes for broadcasting: {}.'
|
|
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
|
|
return tuple(result_shape)
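
# Worked examples of the rule above (ranks must already agree; size-1 axes
# stretch to the common size):
#
#   (2, 1, 4) with (2, 3, 4)  ->  (2, 3, 4)
#   (3,)      with ()         ->  (3,)        # scalar shapes are ignored entirely
#   (2, 3)    with (3, 2)     ->  TypeError: incompatible shapes
#   (2, 3)    with (1, 2, 3)  ->  TypeError: different ranks (no rank promotion here)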
|
|
|
|
|
|
def binop(result_dtype, accepted_dtypes, name):
|
|
dtype_rule = partial(binop_dtype_rule, result_dtype, accepted_dtypes, name)
|
|
shape_rule = partial(broadcasting_shape_rule, name)
|
|
prim = standard_primitive(shape_rule, dtype_rule, name)
|
|
batching.defbroadcasting(prim)
|
|
return prim
|
|
standard_binop = partial(binop, _input_dtype)
|
|
|
|
|
|
# NOTE(mattjj): this isn't great for orchestrate fwd mode because it means JVPs
|
|
# get two extra ops in them: a reshape and a broadcast_in_dim (or sometimes just
|
|
# a broadcast). but saving the shape info with the primitives isn't great either
|
|
# because then we can't trace these ops without shape data.
|
|
def _brcast(x, *others):
|
|
# used in jvprules to make binop broadcasting explicit for transposability.
|
|
# requires shape info during jvp tracing, which isn't strictly necessary.
|
|
shapes = filter(None, map(onp.shape, (x,) + others))
|
|
shape = tuple(shapes and onp.max(shapes, axis=0))
|
|
if onp.shape(x) != shape:
|
|
return _brcast_to(x, shape)
|
|
else:
|
|
return x
|
|
|
|
|
|
def _brcast_to(x, shape):
|
|
x_shape = onp.shape(x)
|
|
assert x_shape != shape
|
|
if x_shape:
|
|
assert len(x_shape) == len(shape)
|
|
broadcast_dimensions, = onp.where(onp.equal(x_shape, shape))
|
|
squeezed_dimensions, = onp.where(onp.not_equal(x_shape, shape))
|
|
inshape = onp.delete(x_shape, squeezed_dimensions)
|
|
return broadcast_in_dim(reshape(x, inshape), shape, broadcast_dimensions)
|
|
else:
|
|
return broadcast(x, shape)
|
|
|
|
|
|
_f32 = {onp.float32}
|
|
_float = {onp.floating}
|
|
_complex = {onp.complex64}
|
|
_int = {onp.integer}
|
|
_bool = {onp.bool_}
|
|
|
|
_num = _int | _float | _complex
|
|
_any = _int | _float | _complex | _bool
|
|
|
|
|
|
neg_p = standard_unop(_num, 'neg')
|
|
ad.deflinear(neg_p, lambda t: [neg(t)])
|
|
batching.defvectorized(neg_p)
|
|
|
|
sign_p = standard_unop(_num, 'sign')
|
|
ad.defjvp_zero(sign_p)
|
|
|
|
floor_p = standard_unop(_float, 'floor')
|
|
ad.defjvp_zero(floor_p)
|
|
|
|
ceil_p = standard_unop(_float, 'ceil')
|
|
ad.defjvp_zero(ceil_p)
|
|
|
|
round_p = standard_unop(_float, 'round')
|
|
ad.defjvp_zero(round_p)
|
|
|
|
is_finite_p = unop(_fixed_dtype(onp.bool_), _float, 'is_finite')
|
|
ad.defjvp_zero(is_finite_p)
|
|
|
|
exp_p = standard_unop(_float | _complex, 'exp')
|
|
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
|
|
|
|
log_p = standard_unop(_float | _complex, 'log')
|
|
ad.defjvp(log_p, lambda g, x: div(g, x))
|
|
|
|
expm1_p = standard_unop(_float | _complex, 'expm1')
|
|
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
|
|
|
|
log1p_p = standard_unop(_float | _complex, 'log1p')
|
|
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
|
|
|
|
tanh_p = standard_unop(_float | _complex, 'tanh')
|
|
ad.defjvp(tanh_p, lambda g, x: div(g, pow(cosh(x), _two(x))))
|
|
|
|
sin_p = standard_unop(_float | _complex, 'sin')
|
|
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
|
|
|
cos_p = standard_unop(_float | _complex, 'cos')
|
|
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
|
|
|
atan2_p = standard_binop([_float, _float], 'atan2')
|
|
|
|
lgamma_p = standard_unop(_float, 'lgamma')
|
|
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
|
|
|
|
digamma_p = standard_unop(_float, 'digamma')
|
|
|
|
erf_p = standard_unop(_float, 'erf')
|
|
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / onp.sqrt(onp.pi)),
|
|
mul(g, exp(neg(square(x))))))
|
|
|
|
erfc_p = standard_unop(_float, 'erfc')
|
|
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, 2. / onp.sqrt(onp.pi)),
|
|
mul(neg(g), exp(neg(square(x))))))
|
|
|
|
erf_inv_p = standard_unop(_float, 'erf_inv')
|
|
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, onp.sqrt(onp.pi) / 2.),
|
|
mul(g, exp(square(ans)))))
|
|
|
|
real_p = unop(_fixed_dtype(onp.float32), _complex, 'real')
|
|
ad.deflinear(real_p, lambda t: [complex(t, onp.zeros((), onp.float32))])
|
|
|
|
imag_p = unop(_fixed_dtype(onp.float32), _complex, 'imag')
|
|
ad.deflinear(imag_p, lambda t: [complex(onp.zeros((), onp.float32), neg(t))])
|
|
|
|
complex_p = standard_binop([_f32, _f32], 'complex')
|
|
ad.deflinear(complex_p, lambda t: [real(t), imag(t)])
|
|
|
|
# TODO promotes dtypes, need to remember whether we came from float or not
|
|
conj_p = unop(_fixed_dtype(onp.complex64), _float | _complex, 'conj')
|
|
ad.deflinear(conj_p, lambda t: [conj(t)])
|
|
|
|
abs_p = unop(_complex_basetype, _num, 'abs')
|
|
ad.defjvp2(abs_p,
|
|
lambda g, ans, x: div(_maybe_real(mul(g, _maybe_conj(x))),
|
|
_replace_zero(ans)))
|
|
_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
|
|
_maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
|
|
|
# TODO handle broadcasting
|
|
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
|
|
ad.defjvp(pow_p,
|
|
lambda g, x, y: mul(_brcast(g, y), mul(y, pow(x, select(
|
|
eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))),
|
|
lambda g, x, y: mul(_brcast(g, x),
|
|
mul(log(_replace_zero(x)), pow(x, y))))
|
|
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
|
|
|
|
not_p = standard_unop(_int | _bool, 'not')
|
|
|
|
and_p = standard_binop([_any, _any], 'and')
|
|
ad.defjvp_zero(and_p)
|
|
|
|
or_p = standard_binop([_any, _any], 'or')
|
|
ad.defjvp_zero(or_p)
|
|
|
|
xor_p = standard_binop([_any, _any], 'xor')
|
|
ad.defjvp_zero(xor_p)
|
|
|
|
add_p = standard_binop([_num, _num], 'add')
|
|
ad.defjvp(add_p, lambda g, x, y: _brcast(g, y), lambda g, x, y: _brcast(g, x))
|
|
|
|
sub_p = standard_binop([_num, _num], 'sub')
|
|
ad.defjvp(sub_p,
|
|
lambda g, x, y: _brcast(g, y),
|
|
lambda g, x, y: _brcast(neg(g), x))
|
|
|
|
mul_p = standard_binop([_num, _num], 'mul')
|
|
ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul) # TODO
|
|
|
|
|
|
def div_transpose_rule(cotangent, x, y):
|
|
assert x is None
|
|
res = ad_util.zero if cotangent is ad_util.zero else div(cotangent, y)
|
|
return res, None
|
|
div_p = standard_binop([_num, _num], 'div')
|
|
ad.defjvp(div_p,
|
|
lambda g, x, y: div(_brcast(g, y), y),
|
|
lambda g, x, y: div(mul(neg(_brcast(g, x)), x), pow(y, _two(y))))
|
|
ad.primitive_transposes[div_p] = div_transpose_rule
|
|
|
|
rem_p = standard_binop([_num, _num], 'rem')
|
|
ad.defjvp(rem_p,
|
|
lambda g, x, y: _brcast(g, y),
|
|
lambda g, x, y: mul(neg(g), floor(div(x, y))))
|
|
|
|
|
|
max_p = standard_binop([_any, _any], 'max')
|
|
ad.defjvp2(max_p,
|
|
lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
|
|
lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))
|
|
|
|
min_p = standard_binop([_any, _any], 'min')
|
|
ad.defjvp2(min_p,
|
|
lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
|
|
lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))
|
|
|
|
|
|
shift_left_p = standard_binop([_int, _int], 'shift_left')
|
|
ad.defjvp_zero(shift_left_p)
|
|
|
|
shift_right_arithmetic_p = standard_binop([_int, _int], 'shift_right_arithmetic')
|
|
ad.defjvp_zero(shift_right_arithmetic_p)
|
|
|
|
shift_right_logical_p = standard_binop([_int, _int], 'shift_right_logical')
|
|
ad.defjvp_zero(shift_right_logical_p)
|
|
|
|
eq_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'eq')
|
|
ad.defjvp_zero(eq_p)
|
|
|
|
ne_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'ne')
|
|
ad.defjvp_zero(ne_p)
|
|
|
|
ge_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'ge')
|
|
ad.defjvp_zero(ge_p)
|
|
|
|
gt_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'gt')
|
|
ad.defjvp_zero(gt_p)
|
|
|
|
le_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'le')
|
|
ad.defjvp_zero(le_p)
|
|
|
|
lt_p = binop(_fixed_dtype(onp.bool_), [_any, _any], 'lt')
|
|
ad.defjvp_zero(lt_p)
|
|
|
|
|
|
def convert_element_type_shape_rule(operand, new_dtype, old_dtype):
|
|
return operand.shape
|
|
|
|
def convert_element_type_dtype_rule(operand, new_dtype, old_dtype):
|
|
return new_dtype
|
|
|
|
def convert_element_type_translation_rule(c, operand, new_dtype, old_dtype):
|
|
new_etype = xla_bridge.dtype_to_etype(new_dtype)
|
|
return c.ConvertElementType(operand, new_element_type=new_etype)
|
|
|
|
convert_element_type_p = standard_primitive(
|
|
convert_element_type_shape_rule, convert_element_type_dtype_rule,
|
|
'convert_element_type', convert_element_type_translation_rule)
|
|
ad.deflinear(
|
|
convert_element_type_p,
|
|
lambda t, new_dtype, old_dtype: [convert_element_type(t, old_dtype)])
|
|
batching.defvectorized(convert_element_type_p)
|
|
|
|
|
|
def bitcast_convert_type_shape_rule(operand, new_dtype):
|
|
return operand.shape
|
|
|
|
def bitcast_convert_type_dtype_rule(operand, new_dtype):
|
|
return new_dtype
|
|
|
|
def bitcast_convert_type_translation_rule(c, operand, new_dtype):
|
|
new_etype = xla_bridge.dtype_to_etype(new_dtype)
|
|
return c.BitcastConvertType(operand, new_element_type=new_etype)
|
|
|
|
bitcast_convert_type_p = standard_primitive(
|
|
bitcast_convert_type_shape_rule, bitcast_convert_type_dtype_rule,
|
|
'bitcast_convert_type', bitcast_convert_type_translation_rule)
|
|
ad.defjvp_zero(bitcast_convert_type_p)
|
|
batching.defvectorized(bitcast_convert_type_p)
|
|
|
|
|
|
def conv_general_dilated_shape_rule(
|
|
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
|
dimension_numbers=None, **unused_kwargs):
|
|
if dimension_numbers is None:
|
|
lhs_dilated = _dilate_shape(lhs.shape, lhs_dilation)
|
|
rhs_dilated = _dilate_shape(rhs.shape, rhs_dilation)
|
|
_check_conv_shapes('conv_general_dilated', lhs_dilated, rhs_dilated,
|
|
window_strides)
|
|
return conv_shape_tuple(lhs_dilated, rhs_dilated, window_strides, padding)
|
|
else:
|
|
if not isinstance(dimension_numbers, (tuple, list)):
|
|
msg = "conv_general_dilated dimension_numbers must be tuple/list, got {}."
|
|
raise TypeError(msg.format(type(dimension_numbers)))
|
|
if len(dimension_numbers) != 3:
|
|
msg = "conv_general_dilated dimension_numbers must be length 3, got {}."
|
|
raise TypeError(msg.format(len(dimension_numbers)))
|
|
if not all(isinstance(elt, str) for elt in dimension_numbers):
|
|
msg = ("conv_general_dilated dimension_numbers elements must be strings, "
|
|
"got {}.")
|
|
raise TypeError(msg.format(tuple(map(type, dimension_numbers))))
|
|
msg = ("conv_general_dilated dimension_numbers[{}] must have len equal to "
|
|
"the ndim of lhs and rhs, got {} for lhs and rhs shapes {} and {}.")
|
|
for i, elt in enumerate(dimension_numbers):
|
|
if len(elt) != lhs.ndim:
|
|
raise TypeError(msg.format(i, len(elt), lhs.shape, rhs.shape))
|
|
|
|
lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
|
|
lhs_trans = _dilate_shape(onp.take(lhs.shape, lhs_perm), lhs_dilation)
|
|
rhs_trans = _dilate_shape(onp.take(rhs.shape, rhs_perm), rhs_dilation)
|
|
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding)
|
|
return tuple(onp.take(out_trans, onp.argsort(out_perm)))
|
|
|
|
def conv_general_dilated_dtype_rule(
|
|
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
|
dimension_numbers, **unused_kwargs):
|
|
return binop_dtype_rule(_input_dtype, [_f32, _f32], 'conv_general_dilated',
|
|
lhs, rhs)
|
|
|
|
def conv_general_dilated_transpose_lhs(
|
|
g, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
|
dimension_numbers, lhs_shape, rhs_shape):
|
|
if dimension_numbers is None:
|
|
nd = len(lhs_shape)
|
|
lhs_sdims = rhs_sdims = out_sdims = list(range(2, nd))
|
|
trans_dimension_numbers = ConvolutionDimensionNumbers(
|
|
tuple(range(nd)), (1, 0) + tuple(range(2, nd)), tuple(range(nd)))
|
|
else:
|
|
lhs_sdims, rhs_sdims, out_sdims = _get_sdims(dimension_numbers)
|
|
lhs_spec, rhs_spec, out_spec = dimension_numbers
|
|
trans_dimension_numbers = out_spec, _charswap("I", "O", rhs_spec), lhs_spec
|
|
|
|
padding = _conv_general_vjp_lhs_padding(
|
|
onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims),
|
|
window_strides, onp.take(g.shape, out_sdims), padding, lhs_dilation,
|
|
rhs_dilation)
|
|
revd_weights = rev(rhs, rhs_sdims)
|
|
return conv_general_dilated(
|
|
g, revd_weights, window_strides=lhs_dilation, padding=padding,
|
|
lhs_dilation=window_strides, rhs_dilation=rhs_dilation,
|
|
dimension_numbers=trans_dimension_numbers)
|
|
|
|
|
|
def conv_general_dilated_transpose_rhs(
|
|
g, lhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
|
dimension_numbers, lhs_shape, rhs_shape):
|
|
if dimension_numbers is None:
|
|
nd = len(lhs_shape)
|
|
lhs_sdims = rhs_sdims = out_sdims = list(range(2, nd))
|
|
trans_dimension_numbers = ConvolutionDimensionNumbers(
|
|
(1, 0) + tuple(range(2, nd)),
|
|
(1, 0) + tuple(range(2, nd)),
|
|
(1, 0) + tuple(range(2, nd)))
|
|
else:
|
|
lhs_sdims, rhs_sdims, out_sdims = _get_sdims(dimension_numbers)
|
|
lhs_spec, rhs_spec, out_spec = dimension_numbers
|
|
trans_dimension_numbers = (_charswap("C", "N", lhs_spec),
|
|
out_spec.translate(string.maketrans("NC", "IO")),
|
|
rhs_spec.translate(string.maketrans("IO", "NC")))
|
|
|
|
padding = _conv_general_vjp_rhs_padding(
|
|
onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims),
|
|
window_strides, onp.take(g.shape, out_sdims), padding, lhs_dilation,
|
|
rhs_dilation)
|
|
return conv_general_dilated(
|
|
lhs, g, window_strides=rhs_dilation, padding=padding,
|
|
lhs_dilation=lhs_dilation, rhs_dilation=window_strides,
|
|
dimension_numbers=trans_dimension_numbers)
|
|
|
|
|
|
def conv_general_dilated_translation_rule(
|
|
c, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
|
dimension_numbers, **unused_kwargs):
|
|
if isinstance(dimension_numbers, ConvolutionDimensionNumbers):
|
|
dimension_numbers = _conv_general_proto(dimension_numbers)
|
|
return c.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation,
|
|
rhs_dilation, dimension_numbers)
|
|
|
|
conv_general_dilated_p = standard_primitive(
|
|
conv_general_dilated_shape_rule, conv_general_dilated_dtype_rule,
|
|
'conv_general_dilated', conv_general_dilated_translation_rule)
|
|
ad.defbilinear(conv_general_dilated_p,
|
|
conv_general_dilated_transpose_lhs,
|
|
conv_general_dilated_transpose_rhs)
|
|
|
|
|
|
def dot_shape_rule(lhs, rhs):
|
|
if lhs.ndim == 0 or rhs.ndim == 0:
|
|
msg = "Dot only supports rank 1 or above, got shapes {} and {}."
|
|
raise TypeError(msg.format(lhs.shape, rhs.shape))
|
|
if lhs.ndim > 2 or rhs.ndim > 2:
|
|
msg = "Dot only supports rank 2 or less, got shapes {} and {}."
|
|
raise TypeError(msg.format(lhs.shape, rhs.shape))
|
|
|
|
def require(shape_cond):
|
|
if not shape_cond:
|
|
msg = "Incompatible shapes for dot: got {} and {}."
|
|
raise TypeError(msg.format(lhs.shape, rhs.shape))
|
|
|
|
if lhs.ndim == rhs.ndim == 1:
|
|
require(lhs.shape == rhs.shape)
|
|
return ()
|
|
elif lhs.ndim == rhs.ndim == 2:
|
|
require(lhs.shape[1] == rhs.shape[0])
|
|
return (lhs.shape[0], rhs.shape[1])
|
|
elif rhs.ndim == 1:
|
|
require(lhs.shape[-1] == rhs.shape[0])
|
|
return lhs.shape[:-1]
|
|
else:
|
|
require(lhs.shape[-1] == rhs.shape[-2])
|
|
return lhs.shape[:-1] + rhs.shape[:-2] + rhs.shape[-1:]
|
|
|
|
def dot_transpose_lhs(t, rhs):
|
|
if onp.ndim(t) == onp.ndim(rhs) == 2:
|
|
return dot(t, transpose(rhs, (1, 0)))
|
|
elif onp.ndim(t) == 1 and onp.ndim(rhs) == 2:
|
|
return dot(rhs, t)
|
|
elif onp.ndim(t) == onp.ndim(rhs) == 1:
|
|
return _outer(t, rhs)
|
|
elif onp.ndim(t) == 0 or onp.ndim(rhs) == 0:
|
|
return mul(t, rhs)
|
|
else:
|
|
raise TypeError
|
|
|
|
def dot_transpose_rhs(t, lhs):
|
|
if onp.ndim(lhs) == onp.ndim(t) == 2:
|
|
return dot(transpose(lhs, (1, 0)), t)
|
|
elif onp.ndim(lhs) == 2 and onp.ndim(t) == 1:
|
|
return dot(t, lhs)
|
|
elif onp.ndim(t) == onp.ndim(lhs) == 1:
|
|
return _outer(lhs, t)
|
|
elif onp.ndim(t) == 0 or onp.ndim(lhs) == 0:
|
|
return mul(t, lhs)
|
|
else:
|
|
raise TypeError
|
|
|
|
def _outer(x, y):
|
|
assert onp.ndim(x) == onp.ndim(y) == 1
|
|
return mul(reshape(x, (x.shape[0], 1)), reshape(y, (1, y.shape[0])))
|
|
|
|
def dot_batch_rule(batched_args, batch_dims):
|
|
lhs, rhs = batched_args
|
|
lbd, rbd = batch_dims
|
|
T = lambda x: transpose(x, onp.arange(onp.ndim(x))[::-1])
|
|
|
|
if max(onp.ndim(lhs), onp.ndim(rhs)) <= 2:
|
|
if rbd is None:
|
|
assert lbd in (0, 1)
|
|
if lbd == 0:
|
|
return dot(lhs, rhs), 0
|
|
else:
|
|
return dot(T(rhs), lhs), 1
|
|
|
|
if lbd is None:
|
|
assert rbd in (0, 1)
|
|
if rbd == onp.ndim(rhs) - 1:
|
|
return dot(lhs, rhs), 1
|
|
else:
|
|
return dot(rhs, T(lhs)), 0
|
|
|
|
assert False # unreachable
|
|
|
|
if lbd is None:
|
|
assert rbd is not None
|
|
lhs = broadcast(lhs, (rhs.shape[rbd],))
|
|
else:
|
|
lhs = batching.move_dim_to_front(lhs, lbd)
|
|
lhs_batch = (0,)
|
|
lhs_contracting = (onp.ndim(lhs) - 1,)
|
|
|
|
if rbd is None:
|
|
assert lbd is not None
|
|
rhs = broadcast(rhs, (lhs.shape[lbd],))
|
|
else:
|
|
rhs = batching.move_dim_to_front(rhs, rbd)
|
|
rhs_batch = (0,)
|
|
rhs_contracting = (onp.arange(1, onp.ndim(rhs))[-2:][0],)
|
|
|
|
dim_nums = [(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)]
|
|
return dot_general(lhs, rhs, dim_nums), 0
|
|
|
|
dot_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_num, _num], 'dot')
|
|
dot_p = standard_primitive(dot_shape_rule, dot_dtype_rule, 'dot')
|
|
ad.defbilinear(dot_p, dot_transpose_lhs, dot_transpose_rhs)
|
|
batching.primitive_batchers[dot_p] = dot_batch_rule
|
|
|
|
|
|
def dot_general_shape_rule(lhs, rhs, dimension_numbers):
|
|
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
|
if len(lhs_batch) != len(rhs_batch):
|
|
msg = ("dot_general requires equal numbers of lhs_batch and rhs_batch "
|
|
"dimensions, got lhs_batch {} and rhs_batch {}.")
|
|
raise TypeError(msg.format(lhs_batch, rhs_batch))
|
|
if not onp.all(onp.equal(lhs_batch, rhs_batch)):
|
|
msg = ("dot_general requires same lhs and rhs batch dimension numbers, "
|
|
"got {} and {}.")
|
|
raise TypeError(msg.format(lhs_batch, rhs_batch))
|
|
lhs_batch_shape = onp.take(lhs.shape, lhs_batch)
|
|
rhs_batch_shape = onp.take(rhs.shape, rhs_batch)
|
|
if not onp.all(onp.equal(lhs_batch_shape, rhs_batch_shape)):
|
|
msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
|
|
"to have the same shape, got {} and {}.")
|
|
raise TypeError(msg.format(lhs_batch_shape, rhs_batch_shape))
|
|
if tuple(sorted(lhs_batch)) != tuple(range(len(lhs_batch))):
|
|
msg = ("dot_general requires lhs batch dimensions to precede contracting "
|
|
"and non-contracting dimensions, got lhs_batch {}.")
|
|
raise TypeError(msg.format(lhs_batch))
|
|
if tuple(sorted(rhs_batch)) != tuple(range(len(rhs_batch))):
|
|
msg = ("dot_general requires rhs batch dimensions to precede contracting "
|
|
"and non-contracting dimensions, got rhs_batch {}.")
|
|
raise TypeError(msg.format(rhs_batch))
|
|
if not len(lhs_contracting) == len(rhs_contracting) == 1:
|
|
msg = ("dot_general accepts exactly one lhs_contracting and "
|
|
"rhs_contracting dimension, got {} and {}.")
|
|
raise TypeError(msg.format(lhs_contracting, rhs_contracting))
|
|
lhs_contracting_shape = onp.take(lhs.shape, lhs_contracting)
|
|
rhs_contracting_shape = onp.take(rhs.shape, rhs_contracting)
|
|
if not onp.all(onp.equal(lhs_contracting_shape, rhs_contracting_shape)):
|
|
msg = ("dot_general requires contracting dimensions to have the same "
|
|
"shape, got {} and {}.")
|
|
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
|
|
if lhs.ndim > len(lhs_batch) + len(lhs_contracting) + 1:
|
|
msg = ("dot_general requires either one or zero non-batch non-contracting "
|
|
"lhs dimension, got {}.")
|
|
diff = lhs.ndim - len(lhs_batch) - len(lhs_contracting)
|
|
raise TypeError(msg.format(diff))
|
|
if rhs.ndim > len(rhs_batch) + len(rhs_contracting) + 1:
|
|
msg = ("dot_general requires either one or zero non-batch non-contracting "
|
|
"rhs dimension, got {}.")
|
|
diff = rhs.ndim - len(rhs_batch) - len(rhs_contracting)
|
|
raise TypeError(msg.format(diff))
|
|
|
|
batch_shape = tuple(onp.take(lhs.shape, lhs_batch))
|
|
lhs_contract_or_batch = tuple(lhs_contracting) + tuple(lhs_batch)
|
|
lhs_tensored_shape = tuple(onp.delete(lhs.shape, lhs_contract_or_batch))
|
|
rhs_contract_or_batch = tuple(rhs_contracting) + tuple(rhs_batch)
|
|
rhs_tensored_shape = tuple(onp.delete(rhs.shape, rhs_contract_or_batch))
|
|
return batch_shape + lhs_tensored_shape + rhs_tensored_shape
|
|
|
|
|
|
def dot_general_dtype_rule(lhs, rhs, dimension_numbers):
|
|
return binop_dtype_rule(_input_dtype, [_num, _num], 'dot_general', lhs, rhs)
|
|
|
|
|
|
def dot_general_transpose_lhs(g, y, dimension_numbers, swap_ans=False):
|
|
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
|
|
x_ndim = g.ndim - y.ndim + len(x_batch) + 2 * len(x_contract)
|
|
x_kept = remaining(range(x_ndim), x_contract, x_batch)
|
|
y_kept = remaining(range(y.ndim), y_contract, y_batch)
|
|
if swap_ans:
|
|
ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept)
|
|
else:
|
|
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
|
|
dims = ((ans_y, y_kept), (ans_batch, y_batch))
|
|
x_contract_sorted_by_y = list(onp.take(x_contract, onp.argsort(y_contract)))
|
|
out_axes = onp.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
|
|
return transpose(dot_general(g, y, dims), tuple(out_axes))
|
|
|
|
def dot_general_transpose_rhs(g, x, dimension_numbers):
|
|
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
|
|
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
|
|
return dot_general_transpose_lhs(g, x, swapped_dimension_numbers, True)
|
|
|
|
|
|
# def dot_general_batch_rule(batched_args, batch_dims, dimension_numbers):
|
|
# assert False # TODO
|
|
|
|
dot_general_p = standard_primitive(dot_general_shape_rule,
|
|
dot_general_dtype_rule, 'dot_general')
|
|
ad.defbilinear(dot_general_p,
|
|
dot_general_transpose_lhs, dot_general_transpose_rhs)
|
|
# batching.primitive_batchers[dot_general_p] = dot_general_batch_rule
|
|
|
|
|
|
def broadcast_shape_rule(operand, sizes):
|
|
_check_shapelike('broadcast', 'sizes', sizes)
|
|
return tuple(sizes) + operand.shape
|
|
|
|
def broadcast_batch_rule(batched_args, batch_dims, sizes):
|
|
operand, = batched_args
|
|
bdim, = batch_dims
|
|
new_bdim = None if bdim is None else bdim + len(sizes)
|
|
return broadcast(operand, sizes), new_bdim
|
|
|
|
broadcast_p = standard_primitive(
|
|
broadcast_shape_rule, _input_dtype, 'broadcast')
|
|
ad.deflinear(broadcast_p, lambda t, sizes: [_reduce_sum(t, range(len(sizes)))])
|
|
batching.primitive_batchers[broadcast_p] = broadcast_batch_rule
|
|
|
|
|
|
def broadcast_in_dim_shape_rule(operand, shape, broadcast_dimensions):
|
|
_check_shapelike('broadcast_in_dim', 'shape', shape)
|
|
_check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
|
|
broadcast_dimensions)
|
|
if operand.ndim != len(broadcast_dimensions):
|
|
    msg = ('broadcast_in_dim broadcast_dimensions must have length equal to '
           'operand ndim, got broadcast_dimensions {} for operand ndim {}.')
|
|
raise TypeError(msg.format(broadcast_dimensions, operand.ndim))
|
|
if not set(broadcast_dimensions).issubset(set(range(len(shape)))):
|
|
msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
|
|
'dimensions, got {} for operand ndim {} and shape {}.')
|
|
raise TypeError(msg.format(broadcast_dimensions, operand.ndim, shape))
|
|
return shape
|
|
|
|
def broadcast_in_dim_transpose_rule(t, shape, broadcast_dimensions):
|
|
axes = tuple(onp.delete(range(len(shape)), broadcast_dimensions))
|
|
return [_reduce_sum(t, axes)]
|
|
|
|
def broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
|
|
broadcast_dimensions):
|
|
operand, = batched_args
|
|
bdim, = batch_dims
|
|
new_shape = list(shape)
|
|
new_shape.insert(bdim, operand.shape[bdim])
|
|
new_broadcast_dimensions = [d if d < bdim else d + 1 for d in broadcast_dimensions]
|
|
new_broadcast_dimensions.insert(bdim, bdim)
|
|
return broadcast_in_dim(operand, new_shape, new_broadcast_dimensions), bdim
|
|
|
|
|
|
broadcast_in_dim_p = standard_primitive(
|
|
broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
|
|
ad.deflinear(broadcast_in_dim_p, broadcast_in_dim_transpose_rule)
|
|
batching.primitive_batchers[broadcast_in_dim_p] = broadcast_in_dim_batch_rule
|
|
|
|
|
|
def clamp_shape_rule(min, operand, max):
|
|
if min.shape and min.shape != operand.shape:
|
|
m = "clamp requires min.shape == operand.shape or min.shape == (), got {}."
|
|
raise TypeError(m.format(min.shape))
|
|
if max.shape and max.shape != operand.shape:
|
|
m = "clamp requires max.shape == operand.shape or max.shape == (), got {}."
|
|
raise TypeError(m.format(max.shape))
|
|
return operand.shape
|
|
|
|
clamp_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_any, _any, _any],
|
|
'clamp')
|
|
|
|
clamp_p = standard_primitive(clamp_shape_rule, clamp_dtype_rule, 'clamp')
|
|
ad.defjvp(clamp_p,
|
|
lambda g, min, operand, max:
|
|
select(bitwise_and(gt(min, operand), lt(min, max)),
|
|
_brcast(g, operand), _zeros(operand)),
|
|
lambda g, min, operand, max:
|
|
select(bitwise_and(gt(operand, min), lt(operand, max)),
|
|
g, _zeros(operand)),
|
|
lambda g, min, operand, max:
|
|
select(lt(max, operand), _brcast(g, operand), _zeros(operand)))
|
|
|
|
|
|
def concatenate_shape_rule(*operands, **kwargs):
|
|
dimension = kwargs.pop('dimension')
|
|
if not operands:
|
|
msg = "concatenate expects at least one operand, got 0."
|
|
raise TypeError(msg)
|
|
if not all(isinstance(operand, UnshapedArray) for operand in operands):
|
|
msg = "All objects to concatenate must be arrays, got {}."
|
|
op = next(op for op in operands if not isinstance(op, UnshapedArray))
|
|
raise TypeError(msg.format(type(op)))
|
|
if len(set(operand.ndim for operand in operands)) != 1:
|
|
msg = "Cannot concatenate arrays with different ranks, got {}."
|
|
raise TypeError(msg.format(", ".join(str(o.ndim) for o in operands)))
|
|
shapes = onp.array([operand.shape for operand in operands])
|
|
if not 0 <= dimension < shapes.shape[1]:
|
|
msg = "concatenate dimension out of bounds: dimension {} for shapes {}."
|
|
raise TypeError(msg.format(dimension, ", ".join(map(str, shapes))))
|
|
if not onp.all(onp.delete(shapes[0] == shapes, dimension, axis=1)):
|
|
msg = ("Cannot concatenate arrays with shapes that differ in dimensions "
|
|
"other than the one being concatenated: dimension {} for shapes {}.")
|
|
raise TypeError(msg.format(dimension, ", ".join(map(str, shapes))))
|
|
|
|
concat_size = sum(o.shape[dimension] for o in operands)
|
|
ex_shape = operands[0].shape
|
|
return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:]
|
|
|
|
def concatenate_dtype_rule(*operands, **kwargs):
|
|
_check_same_dtypes('concatenate', False, *(o.dtype for o in operands))
|
|
return operands[0].dtype
|
|
|
|
def concatenate_translation_rule(c, *operands, **kwargs):
|
|
dimension = kwargs.pop('dimension')
|
|
return c.Concatenate(operands, dimension=dimension)
|
|
|
|
def concatenate_transpose_rule(t, *operands, **kwargs):
|
|
dimension = kwargs.pop('dimension')
|
|
operand_shapes = kwargs.pop('operand_shapes')
|
|
limit_points = onp.cumsum([shape[dimension] for shape in operand_shapes])
|
|
|
|
starts = onp.zeros((len(operands), t.ndim))
|
|
starts[1:, dimension] = limit_points[:-1]
|
|
limits = onp.tile(t.shape, (len(operands), 1))
|
|
limits[:, dimension] = limit_points
|
|
|
|
return [slice(t, start, limit) if o is None else None
|
|
for o, start, limit in zip(operands, starts, limits)]
|
|
|
|
concatenate_p = standard_primitive(
|
|
concatenate_shape_rule, concatenate_dtype_rule, 'concatenate',
|
|
concatenate_translation_rule)
|
|
ad.deflinear(concatenate_p, concatenate_transpose_rule)
|
|
ad.primitive_transposes[concatenate_p] = concatenate_transpose_rule
|
|
|
|
|
|
def pad_shape_rule(operand, padding_value, padding_config):
|
|
if operand.dtype != padding_value.dtype:
|
|
msg = "pad operand and padding_value must be same dtype: got {} and {}."
|
|
raise TypeError(msg.format(operand.dtype, padding_value.dtype))
|
|
|
|
lo, hi, interior = zip(*padding_config)
|
|
out_shape = onp.add(onp.add(onp.add(lo, hi), operand.shape),
|
|
onp.multiply(interior, onp.subtract(operand.shape, 1)))
|
|
return tuple(out_shape)
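
# Worked example of the size computation above: each (lo, hi, interior) entry
# contributes lo + hi edge padding plus `interior` elements between existing
# ones, so a dimension of size s becomes lo + hi + s + interior * (s - 1).
# For padding_config ((1, 2, 1),) on a length-3 axis: 1 + 2 + 3 + 1 * 2 == 8.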
|
|
|
|
def pad_transpose(t, operand, padding_value, padding_config):
|
|
lo, hi, interior = zip(*padding_config)
|
|
if onp.any(onp.less(lo, 0)) or onp.any(onp.less(hi, 0)):
|
|
msg = "pad transpose not implemented for negative padding, got {}."
|
|
raise NotImplementedError(msg.format(padding_config))
|
|
|
|
total = lambda x: _reduce_sum(x, list(range(t.ndim)))
|
|
|
|
t_op = lambda: slice(t, lo, onp.subtract(t.shape, hi), onp.add(interior, 1))
|
|
t_operand = t_op() if operand is None else None
|
|
|
|
if padding_value is None:
|
|
t_operand = t_op() if t_operand is None else t_operand
|
|
t_padv = sub(total(t), total(t_operand))
|
|
else:
|
|
t_padv = None
|
|
|
|
return [t_operand, t_padv]
|
|
|
|
pad_p = standard_primitive(pad_shape_rule, _input_dtype, 'pad')
|
|
ad.deflinear(pad_p, pad_transpose)
|
|
ad.primitive_transposes[pad_p] = pad_transpose
|
|
|
|
|
|
def reshape_shape_rule(operand, new_sizes, dimensions, **unused_kwargs):
|
|
if not onp.all(onp.greater_equal(new_sizes, 0)):
|
|
    msg = 'reshape new_sizes must all be nonnegative, got {}.'
|
|
raise TypeError(msg.format(new_sizes))
|
|
if onp.prod(onp.shape(operand)) != onp.prod(new_sizes):
|
|
msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
|
|
raise TypeError(msg.format(new_sizes, onp.shape(operand)))
|
|
if dimensions is not None:
|
|
if set(dimensions) != set(range(onp.ndim(operand))):
|
|
msg = ('reshape dimensions must be a permutation of operand dimensions, '
|
|
'got dimensions {} for shape {}.')
|
|
raise TypeError(msg.format(dimensions, onp.shape(operand)))
|
|
return tuple(new_sizes)
|
|
|
|
def reshape_dtype_rule(operand, new_sizes, dimensions, **unused_kwargs):
|
|
return operand.dtype
|
|
|
|
def reshape_translation_rule(c, operand, new_sizes, dimensions, old_sizes):
|
|
del old_sizes # Unused.
|
|
return c.Reshape(operand, new_sizes=new_sizes, dimensions=dimensions)
|
|
|
|
def reshape_transpose_rule(t, new_sizes, dimensions, old_sizes):
|
|
out = reshape(t, old_sizes)
|
|
if dimensions is None:
|
|
return [out]
|
|
else:
|
|
return [transpose(out, onp.argsort(dimensions))]
|
|
|
|
def reshape_batch_rule(batched_args, batch_dims, new_sizes, dimensions, **unused):
|
|
operand, = batched_args
|
|
bdim, = batch_dims
|
|
operand = batching.move_dim_to_front(operand, bdim)
|
|
if dimensions is not None:
|
|
raise NotImplementedError # TODO(mattjj): handle reshape w/ dimensions
|
|
dimensions = (0,) + tuple(onp.add(1, dimensions))
|
|
return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0
|
|
|
|
reshape_p = standard_primitive(reshape_shape_rule, reshape_dtype_rule,
|
|
'reshape', reshape_translation_rule)
|
|
ad.deflinear(reshape_p, reshape_transpose_rule)
|
|
batching.primitive_batchers[reshape_p] = reshape_batch_rule
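# Brief sketch of the transpose rule above (hand-worked, illustrative only):
# reshape is linear, so the cotangent of reshape(x, (6,)) for x of shape (2, 3)
# is just the cotangent reshaped back to old_sizes, i.e. reshape(t, (2, 3));
# the argsort(dimensions) transpose only applies when an explicit dimension
# permutation was given.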


def rev_shape_rule(operand, dimensions):
  _check_shapelike('rev', 'dimensions', dimensions)
  if len(set(dimensions)) != len(dimensions):
    msg = 'rev dimensions must be unique, got {}.'
    raise TypeError(msg.format(dimensions))
  if not _max(dimensions) < operand.ndim:
    msg = ('rev dimensions must all be less than operand ndim, got dimensions '
           '{} for operand ndim {}.')
    raise TypeError(msg.format(dimensions, operand.ndim))
  return operand.shape

rev_p = standard_primitive(rev_shape_rule, _input_dtype, 'rev')
ad.deflinear(rev_p, lambda t, dimensions: [rev(t, dimensions)])


def transpose_shape_rule(operand, permutation):
  if not isinstance(permutation, (tuple, list, onp.ndarray)):
    msg = "transpose permutation must be a tuple/list/ndarray, got {}."
    raise TypeError(msg.format(type(permutation)))
  if tuple(sorted(permutation)) != tuple(range(operand.ndim)):
    msg = ("transpose permutation isn't a permutation of operand dimensions, "
           "got permutation {} for operand shape {}.")
    raise TypeError(msg.format(permutation, operand.shape))
  return tuple(onp.take(operand.shape, permutation))

def transpose_batch_rule(batched_args, batch_dims, permutation):
  operand, = batched_args
  bdim, = batch_dims
  perm = tuple(onp.insert(onp.add(permutation, 1), bdim, 0))
  return transpose(operand, perm), 0

transpose_p = standard_primitive(transpose_shape_rule, _input_dtype,
                                 'transpose')
ad.deflinear(transpose_p,
             lambda t, permutation: [transpose(t, onp.argsort(permutation))])
batching.primitive_batchers[transpose_p] = transpose_batch_rule
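# Hand-worked example for the transpose rules above (illustrative only): for a
# permutation (2, 0, 1) the linear transpose uses argsort((2, 0, 1)) = (1, 2, 0),
# which is the inverse permutation, so transposing twice round-trips the axes.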


def select_shape_rule(pred, on_true, on_false):
  if on_true.shape != on_false.shape:
    msg = "select on_true and on_false must have the same shape, got {} and {}."
    raise TypeError(msg.format(on_true.shape, on_false.shape))
  if pred.shape and pred.shape != on_true.shape:
    msg = ("select pred must be scalar or have the same shape as on_true and "
           "on_false, got pred shape {} for on_true and on_false of shape {}.")
    raise TypeError(msg.format(pred.shape, on_true.shape))
  return on_true.shape

def select_dtype_rule(pred, on_true, on_false):
  _check_same_dtypes("select", False, on_true.dtype, on_false.dtype)
  if not onp.issubdtype(pred.dtype, onp.bool_):
    msg = "select pred must be boolean type, got {}."
    raise TypeError(msg.format(pred.dtype))
  return on_true.dtype

def select_transpose_rule(t, pred, on_true, on_false):
  return [None,
          select(pred, t, _zeros(on_false)) if on_true is None else None,
          select(pred, _zeros(on_true), t) if on_false is None else None]

select_p = standard_primitive(select_shape_rule, select_dtype_rule, 'select')
ad.defjvp(select_p,
          None,
          lambda g, b, x, y: select(b, g, _zeros(g)),
          lambda g, b, x, y: select(b, _zeros(g), g))
ad.primitive_transposes[select_p] = select_transpose_rule


def slice_shape_rule(operand, start_indices, limit_indices, strides,
                     operand_shape):
  _check_shapelike("slice", "start_indices", start_indices)
  _check_shapelike("slice", "limit_indices", limit_indices)
  if operand.ndim != len(start_indices):
    msg = ("slice start_indices must have length equal to the number of "
           "dimensions of the operand, got indices {} for operand shape {}.")
    raise TypeError(msg.format(start_indices, operand.shape))
  if len(start_indices) != len(limit_indices):
    msg = ("slice limit_indices must have the same length as start_indices, "
           "got start_indices {} and limit_indices {}.")
    raise TypeError(msg.format(start_indices, limit_indices))
  if not onp.all(onp.less_equal(limit_indices, operand.shape)):
    msg = ("slice limit_indices must be less than or equal to operand shape, "
           "got limit_indices {} for operand shape {}.")
    raise TypeError(msg.format(limit_indices, operand.shape))
  if not onp.all(onp.greater_equal(start_indices, 0)):
    msg = ("slice start_indices must be greater than or equal to zero, "
           "got start_indices of {}.")
    raise TypeError(msg.format(start_indices))
  if not onp.all(onp.greater_equal(limit_indices, start_indices)):
    msg = ("slice limit_indices must be greater than or equal to start_indices,"
           " got start_indices {} and limit_indices {}.")
    raise TypeError(msg.format(start_indices, limit_indices))
  if strides is None:
    strides = onp.ones(operand.ndim, onp.int32)
  else:
    _check_shapelike("slice", "strides", strides)
    if len(strides) != operand.ndim:
      msg = ("slice strides must have length equal to the number of dimensions "
             "of the operand, got strides {} for operand shape {}.")
      raise TypeError(msg.format(strides, operand.shape))
    if not onp.all(onp.greater(strides, 0)):
      msg = "slice strides must be positive, got {}."
      raise TypeError(msg.format(strides))

  result_shape = onp.divide(onp.add(onp.subtract(limit_indices, start_indices),
                                    strides) - 1,
                            strides)
  return tuple(result_shape)
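# Hand-worked example of the shape computation above (illustrative only): for
# start_indices (1,), limit_indices (7,), strides (2,) the result extent is
# (7 - 1 + 2 - 1) // 2 = 3, i.e. the elements at positions 1, 3, and 5.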

def slice_translation_rule(c, operand, start_indices, limit_indices, strides,
                           operand_shape):
  return c.Slice(operand, start_indices, limit_indices, strides)

def slice_transpose_rule(t, start_indices, limit_indices, strides,
                         operand_shape):
  if strides is None or onp.all(onp.equal(strides, 1)):
    pads = zip(start_indices, onp.subtract(operand_shape, limit_indices),
               (0,) * len(start_indices))
  else:
    real_limits = onp.add(onp.add(start_indices, 1),
                          onp.multiply(onp.subtract(t.shape, 1), strides))
    pads = zip(start_indices, onp.subtract(operand_shape, real_limits),
               onp.subtract(strides, 1))
  result = pad(t, _const(t, 0), pads)
  assert result.shape == operand_shape
  return [result]

def slice_batching_rule(batched_args, batch_dims, start_indices, limit_indices,
                        strides, **unused_kwargs):
  operand, = batched_args
  bdim, = batch_dims

  new_start_indices = list(start_indices)
  new_start_indices.insert(bdim, 0)

  new_limit_indices = list(limit_indices)
  new_limit_indices.insert(bdim, operand.shape[bdim])

  if strides is None:
    new_strides = None
  else:
    new_strides = list(strides)
    new_strides.insert(bdim, 1)

  out = slice(operand, new_start_indices, new_limit_indices, new_strides)
  return out, bdim

slice_p = standard_primitive(slice_shape_rule, _input_dtype, 'slice',
                             slice_translation_rule)
ad.deflinear(slice_p, slice_transpose_rule)
batching.primitive_batchers[slice_p] = slice_batching_rule
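# Hand-checked instance of the strided branch of slice_transpose_rule above
# (illustrative only): slicing a length-8 operand with start 1, limit 7,
# stride 2 yields a cotangent t of length 3; real_limits = 1 + 1 + (3 - 1) * 2
# = 6, so pads = (1, 2, 1) and the padded result has extent
# 1 + 2 + 3 + 1 * (3 - 1) = 8, matching the operand.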


def dynamic_slice_shape_rule(operand, start_indices, slice_sizes,
                             operand_shape):
  if operand.ndim != len(start_indices):
    msg = ("dynamic_slice start_indices must have length equal to the number "
           "of dimensions of the operand, got indices {} for operand shape {}.")
    raise TypeError(msg.format(start_indices, operand.shape))
  if len(start_indices) != len(slice_sizes):
    msg = ("dynamic_slice slice_sizes must have the same length as "
           "start_indices, got start_indices length {} and slice_sizes {}.")
    raise TypeError(msg.format(len(start_indices), slice_sizes))
  if not onp.all(onp.less_equal(slice_sizes, operand.shape)):
    msg = ("dynamic_slice slice_sizes must be less than or equal to operand "
           "shape, got slice_sizes {} for operand shape {}.")
    raise TypeError(msg.format(slice_sizes, operand.shape))
  if not onp.all(onp.greater_equal(slice_sizes, 0)):
    msg = ("dynamic_slice slice_sizes must be greater than or equal to zero, "
           "got slice_sizes of {}.")
    raise TypeError(msg.format(slice_sizes))
  return tuple(slice_sizes)

def dynamic_slice_translation_rule(c, operand, start_indices, slice_sizes,
                                   operand_shape):
  return c.DynamicSlice(operand, start_indices, slice_sizes)

def dynamic_slice_jvp_rule(g, operand, start_indices, slice_sizes,
                           operand_shape):
  return dynamic_slice(g, start_indices, slice_sizes)

def dynamic_slice_transpose_rule(t, operand, start_indices, slice_sizes,
                                 operand_shape):
  assert operand is None
  zeros = broadcast(_const(t, 0), operand_shape)
  return [dynamic_update_slice(zeros, t, start_indices)]

dynamic_slice_p = standard_primitive(
    dynamic_slice_shape_rule, _input_dtype, 'dynamic_slice',
    dynamic_slice_translation_rule)
ad.defjvp(dynamic_slice_p, dynamic_slice_jvp_rule, None)
ad.primitive_transposes[dynamic_slice_p] = dynamic_slice_transpose_rule
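# Brief sketch of the transpose rule above (hand-worked, illustrative only):
# dynamic_slice is linear in its operand, so the cotangent of a (2, 2) window
# taken from a (4, 4) operand is scattered back by writing t into a (4, 4)
# array of zeros at the same start_indices via dynamic_update_slice.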


def dynamic_update_slice_shape_rule(operand, update, start_indices,
                                    update_shape):
  if operand.ndim != update.ndim:
    msg = ("dynamic_update_slice update must have the same rank as operand, "
           "got update shape {} for operand shape {}.")
    raise TypeError(msg.format(update.shape, operand.shape))
  if operand.ndim != len(start_indices):
    msg = ("dynamic_update_slice start_indices must have length equal to the "
           "rank of operand, got indices {} for operand shape {}.")
    raise TypeError(msg.format(start_indices, operand.shape))
  if not onp.all(onp.less_equal(update.shape, operand.shape)):
    msg = ("dynamic_update_slice update shape must not be larger than operand "
           "shape, got update shape {} for operand shape {}.")
    raise TypeError(msg.format(update.shape, operand.shape))
  return operand.shape

def dynamic_update_slice_dtype_rule(operand, update, start_indices,
                                    update_shape):
  _check_same_dtypes("dynamic_update_slice", False, operand.dtype, update.dtype)
  return operand.dtype

def dynamic_update_slice_jvp(primals, tangents, update_shape):
  operand, update, start_indices = primals
  g_operand, g_update, g_start_indices = tangents
  assert g_start_indices is ad_util.zero
  val_out = dynamic_update_slice(operand, update, start_indices)
  tangent_out = dynamic_update_slice(g_operand, g_update, start_indices)
  return val_out, tangent_out

def dynamic_update_slice_transpose_rule(t, operand, update, start_indices,
                                        update_shape):
  assert start_indices is not None
  dus = dynamic_update_slice
  ds = dynamic_slice
  zeros = _zeros(t, shape=update_shape)
  operand_t = dus(t, zeros, start_indices) if operand is None else None
  update_t = ds(t, start_indices, update_shape) if update is None else None
  return [operand_t, update_t, None]

def dynamic_update_slice_translation_rule(c, operand, update, start_indices,
                                          update_shape):
  return c.DynamicUpdateSlice(operand, update, start_indices)

dynamic_update_slice_p = standard_primitive(
    dynamic_update_slice_shape_rule, dynamic_update_slice_dtype_rule,
    'dynamic_update_slice', dynamic_update_slice_translation_rule)
ad.primitive_jvps[dynamic_update_slice_p] = dynamic_update_slice_jvp
ad.primitive_transposes[dynamic_update_slice_p] = \
    dynamic_update_slice_transpose_rule
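# Brief sketch of the transpose rule above (hand-worked, illustrative only):
# writing an update into an operand passes the cotangent straight through
# outside the window and masks it inside, so operand_t is t with zeros written
# over the updated window, while update_t is just that window of t read back
# with dynamic_slice.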


def index_take_shape_rule(src, *idxs, **kwargs):
  axes = kwargs['axes']
  return (idxs[0].shape[0],) + tuple(onp.delete(src.shape, axes))

def index_take_translation_rule(c, src, *idxs, **kwargs):
  jaxpr = kwargs['jaxpr']
  consts = kwargs['consts']
  shapes = map(c.GetShape, (src,) + idxs)
  xla_computation = xla.jaxpr_computation(jaxpr, consts, (), *shapes)
  return c.Call(xla_computation, (src,) + idxs)

def index_take_jvp(primals, tangents, axes, input_shape, jaxpr, consts):
  src = primals[0]
  idxs = tuple(primals[1:])
  g = tangents[0]
  return index_take(src, idxs, axes), index_take(g, idxs, axes)

def index_take_transpose_rule(t, src, *idxs, **kwargs):
  assert src is None
  axes = kwargs['axes']
  input_shape = kwargs['input_shape']
  t_src = index_untake(t, _zeros(t, shape=input_shape), idxs, axes)
  return [t_src] + [None] * len(idxs)

index_take_p = standard_primitive(index_take_shape_rule, _input_dtype,
                                  'index_take', index_take_translation_rule)
ad.primitive_jvps[index_take_p] = index_take_jvp
ad.primitive_transposes[index_take_p] = index_take_transpose_rule


def index_untake_shape_rule(src, dst, *idxs, **kwargs):
  return dst.shape

def index_untake_translation_rule(c, src, dst, *idxs, **kwargs):
  jaxpr = kwargs['jaxpr']
  consts = kwargs['consts']
  shapes = map(c.GetShape, (src, dst) + idxs)
  xla_computation = xla.jaxpr_computation(jaxpr, consts, (), *shapes)
  return c.Call(xla_computation, (src, dst) + idxs)

def index_untake_jvp(primals, tangents, axes, jaxpr, consts):
  src, dst = primals[0], primals[1]
  idxs = tuple(primals[2:])
  g_src, g_dst = tangents[0], tangents[1]
  val_out = index_untake(src, dst, idxs, axes)
  tangent_out = index_untake(g_src, g_dst, idxs, axes)
  return val_out, tangent_out

def index_untake_transpose_rule(t, src, dst, *idxs, **kwargs):
  axes = kwargs['axes']
  t_src = index_take(t, idxs, axes) if src is None else None
  t_dst = t if dst is None else None
  return [t_src, t_dst] + [None] * len(idxs)

index_untake_p = standard_primitive(
    index_untake_shape_rule, _input_dtype, 'index_untake',
    index_untake_translation_rule)
ad.primitive_jvps[index_untake_p] = index_untake_jvp
ad.primitive_transposes[index_untake_p] = index_untake_transpose_rule


def reduce_shape_rule(operand, init_value, jaxpr, consts, dimensions):
  return tuple(onp.delete(operand.shape, dimensions))

def reduce_translation_rule(c, operand, init_value, jaxpr, consts, dimensions):
  xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
  return c.Reduce(operand, init_value, xla_computation, dimensions)

def _reduction_computation(c, jaxpr, consts, init_value):
  shape = c.GetShape(init_value)
  return xla.jaxpr_computation(jaxpr, consts, (), shape, shape)

reduce_p = standard_primitive(reduce_shape_rule, _input_dtype, 'reduce',
                              reduce_translation_rule)
batching.defreducer(reduce_p)


def reduce_sum_shape_rule(operand, axes, input_shape):
  return tuple(onp.delete(operand.shape, axes))

def reduce_sum_translation_rule(c, operand, axes, input_shape):
  dtype = c.GetShape(operand).numpy_dtype()
  scalar = xla_bridge.Shape.array_shape(dtype, ())
  return c.Reduce(operand, c.Constant(onp.array(0, dtype)),
                  xla.primitive_computation(add_p, scalar, scalar),
                  axes)

def reduce_sum_transpose_rule(cotangent, input_shape, axes):
  broadcast_dimensions = tuple(onp.delete(onp.arange(len(input_shape)), axes))
  result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions)
  assert result.shape == input_shape
  return [result]

reduce_sum_p = standard_primitive(reduce_sum_shape_rule, _input_dtype,
                                  'reduce_sum', reduce_sum_translation_rule)
ad.deflinear(reduce_sum_p, reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p)
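# Hand-worked instance of the transpose rule above (illustrative only): summing
# an input of shape (3, 4) over axes (1,) leaves a cotangent of shape (3,);
# broadcast_dimensions is (0,), so broadcast_in_dim replays that cotangent
# across the summed axis to recover a (3, 4) result.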


def reduce_chooser_shape_rule(operand, axes):
  return tuple(onp.delete(operand.shape, axes))

def reduce_chooser_translation_rule(prim, identity, c, operand, axes):
  dtype = c.GetShape(operand).numpy_dtype()
  scalar = xla_bridge.Shape.array_shape(dtype, ())
  return c.Reduce(operand, c.Constant(identity(dtype)),
                  xla.primitive_computation(prim, scalar, scalar), axes)

def reduce_chooser_jvp_rule(g, ans, operand, axes):
  # TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
  # locations in a single pass (rather than comparing equality) and use a
  # gather, and/or even push along the chosen elements of g (b/112040122)
  shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
  location_indicators = convert_element_type(
      _eq_meet(operand, reshape(ans, shape)), g.dtype)
  counts = _reduce_sum(location_indicators, axes)
  return div(_reduce_sum(mul(g, location_indicators), axes), counts)

reduce_max_translation_rule = partial(reduce_chooser_translation_rule, max_p,
                                      _get_max_identity)
reduce_max_p = standard_primitive(reduce_chooser_shape_rule, _input_dtype,
                                  'reduce_max', reduce_max_translation_rule)
ad.defjvp2(reduce_max_p, reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p)


reduce_min_translation_rule = partial(
    reduce_chooser_translation_rule, min_p, _get_min_identity)
reduce_min_p = standard_primitive(reduce_chooser_shape_rule, _input_dtype,
                                  'reduce_min', reduce_min_translation_rule)
ad.defjvp2(reduce_min_p, reduce_chooser_jvp_rule)
batching.defreducer(reduce_min_p)
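# Hand-worked instance of reduce_chooser_jvp_rule above (illustrative only):
# for operand [3., 5., 5.] reduced with max over axis 0, ans is 5., the
# location_indicators are [0., 1., 1.] and counts is 2., so a tangent
# g = [g0, g1, g2] yields (g1 + g2) / 2, splitting the subgradient evenly
# between the tied maxima.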


def reduce_window_shape_rule(operand, init_value, jaxpr, consts,
                             window_dimensions, window_strides, padding):
  if operand.dtype != init_value.dtype:
    msg = ("reduce_window got inconsistent dtypes for operand and init_value: "
           "got operand dtype {} and init_value dtype {}.")
    raise TypeError(msg.format(operand.dtype, init_value.dtype))
  return common_reduce_window_shape_rule(operand, window_dimensions,
                                         window_strides, padding)

def reduce_window_translation_rule(c, operand, init_value, jaxpr, consts,
                                   window_dimensions, window_strides, padding):
  xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
  return c.ReduceWindow(operand, init_value, xla_computation, window_dimensions,
                        window_strides, padding)

reduce_window_p = standard_primitive(
    reduce_window_shape_rule, _input_dtype, 'reduce_window',
    reduce_window_translation_rule)


def reduce_window_sum_shape_rule(operand, window_dimensions, window_strides,
                                 padding, input_shape):
  return common_reduce_window_shape_rule(operand, window_dimensions,
                                         window_strides, padding)

def reduce_window_sum_translation_rule(c, operand, window_dimensions,
                                       window_strides, padding, input_shape):
  dtype = c.GetShape(operand).numpy_dtype()
  scalar = xla_bridge.Shape.array_shape(dtype, ())
  return c.ReduceWindow(operand, c.Constant(onp.array(0, dtype)),
                        xla.primitive_computation(add_p, scalar, scalar),
                        window_dimensions, window_strides, padding)

def reduce_window_sum_transpose_rule(cotangent, window_dimensions,
                                     window_strides, padding, input_shape):
  in_pads = padtype_to_pads(input_shape, window_dimensions, window_strides,
                            padding)
  ones = [1] * len(input_shape)
  pads = _conv_general_vjp_lhs_padding(
      input_shape, window_dimensions, window_strides, cotangent.shape, in_pads,
      ones, ones)
  padding_config = [(lo, hi, stride - 1)
                    for (lo, hi), stride in zip(pads, window_strides)]
  pad_cotangent = pad(cotangent, _zero(cotangent), padding_config)
  result = _reduce_window_sum(pad_cotangent, window_dimensions, ones,
                              xla_bridge.get_xla_client().PaddingType.VALID)
  assert result.shape == input_shape
  return [result]

reduce_window_sum_p = standard_primitive(
    reduce_window_sum_shape_rule, _input_dtype, 'reduce_window_sum',
    reduce_window_sum_translation_rule)
ad.deflinear(reduce_window_sum_p, reduce_window_sum_transpose_rule)


def reduce_window_chooser_translation_rule(
    prim, identity, c, operand, window_dimensions, window_strides, padding):
  dtype = c.GetShape(operand).numpy_dtype()
  scalar = xla_bridge.Shape.array_shape(dtype, ())
  return c.ReduceWindow(operand, c.Constant(identity(dtype)),
                        xla.primitive_computation(prim, scalar, scalar),
                        window_dimensions, window_strides, padding)

def reduce_window_chooser_jvp_rule(prim, g, operand, window_dimensions,
                                   window_strides, padding):
  assert prim is max_p or prim is min_p
  select_prim = ge_p if prim is max_p else le_p
  return _select_and_gather_add(g, operand, select_prim, window_dimensions,
                                window_strides, padding)


def common_reduce_window_shape_rule(operand, window_dimensions, window_strides,
                                    padding):
  _check_shapelike("reduce_window", "window_dimensions", window_dimensions)
  _check_shapelike("reduce_window", "window_strides", window_strides)
  if operand.ndim != len(window_dimensions):
    msg = ("reduce_window got the wrong number of window_dimensions for "
           "operand: got operand shape {} with window_dimensions {}.")
    raise TypeError(msg.format(operand.shape, window_dimensions))
  if len(window_strides) != len(window_dimensions):
    msg = ("reduce_window got inconsistent window_strides and "
           "window_dimensions: got window_strides {} and window_dimensions {}.")
    raise TypeError(msg.format(window_strides, window_dimensions))

  return reduce_window_shape_tuple(operand.shape, window_dimensions,
                                   window_strides, padding)

def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding):
  pads = padtype_to_pads(operand_shape, window_dimensions, window_strides, padding)
  operand_padded = onp.add(operand_shape, onp.add(*zip(*pads)))
  t = onp.divide(onp.subtract(operand_padded, window_dimensions), window_strides) + 1
  return tuple(t)
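# Hand-worked instance of the window shape computation above (illustrative
# only): an operand of shape (9,), window_dimensions (3,), window_strides (2,)
# and 'VALID' padding gives pads (0, 0), a padded extent of 9, and
# (9 - 3) / 2 + 1 = 4 window positions.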


reduce_window_max_translation_rule = partial(
    reduce_window_chooser_translation_rule, max_p, _get_max_identity)
reduce_window_max_p = standard_primitive(
    common_reduce_window_shape_rule, _input_dtype, 'reduce_window_max',
    reduce_window_max_translation_rule)
ad.defjvp(reduce_window_max_p, partial(reduce_window_chooser_jvp_rule, max_p))


reduce_window_min_translation_rule = partial(
    reduce_window_chooser_translation_rule, min_p, _get_min_identity)
reduce_window_min_p = standard_primitive(
    common_reduce_window_shape_rule, _input_dtype, 'reduce_window_min',
    reduce_window_min_translation_rule)
ad.defjvp(reduce_window_min_p, partial(reduce_window_chooser_jvp_rule, min_p))


def select_and_scatter_shape_rule(
    operand, source, init_value, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, window_dimensions, window_strides, padding):
  _check_shapelike("select_and_scatter", "window_dimensions", window_dimensions)
  _check_shapelike("select_and_scatter", "window_strides", window_strides)
  if len(window_dimensions) != len(window_strides):
    msg = ("select_and_scatter got inconsistent window_strides and "
           "window_dimensions: got window_strides {} and window_dimensions {}.")
    raise TypeError(msg.format(window_strides, window_dimensions))
  return operand.shape

def select_and_scatter_translation(c, operand, source, init_value, select_jaxpr,
                                   select_consts, scatter_jaxpr, scatter_consts,
                                   window_dimensions, window_strides, padding):
  select = _reduction_computation(c, select_jaxpr, select_consts, init_value)
  scatter = _reduction_computation(c, scatter_jaxpr, scatter_consts, init_value)
  return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
                            padding, source, init_value, scatter)

select_and_scatter_p = standard_primitive(
    select_and_scatter_shape_rule, _input_dtype, 'select_and_scatter',
    select_and_scatter_translation)


def select_and_scatter_add_shape_rule(
    source, operand, select_prim, window_dimensions, window_strides, padding):
  return operand.shape

def select_and_scatter_add_translation(
    c, source, operand, select_prim, window_dimensions, window_strides,
    padding):
  dtype = c.GetShape(operand).numpy_dtype()
  scalar = xla_bridge.Shape.array_shape(dtype, ())
  select = xla.primitive_computation(select_prim, scalar, scalar)
  scatter = xla.primitive_computation(add_p, scalar, scalar)
  zero = c.Constant(onp.array(0, dtype))
  return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
                            padding, source, zero, scatter)

def select_and_scatter_add_transpose(
    t, source, operand, select_prim, window_dimensions, window_strides,
    padding):
  assert source is None and operand is not None
  result = _select_and_gather_add(t, operand, select_prim, window_dimensions,
                                  window_strides, padding)
  return [result, None]

select_and_scatter_add_p = standard_primitive(
    select_and_scatter_add_shape_rule, _input_dtype, 'select_and_scatter_add',
    select_and_scatter_add_translation)
ad.primitive_transposes[select_and_scatter_add_p] = \
    select_and_scatter_add_transpose


def select_and_gather_add_shape_rule(
    tangents, operand, select_prim, window_dimensions, window_strides, padding):
  if tangents.shape != operand.shape:
    msg = ("select_and_gather_add tangents and operand shapes must match, "
           "got {} and {}.")
    raise TypeError(msg.format(tangents.shape, operand.shape))
  return common_reduce_window_shape_rule(operand, window_dimensions,
                                         window_strides, padding)

def select_and_gather_add_translation(
    c, tangents, operand, select_prim, window_dimensions, window_strides,
    padding):
  raise NotImplementedError("No efficient translation.")

def select_and_gather_add_transpose(
    t, tangents, operand, select_prim, window_dimensions, window_strides,
    padding):
  assert tangents is None and operand is not None
  result = _select_and_scatter_add(t, operand, select_prim, window_dimensions,
                                   window_strides, padding)
  return [result, None]

select_and_gather_add_p = standard_primitive(
    select_and_gather_add_shape_rule, _input_dtype, 'select_and_gather_add',
    select_and_gather_add_translation)
ad.primitive_transposes[select_and_gather_add_p] = \
    select_and_gather_add_transpose


sort_shape = lambda operand, dimension: operand.shape

def sort_jvp_rule(g, operand, dimension):
  _, g_out = sort_key_val(operand, g, dimension)
  return g_out

sort_p = standard_primitive(sort_shape, _input_dtype, 'sort')
ad.defjvp(sort_p, sort_jvp_rule)


def sort_key_val_abstract_eval(keys, values, dimension):
  return core.AbstractTuple((keys, values))

def sort_key_val_impl(keys, values, dimension):
  out = xla.apply_primitive(sort_key_val_p, keys, values, dimension=dimension)
  sorted_keys, sorted_values = out
  return core.pack((sorted_keys, sorted_values))

def sort_key_val_jvp(primals, tangents, dimension):
  # NOTE(mattjj): this re-sorts three times, but if we had a variadic
  # sort_key_val, or if we could apply a fixed permutation efficiently, we could
  # implement this jvp rule with a single sort. The apply_permutation primitive
  # would make the jvp (and corresponding transpose rule) faster and easier.
  # This would also be cleaner if we didn't get the sorted keys out.
  # TODO(mattjj): make sort_key_val variadic, no sorted keys out by default
  keys, values = primals
  keys_tangents, values_tangents = tangents

  val_out = sort_key_val(keys, values, dimension)

  keys_tangents_out = sort_jvp_rule(keys_tangents, keys, dimension)
  values_tangents_out = sort_jvp_rule(values_tangents, keys, dimension)
  tangents_out = keys_tangents_out, values_tangents_out

  return core.pack(val_out), core.pack(tangents_out)

def sort_key_val_transpose_rule(t, keys, values, dimension):
  t_keys, t_values = t
  assert t_keys is ad_util.zero
  broadcasted_iota = broadcast_in_dim(
      onp.arange(keys.shape[dimension]), keys.shape, [dimension % keys.ndim])
  _, perm = sort_key_val(keys, broadcasted_iota)
  keys_result = ad_util.zero if keys is None else None
  values_result = sort_key_val(perm, t_values)[1] if values is None else None
  return [keys_result, values_result]

sort_key_val_p = Primitive('sort_key_val')
sort_key_val_p.def_impl(sort_key_val_impl)
sort_key_val_p.def_abstract_eval(sort_key_val_abstract_eval)
xla.translations[sort_key_val_p] = partial(standard_translate, 'sort_key_val')
ad.primitive_jvps[sort_key_val_p] = sort_key_val_jvp
ad.primitive_transposes[sort_key_val_p] = sort_key_val_transpose_rule
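# Hand-worked instance of the transpose rule above (illustrative only): for
# keys [3, 1, 2] along dimension 0, sorting an iota [0, 1, 2] by the keys gives
# perm = [1, 2, 0]; re-sorting the cotangent values keyed by perm applies the
# inverse permutation, routing each cotangent back to its pre-sort position.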


def while_loop_abstract_eval(init_val, opaque_params):
  abs_out = opaque_params.val[0]
  return maybe_tracer_tuple_to_abstract_tuple(abs_out)

def while_loop_translation_rule(c, init_val, opaque_params):
  shape = c.GetShape(init_val)
  abs_out, cond_jaxpr, cond_consts, body_jaxpr, body_consts = opaque_params.val
  cond_computation = xla.jaxpr_computation(cond_jaxpr, cond_consts, (), shape)
  body_computation = xla.jaxpr_computation(body_jaxpr, body_consts, (), shape)
  return c.While(cond_computation, body_computation, init_val)

while_p = Primitive('while')
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(while_loop_abstract_eval)
xla.translations[while_p] = while_loop_translation_rule


### util


def _dilate_shape(shape, dilation):
  """Utility function for computing the shape resulting from a dilation."""
  if not onp.all(onp.greater(dilation, 0)):
    msg = "All dilations must be positive, got {}."
    raise TypeError(msg.format(dilation))
  dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation)
  return onp.multiply(dilation, onp.subtract(shape, 1)) + 1
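# Hand-worked example (illustrative only): _dilate_shape((4,), (2,)) inserts
# one hole between consecutive elements, giving 2 * (4 - 1) + 1 = 7.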


def padtype_to_pads(in_shape, window_shape, window_strides, padding):
  """Convert padding string to list of pairs of pad values."""
  PaddingType = xla_bridge.get_xla_client().PaddingType

  if isinstance(padding, str):
    mapping = {'VALID': PaddingType.VALID, 'SAME': PaddingType.SAME}
    try:
      padding = mapping[padding.upper()]
    except KeyError:
      msg = "Unrecognized padding type: expected 'VALID' or 'SAME', got {}."
      raise RuntimeError(msg.format(padding))

  if padding == PaddingType.SAME:
    out_shape = onp.ceil(onp.true_divide(in_shape, window_strides)).astype(int)
    pad_sizes = [_max((out_size - 1) * stride + window_shape - in_size, 0)
                 for out_size, stride, window_shape, in_size
                 in zip(out_shape, window_strides, window_shape, in_shape)]
    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
  elif padding == PaddingType.VALID:
    return [(0, 0)] * len(in_shape)
  else:
    msg = "Unknown padding type: {}."
    raise TypeError(msg.format(padding))
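# Hand-worked example of the 'SAME' branch above (illustrative only): for
# in_shape (5,), window_shape (3,), window_strides (2,) the output extent is
# ceil(5 / 2) = 3, the total pad is (3 - 1) * 2 + 3 - 5 = 2, and the pad is
# split as (1, 1).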


def _check_same_dtypes(name, ignore_fp_precision, *dtypes):
  """Check that dtypes agree, possibly ignoring float precision."""
  # the `ignore_fp_precision` flag exists because the XLA shape inference logic
  # allows mixed floating point precision, but the HLO verifier often rejects it
  dtypes = map(onp.dtype, dtypes)  # canonicalize
  if ignore_fp_precision:
    dtypes = [
        onp.floating if onp.issubdtype(dtype, onp.floating)
        else onp.complexfloating if onp.issubdtype(dtype, onp.complexfloating)
        else dtype for dtype in dtypes]
  if len({xla_bridge.canonicalize_dtype(t) for t in dtypes}) != 1:
    if ignore_fp_precision:
      msg = ("{} requires arguments to have same dtypes up to floating point "
             "precision, got {}.")
    else:
      msg = "{} requires arguments to have the same dtypes, got {}."
    raise TypeError(msg.format(name, ", ".join(map(str, dtypes))))


def _check_conv_shapes(fun_name, lhs_shape, rhs_shape, window_strides):
  """Check that conv shapes are valid and are consistent with window_strides."""
  if len(lhs_shape) != len(rhs_shape):
    msg = "Arguments to {} must have same rank, got {} and {}."
    raise TypeError(msg.format(fun_name, len(lhs_shape), len(rhs_shape)))
  if len(lhs_shape) < 2:
    msg = "Arguments to {} must have rank at least 2, got {} and {}."
    raise TypeError(msg.format(fun_name, len(lhs_shape), len(rhs_shape)))
  if lhs_shape[1] != rhs_shape[1]:
    msg = "Arguments to {} must agree on input feature size, got {} and {}."
    raise TypeError(msg.format(fun_name, lhs_shape[1], rhs_shape[1]))
  _check_shapelike(fun_name, "window_strides", window_strides)
  if not onp.all(onp.greater(window_strides, 0)):
    msg = "All elements of window_strides must be positive, got {}."
    raise TypeError(msg.format(window_strides))
  if len(window_strides) != len(lhs_shape) - 2:
    msg = "{} window_strides has wrong length: expected {}, got {}."
    expected_length = len(lhs_shape) - 2
    raise TypeError(msg.format(fun_name, expected_length, len(window_strides)))


def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
  """Compute the shape tuple of a conv given input shapes in canonical order."""
  if isinstance(pads, str):
    pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
  if len(pads) != len(lhs_shape) - 2:
    msg = "Wrong number of explicit pads for convolution: expected {}, got {}."
    raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))

  lhs_padded = onp.add(lhs_shape[2:], onp.add(*zip(*pads)))
  out_space = onp.divide(onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
  out_space = onp.maximum(0, out_space)
  out_shape = (lhs_shape[0], rhs_shape[0]) + tuple(out_space)
  return tuple(out_shape)
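# Hand-worked example (illustrative only): for an NCHW lhs of shape
# (1, 3, 8, 8), an OIHW rhs of shape (16, 3, 3, 3), strides (1, 1) and 'VALID'
# padding, the spatial output is (8 - 3) / 1 + 1 = 6 per dimension, so
# conv_shape_tuple returns (1, 16, 6, 6).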


def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
                             dimension_numbers):
  lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
  lhs_trans = onp.take(lhs_shape, lhs_perm)
  rhs_trans = onp.take(rhs_shape, rhs_perm)
  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding)
  return tuple(onp.take(out_trans, onp.argsort(out_perm)))


def _check_shapelike(fun_name, arg_name, obj):
  """Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints)."""
  if not isinstance(obj, (tuple, list, onp.ndarray)):
    msg = "{} {} must be of type tuple/list/ndarray, got {}."
    raise TypeError(msg.format(fun_name, arg_name, type(obj)))
  # bool(obj) for an ndarray raises an error, so we check len
  if not len(obj):  # pylint: disable=g-explicit-length-test
    return
  obj_arr = onp.array(obj)
  if obj_arr.ndim != 1:
    msg = "{} {} must be rank 1, got {}."
    raise TypeError(msg.format(fun_name, arg_name, obj_arr.ndim))
  if not onp.issubdtype(obj_arr.dtype, onp.integer):
    msg = "{} {} must have every element be an integer type, got {}."
    raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj))))
  if not (obj_arr >= 0).all():
    msg = "{} {} must have every element be nonnegative, got {}."
    raise TypeError(msg.format(fun_name, arg_name, obj))


def conv_general_permutations(dimension_numbers):
  """Utility for convolution dimension permutations relative to Conv HLO."""
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  lhs_char, rhs_char, out_char = charpairs = ("N", "C"), ("O", "I"), ("N", "C")
  for i, (a, b) in enumerate(charpairs):
    if not dimension_numbers[i].count(a) == dimension_numbers[i].count(b) == 1:
      msg = ("convolution dimension_numbers[{}] must contain the characters "
             "'{}' and '{}' exactly once, got {}.")
      raise TypeError(msg.format(i, a, b, dimension_numbers[i]))
    if len(dimension_numbers[i]) != len(set(dimension_numbers[i])):
      msg = ("convolution dimension_numbers[{}] cannot have duplicate "
             "characters, got {}.")
      raise TypeError(msg.format(i, dimension_numbers[i]))
  if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) ==
          set(out_spec) - set(out_char)):
    msg = ("convolution dimension_numbers elements must each have the same "
           "set of spatial characters, got {}.")
    raise TypeError(msg.format(dimension_numbers))

  def getperm(spec, charpair):
    spatial = (i for i, c in enumerate(spec) if c not in charpair)
    if spec is not rhs_spec:
      spatial = sorted(spatial, key=lambda i: rhs_spec.index(spec[i]))
    return (spec.index(charpair[0]), spec.index(charpair[1])) + tuple(spatial)

  lhs_perm, rhs_perm, out_perm = map(getperm, dimension_numbers, charpairs)
  return lhs_perm, rhs_perm, out_perm
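# Hand-worked example (illustrative only): for dimension_numbers
# ("NHWC", "HWIO", "NHWC"), getperm yields lhs_perm (0, 3, 1, 2),
# rhs_perm (3, 2, 0, 1) and out_perm (0, 3, 1, 2), i.e. the permutations that
# bring each operand into the canonical (N, C, spatial...) / (O, I, spatial...)
# layout expected by the Conv HLO.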


def _dynamic_slice_indices(operand, start_indices):
  if isinstance(start_indices, (tuple, list)):
    start_indices = concatenate([reshape(i, [1]) for i in start_indices], 0)
  return rem(start_indices, onp.array(operand.shape, start_indices.dtype))


_const = lambda example, val: onp.array(val, _dtype(example))
_zeros = partial(full_like, fill_value=0)
_zero = partial(full_like, shape=(), fill_value=0)
_ones = partial(full_like, fill_value=1)
_one = partial(full_like, shape=(), fill_value=1)
_twos = partial(full_like, fill_value=2)
_two = partial(full_like, shape=(), fill_value=2)

_dtype = onp.result_type
_iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating)


def ranges_like(*xs):
  start = 0
  for x in xs:
    x_len = len(x)
    yield range(start, start + x_len)
    start += x_len


def remaining(original, *removed_lists):
  blacklist = set(itertools.chain(*removed_lists))
  return [i for i in original if i not in blacklist]


def _charswap(a, b, s):
  return s.translate(string.maketrans(a+b, b+a))


def _get_sdims(dimension_numbers):
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  rhs_sdims = [i for i, c in enumerate(rhs_spec) if c not in {"I", "O"}]
  lhs_sdims = sorted((i for i, c in enumerate(lhs_spec) if c not in {"N", "C"}),
                     key=lambda i: rhs_spec.index(lhs_spec[i]))
  out_sdims = sorted((i for i, c in enumerate(out_spec) if c not in {"N", "C"}),
                     key=lambda i: rhs_spec.index(out_spec[i]))
  return lhs_sdims, rhs_sdims, out_sdims


ConvolutionDimensionNumbers = collections.namedtuple(
    "ConvolutionDimensionNumbers", ["lhs_spec", "rhs_spec", "out_spec"])

def _conv_general_proto(dimension_numbers):
  assert type(dimension_numbers) is ConvolutionDimensionNumbers
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  proto = xla_bridge.xla_data_pb2.ConvolutionDimensionNumbers()
  proto.input_batch_dimension = lhs_spec[0]
  proto.input_feature_dimension = lhs_spec[1]
  proto.output_batch_dimension = out_spec[0]
  proto.output_feature_dimension = out_spec[1]
  proto.kernel_output_feature_dimension = rhs_spec[0]
  proto.kernel_input_feature_dimension = rhs_spec[1]
  proto.input_spatial_dimensions.extend(lhs_spec[2:])
  proto.kernel_spatial_dimensions.extend(rhs_spec[2:])
  proto.output_spatial_dimensions.extend(out_spec[2:])
  return proto


def _conv_general_vjp_lhs_padding(
    in_shape, window_dimensions, window_strides, out_shape, padding,
    lhs_dilation, rhs_dilation):
  lhs_dilated_shape = _dilate_shape(in_shape, lhs_dilation)
  out_dilated_shape = _dilate_shape(out_shape, window_strides)
  pad_before = onp.subtract(window_dimensions, [lo for lo, _ in padding]) - 1
  pad_after = (onp.add(lhs_dilated_shape, window_dimensions) - 1
               - out_dilated_shape - pad_before)
  return zip(pad_before, pad_after)


def _conv_general_vjp_rhs_padding(
    in_shape, window_dimensions, window_strides, out_shape, padding,
    lhs_dilation, rhs_dilation):
  lhs_dilated_shape = _dilate_shape(in_shape, lhs_dilation)
  rhs_dilated_shape = _dilate_shape(window_dimensions, rhs_dilation)
  out_dilated_shape = _dilate_shape(out_shape, window_strides)
  total_in_pad = out_dilated_shape + rhs_dilated_shape - lhs_dilated_shape - 1
  return [(pad[0], tot - pad[0]) for pad, tot in zip(padding, total_in_pad)]


def _balanced_eq(x, z, y):
  return div(select(_eq_meet(x, z), _ones(z), _zeros(z)),
             select(_eq_meet(y, z), _twos(z), _ones(z)))


def _eq_meet(a, b):
  a_dtype, b_dtype = _dtype(a), _dtype(b)
  if a_dtype != b_dtype:
    higher_dtype = onp.promote_types(a_dtype, b_dtype)
    if higher_dtype == a_dtype:
      a = convert_element_type(a, b_dtype)
    else:
      b = convert_element_type(b, a_dtype)
  return eq(a, b)


def maybe_tracer_tuple_to_abstract_tuple(tup):
  if isinstance(tup, pe.JaxprTracerTuple):
    return core.AbstractTuple(map(maybe_tracer_tuple_to_abstract_tuple, tup))
  elif isinstance(tup, core.AbstractValue):
    return tup
  elif tup is None:
    return core.AbstractTuple(())  # TODO(dougalm): check this
  else:
    raise TypeError(tup)


def subvals(lst, replace):
  lst = list(lst)
  for i, v in replace:
    lst[i] = v
  return tuple(lst)


def _abstractify(x):
  # abstractify wrapper used internally for primitives like _while_loop
  if isinstance(x, core.Tracer):
    # TODO(mattjj,dougalm): check that it's at least ShapedArray
    return pe.PartialVal((x.aval, core.unit))
  else:
    return pe.PartialVal((xla.abstractify(x), core.unit))