mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add support for integer dot operations.
Lower to a sum of products for integers since XLA currently lacks support for integer dots.
This commit is contained in:
parent
8309836dd0
commit
15361c4dde
@ -457,6 +457,17 @@ def dot(lhs, rhs):
|
||||
Returns:
|
||||
An array containing the product.
|
||||
"""
|
||||
# XLA doesn't support integer dots, so emit a sum of products instead.
|
||||
if onp.issubdtype(lhs.dtype, onp.integer):
|
||||
lhs_shape = onp.shape(lhs)
|
||||
lhs_ndim = len(lhs_shape)
|
||||
rhs_ndim = onp.ndim(rhs)
|
||||
if rhs_ndim > 1:
|
||||
lhs = broadcast_in_dim(lhs, lhs_shape + (1,), tuple(range(len(lhs_shape))))
|
||||
if lhs_ndim > 1:
|
||||
rhs = broadcast(rhs, (1,))
|
||||
return reduce(mul(lhs, rhs), _zero(lhs), add, (len(lhs_shape) - 1,))
|
||||
|
||||
return dot_p.bind(lhs, rhs)
|
||||
|
||||
def dot_general(lhs, rhs, dimension_numbers):
|
||||
@ -476,9 +487,35 @@ def dot_general(lhs, rhs, dimension_numbers):
|
||||
Returns:
|
||||
An array containing the result.
|
||||
"""
|
||||
lhs_dims, rhs_dims = dimension_numbers
|
||||
dimension_numbers = (tuple(map(tuple, lhs_dims)), tuple(map(tuple, rhs_dims)))
|
||||
return dot_general_p.bind(lhs, rhs, dimension_numbers=dimension_numbers)
|
||||
contract_dims, batch_dims = dimension_numbers
|
||||
contract_dims = tuple(map(tuple, contract_dims))
|
||||
batch_dims = tuple(map(tuple, batch_dims))
|
||||
if onp.issubdtype(lhs.dtype, onp.integer):
|
||||
# XLA doesn't support integer dots, so emit a sum of products instead.
|
||||
lhs_contract_dims, rhs_contract_dims = contract_dims
|
||||
lhs_batch_dims, rhs_batch_dims = batch_dims
|
||||
lhs_noncontract_dims = tuple(sorted(
|
||||
set(range(onp.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims)))
|
||||
rhs_noncontract_dims = tuple(sorted(
|
||||
set(range(onp.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims)))
|
||||
lhs = transpose(lhs,
|
||||
lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims)
|
||||
rhs = transpose(rhs,
|
||||
rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims)
|
||||
new_lhs_shape = onp.insert(
|
||||
onp.shape(lhs), len(lhs_batch_dims) + len(lhs_noncontract_dims),
|
||||
(1,) * len(rhs_noncontract_dims))
|
||||
new_rhs_shape = onp.insert(onp.shape(rhs), len(lhs_batch_dims),
|
||||
(1,) * len(lhs_noncontract_dims))
|
||||
lhs = reshape(lhs, new_lhs_shape)
|
||||
rhs = reshape(rhs, new_rhs_shape)
|
||||
out_ndim = (len(lhs_batch_dims) + len(lhs_noncontract_dims) +
|
||||
len(rhs_noncontract_dims))
|
||||
return reduce(mul(lhs, rhs), _zero(lhs), add,
|
||||
tuple(range(out_ndim, out_ndim + len(lhs_contract_dims))))
|
||||
|
||||
return dot_general_p.bind(lhs, rhs,
|
||||
dimension_numbers=(contract_dims, batch_dims))
|
||||
|
||||
def broadcast(operand, sizes):
|
||||
"""Broadcasts an array, adding new major dimensions.
|
||||
|
@ -497,7 +497,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
("tensor-matrix", (4, 3, 2), (2, 5)),
|
||||
("matrix-tensor", (5, 2), (3, 2, 4)),
|
||||
("tensor-tensor", (2, 3, 4), (5, 4, 1))]
|
||||
for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)))
|
||||
for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
|
||||
def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
|
||||
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
|
||||
self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True)
|
||||
@ -523,7 +523,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
("tensor-matrix", (5, 2, 3), (3, 2)),
|
||||
("tensor-tensor", (5, 3, 4), (5, 4, 1)),
|
||||
("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
|
||||
for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)))
|
||||
for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
|
||||
def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
|
||||
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
|
||||
self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker,
|
||||
@ -546,7 +546,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
[(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
|
||||
[(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
|
||||
]
|
||||
for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)))
|
||||
for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
|
||||
def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng):
|
||||
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
|
||||
lnp_fun = lambda a, b: lnp.tensordot(a, b, axes)
|
||||
|
@ -588,7 +588,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
|
||||
"rng": rng}
|
||||
for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
|
||||
for dtype in float_dtypes
|
||||
for dtype in default_dtypes
|
||||
for rng in [jtu.rand_default()]))
|
||||
def testDot(self, lhs_shape, rhs_shape, dtype, rng):
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
@ -601,7 +601,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
|
||||
"rng": rng}
|
||||
for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
|
||||
for dtype in float_dtypes
|
||||
for dtype in default_dtypes
|
||||
for rng in [jtu.rand_default()]))
|
||||
def testDotAgainstNumpy(self, lhs_shape, rhs_shape, dtype, rng):
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
@ -626,7 +626,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
# [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
|
||||
[(3, 2), (2, 4), [1], [0]],
|
||||
]
|
||||
for dtype in float_dtypes
|
||||
for dtype in default_dtypes
|
||||
for rng in [jtu.rand_small()]))
|
||||
def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
|
||||
lhs_contracting, rhs_contracting, rng):
|
||||
@ -650,7 +650,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
|
||||
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
|
||||
]
|
||||
for dtype in float_dtypes
|
||||
for dtype in default_dtypes
|
||||
for rng in [jtu.rand_small()]))
|
||||
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
|
||||
dimension_numbers, rng):
|
||||
@ -673,7 +673,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
|
||||
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
|
||||
]
|
||||
for dtype in float_dtypes
|
||||
for dtype in default_dtypes
|
||||
for rng in [jtu.rand_small()]))
|
||||
def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype,
|
||||
dimension_numbers, rng):
|
||||
|
Loading…
x
Reference in New Issue
Block a user