remove dot primitive in favor of dot_general

James Bradbury 2019-10-08 14:23:30 -07:00
parent 096a52a3a3
commit 6d29c4e352
2 changed files with 12 additions and 165 deletions


@@ -508,7 +508,12 @@ def dot(lhs, rhs, precision=None):
rhs = broadcast(rhs, (1,))
return reduce(mul(lhs, rhs), _zero(lhs), add, (len(lhs_shape) - 1,))
return dot_p.bind(lhs, rhs, precision=_canonicalize_precision(precision))
if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and lhs.shape[-1] == rhs.shape[0]:
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
precision=precision)
else:
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
lhs.shape, rhs.shape))
def dot_general(lhs, rhs, dimension_numbers, precision=None):
"""More general contraction operator.
@@ -2019,109 +2024,6 @@ def _reshape_axis_out_of(src, size1, x):
shape[src:src+1] = [size1, size2]
return reshape(x, shape)
def _dot_shape_rule(lhs, rhs, precision):
if lhs.ndim == 0 or rhs.ndim == 0:
msg = "Dot only supports rank 1 or above, got shapes {} and {}."
raise TypeError(msg.format(lhs.shape, rhs.shape))
if lhs.ndim > 2 or rhs.ndim > 2:
msg = "Dot only supports rank 2 or less, got shapes {} and {}."
raise TypeError(msg.format(lhs.shape, rhs.shape))
def require(shape_cond):
if not shape_cond:
msg = "Incompatible shapes for dot: got {} and {}."
raise TypeError(msg.format(lhs.shape, rhs.shape))
if lhs.ndim == rhs.ndim == 1:
require(lhs.shape == rhs.shape)
return ()
elif lhs.ndim == rhs.ndim == 2:
require(lhs.shape[1] == rhs.shape[0])
return (lhs.shape[0], rhs.shape[1])
elif rhs.ndim == 1:
require(lhs.shape[-1] == rhs.shape[0])
return lhs.shape[:-1]
else:
require(lhs.shape[-1] == rhs.shape[-2])
return lhs.shape[:-1] + rhs.shape[:-2] + rhs.shape[-1:]
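A quick, illustrative sanity check (not part of the commit): the rank restrictions encoded in the removed shape rule survive in the new wrapper, since lax.dot still accepts only rank-1 and rank-2 operands and reports incompatible shapes otherwise.

import jax.numpy as jnp
from jax import lax

assert lax.dot(jnp.ones((4,)), jnp.ones((4,))).shape == ()          # vector-vector
assert lax.dot(jnp.ones((3, 4)), jnp.ones((4,))).shape == (3,)      # matrix-vector
assert lax.dot(jnp.ones((3, 4)), jnp.ones((4, 5))).shape == (3, 5)  # matrix-matrix
# rank-3 and higher operands are rejected by lax.dot; call lax.dot_general directly.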
def _dot_transpose_lhs(t, rhs, precision):
if onp.ndim(t) == onp.ndim(rhs) == 2:
return dot(t, transpose(rhs, (1, 0)), precision=precision)
elif onp.ndim(t) == 1 and onp.ndim(rhs) == 2:
return dot(rhs, t, precision=precision)
elif onp.ndim(t) == onp.ndim(rhs) == 1:
return _outer(t, rhs)
elif onp.ndim(t) == 0 or onp.ndim(rhs) == 0:
return mul(t, rhs)
else:
raise TypeError
def _dot_transpose_rhs(t, lhs, precision):
if onp.ndim(lhs) == onp.ndim(t) == 2:
return dot(transpose(lhs, (1, 0)), t)
elif onp.ndim(lhs) == 2 and onp.ndim(t) == 1:
return dot(t, lhs, precision=precision)
elif onp.ndim(t) == onp.ndim(lhs) == 1:
return _outer(lhs, t)
elif onp.ndim(t) == 0 or onp.ndim(lhs) == 0:
return mul(t, lhs)
else:
raise TypeError
def _outer(x, y):
assert onp.ndim(x) == onp.ndim(y) == 1
return mul(reshape(x, (x.shape[0], 1)), reshape(y, (1, y.shape[0])))
def _dot_batch_rule(batched_args, batch_dims, precision=None):
lhs, rhs = batched_args
lbd, rbd = batch_dims
T = lambda x: transpose(x, onp.arange(onp.ndim(x))[::-1])
# in some cases, we can call dot instead of dot_general
if max(onp.ndim(lhs), onp.ndim(rhs)) <= 2:
if rbd is None:
assert lbd in (0, 1)
if lbd == 0:
return dot(lhs, rhs, precision=precision), 0
else:
return dot(T(rhs), lhs, precision=precision), onp.ndim(rhs) - 1
if lbd is None:
assert rbd in (0, 1)
if rbd == onp.ndim(rhs) - 1:
return dot(lhs, rhs, precision=precision), onp.ndim(lhs) - 1
else:
return dot(rhs, T(lhs), precision=precision), 0
assert lbd is not None and rbd is not None
assert lhs.ndim == rhs.ndim == 2 # dot only supports rank 1 and above
lhs = batching.moveaxis(lhs, lbd, 0)
rhs = batching.moveaxis(rhs, rbd, 0)
out = dot_general(lhs, rhs, [((1,), (1,)), ((0,), (0,))],
precision=precision)
return out, 0
if lbd is None:
assert rbd is not None
lhs = broadcast(lhs, (rhs.shape[rbd],))
else:
lhs = batching.moveaxis(lhs, lbd, 0)
lhs_batch = (0,)
lhs_contracting = (onp.ndim(lhs) - 1,)
if rbd is None:
assert lbd is not None
rhs = broadcast(rhs, (lhs.shape[0],))
else:
rhs = batching.moveaxis(rhs, rbd, 0)
rhs_batch = (0,)
rhs_contracting = (onp.arange(1, onp.ndim(rhs))[-2:][0],)
dim_nums = [(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)]
return dot_general(lhs, rhs, dim_nums, precision=precision), 0
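The batching rule removed above assembled an equivalent dot_general by hand; a hedged sketch of roughly the call a batched matrix product reduces to (shapes illustrative):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((8, 3, 4))   # a batch of 8 lhs matrices
rhs = jnp.ones((8, 4, 5))   # a batch of 8 rhs matrices

# contract lhs dim 2 with rhs dim 1 and treat dim 0 of both as a batch dim,
# i.e. the dim_nums the deleted rule built via moveaxis and broadcast.
out = lax.dot_general(lhs, rhs, (((2,), (1,)), ((0,), (0,))))
assert out.shape == (8, 3, 5)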
def _precision_config(precision):
if precision is not None:
config = xla_client.PrecisionConfig()
@@ -2129,55 +2031,6 @@ def _precision_config(precision):
return config
return None
def _dot_translation_rule(c, lhs, rhs, precision):
return c.Dot(lhs, rhs, precision_config=_precision_config(precision))
def _dot_polymorphic_shape_rule(shape_exprs, precision):
del precision # Unused.
lhs_shape, rhs_shape = shape_exprs
lhs_ndim, rhs_ndim = len(lhs_shape), len(rhs_shape)
if lhs_ndim == rhs_ndim == 1:
if not lhs_shape == rhs_shape: raise ShapeError
return ShapeExpr(())
elif lhs_ndim == rhs_ndim == 2:
if not lhs_shape[1] == rhs_shape[0]: raise ShapeError
return ShapeExpr((lhs_shape[0], rhs_shape[1]))
elif rhs_ndim == 1:
if not lhs_shape[1] == rhs_shape[0]: raise ShapeError
return ShapeExpr((lhs_shape[0],))
else:
if not lhs_shape[0] == rhs_shape[0]: raise ShapeError
return ShapeExpr((rhs_shape[1],))
def _dot_masking_rule(padded_vals, logical_shapes, precision):
lhs, rhs = padded_vals
lhs_shape, rhs_shape = logical_shapes
lhs_ndim, rhs_ndim = len(lhs_shape), len(rhs_shape)
if lhs_ndim == rhs_ndim == 1:
masked_lhs = select(iota(onp.int32, lhs.shape[0]) < lhs_shape[0],
lhs, zeros_like_array(lhs))
return dot_p.bind(masked_lhs, rhs, precision=precision)
elif lhs_ndim == rhs_ndim == 2:
# TODO could avoid select if we check whether contracted axis is masked
masked_lhs = select(broadcasted_iota(onp.int32, lhs.shape, 1) < lhs_shape[1],
lhs, zeros_like_array(lhs))
return dot_p.bind(masked_lhs, rhs, precision=precision)
elif rhs_ndim == 1:
raise NotImplementedError
else:
raise NotImplementedError
_dot_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_num, _num], 'dot')
dot_p = standard_primitive(_dot_shape_rule, _dot_dtype_rule, 'dot',
_dot_translation_rule)
ad.defbilinear(dot_p, _dot_transpose_lhs, _dot_transpose_rhs)
batching.primitive_batchers[dot_p] = _dot_batch_rule
masking.shape_rules[dot_p] = _dot_polymorphic_shape_rule
masking.masking_rules[dot_p] = _dot_masking_rule
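With dot_p and these registrations gone, dot products trace straight to dot_general_p, whose existing differentiation and batching rules take over; an illustrative way to check (jaxpr printing format may vary across versions):

import jax
import jax.numpy as jnp

f = jax.vmap(jnp.dot)  # batched matrix product
jaxpr = jax.make_jaxpr(f)(jnp.ones((8, 3, 4)), jnp.ones((8, 4, 5)))
# the jaxpr should contain a dot_general equation and no standalone dot primitive
print(jaxpr)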
def _dot_general_shape_rule(lhs, rhs, dimension_numbers, precision):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
@@ -2276,8 +2129,8 @@ def _dot_general_batch_rule(batched_args, batch_dims, dimension_numbers,
else:
# move the new dimension to the end of lhs to avoid changing batch dims
lhs = batching.moveaxis(lhs, lbd, lhs.ndim - 1)
# lhs tensor product dims in result come after batch dims
result_batch_dim = lhs.ndim - len(lhs_contract) - 1
else:
if rhs_batch == () or rbd > onp.max(rhs_batch):
# can avoid transposes
@@ -2288,10 +2141,10 @@ def _dot_general_batch_rule(batched_args, batch_dims, dimension_numbers,
else:
# move the new dimension to the end of rhs to avoid changing batch dims
rhs = batching.moveaxis(rhs, rbd, rhs.ndim - 1)
# rhs tensor product dims in result come after batch dims + lhs tensor
# product dims
result_batch_dim = (lhs.ndim - len(lhs_contract) - len(lhs_batch) +
rhs.ndim - len(rhs_contract) - 1)
new_dimension_numbers = [(lhs_contract, rhs_contract), (lhs_batch, rhs_batch)]
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
precision=precision)
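The result_batch_dim bookkeeping above records where the mapped axis lands in the output when a new batch dimension has to be introduced for one operand; a small sketch (shapes illustrative):

import jax
import jax.numpy as jnp
from jax import lax

rhs = jnp.ones((4, 5))
matmul = lambda m: lax.dot_general(m, rhs, (((1,), (0,)), ((), ())))

# batch only the lhs over its leading axis; the batch rule reports which
# output axis carries the mapped dimension (axis 0 here).
out = jax.vmap(matmul)(jnp.ones((8, 3, 4)))
assert out.shape == (8, 3, 5)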


@@ -411,11 +411,6 @@ _defreducer(lax.reduce_max_p, pmax_p)
_defreducer(lax.reduce_min_p, pmin_p)
def _dot_papply_rule(name, size, vals, dims, precision):
x, _ = vals
dim_nums = [((x.ndim,), (0,)), ((), ())]
return _dot_general_papply_rule(name, size, vals, dims, dim_nums, precision)
def _dot_general_papply_rule(name, size, vals, dims, dimension_numbers,
precision):
x, y = vals
@@ -700,7 +695,6 @@ def _gather_papply_rule(
raise NotImplementedError
parallel.papply_primitive_rules[lax.dot_p] = _dot_papply_rule
parallel.papply_primitive_rules[lax.dot_general_p] = _dot_general_papply_rule
parallel.papply_primitive_rules[lax.reshape_p] = _reshape_papply_rule
parallel.papply_primitive_rules[lax.transpose_p] = _transpose_papply_rule