Add support for integer dot operations.

Lower to a sum of products for integers since XLA currently lacks support for integer dots.
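The core trick, sketched below in plain NumPy (illustrative only, not part of the commit): a dot along a contracting dimension is just an elementwise multiply followed by a sum reduction over that dimension.

import numpy as onp

# Illustrative only: a 1D integer dot expressed as a sum of products,
# the same lowering strategy this commit applies inside lax.
def int_dot_1d(lhs, rhs):
  return onp.sum(lhs * rhs)

assert int_dot_1d(onp.array([1, 2, 3]), onp.array([4, 5, 6])) == 32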
Peter Hawkins 2019-06-06 17:21:21 -04:00
parent 8309836dd0
commit 15361c4dde
3 changed files with 48 additions and 11 deletions


@@ -457,6 +457,17 @@ def dot(lhs, rhs):
   Returns:
     An array containing the product.
   """
+  # XLA doesn't support integer dots, so emit a sum of products instead.
+  if onp.issubdtype(lhs.dtype, onp.integer):
+    lhs_shape = onp.shape(lhs)
+    lhs_ndim = len(lhs_shape)
+    rhs_ndim = onp.ndim(rhs)
+    if rhs_ndim > 1:
+      lhs = broadcast_in_dim(lhs, lhs_shape + (1,), tuple(range(len(lhs_shape))))
+    if lhs_ndim > 1:
+      rhs = broadcast(rhs, (1,))
+    return reduce(mul(lhs, rhs), _zero(lhs), add, (len(lhs_shape) - 1,))
   return dot_p.bind(lhs, rhs)
 
 def dot_general(lhs, rhs, dimension_numbers):
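To see why the broadcasting in dot above works, here is a NumPy-only walkthrough of the matrix-matrix case (shapes and names are illustrative, not from the commit):

import numpy as onp

m, k, n = 2, 3, 4
lhs = onp.arange(m * k, dtype=onp.int32).reshape(m, k)
rhs = onp.arange(k * n, dtype=onp.int32).reshape(k, n)

# rhs has rank > 1, so append a trailing size-1 dim to lhs: (m, k) -> (m, k, 1).
lhs_b = lhs.reshape(m, k, 1)
# lhs has rank > 1, so prepend a size-1 dim to rhs: (k, n) -> (1, k, n).
rhs_b = rhs.reshape(1, k, n)
# The product broadcasts to (m, k, n); reducing over axis
# len(lhs.shape) - 1 == 1 sums out the contracting dimension.
out = (lhs_b * rhs_b).sum(axis=1)
assert (out == onp.dot(lhs, rhs)).all()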
@@ -476,9 +487,35 @@ def dot_general(lhs, rhs, dimension_numbers):
   Returns:
     An array containing the result.
   """
-  lhs_dims, rhs_dims = dimension_numbers
-  dimension_numbers = (tuple(map(tuple, lhs_dims)), tuple(map(tuple, rhs_dims)))
-  return dot_general_p.bind(lhs, rhs, dimension_numbers=dimension_numbers)
+  contract_dims, batch_dims = dimension_numbers
+  contract_dims = tuple(map(tuple, contract_dims))
+  batch_dims = tuple(map(tuple, batch_dims))
+  if onp.issubdtype(lhs.dtype, onp.integer):
+    # XLA doesn't support integer dots, so emit a sum of products instead.
+    lhs_contract_dims, rhs_contract_dims = contract_dims
+    lhs_batch_dims, rhs_batch_dims = batch_dims
+    lhs_noncontract_dims = tuple(sorted(
+        set(range(onp.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims)))
+    rhs_noncontract_dims = tuple(sorted(
+        set(range(onp.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims)))
+    lhs = transpose(lhs,
+                    lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims)
+    rhs = transpose(rhs,
+                    rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims)
+    new_lhs_shape = onp.insert(
+        onp.shape(lhs), len(lhs_batch_dims) + len(lhs_noncontract_dims),
+        (1,) * len(rhs_noncontract_dims))
+    new_rhs_shape = onp.insert(onp.shape(rhs), len(lhs_batch_dims),
+                               (1,) * len(lhs_noncontract_dims))
+    lhs = reshape(lhs, new_lhs_shape)
+    rhs = reshape(rhs, new_rhs_shape)
+    out_ndim = (len(lhs_batch_dims) + len(lhs_noncontract_dims) +
+                len(rhs_noncontract_dims))
+    return reduce(mul(lhs, rhs), _zero(lhs), add,
+                  tuple(range(out_ndim, out_ndim + len(lhs_contract_dims))))
+  return dot_general_p.bind(lhs, rhs,
+                            dimension_numbers=(contract_dims, batch_dims))
 
 def broadcast(operand, sizes):
   """Broadcasts an array, adding new major dimensions.


@@ -497,7 +497,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
           ("tensor-matrix", (4, 3, 2), (2, 5)),
           ("matrix-tensor", (5, 2), (3, 2, 4)),
           ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
-      for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)))
+      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
   def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
     args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
     self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True)
@@ -523,7 +523,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
           ("tensor-matrix", (5, 2, 3), (3, 2)),
           ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
           ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
-      for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)))
+      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
   def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
     args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
     self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker,
@@ -546,7 +546,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
           [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
           [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
       ]
-      for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)))
+      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
   def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng):
     args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
     lnp_fun = lambda a, b: lnp.tensordot(a, b, axes)
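A quick standalone sanity check of the new behavior these tests now cover (assuming a jax install; the onp/lnp aliases mirror the test file's imports):

import numpy as onp
import jax.numpy as lnp

lhs = onp.arange(6, dtype=onp.int32).reshape(2, 3)
rhs = onp.arange(12, dtype=onp.int32).reshape(3, 4)
# Integer dots previously failed in XLA; with this commit they match NumPy.
onp.testing.assert_array_equal(onp.dot(lhs, rhs),
                               onp.asarray(lnp.dot(lhs, rhs)))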


@@ -588,7 +588,7 @@ class LaxTest(jtu.JaxTestCase):
        "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
        "rng": rng}
       for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
-      for dtype in float_dtypes
+      for dtype in default_dtypes
       for rng in [jtu.rand_default()]))
   def testDot(self, lhs_shape, rhs_shape, dtype, rng):
     args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
@@ -601,7 +601,7 @@ class LaxTest(jtu.JaxTestCase):
        "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
        "rng": rng}
       for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
-      for dtype in float_dtypes
+      for dtype in default_dtypes
       for rng in [jtu.rand_default()]))
   def testDotAgainstNumpy(self, lhs_shape, rhs_shape, dtype, rng):
     args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
@@ -626,7 +626,7 @@ class LaxTest(jtu.JaxTestCase):
           # [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
           [(3, 2), (2, 4), [1], [0]],
       ]
-      for dtype in float_dtypes
+      for dtype in default_dtypes
       for rng in [jtu.rand_small()]))
   def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                  lhs_contracting, rhs_contracting, rng):
@@ -650,7 +650,7 @@ class LaxTest(jtu.JaxTestCase):
           ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
           ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
       ]
-      for dtype in float_dtypes
+      for dtype in default_dtypes
       for rng in [jtu.rand_small()]))
   def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                      dimension_numbers, rng):
@@ -673,7 +673,7 @@ class LaxTest(jtu.JaxTestCase):
           ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
           ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
       ]
-      for dtype in float_dtypes
+      for dtype in default_dtypes
       for rng in [jtu.rand_small()]))
   def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype,
                                  dimension_numbers, rng):
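And a corresponding spot check at the lax level, using the dimension numbers from the batched test case above (illustrative; not part of the diff):

import numpy as onp
from jax import lax

lhs = onp.ones((3, 3, 2), onp.int32)
rhs = onp.ones((3, 2, 4), onp.int32)
# Contract lhs dim 2 with rhs dim 1; batch over dim 0 of both.
out = lax.dot_general(lhs, rhs, (([2], [1]), ([0], [0])))
assert out.shape == (3, 3, 4)
assert (onp.asarray(out) == 2).all()  # each entry sums two 1*1 products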