fix broadcasted eye bug, enable more einsum

This commit is contained in:
Matthew Johnson 2018-12-17 18:16:20 -08:00
parent 6a71e9d6ec
commit 13a0e1168e
2 changed files with 48 additions and 18 deletions

View File

@ -1060,6 +1060,8 @@ def tensordot(a, b, axes=2):
def einsum(*operands):
operands, contractions = opt_einsum.contract_path(
*operands, einsum_call=True, use_blas=True)
sum = lambda x, axes: lax.reduce(x, onp.array(0, x.dtype), lax.add, axes)
for operand_indices, contracted_names, einstr, _, _ in contractions:
input_str, result_names = einstr.split('->')
input_names = input_str.split(',')
@ -1075,14 +1077,21 @@ def einsum(*operands):
uniques = tuple(name for name in contracted_names if counts[name] == 1)
if uniques:
axes = tuple(names.index(name) for name in uniques)
operand = lax.reduce(operand, onp.array(0, _dtype(operand)), lax.add, axes)
operand = sum(operand, axes)
names = names.translate(None, ''.join(uniques))
map(counts.pop, uniques)
# for every repeated index, do a contraction against an identity matrix
for name, count in counts.items():
if count > 1:
raise NotImplementedError
axes = [i for i, n in enumerate(names) if n == name]
eye = lax.broadcasted_eye(operand.dtype, operand.shape, axes)
if name not in result_names:
operand = sum(operand * eye, axes)
names = names.replace(name, '')
else:
operand = sum(operand * eye, axes[:-1])
names = names.replace(name, '', count - 1)
result = operand

View File

@ -73,25 +73,46 @@ class EinsumTest(jtu.JaxTestCase):
s = '...ijk->ki'
check(s, x)
# def test_one_operand_7(self):
# x = rng().randn(3, 3, 3)
# s = 'iii->'
# check(s, x)
def test_one_operand_7(self):
x = rng().randn(3, 3)
s = 'ii->'
check(s, x)
# def test_one_operand_8(self):
# x = rng().randn(3, 3)
# s = 'ii->i'
# check(s, x)
def test_one_operand_8(self):
x = rng().randn(3, 3, 3)
s = 'iii->'
check(s, x)
# def test_one_operand_9(self):
# x = rng().randn(3, 3, 4)
# s = 'iij->i'
# check(s, x)
def test_one_operand_9(self):
x = rng().randn(3, 3)
s = 'ii->i'
check(s, x)
def test_one_operand_10(self):
x = rng().randn(3, 3, 4)
s = 'iij->i'
check(s, x)
def test_one_operand_11(self):
x = rng().randn(3, 3, 3)
s = 'iii->i'
check(s, x)
def test_one_operand_12(self):
x = rng().randn(3, 3, 5, 4, 4)
s = 'iijkk->i'
check(s, x)
def test_one_operand_13(self):
x = rng().randn(3, 3, 5, 4, 4)
s = 'iijkk->ik'
check(s, x)
def test_one_operand_14(self):
x = rng().randn(3, 3, 5, 4, 4)
s = 'iijkl->il'
check(s, x)
# def test_one_operand_10(self):
# x = rng().randn(3, 3, 3)
# s = 'iii->i'
# check(s, x)
if __name__ == '__main__':
absltest.main()