add comment marking a bug

This commit is contained in:
Matthew Johnson 2018-12-19 08:55:59 -08:00
parent 166f45bf2b
commit 9a68bce567
2 changed files with 10 additions and 10 deletions

View File

@ -1082,7 +1082,7 @@ def _einsum(operands, contractions):
input_names = input_str.split(',')
# switch on the number of operands to be processed in this loop iteration.
# every case here sets 'result' and 'names'.
# every case here sets 'operand' and 'names'.
if len(operand_indices) == 1:
operand = operands.pop(operand_indices[0])
names, = input_names
@ -1120,7 +1120,7 @@ def _einsum(operands, contractions):
lhs_batch, rhs_batch = unzip2((lhs_names.find(n), rhs_names.find(n))
for n in batch_names)
if contracted_names:
# contract usint lax.dot_general
# contract using lax.dot_general
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
for n in contracted_names)
@ -1130,6 +1130,7 @@ def _einsum(operands, contractions):
lhs = moveaxis(lhs, lhs_batch, batch_dims)
rhs = moveaxis(rhs, rhs_batch, batch_dims)
batch_names = ''.join(batch_names)
# TODO(mattjj): may need to update lhs_cont and rhs_cont here
else:
batch_dims = tuple(lhs_batch)
batch_names = ''.join(lhs_names[i] for i in batch_dims)

View File

@ -192,14 +192,13 @@ class EinsumTest(jtu.JaxTestCase):
s = 'ij->ij'
check(s, x)
# TODO(mattjj): patch this up!
# def test_tf_unsupported_1(self):
# # from https://www.tensorflow.org/api_docs/python/tf/einsum
# r = rng()
# x = r.randn(2, 3, 5, 1)
# y = r.randn(3, 4, 5, 1)
# s = 'ij...,jk...->ik...'
# check(s, x, y)
def test_tf_unsupported_1(self):
# from https://www.tensorflow.org/api_docs/python/tf/einsum
r = rng()
x = r.randn(2, 3, 5, 1)
y = r.randn(3, 4, 5, 1)
s = 'ij...,jk...->ik...'
check(s, x, y)
def test_tf_unsupported_2(self):
# from https://www.tensorflow.org/api_docs/python/tf/einsum