mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
add comment marking a bug
This commit is contained in:
parent
166f45bf2b
commit
9a68bce567
@ -1082,7 +1082,7 @@ def _einsum(operands, contractions):
|
||||
input_names = input_str.split(',')
|
||||
|
||||
# switch on the number of operands to be processed in this loop iteration.
|
||||
# every case here sets 'result' and 'names'.
|
||||
# every case here sets 'operand' and 'names'.
|
||||
if len(operand_indices) == 1:
|
||||
operand = operands.pop(operand_indices[0])
|
||||
names, = input_names
|
||||
@ -1120,7 +1120,7 @@ def _einsum(operands, contractions):
|
||||
lhs_batch, rhs_batch = unzip2((lhs_names.find(n), rhs_names.find(n))
|
||||
for n in batch_names)
|
||||
if contracted_names:
|
||||
# contract usint lax.dot_general
|
||||
# contract using lax.dot_general
|
||||
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
|
||||
for n in contracted_names)
|
||||
|
||||
@ -1130,6 +1130,7 @@ def _einsum(operands, contractions):
|
||||
lhs = moveaxis(lhs, lhs_batch, batch_dims)
|
||||
rhs = moveaxis(rhs, rhs_batch, batch_dims)
|
||||
batch_names = ''.join(batch_names)
|
||||
# TODO(mattjj): may need to update lhs_cont and rhs_cont here
|
||||
else:
|
||||
batch_dims = tuple(lhs_batch)
|
||||
batch_names = ''.join(lhs_names[i] for i in batch_dims)
|
||||
|
@ -192,14 +192,13 @@ class EinsumTest(jtu.JaxTestCase):
|
||||
s = 'ij->ij'
|
||||
check(s, x)
|
||||
|
||||
# TODO(mattjj): patch this up!
|
||||
# def test_tf_unsupported_1(self):
|
||||
# # from https://www.tensorflow.org/api_docs/python/tf/einsum
|
||||
# r = rng()
|
||||
# x = r.randn(2, 3, 5, 1)
|
||||
# y = r.randn(3, 4, 5, 1)
|
||||
# s = 'ij...,jk...->ik...'
|
||||
# check(s, x, y)
|
||||
def test_tf_unsupported_1(self):
|
||||
# from https://www.tensorflow.org/api_docs/python/tf/einsum
|
||||
r = rng()
|
||||
x = r.randn(2, 3, 5, 1)
|
||||
y = r.randn(3, 4, 5, 1)
|
||||
s = 'ij...,jk...->ik...'
|
||||
check(s, x, y)
|
||||
|
||||
def test_tf_unsupported_2(self):
|
||||
# from https://www.tensorflow.org/api_docs/python/tf/einsum
|
||||
|
Loading…
x
Reference in New Issue
Block a user