Relax dimension ordering rules for dot_general.
JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions. Since batch dimensions are now allowed in arbitrary positions, it's not hard to generalize the dot_general batching rule to avoid performing any transposes (#2972).

In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`.
This commit is contained in:
parent 3fb887421b
commit e2e73a854a
jax/lax/lax.py (167 changed lines)
@@ -614,35 +614,6 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
   contract_dims_seq, batch_dims_seq = dimension_numbers
   contract_dims = tuple(map(lambda x: tuple(x), contract_dims_seq))
   batch_dims = tuple(map(lambda x: tuple(x), batch_dims_seq))
-  if not dtypes.issubdtype(lhs.dtype, np.inexact):
-    # TODO(b/134526360): XLA doesn't support bool or integer dots, so we emit a
-    # sum of products instead.
-    lhs_contract_dims, rhs_contract_dims = contract_dims
-    lhs_batch_dims, rhs_batch_dims = batch_dims
-    lhs_noncontract_dims = tuple(sorted(
-      set(range(np.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims)))
-    rhs_noncontract_dims = tuple(sorted(
-      set(range(np.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims)))
-    lhs = transpose(lhs,
-                    lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims)
-    rhs = transpose(rhs,
-                    rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims)
-
-    lhs_start_expand = len(lhs_batch_dims) + len(lhs_noncontract_dims)
-    lhs_end_expand = lhs_start_expand + len(rhs_noncontract_dims)
-    lhs = expand_dims(lhs, tuple(range(lhs_start_expand, lhs_end_expand)))
-
-    rhs_start_expand = len(lhs_batch_dims)
-    rhs_end_expand = rhs_start_expand + len(lhs_noncontract_dims)
-    rhs = expand_dims(rhs, tuple(range(rhs_start_expand, rhs_end_expand)))
-
-    out_ndim = (len(lhs_batch_dims) + len(lhs_noncontract_dims) +
-                len(rhs_noncontract_dims))
-    op_product = bitwise_and if lhs.dtype == np.bool_ else mul
-    op_sum = bitwise_or if lhs.dtype == np.bool_ else add
-    return reduce(op_product(lhs, rhs), _zero(lhs), op_sum,
-                  tuple(range(out_ndim, out_ndim + len(lhs_contract_dims))))
-
   return dot_general_p.bind(lhs, rhs,
                             dimension_numbers=(contract_dims, batch_dims),
                             precision=_canonicalize_precision(precision))
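To illustrate the wrapper change, a minimal sketch (inputs are made up): non-floating-point operands now reach the dot_general primitive unchanged and are only expanded into a sum of products at XLA lowering time, so a boolean "dot" still computes AND as the product and OR as the sum:

import numpy as np
from jax import lax

x = np.array([[True, False, True]])
y = np.array([[True], [True], [False]])
# Contract x's dim 1 against y's dim 0; no batch dimensions.
out = lax.dot_general(x, y, dimension_numbers=(((1,), (0,)), ((), ())))
print(out)  # [[ True]] -- the OR of elementwise ANDs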
@@ -2714,24 +2685,38 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision):
     msg = ("dot_general requires equal numbers of lhs_batch and rhs_batch "
            "dimensions, got lhs_batch {} and rhs_batch {}.")
     raise TypeError(msg.format(lhs_batch, rhs_batch))
-  if not np.all(np.equal(lhs_batch, rhs_batch)):
-    msg = ("dot_general requires same lhs and rhs batch dimension numbers, "
-           "got {} and {}.")
-    raise TypeError(msg.format(lhs_batch, rhs_batch))
+  lhs_contracting_set, lhs_batch_set = set(lhs_contracting), set(lhs_batch)
+  rhs_contracting_set, rhs_batch_set = set(rhs_contracting), set(rhs_batch)
+  if len(lhs_batch_set) != len(lhs_batch):
+    msg = ("dot_general requires lhs batch dimensions to be distinct, got "
+           f"lhs_batch {lhs_batch}.")
+    raise TypeError(msg)
+  if len(rhs_batch_set) != len(rhs_batch):
+    msg = ("dot_general requires rhs batch dimensions to be distinct, got "
+           f"rhs_batch {rhs_batch}.")
+    raise TypeError(msg)
+  if len(lhs_contracting_set) != len(lhs_contracting):
+    msg = ("dot_general requires lhs contracting dimensions to be distinct, "
+           f"got lhs_contracting {lhs_contracting}.")
+    raise TypeError(msg)
+  if len(rhs_contracting_set) != len(rhs_contracting):
+    msg = ("dot_general requires rhs contracting dimensions to be distinct, "
+           f"got rhs_contracting {rhs_contracting}.")
+    raise TypeError(msg)
+  if lhs_contracting_set & lhs_batch_set:
+    msg = ("dot_general requires lhs batch dimensions to be disjoint from "
+           "contracting dimensions, got lhs_batch {} and lhs_contracting {}.")
+    raise TypeError(msg.format(lhs_batch, lhs_contracting))
+  if rhs_contracting_set & rhs_batch_set:
+    msg = ("dot_general requires rhs batch dimensions to be disjoint from "
+           "contracting dimensions, got rhs_batch {} and rhs_contracting {}.")
+    raise TypeError(msg.format(rhs_batch, rhs_contracting))
   lhs_batch_shape = np.take(lhs.shape, lhs_batch)
   rhs_batch_shape = np.take(rhs.shape, rhs_batch)
   if not np.all(np.equal(lhs_batch_shape, rhs_batch_shape)):
     msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
            "to have the same shape, got {} and {}.")
     raise TypeError(msg.format(lhs_batch_shape, rhs_batch_shape))
-  if tuple(sorted(lhs_batch)) != tuple(range(len(lhs_batch))):
-    msg = ("dot_general requires lhs batch dimensions to precede contracting "
-           "and non-contracting dimensions, got lhs_batch {}.")
-    raise TypeError(msg.format(lhs_batch))
-  if tuple(sorted(rhs_batch)) != tuple(range(len(rhs_batch))):
-    msg = ("dot_general requires rhs batch dimensions to precede contracting "
-           "and non-contracting dimensions, got rhs_batch {}.")
-    raise TypeError(msg.format(rhs_batch))
   lhs_contracting_shape = np.take(lhs.shape, lhs_contracting)
   rhs_contracting_shape = np.take(rhs.shape, rhs_contracting)
   if not np.all(np.equal(lhs_contracting_shape, rhs_contracting_shape)):
@@ -2739,16 +2724,16 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision):
            "shape, got {} and {}.")
     raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
 
-  batch_shape = tuple(np.take(lhs.shape, lhs_batch))
-  lhs_contract_or_batch = tuple(lhs_contracting) + tuple(lhs_batch)
+  batch_shape = tuple(lhs_batch_shape)
+  lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
   lhs_tensored_shape = tuple(np.delete(lhs.shape, lhs_contract_or_batch))
-  rhs_contract_or_batch = tuple(rhs_contracting) + tuple(rhs_batch)
+  rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
   rhs_tensored_shape = tuple(np.delete(rhs.shape, rhs_contract_or_batch))
   return batch_shape + lhs_tensored_shape + rhs_tensored_shape
 
 
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision):
-  return naryop_dtype_rule(_input_dtype, [_num, _num], 'dot_general', lhs, rhs)
+  return naryop_dtype_rule(_input_dtype, [_any, _any], 'dot_general', lhs, rhs)
 
 
 def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
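A worked example of the relaxed shape rule (shapes borrowed from one of the new autodiff test cases below); the result shape is batch_shape + lhs_tensored_shape + rhs_tensored_shape regardless of where the batch dimensions sit in the inputs:

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((3, 5, 2))  # batch dim 1 (size 5), contracting dim 2 (size 2)
rhs = jnp.ones((2, 4, 5))  # batch dim 2 (size 5), contracting dim 0 (size 2)
out = lax.dot_general(lhs, rhs, dimension_numbers=(((2,), (0,)), ((1,), (2,))))
print(out.shape)  # (5, 3, 4): batch shape, then lhs and rhs tensored dims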
@@ -2785,53 +2770,77 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
   lhs, rhs = batched_args
   lbd, rbd = batch_dims
   assert lbd is not None or rbd is not None
+  def bump_dims(dims, b):
+    return tuple(np.add(dims, np.greater_equal(dims, b)))
+
   if lbd is not None and rbd is not None:
     # adding a batch dimension
-    if lbd != 0:
-      lhs = batching.moveaxis(lhs, lbd, 0)
-    if rbd != 0:
-      rhs = batching.moveaxis(rhs, rbd, 0)
-    lhs_batch = (0,) + tuple(np.add(1, lhs_batch))
-    rhs_batch = (0,) + tuple(np.add(1, rhs_batch))
-    lhs_contract = tuple(np.add(1, lhs_contract))
-    rhs_contract = tuple(np.add(1, rhs_contract))
+    lhs_batch = (lbd,) + bump_dims(lhs_batch, lbd)
+    rhs_batch = (rbd,) + bump_dims(rhs_batch, rbd)
+    lhs_contract = bump_dims(lhs_contract, lbd)
+    rhs_contract = bump_dims(rhs_contract, rbd)
     result_batch_dim = 0
   else:
     # adding a tensor product dimension
     if lbd is not None:
-      if lhs_batch == () or lbd > np.max(lhs_batch):
-        # can avoid transposes
-        bump_lhs_contract = np.greater_equal(lhs_contract, lbd)
-        lhs_contract = tuple(np.add(lhs_contract, bump_lhs_contract))
-        result_batch_dim = lbd - len(lhs_contract) + sum(bump_lhs_contract)
-      else:
-        # move the new dimension to the end of lhs to avoid changing batch dims
-        lhs = batching.moveaxis(lhs, lbd, lhs.ndim - 1)
-        # lhs tensor product dims in result come after batch dims
-        result_batch_dim = lhs.ndim - len(lhs_contract) - 1
+      other = tuple(d for d in range(lhs.ndim)
+                    if d not in lhs_batch and d not in lhs_contract)
+      result_batch_dim = (len(lhs_batch) + sum(np.less(other, lbd)))
+      lhs_batch = bump_dims(lhs_batch, lbd)
+      lhs_contract = bump_dims(lhs_contract, lbd)
     else:
-      if rhs_batch == () or rbd > np.max(rhs_batch):
-        # can avoid transposes
-        bump_rhs_contract = np.greater_equal(rhs_contract, rbd)
-        rhs_contract = tuple(np.add(rhs_contract, bump_rhs_contract))
-        result_batch_dim = (rbd + (lhs.ndim - len(lhs_contract) - len(lhs_batch))
-                            - (len(rhs_contract) - sum(bump_rhs_contract)))
-      else:
-        # move the new dimension to the end of rhs to avoid changing batch dims
-        rhs = batching.moveaxis(rhs, rbd, rhs.ndim - 1)
-        # rhs tensor product dims in result come after batch dims + lhs tensor
-        # product dims
-        result_batch_dim = (lhs.ndim - len(lhs_contract) - len(lhs_batch) +
-                            rhs.ndim - len(rhs_contract) - 1)
+      other = tuple(d for d in range(rhs.ndim)
+                    if d not in rhs_batch and d not in rhs_contract)
+      result_batch_dim = (lhs.ndim - len(lhs_contract) +
+                          sum(np.less(other, rbd)))
+      rhs_batch = bump_dims(rhs_batch, rbd)
+      rhs_contract = bump_dims(rhs_contract, rbd)
+
   new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
   batched_out = dot_general(lhs, rhs, new_dimension_numbers,
                             precision=precision)
   return batched_out, int(result_batch_dim)
 
+def _dot_using_sum_of_products(lhs, rhs, *, dimension_numbers):
+  contract_dims, batch_dims = dimension_numbers
+  lhs_contract_dims, rhs_contract_dims = contract_dims
+  lhs_batch_dims, rhs_batch_dims = batch_dims
+  lhs_noncontract_dims = tuple(sorted(
+    set(range(np.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims)))
+  rhs_noncontract_dims = tuple(sorted(
+    set(range(np.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims)))
+  lhs = transpose(lhs,
+                  lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims)
+  rhs = transpose(rhs,
+                  rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims)
+
+  lhs_start_expand = len(lhs_batch_dims) + len(lhs_noncontract_dims)
+  lhs_end_expand = lhs_start_expand + len(rhs_noncontract_dims)
+  lhs = expand_dims(lhs, tuple(range(lhs_start_expand, lhs_end_expand)))
+
+  rhs_start_expand = len(lhs_batch_dims)
+  rhs_end_expand = rhs_start_expand + len(lhs_noncontract_dims)
+  rhs = expand_dims(rhs, tuple(range(rhs_start_expand, rhs_end_expand)))
+
+  out_ndim = (len(lhs_batch_dims) + len(lhs_noncontract_dims) +
+              len(rhs_noncontract_dims))
+  op_product = bitwise_and if lhs.dtype == np.bool_ else mul
+  op_sum = bitwise_or if lhs.dtype == np.bool_ else add
+  return reduce(op_product(lhs, rhs), _zero(lhs), op_sum,
+                tuple(range(out_ndim, out_ndim + len(lhs_contract_dims))))
+
 def _dot_general_translation_rule(c, lhs, rhs, *, dimension_numbers, precision):
-  return xops.DotGeneral(lhs, rhs,
-                         xc.make_dot_dimension_numbers(dimension_numbers),
-                         precision_config=_precision_config(precision))
+  dtype = c.get_shape(lhs).numpy_dtype()
+  if dtypes.issubdtype(dtype, np.inexact):
+    return xops.DotGeneral(lhs, rhs,
+                           xc.make_dot_dimension_numbers(dimension_numbers),
+                           precision_config=_precision_config(precision))
+  else:
+    # TODO(b/134526360): XLA doesn't support bool or integer dots, so we emit a
+    # sum of products instead.
+    translation = xla.lower_fun(_dot_using_sum_of_products,
+                                multiple_results=False)
+    return translation(c, lhs, rhs, dimension_numbers=dimension_numbers)
 
 def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
                               precision):
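To see the generalized batching rule at work, a hypothetical vmap sketch (only the public jax API is assumed): mapping over a middle axis no longer forces a moveaxis of the operands, because the mapped axis can be passed straight through as a batch dimension in place:

import jax
import jax.numpy as jnp
from jax import lax

def f(x, y):  # x: (3, 4), y: (4, 5)
  return lax.dot_general(x, y, dimension_numbers=(((1,), (0,)), ((), ())))

x = jnp.ones((3, 8, 4))  # vmap over axis 1
y = jnp.ones((4, 8, 5))  # vmap over axis 1
out = jax.vmap(f, in_axes=(1, 1))(x, y)
print(out.shape)  # (8, 3, 5): the mapped axis becomes the leading batch dim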
@@ -416,6 +416,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
           ((3, 5), (2, 5), (([1], [1]), ([], []))),
           ((5, 3), (5, 2), (([0], [0]), ([], []))),
           ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
+          ((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))),
+          ((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))),
       ]
       for dtype in float_dtypes))
   def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
@@ -743,10 +743,14 @@ class LaxTest(jtu.JaxTestCase):
        "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
        "rng_factory": rng_factory}
       for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
           [(5,), (5,), [0], [0]],
           [(5, 7), (5,), [0], [0]],
           [(7, 5), (5,), [1], [0]],
           [(3, 5), (2, 5), [1], [1]],
           [(5, 3), (5, 2), [0], [0]],
           [(5, 3, 2), (5, 2, 4), [0], [0]],
+          [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
+          [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]],
+          [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
+          [(3, 2), (2, 4), [1], [0]],
       ]
@@ -773,6 +777,7 @@ class LaxTest(jtu.JaxTestCase):
        "dimension_numbers": dimension_numbers, "rng_factory": rng_factory}
       for lhs_shape, rhs_shape, dimension_numbers in [
           ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
+          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
           ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
       ]
       for dtype in all_dtypes
@@ -797,6 +802,7 @@ class LaxTest(jtu.JaxTestCase):
        "dimension_numbers": dimension_numbers, "rng_factory": rng_factory}
       for lhs_shape, rhs_shape, dimension_numbers in [
           ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
+          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
           ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
       ]
       for dtype in all_dtypes
@@ -238,10 +238,14 @@ class LaxVmapTest(jtu.JaxTestCase):
        "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
        "bdims": bdims, "rng_factory": rng_factory}
       for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
           [(5,), (5,), [0], [0]],
           [(5, 7), (5,), [0], [0]],
           [(7, 5), (5,), [1], [0]],
           [(3, 5), (2, 5), [1], [1]],
           [(5, 3), (5, 2), [0], [0]],
           [(5, 3, 2), (5, 2, 4), [0], [0]],
+          [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
+          [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]],
+          [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
+          [(3, 2), (2, 4), [1], [0]],
       ]
@@ -266,6 +270,7 @@ class LaxVmapTest(jtu.JaxTestCase):
        "dimension_numbers": dimension_numbers, "bdims": bdims, "rng_factory": rng_factory}
       for lhs_shape, rhs_shape, dimension_numbers in [
           ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
+          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
           ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
       ]
       for bdims in all_bdims(lhs_shape, rhs_shape)