Relax dimension ordering rules for dot_general.

JAX currently requires that the batch dimensions of the arguments to dot_general appear first and contiguously. XLA imposes no such restriction, so relax JAX's checks to also allow batch dimensions in arbitrary positions.
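
For example, a minimal sketch (shapes borrowed from one of the new test cases below) in which the batch dimension is neither leading nor in the same position in both operands:

```python
import jax.numpy as jnp
from jax import lax

# Batch dimension is axis 1 of lhs and axis 2 of rhs.
lhs = jnp.ones((3, 5, 2))
rhs = jnp.ones((2, 4, 5))
out = lax.dot_general(lhs, rhs,
                      dimension_numbers=(([2], [0]),    # contracting dims
                                         ([1], [2])))   # batch dims
print(out.shape)  # batch dims still come first in the result: (5, 3, 4)
```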

Since batch dimensions may now appear in arbitrary positions, it becomes straightforward to generalize the dot_general batching rule so that it avoids performing any transposes (#2972).
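
A rough sketch of what this enables, using hypothetical shapes in which the mapped axis sits in the middle of both operands (a case that previously forced the batching rule to move the axis to the front first):

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x, y):
  # Plain matrix product via dot_general: contract x's dim 1 with y's dim 0.
  return lax.dot_general(x, y, (([1], [0]), ([], [])))

xs = jnp.ones((3, 8, 2))  # map over axis 1
ys = jnp.ones((2, 8, 5))  # map over axis 1
out = jax.vmap(f, in_axes=(1, 1))(xs, ys)
print(out.shape)  # (8, 3, 5)

# jax.make_jaxpr(jax.vmap(f, in_axes=(1, 1)))(xs, ys) should now show a single
# dot_general with batch dimensions (1,)/(1,) and no transpose operations.
```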

In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`.
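
A quick illustration, assuming only what the commit describes: integer dots still work, but because the expansion now happens at lowering time, the traced jaxpr contains an ordinary dot_general even for non-inexact dtypes.

```python
import jax.numpy as jnp
from jax import jit, lax, make_jaxpr

a = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)
b = jnp.arange(12, dtype=jnp.int32).reshape(3, 4)
dot = lambda a, b: lax.dot_general(a, b, (([1], [0]), ([], [])))

print(make_jaxpr(dot)(a, b))  # shows dot_general, not a mul/reduce expansion
print(jit(dot)(a, b))         # sum-of-products expansion applied when lowering
```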
Peter Hawkins, 2020-07-16 16:23:27 -04:00
commit e2e73a854a (parent 3fb887421b)
4 changed files with 101 additions and 79 deletions


@@ -614,35 +614,6 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
contract_dims_seq, batch_dims_seq = dimension_numbers
contract_dims = tuple(map(lambda x: tuple(x), contract_dims_seq))
batch_dims = tuple(map(lambda x: tuple(x), batch_dims_seq))
if not dtypes.issubdtype(lhs.dtype, np.inexact):
# TODO(b/134526360): XLA doesn't support bool or integer dots, so we emit a
# sum of products instead.
lhs_contract_dims, rhs_contract_dims = contract_dims
lhs_batch_dims, rhs_batch_dims = batch_dims
lhs_noncontract_dims = tuple(sorted(
set(range(np.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims)))
rhs_noncontract_dims = tuple(sorted(
set(range(np.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims)))
lhs = transpose(lhs,
lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims)
rhs = transpose(rhs,
rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims)
lhs_start_expand = len(lhs_batch_dims) + len(lhs_noncontract_dims)
lhs_end_expand = lhs_start_expand + len(rhs_noncontract_dims)
lhs = expand_dims(lhs, tuple(range(lhs_start_expand, lhs_end_expand)))
rhs_start_expand = len(lhs_batch_dims)
rhs_end_expand = rhs_start_expand + len(lhs_noncontract_dims)
rhs = expand_dims(rhs, tuple(range(rhs_start_expand, rhs_end_expand)))
out_ndim = (len(lhs_batch_dims) + len(lhs_noncontract_dims) +
len(rhs_noncontract_dims))
op_product = bitwise_and if lhs.dtype == np.bool_ else mul
op_sum = bitwise_or if lhs.dtype == np.bool_ else add
return reduce(op_product(lhs, rhs), _zero(lhs), op_sum,
tuple(range(out_ndim, out_ndim + len(lhs_contract_dims))))
return dot_general_p.bind(lhs, rhs,
dimension_numbers=(contract_dims, batch_dims),
precision=_canonicalize_precision(precision))
@@ -2714,24 +2685,38 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision):
msg = ("dot_general requires equal numbers of lhs_batch and rhs_batch "
"dimensions, got lhs_batch {} and rhs_batch {}.")
raise TypeError(msg.format(lhs_batch, rhs_batch))
if not np.all(np.equal(lhs_batch, rhs_batch)):
msg = ("dot_general requires same lhs and rhs batch dimension numbers, "
"got {} and {}.")
raise TypeError(msg.format(lhs_batch, rhs_batch))
lhs_contracting_set, lhs_batch_set = set(lhs_contracting), set(lhs_batch)
rhs_contracting_set, rhs_batch_set = set(rhs_contracting), set(rhs_batch)
if len(lhs_batch_set) != len(lhs_batch):
msg = ("dot_general requires lhs batch dimensions to be distinct, got "
f"lhs_batch {lhs_batch}.")
raise TypeError(msg)
if len(rhs_batch_set) != len(rhs_batch):
msg = ("dot_general requires rhs batch dimensions to be distinct, got "
f"rhs_batch {rhs_batch}.")
raise TypeError(msg)
if len(lhs_contracting_set) != len(lhs_contracting):
msg = ("dot_general requires lhs contracting dimensions to be distinct, "
f"got lhs_contracting {lhs_contracting}.")
raise TypeError(msg)
if len(rhs_contracting_set) != len(rhs_contracting):
msg = ("dot_general requires rhs contracting dimensions to be distinct, "
f"got rhs_contracting {rhs_contracting}.")
raise TypeError(msg)
if lhs_contracting_set & lhs_batch_set:
msg = ("dot_general requires lhs batch dimensions to be disjoint from "
"contracting dimensions, got lhs_batch {} and lhs_contracting {}.")
raise TypeError(msg.format(lhs_batch, lhs_contracting))
if rhs_contracting_set & rhs_batch_set:
msg = ("dot_general requires rhs batch dimensions to be disjoint from "
"contracting dimensions, got rhs_batch {} and rhs_contracting {}.")
raise TypeError(msg.format(rhs_batch, rhs_contracting))
lhs_batch_shape = np.take(lhs.shape, lhs_batch)
rhs_batch_shape = np.take(rhs.shape, rhs_batch)
if not np.all(np.equal(lhs_batch_shape, rhs_batch_shape)):
msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
"to have the same shape, got {} and {}.")
raise TypeError(msg.format(lhs_batch_shape, rhs_batch_shape))
if tuple(sorted(lhs_batch)) != tuple(range(len(lhs_batch))):
msg = ("dot_general requires lhs batch dimensions to precede contracting "
"and non-contracting dimensions, got lhs_batch {}.")
raise TypeError(msg.format(lhs_batch))
if tuple(sorted(rhs_batch)) != tuple(range(len(rhs_batch))):
msg = ("dot_general requires rhs batch dimensions to precede contracting "
"and non-contracting dimensions, got rhs_batch {}.")
raise TypeError(msg.format(rhs_batch))
lhs_contracting_shape = np.take(lhs.shape, lhs_contracting)
rhs_contracting_shape = np.take(rhs.shape, rhs_contracting)
if not np.all(np.equal(lhs_contracting_shape, rhs_contracting_shape)):
@@ -2739,16 +2724,16 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision):
"shape, got {} and {}.")
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
batch_shape = tuple(np.take(lhs.shape, lhs_batch))
lhs_contract_or_batch = tuple(lhs_contracting) + tuple(lhs_batch)
batch_shape = tuple(lhs_batch_shape)
lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
lhs_tensored_shape = tuple(np.delete(lhs.shape, lhs_contract_or_batch))
rhs_contract_or_batch = tuple(rhs_contracting) + tuple(rhs_batch)
rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
rhs_tensored_shape = tuple(np.delete(rhs.shape, rhs_contract_or_batch))
return batch_shape + lhs_tensored_shape + rhs_tensored_shape
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision):
return naryop_dtype_rule(_input_dtype, [_num, _num], 'dot_general', lhs, rhs)
return naryop_dtype_rule(_input_dtype, [_any, _any], 'dot_general', lhs, rhs)
def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
@@ -2785,53 +2770,77 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
lhs, rhs = batched_args
lbd, rbd = batch_dims
assert lbd is not None or rbd is not None
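# bump_dims shifts each dimension index that is >= b up by one, making room
# for the new batch dimension inserted at position b.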
def bump_dims(dims, b):
return tuple(np.add(dims, np.greater_equal(dims, b)))
if lbd is not None and rbd is not None:
# adding a batch dimension
if lbd != 0:
lhs = batching.moveaxis(lhs, lbd, 0)
if rbd != 0:
rhs = batching.moveaxis(rhs, rbd, 0)
lhs_batch = (0,) + tuple(np.add(1, lhs_batch))
rhs_batch = (0,) + tuple(np.add(1, rhs_batch))
lhs_contract = tuple(np.add(1, lhs_contract))
rhs_contract = tuple(np.add(1, rhs_contract))
lhs_batch = (lbd,) + bump_dims(lhs_batch, lbd)
rhs_batch = (rbd,) + bump_dims(rhs_batch, rbd)
lhs_contract = bump_dims(lhs_contract, lbd)
rhs_contract = bump_dims(rhs_contract, rbd)
result_batch_dim = 0
else:
# adding a tensor product dimension
if lbd is not None:
if lhs_batch == () or lbd > np.max(lhs_batch):
# can avoid transposes
bump_lhs_contract = np.greater_equal(lhs_contract, lbd)
lhs_contract = tuple(np.add(lhs_contract, bump_lhs_contract))
result_batch_dim = lbd - len(lhs_contract) + sum(bump_lhs_contract)
else:
# move the new dimension to the end of lhs to avoid changing batch dims
lhs = batching.moveaxis(lhs, lbd, lhs.ndim - 1)
# lhs tensor product dims in result come after batch dims
result_batch_dim = lhs.ndim - len(lhs_contract) - 1
other = tuple(d for d in range(lhs.ndim)
if d not in lhs_batch and d not in lhs_contract)
result_batch_dim = (len(lhs_batch) + sum(np.less(other, lbd)))
lhs_batch = bump_dims(lhs_batch, lbd)
lhs_contract = bump_dims(lhs_contract, lbd)
else:
if rhs_batch == () or rbd > np.max(rhs_batch):
# can avoid transposes
bump_rhs_contract = np.greater_equal(rhs_contract, rbd)
rhs_contract = tuple(np.add(rhs_contract, bump_rhs_contract))
result_batch_dim = (rbd + (lhs.ndim - len(lhs_contract) - len(lhs_batch))
- (len(rhs_contract) - sum(bump_rhs_contract)))
else:
# move the new dimension to the end of rhs to avoid changing batch dims
rhs = batching.moveaxis(rhs, rbd, rhs.ndim - 1)
# rhs tensor product dims in result come after batch dims + lhs tensor
# product dims
result_batch_dim = (lhs.ndim - len(lhs_contract) - len(lhs_batch) +
rhs.ndim - len(rhs_contract) - 1)
other = tuple(d for d in range(rhs.ndim)
if d not in rhs_batch and d not in rhs_contract)
result_batch_dim = (lhs.ndim - len(lhs_contract) +
sum(np.less(other, rbd)))
rhs_batch = bump_dims(rhs_batch, rbd)
rhs_contract = bump_dims(rhs_contract, rbd)
new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
precision=precision)
return batched_out, int(result_batch_dim)
def _dot_using_sum_of_products(lhs, rhs, *, dimension_numbers):
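# Expands dot_general into a broadcasted elementwise product followed by a
# reduction, for dtypes that XLA's DotGeneral does not handle.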
contract_dims, batch_dims = dimension_numbers
lhs_contract_dims, rhs_contract_dims = contract_dims
lhs_batch_dims, rhs_batch_dims = batch_dims
lhs_noncontract_dims = tuple(sorted(
set(range(np.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims)))
rhs_noncontract_dims = tuple(sorted(
set(range(np.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims)))
lhs = transpose(lhs,
lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims)
rhs = transpose(rhs,
rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims)
lhs_start_expand = len(lhs_batch_dims) + len(lhs_noncontract_dims)
lhs_end_expand = lhs_start_expand + len(rhs_noncontract_dims)
lhs = expand_dims(lhs, tuple(range(lhs_start_expand, lhs_end_expand)))
rhs_start_expand = len(lhs_batch_dims)
rhs_end_expand = rhs_start_expand + len(lhs_noncontract_dims)
rhs = expand_dims(rhs, tuple(range(rhs_start_expand, rhs_end_expand)))
out_ndim = (len(lhs_batch_dims) + len(lhs_noncontract_dims) +
len(rhs_noncontract_dims))
op_product = bitwise_and if lhs.dtype == np.bool_ else mul
op_sum = bitwise_or if lhs.dtype == np.bool_ else add
return reduce(op_product(lhs, rhs), _zero(lhs), op_sum,
tuple(range(out_ndim, out_ndim + len(lhs_contract_dims))))
def _dot_general_translation_rule(c, lhs, rhs, *, dimension_numbers, precision):
return xops.DotGeneral(lhs, rhs,
xc.make_dot_dimension_numbers(dimension_numbers),
precision_config=_precision_config(precision))
dtype = c.get_shape(lhs).numpy_dtype()
if dtypes.issubdtype(dtype, np.inexact):
return xops.DotGeneral(lhs, rhs,
xc.make_dot_dimension_numbers(dimension_numbers),
precision_config=_precision_config(precision))
else:
# TODO(b/134526360): XLA doesn't support bool or integer dots, so we emit a
# sum of products instead.
translation = xla.lower_fun(_dot_using_sum_of_products,
multiple_results=False)
return translation(c, lhs, rhs, dimension_numbers=dimension_numbers)
def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
precision):


@@ -416,6 +416,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
((3, 5), (2, 5), (([1], [1]), ([], []))),
((5, 3), (5, 2), (([0], [0]), ([], []))),
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))),
((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))),
]
for dtype in float_dtypes))
def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,


@@ -743,10 +743,14 @@ class LaxTest(jtu.JaxTestCase):
"lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
"rng_factory": rng_factory}
for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
[(5,), (5,), [0], [0]],
[(5, 7), (5,), [0], [0]],
[(7, 5), (5,), [1], [0]],
[(3, 5), (2, 5), [1], [1]],
[(5, 3), (5, 2), [0], [0]],
[(5, 3, 2), (5, 2, 4), [0], [0]],
[(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
[(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]],
[(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
[(3, 2), (2, 4), [1], [0]],
]
@@ -773,6 +777,7 @@ class LaxTest(jtu.JaxTestCase):
"dimension_numbers": dimension_numbers, "rng_factory": rng_factory}
for lhs_shape, rhs_shape, dimension_numbers in [
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
]
for dtype in all_dtypes
@@ -797,6 +802,7 @@ class LaxTest(jtu.JaxTestCase):
"dimension_numbers": dimension_numbers, "rng_factory": rng_factory}
for lhs_shape, rhs_shape, dimension_numbers in [
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
]
for dtype in all_dtypes


@@ -238,10 +238,14 @@ class LaxVmapTest(jtu.JaxTestCase):
"lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
"bdims": bdims, "rng_factory": rng_factory}
for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
[(5,), (5,), [0], [0]],
[(5, 7), (5,), [0], [0]],
[(7, 5), (5,), [1], [0]],
[(3, 5), (2, 5), [1], [1]],
[(5, 3), (5, 2), [0], [0]],
[(5, 3, 2), (5, 2, 4), [0], [0]],
[(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
[(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]],
[(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
[(3, 2), (2, 4), [1], [0]],
]
@@ -266,6 +270,7 @@ class LaxVmapTest(jtu.JaxTestCase):
"dimension_numbers": dimension_numbers, "bdims": bdims, "rng_factory": rng_factory}
for lhs_shape, rhs_shape, dimension_numbers in [
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
]
for bdims in all_bdims(lhs_shape, rhs_shape)