commit 04b6f15cdb
Author: jax authors
Date:   2022-04-26 05:58:55 -07:00

    Merge pull request #7772 from bloops:xeinsum

    PiperOrigin-RevId: 444535710

3 changed files with 170 additions and 37 deletions

@@ -36,7 +36,7 @@ from jax._src.lax import lax
 from jax._src.lax import slicing
 from jax._src.numpy import lax_numpy
 import jax._src.util as util
-from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis
+from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
@@ -419,50 +419,74 @@ def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ()),
                      precision=lax.canonicalize_precision(precision))


-def xeinsum(spec: str, x, y):
+def xeinsum(spec: str, *operands):
   in_spec, out_spec = spec.split('->')
-  (lhs_subs, lhs_named), (rhs_subs, rhs_named) = XeinsumSpecParser(in_spec).parse_args()
+  all_in_subs, all_in_named = unzip2(XeinsumSpecParser(in_spec).parse_args())
   (out_subs, out_named), = XeinsumSpecParser(out_spec).parse_args()
-  all_named = {*lhs_named, *rhs_named, *out_named}
-  all_subs = {*lhs_subs, *rhs_subs, *out_subs}
-  lhs_uniques = set(lhs_subs) - set(rhs_subs)
-  rhs_uniques = set(rhs_subs) - set(lhs_subs)
-  if all_subs & all_named:
-    raise NotImplementedError
-  if not set(out_named).issubset({*lhs_named, *rhs_named}):
-    raise ValueError
-
-  # if a named axis appears in both inputs and not the output, contract!
-  named_contract = list(all_named - set(out_named))
+  if len(operands) != len(all_in_named):
+    raise ValueError("Expecting the same number of argument specs in the "
+                     f"subscript ({in_spec}) as the number of operands. But got "
+                     f"{len(all_in_named)} argument specs for "
+                     f"{len(operands)} operands")

-  # if a subscript appears in both inputs and not the outputs, contract!
-  subs_contract = all_subs - set(out_subs)
+  if len(operands) > 2:
+    raise NotImplementedError("Only one or two operands are supported. "
+                              f"But got {len(operands)} operands")

-  lhs_reduce_axes = [lhs_subs.index(n) for n in lhs_uniques & subs_contract]
-  if lhs_reduce_axes:
-    x = lax._reduce_sum(x, lhs_reduce_axes)
-    for i in sorted(lhs_reduce_axes, reverse=True):
-      del lhs_subs[i]
+  # output subs and named axes must appear in at least one of the inputs.
+  if not set(out_named).issubset(set().union(*all_in_named)):
+    raise ValueError("Found named axes "
+                     f"{set(out_named) - set().union(*all_in_named)} "
+                     "appearing in the output spec but not in the input")
+  if not set(out_subs).issubset(set().union(*all_in_subs)):
+    raise ValueError("Found subscript(s) "
+                     f"{set(out_subs) - set().union(*all_in_subs)} "
+                     "appearing in the output spec but not in the input")

-  rhs_reduce_axes = [rhs_subs.index(n) for n in rhs_uniques & subs_contract]
-  if rhs_reduce_axes:
-    y = lax._reduce_sum(y, rhs_reduce_axes)
-    for i in sorted(rhs_reduce_axes, reverse=True):
-      del rhs_subs[i]
+  xs = list(operands)
+  for idx, (in_subs, in_named) in enumerate(safe_zip(all_in_subs, all_in_named)):
+    # if a subscript axis appears only in one input and not the output, reduce!
+    other_named = set().union(  # type: ignore
+        *[named for i, named in enumerate(all_in_named) if i != idx])
+    other_subs = set().union(  # type: ignore
+        *[subs for i, subs in enumerate(all_in_subs) if i != idx])

-  pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n))
-                        for n in subs_contract - (lhs_uniques | rhs_uniques))
+    subs_reduce = list(set(in_subs) - {*out_subs, *other_subs})
+    subs_reduce_axes = [in_subs.index(n) for n in subs_reduce]
+    named_reduce_axes = list(set(in_named) - {*out_named, *other_named})

-  # if a subscript appears in both inputs _and_ the outputs, batch!
-  subs_batch = all_subs - subs_contract
-  if subs_batch & (lhs_uniques | rhs_uniques):
-    raise NotImplementedError
+    if subs_reduce_axes or named_reduce_axes:
+      xs[idx] = psum(xs[idx], axis_name=subs_reduce_axes + named_reduce_axes)
+      for i in sorted(subs_reduce_axes, reverse=True):
+        del all_in_subs[idx][i]
+      for named_axis in named_reduce_axes:
+        all_in_named[idx].remove(named_axis)

-  pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n))
-                     for n in subs_batch)
+  if len(operands) == 1:
+    return xs[0]
+
+  if len(operands) == 2:
+    x, y = xs
+    lhs_subs, rhs_subs = all_in_subs
+    lhs_named, rhs_named = all_in_named
+
+    # if a named axis appears in both inputs and not the output, contract!
+    named_contract = list((set(lhs_named) & set(rhs_named)) - set(out_named))
+
+    # if a subscript appears in both inputs and not the outputs, contract!
+    subs_contract = (set(lhs_subs) & set(rhs_subs)) - set(out_subs)
+    pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n))
+                          for n in subs_contract)
+
+    # if a subscript appears in both inputs _and_ the outputs, batch!
+    subs_batch = (set(lhs_subs) & set(rhs_subs)) - subs_contract
+    pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n)) for n in subs_batch)

-  return pdot(x, y, axis_name=named_contract,
-              pos_contract=pos_contract, pos_batch=pos_batch)
+    return pdot(x, y, axis_name=named_contract,
+                pos_contract=pos_contract, pos_batch=pos_batch)


 class XeinsumSpecParser:
   spec: str
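Not part of the diff: a short usage sketch, assuming this commit's JAX, of how the generalized xeinsum is reached from jnp.einsum. The shapes and spec string below are illustrative and mirror the new batch-matmul test added further down; _use_xeinsum=True forces dispatch to lax.xeinsum.

# Sketch: a two-operand positional spec routed through xeinsum -> pdot.
import numpy as np
import jax.numpy as jnp

rng = np.random.RandomState(0)
x = rng.randn(3, 5, 4)
y = rng.randn(3, 4, 2)

# 'b' is shared by both inputs and the output, so it becomes
# pos_batch ((0,), (0,)); 'j' is shared by the inputs only, so it
# becomes pos_contract ((2,), (1,)).
out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
np.testing.assert_allclose(out, np.einsum('bij,bjk->bik', x, y),
                           rtol=1e-4, atol=1e-4)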

@@ -2826,8 +2826,7 @@ def einsum(*operands, out=None, optimize='optimal', precision=None,
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")

-  if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0] and
-      len(operands[1:]) == 2):
+  if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0]):
     return lax.xeinsum(*operands)

   optimize = 'optimal' if optimize is True else optimize
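Not part of the diff: with the two-operand requirement dropped from the '{' dispatch above, a one-operand spec with named axes can now reach lax.xeinsum. A minimal sketch under xmap, following the pattern of the new unary tests below (axis names, shapes, and the mesh-free setup are assumptions; the tests themselves run under a mesh):

# Sketch: named axes 'i' and 'j' appear only in the input, so xeinsum
# reduces them with psum and only 'b' survives in the output.
from functools import partial

import numpy as np
import jax.numpy as jnp
from jax.experimental.maps import xmap

x = np.random.RandomState(0).randn(8, 6, 4)
out = xmap(partial(jnp.einsum, '{b,i,j}->{b}'),
           in_axes=['b', 'i', 'j'], out_axes=['b'])(x)
np.testing.assert_allclose(out, np.einsum('bij->b', x),
                           rtol=1e-4, atol=1e-4)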

@@ -1332,6 +1332,15 @@ class PDotTests(XMapTestCase):
     expected = np.einsum('ij,ij->i', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)

+  def test_xeinsum_no_named_axes_batch_matmul(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(3, 5, 4)
+    y = rng.randn(3, 4, 2)
+    out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
+    expected = np.einsum('bij,bjk->bik', x, y)
+    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+    self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
+
   def test_xeinsum_no_named_axes_reduce_sum(self):
     rng = self.rng()
     x = rng.randn(3)
@@ -1341,6 +1350,107 @@ class PDotTests(XMapTestCase):
     self.assertAllClose(out, expected, check_dtypes=True)

+  def test_xeinsum_no_named_axes_reduce_and_contract(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(3, 5, 4)
+    y = rng.randn(2, 4, 2)
+    out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
+    expected = np.einsum('bij,cjk->ik', x, y)
+    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+    self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
+
+  def test_xeinsum_named_axes_reduce(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(3, 4)
+    y = rng.randn(5,)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=(['i', 'j'], ['k']),
+                 out_axes=['i', 'k'])(x, y)
+      expected = np.einsum('ij,k->ik', x, y)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+
+    check('{i,j},{k}->{i,k}')
+
+  @jtu.with_mesh([('x', 2), ('y', 2)])
+  def test_xeinsum_named_axes_reduce_with_mesh(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(6, 4)
+    y = rng.randn(8,)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=(['i', 'j'], ['k']),
+                 out_axes=['i', 'k'],
+                 axis_resources={'i': 'x', 'k': 'y'})(x, y)
+      expected = np.einsum('ij,k->ik', x, y)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+
+    check('{i,j},{k}->{i,k}')
+    check('{i,j},{k}->{k,i}')  # order of named axes in the spec doesn't matter!
+    check('{j,i},{k}->{i,k}')
+    check('{j,i},{k}->{k,i}')
+
+  @jtu.with_mesh([('x', 2), ('y', 2)])
+  def test_xeinsum_named_axes_batch_matmul_with_mesh(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(8, 3, 4)
+    y = rng.randn(8, 4, 5)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=(['b', 'i', 'j'], ['b', 'j', 'k']),
+                 out_axes=['b', 'i', 'k'],
+                 axis_resources={'b': 'x', 'j': 'y'})(x, y)
+      expected = np.einsum('bij,bjk->bik', x, y)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+
+    check('{b,i,j},{b,j,k}->{b,i,k}')
+    check('{j,i,b},{j,b,k}->{i,b,k}')  # order of named axes in the spec doesn't matter!
+
+  @jtu.with_mesh([('x', 2), ('y', 2)])
+  def test_xeinsum_named_axes_unary_reduce_with_mesh(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(8, 6, 4)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=['b', 'i', 'j'],
+                 out_axes=['b'],
+                 axis_resources={'b': 'x', 'i': 'y'})(x)
+      expected = np.einsum('bij->b', x)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+
+    check('{b,i,j}->{b}')
+    check('{b,j,i}->{b}')  # order of named axes in the spec doesn't matter!
+    check('{i,j,b}->{b}')
+
+  @jtu.with_mesh([('x', 2), ('y', 2)])
+  def test_xeinsum_mixed_axes_unary_reduce_with_mesh(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(8, 6, 4, 5)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=['b', 'i', ...],
+                 out_axes=['b', ...],
+                 axis_resources={'b': 'x', 'i': 'y'})(x)
+      expected = np.einsum('bijk->bk', x)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+
+    check('jk{i,b}->k{b}')
+

 class XMapErrorTest(jtu.JaxTestCase):

   @jtu.with_mesh([('x', 2)])