mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
add dot_general masking rules
This commit is contained in:
parent
658882513e
commit
096a52a3a3
@ -2301,12 +2301,55 @@ def _dot_general_translation_rule(c, lhs, rhs, dimension_numbers, precision):
|
||||
return c.DotGeneral(lhs, rhs, dimension_numbers,
|
||||
precision_config=_precision_config(precision))
|
||||
|
||||
def _dot_general_polymorphic_shape_rule(shape_exprs, dimension_numbers,
                                        precision):
  """Shape rule for dot_general over symbolic shape expressions.

  Args:
    shape_exprs: pair ``(lhs_shape, rhs_shape)`` of indexable shapes.
    dimension_numbers: nested pair
      ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))`` of
      dimension indices into the respective operand shapes.
    precision: ignored by this rule.

  Returns:
    A ``ShapeExpr`` built from the batch dimensions (taken from the lhs),
    followed by the lhs dimensions that are neither batch nor contracting,
    followed by the rhs dimensions that are neither batch nor contracting.

  Raises:
    ShapeError: if the batch dimension sizes of lhs and rhs differ, or if
      their contracting dimension sizes differ.
  """
  del precision  # Unused.
  lhs_shape, rhs_shape = shape_exprs
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

  def pick(shape, dims):
    # Sizes of `shape` at the listed dimension indices, in the listed order.
    return [shape[d] for d in dims]

  def remaining(shape, batch, contract):
    # Sizes of the dimensions that take part in neither batching nor
    # contraction, in their original order.
    return [shape[d] for d in range(len(shape))
            if d not in batch and d not in contract]

  # Batch dimensions must agree between the two operands.
  batch_shape = pick(lhs_shape, lhs_batch)
  if batch_shape != pick(rhs_shape, rhs_batch):
    raise ShapeError

  # Contracting dimensions must agree as well.
  if pick(lhs_shape, lhs_contract) != pick(rhs_shape, rhs_contract):
    raise ShapeError

  lhs_free = remaining(lhs_shape, lhs_batch, lhs_contract)
  rhs_free = remaining(rhs_shape, rhs_batch, rhs_contract)
  return ShapeExpr(batch_shape + lhs_free + rhs_free)
|
||||
|
||||
def _dot_general_masking_rule(padded_vals, logical_shapes, dimension_numbers,
                              precision):
  """Masking rule for dot_general on padded operands.

  Entries of the lhs whose index along any lhs contracting dimension is not
  below the logical size of that dimension are replaced with zeros before
  the contraction. Only the lhs is masked; the rhs is passed through as is.

  Args:
    padded_vals: pair ``(lhs, rhs)`` of padded operands.
    logical_shapes: pair of logical shapes for ``(lhs, rhs)``; only the lhs
      entry is consulted.
    dimension_numbers: nested pair
      ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))``; forwarded
      unchanged to ``dot_general``.
    precision: forwarded unchanged to ``dot_general``.

  Returns:
    The result of ``dot_general`` applied to the (possibly masked) lhs and
    the rhs.
  """
  lhs, rhs = padded_vals
  lhs_logical_shape, _ = logical_shapes
  (lhs_contract, _), (_, _) = dimension_numbers

  operand = lhs
  if len(lhs_contract) != 0:
    # One boolean mask per contracting dimension: True where the index along
    # that dimension lies inside the logical extent.
    per_dim_masks = [
        broadcasted_iota(onp.int32, lhs.shape, d) < lhs_logical_shape[d]
        for d in lhs_contract]
    # An element is kept only if it is in range along every contracting dim.
    remaining_masks = iter(per_dim_masks)
    keep = next(remaining_masks)
    for dim_mask in remaining_masks:
      keep &= dim_mask
    operand = select(keep, lhs, zeros_like_array(lhs))

  return dot_general(operand, rhs, dimension_numbers, precision=precision)
|
||||
|
||||
# The dot_general primitive, built by standard_primitive from its shape rule,
# dtype rule, name ('dot_general') and translation rule.
dot_general_p = standard_primitive(_dot_general_shape_rule,
                                   _dot_general_dtype_rule, 'dot_general',
                                   _dot_general_translation_rule)
# Declare the primitive bilinear to `ad`, passing the lhs and rhs transpose
# functions.
ad.defbilinear(dot_general_p,
               _dot_general_transpose_lhs, _dot_general_transpose_rhs)
# Batching rule lookup table entry for this primitive.
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
# Masking support: the polymorphic shape rule and the rule that evaluates
# dot_general on padded operands.
masking.shape_rules[dot_general_p] = _dot_general_polymorphic_shape_rule
masking.masking_rules[dot_general_p] = _dot_general_masking_rule
|
||||
|
||||
|
||||
def _broadcast_shape_rule(operand, sizes):
|
||||
|
@ -64,7 +64,7 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
def thunk():
|
||||
@shapecheck(['(m, n)', 'n'], 'm')
|
||||
def matvec(A, b):
|
||||
return np.dot(b, A)
|
||||
return lax.dot_general(A, b, [((0,), (0,)), ((), ())])
|
||||
self.assertRaisesRegex(ShapeError, "", thunk)
|
||||
|
||||
def test_flatten_shape_checking(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user