mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
fix broadcasted eye bug, enable more einsum
This commit is contained in:
parent
6a71e9d6ec
commit
13a0e1168e
@ -1060,6 +1060,8 @@ def tensordot(a, b, axes=2):
|
||||
def einsum(*operands):
|
||||
operands, contractions = opt_einsum.contract_path(
|
||||
*operands, einsum_call=True, use_blas=True)
|
||||
sum = lambda x, axes: lax.reduce(x, onp.array(0, x.dtype), lax.add, axes)
|
||||
|
||||
for operand_indices, contracted_names, einstr, _, _ in contractions:
|
||||
input_str, result_names = einstr.split('->')
|
||||
input_names = input_str.split(',')
|
||||
@ -1075,14 +1077,21 @@ def einsum(*operands):
|
||||
uniques = tuple(name for name in contracted_names if counts[name] == 1)
|
||||
if uniques:
|
||||
axes = tuple(names.index(name) for name in uniques)
|
||||
operand = lax.reduce(operand, onp.array(0, _dtype(operand)), lax.add, axes)
|
||||
operand = sum(operand, axes)
|
||||
names = names.translate(None, ''.join(uniques))
|
||||
map(counts.pop, uniques)
|
||||
|
||||
# for every repeated index, do a contraction against an identity matrix
|
||||
for name, count in counts.items():
|
||||
if count > 1:
|
||||
raise NotImplementedError
|
||||
axes = [i for i, n in enumerate(names) if n == name]
|
||||
eye = lax.broadcasted_eye(operand.dtype, operand.shape, axes)
|
||||
if name not in result_names:
|
||||
operand = sum(operand * eye, axes)
|
||||
names = names.replace(name, '')
|
||||
else:
|
||||
operand = sum(operand * eye, axes[:-1])
|
||||
names = names.replace(name, '', count - 1)
|
||||
|
||||
result = operand
|
||||
|
||||
|
@ -73,25 +73,46 @@ class EinsumTest(jtu.JaxTestCase):
|
||||
s = '...ijk->ki'
|
||||
check(s, x)
|
||||
|
||||
# def test_one_operand_7(self):
|
||||
# x = rng().randn(3, 3, 3)
|
||||
# s = 'iii->'
|
||||
# check(s, x)
|
||||
def test_one_operand_7(self):
|
||||
x = rng().randn(3, 3)
|
||||
s = 'ii->'
|
||||
check(s, x)
|
||||
|
||||
# def test_one_operand_8(self):
|
||||
# x = rng().randn(3, 3)
|
||||
# s = 'ii->i'
|
||||
# check(s, x)
|
||||
def test_one_operand_8(self):
|
||||
x = rng().randn(3, 3, 3)
|
||||
s = 'iii->'
|
||||
check(s, x)
|
||||
|
||||
# def test_one_operand_9(self):
|
||||
# x = rng().randn(3, 3, 4)
|
||||
# s = 'iij->i'
|
||||
# check(s, x)
|
||||
def test_one_operand_9(self):
|
||||
x = rng().randn(3, 3)
|
||||
s = 'ii->i'
|
||||
check(s, x)
|
||||
|
||||
def test_one_operand_10(self):
|
||||
x = rng().randn(3, 3, 4)
|
||||
s = 'iij->i'
|
||||
check(s, x)
|
||||
|
||||
def test_one_operand_11(self):
|
||||
x = rng().randn(3, 3, 3)
|
||||
s = 'iii->i'
|
||||
check(s, x)
|
||||
|
||||
def test_one_operand_12(self):
|
||||
x = rng().randn(3, 3, 5, 4, 4)
|
||||
s = 'iijkk->i'
|
||||
check(s, x)
|
||||
|
||||
def test_one_operand_13(self):
|
||||
x = rng().randn(3, 3, 5, 4, 4)
|
||||
s = 'iijkk->ik'
|
||||
check(s, x)
|
||||
|
||||
def test_one_operand_14(self):
|
||||
x = rng().randn(3, 3, 5, 4, 4)
|
||||
s = 'iijkl->il'
|
||||
check(s, x)
|
||||
|
||||
# def test_one_operand_10(self):
|
||||
# x = rng().randn(3, 3, 3)
|
||||
# s = 'iii->i'
|
||||
# check(s, x)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user