add dot_general masking rules

This commit is contained in:
James Bradbury 2019-10-08 14:22:51 -07:00
parent 658882513e
commit 096a52a3a3
2 changed files with 44 additions and 1 deletions

View File

@ -2301,12 +2301,55 @@ def _dot_general_translation_rule(c, lhs, rhs, dimension_numbers, precision):
return c.DotGeneral(lhs, rhs, dimension_numbers,
precision_config=_precision_config(precision))
def _dot_general_polymorphic_shape_rule(shape_exprs, dimension_numbers,
                                        precision):
  """Compute the symbolic output shape of ``dot_general``.

  Args:
    shape_exprs: pair ``(lhs_shape, rhs_shape)`` of indexable shape
      expressions, one per operand.
    dimension_numbers: nested pair
      ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))`` of dimension
      indices into the corresponding operand shape.
    precision: ignored by this rule.

  Returns:
    A ``ShapeExpr`` built from the batch dimensions (taken from the lhs),
    followed by the lhs dimensions that are neither batch nor contracting,
    followed by the rhs dimensions that are neither batch nor contracting.

  Raises:
    ShapeError: if the batch dimensions of the two operands differ, or if
      their contracting dimensions differ.
  """
  del precision  # Unused.
  lhs_shape, rhs_shape = shape_exprs
  lhs_rank = len(lhs_shape)
  rhs_rank = len(rhs_shape)
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

  def dims_at(shape, axes):
    # Pick out the entries of `shape` at the given axes, in the given order.
    return [shape[axis] for axis in axes]

  def free_dims(shape, rank, batch_axes, contract_axes):
    # Entries of `shape` whose axis is neither a batch nor a contracting axis.
    return [shape[axis] for axis in range(rank)
            if not (axis in batch_axes or axis in contract_axes)]

  # Batch dimensions must agree pairwise between the two operands.
  batch_dims = dims_at(lhs_shape, lhs_batch)
  if batch_dims != dims_at(rhs_shape, rhs_batch):
    raise ShapeError

  # Likewise for the dimensions that are summed over.
  if dims_at(lhs_shape, lhs_contract) != dims_at(rhs_shape, rhs_contract):
    raise ShapeError

  lhs_free = free_dims(lhs_shape, lhs_rank, lhs_batch, lhs_contract)
  rhs_free = free_dims(rhs_shape, rhs_rank, rhs_batch, rhs_contract)
  return ShapeExpr(batch_dims + lhs_free + rhs_free)
def _dot_general_masking_rule(padded_vals, logical_shapes, dimension_numbers,
                              precision):
  """Masking rule for ``dot_general`` applied to padded operands.

  Entries of the padded lhs that lie at or beyond the logical size along any
  lhs contracting dimension are replaced by zeros before the regular
  ``dot_general`` is evaluated. The rhs is passed through unchanged.

  Args:
    padded_vals: pair ``(lhs, rhs)`` of padded operands.
    logical_shapes: pair ``(lhs_shape, rhs_shape)``; only ``lhs_shape`` is
      read, indexed by dimension number.
    dimension_numbers: nested pair
      ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))``; only
      ``lhs_contract`` is used for masking, the whole value is forwarded to
      ``dot_general``.
    precision: forwarded unchanged to ``dot_general``.

  Returns:
    The result of ``dot_general`` on the (possibly masked) lhs and the rhs.
  """
  lhs, rhs = padded_vals
  lhs_shape, _ = logical_shapes
  # Same nested unpacking as the shape rule, so a malformed
  # `dimension_numbers` is still rejected here; only lhs_contract is needed.
  (lhs_contract, _), (_, _) = dimension_numbers

  if len(lhs_contract) == 0:
    # No contracting dimensions, hence nothing to mask.
    return dot_general(lhs, rhs, dimension_numbers, precision=precision)

  # We need only mask the lhs contraction dimensions: a zero in the lhs
  # removes the matching rhs term from the sum.
  # NOTE(review): this presumes the rhs padding is finite; 0 * nan/inf would
  # still propagate -- confirm what padded values may contain.
  in_bounds = None
  for d in lhs_contract:
    # True where the index along dimension d is inside the logical extent.
    within_d = broadcasted_iota(onp.int32, lhs.shape, d) < lhs_shape[d]
    in_bounds = within_d if in_bounds is None else in_bounds & within_d

  masked_lhs = select(in_bounds, lhs, zeros_like_array(lhs))
  return dot_general(masked_lhs, rhs, dimension_numbers, precision=precision)
# Create the `dot_general` primitive from its shape rule, dtype rule, name and
# translation rule.
dot_general_p = standard_primitive(_dot_general_shape_rule,
_dot_general_dtype_rule, 'dot_general',
_dot_general_translation_rule)
# Register the primitive as bilinear for autodiff, with one transpose rule per
# operand (lhs, then rhs).
ad.defbilinear(dot_general_p,
_dot_general_transpose_lhs, _dot_general_transpose_rhs)
# Batching rule for the primitive.
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
# Masking support: the polymorphic shape rule and the rule that masks padded
# operands.
masking.shape_rules[dot_general_p] = _dot_general_polymorphic_shape_rule
masking.masking_rules[dot_general_p] = _dot_general_masking_rule
def _broadcast_shape_rule(operand, sizes):

View File

@ -64,7 +64,7 @@ class MaskingTest(jtu.JaxTestCase):
def thunk():
@shapecheck(['(m, n)', 'n'], 'm')
def matvec(A, b):
return np.dot(b, A)
return lax.dot_general(A, b, [((0,), (0,)), ((), ())])
self.assertRaisesRegex(ShapeError, "", thunk)
def test_flatten_shape_checking(self):