mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add unary xeinsum and allow named axis reductions for unary and binary xeinsums
This commit is contained in:
parent
098f2126ae
commit
a147046d18
@ -36,7 +36,7 @@ from jax._src.lax import lax
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.numpy import lax_numpy
|
||||
import jax._src.util as util
|
||||
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis
|
||||
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
|
||||
@ -419,50 +419,74 @@ def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ()),
|
||||
precision=lax.canonicalize_precision(precision))
|
||||
|
||||
|
||||
def xeinsum(spec: str, x, y):
|
||||
def xeinsum(spec: str, *operands):
|
||||
in_spec, out_spec = spec.split('->')
|
||||
(lhs_subs, lhs_named), (rhs_subs, rhs_named) = XeinsumSpecParser(in_spec).parse_args()
|
||||
all_in_subs, all_in_named = unzip2(XeinsumSpecParser(in_spec).parse_args())
|
||||
(out_subs, out_named), = XeinsumSpecParser(out_spec).parse_args()
|
||||
all_named = {*lhs_named, *rhs_named, *out_named}
|
||||
all_subs = {*lhs_subs, *rhs_subs, *out_subs}
|
||||
lhs_uniques = set(lhs_subs) - set(rhs_subs)
|
||||
rhs_uniques = set(rhs_subs) - set(lhs_subs)
|
||||
if all_subs & all_named:
|
||||
raise NotImplementedError
|
||||
if not set(out_named).issubset({*lhs_named, *rhs_named}):
|
||||
raise ValueError
|
||||
|
||||
# if a named axis appears in both inputs and not the output, contract!
|
||||
named_contract = list(all_named - set(out_named))
|
||||
if len(operands) != len(all_in_named):
|
||||
raise ValueError("Expecting the same number of argument specs in the "
|
||||
"subscript ({in_spec}) as the number of operands. But got "
|
||||
"{len(all_in_named)} argument specs for "
|
||||
"{len(operands)} operands")
|
||||
|
||||
# if a subscript appears in both inputs and not the outputs, contract!
|
||||
subs_contract = all_subs - set(out_subs)
|
||||
if len(operands) > 2:
|
||||
raise NotImplementedError("Only one or two operands are supported. "
|
||||
f"But got {len(operands)} operands")
|
||||
|
||||
lhs_reduce_axes = [lhs_subs.index(n) for n in lhs_uniques & subs_contract]
|
||||
if lhs_reduce_axes:
|
||||
x = lax._reduce_sum(x, lhs_reduce_axes)
|
||||
for i in sorted(lhs_reduce_axes, reverse=True):
|
||||
del lhs_subs[i]
|
||||
# output subs and named axes must appear in at least one of the inputs.
|
||||
if not set(out_named).issubset(set().union(*all_in_named)):
|
||||
raise ValueError("Found named axes "
|
||||
f"{set(out_named) - set().union(*all_in_named)} "
|
||||
"appearing in the output spec but not in the input")
|
||||
if not set(out_subs).issubset(set().union(*all_in_subs)):
|
||||
raise ValueError("Found subscript(s) "
|
||||
f"{set(out_subs) - set().union(*all_in_subs)} "
|
||||
"appearing in the output spec but not in the input")
|
||||
|
||||
rhs_reduce_axes = [rhs_subs.index(n) for n in rhs_uniques & subs_contract]
|
||||
if rhs_reduce_axes:
|
||||
y = lax._reduce_sum(y, rhs_reduce_axes)
|
||||
for i in sorted(rhs_reduce_axes, reverse=True):
|
||||
del rhs_subs[i]
|
||||
xs = list(operands)
|
||||
for idx, (in_subs, in_named) in enumerate(safe_zip(all_in_subs, all_in_named)):
|
||||
# if a subscript axis appears only in one input and not the output, reduce!
|
||||
other_named = set().union( # type: ignore
|
||||
*[named for i, named in enumerate(all_in_named) if i != idx])
|
||||
other_subs = set().union( # type: ignore
|
||||
*[subs for i, subs in enumerate(all_in_subs) if i != idx])
|
||||
|
||||
pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n))
|
||||
for n in subs_contract - (lhs_uniques | rhs_uniques))
|
||||
subs_reduce = list(set(in_subs) - {*out_subs, *other_subs})
|
||||
subs_reduce_axes = [in_subs.index(n) for n in subs_reduce]
|
||||
named_reduce_axes = list(set(in_named) - {*out_named, *other_named})
|
||||
|
||||
# if a subscript appears in both inputs _and_ the outputs, batch!
|
||||
subs_batch = all_subs - subs_contract
|
||||
if subs_batch & (lhs_uniques | rhs_uniques):
|
||||
raise NotImplementedError
|
||||
if subs_reduce_axes or named_reduce_axes:
|
||||
xs[idx] = psum(xs[idx], axis_name=subs_reduce_axes + named_reduce_axes)
|
||||
for i in sorted(subs_reduce_axes, reverse=True):
|
||||
del all_in_subs[idx][i]
|
||||
for named_axis in named_reduce_axes:
|
||||
all_in_named[idx].remove(named_axis)
|
||||
|
||||
pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n))
|
||||
for n in subs_batch)
|
||||
if len(operands) == 1:
|
||||
return xs[0]
|
||||
|
||||
if len(operands) == 2:
|
||||
x, y = xs
|
||||
lhs_subs, rhs_subs = all_in_subs
|
||||
lhs_named, rhs_named = all_in_named
|
||||
|
||||
# if a named axis appears in both inputs and not the output, contract!
|
||||
named_contract = list((set(lhs_named) & set(rhs_named)) - set(out_named))
|
||||
|
||||
# if a subscript appears in both inputs and not the outputs, contract!
|
||||
subs_contract = (set(lhs_subs) & set(rhs_subs)) - set(out_subs)
|
||||
|
||||
pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n))
|
||||
for n in subs_contract)
|
||||
|
||||
# if a subscript appears in both inputs _and_ the outputs, batch!
|
||||
subs_batch = (set(lhs_subs) & set(rhs_subs)) - subs_contract
|
||||
pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n)) for n in subs_batch)
|
||||
|
||||
return pdot(x, y, axis_name=named_contract,
|
||||
pos_contract=pos_contract, pos_batch=pos_batch)
|
||||
|
||||
return pdot(x, y, axis_name=named_contract,
|
||||
pos_contract=pos_contract, pos_batch=pos_batch)
|
||||
|
||||
class XeinsumSpecParser:
|
||||
spec: str
|
||||
|
@ -2826,8 +2826,7 @@ def einsum(*operands, out=None, optimize='optimal', precision=None,
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
|
||||
|
||||
if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0] and
|
||||
len(operands[1:]) == 2):
|
||||
if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0]):
|
||||
return lax.xeinsum(*operands)
|
||||
|
||||
optimize = 'optimal' if optimize is True else optimize
|
||||
|
@ -1332,6 +1332,15 @@ class PDotTests(XMapTestCase):
|
||||
expected = np.einsum('ij,ij->i', x, y)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
def test_xeinsum_no_named_axes_batch_matmul(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(3, 5, 4)
|
||||
y = rng.randn(3, 4, 2)
|
||||
out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
|
||||
expected = np.einsum('bij,bjk->bik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
|
||||
|
||||
def test_xeinsum_no_named_axes_reduce_sum(self):
|
||||
rng = self.rng()
|
||||
x = rng.randn(3)
|
||||
@ -1341,6 +1350,107 @@ class PDotTests(XMapTestCase):
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
|
||||
def test_xeinsum_no_named_axes_reduce_and_contract(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(3, 5, 4)
|
||||
y = rng.randn(2, 4, 2)
|
||||
out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
|
||||
expected = np.einsum('bij,cjk->ik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
|
||||
|
||||
def test_xeinsum_named_axes_reduce(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(3, 4)
|
||||
y = rng.randn(5,)
|
||||
|
||||
def check(spec):
|
||||
out = xmap(partial(jnp.einsum, spec),
|
||||
in_axes=(['i', 'j'], ['k']),
|
||||
out_axes=['i', 'k'])(x, y)
|
||||
expected = np.einsum('ij,k->ik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
check('{i,j},{k}->{i,k}')
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def test_xeinsum_named_axes_reduce_with_mesh(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(6, 4)
|
||||
y = rng.randn(8,)
|
||||
|
||||
def check(spec):
|
||||
out = xmap(partial(jnp.einsum, spec),
|
||||
in_axes=(['i', 'j'], ['k']),
|
||||
out_axes=['i', 'k'],
|
||||
axis_resources={'i': 'x', 'k': 'y'})(x, y)
|
||||
expected = np.einsum('ij,k->ik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
check('{i,j},{k}->{i,k}')
|
||||
check('{i,j},{k}->{k,i}') # order of named axes in the spec doesn't matter!
|
||||
check('{j,i},{k}->{i,k}')
|
||||
check('{j,i},{k}->{k,i}')
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def test_xeinsum_named_axes_batch_matmul_with_mesh(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(8, 3, 4)
|
||||
y = rng.randn(8, 4, 5)
|
||||
|
||||
def check(spec):
|
||||
out = xmap(partial(jnp.einsum, spec),
|
||||
in_axes=(['b', 'i', 'j'], ['b', 'j', 'k']),
|
||||
out_axes=['b', 'i', 'k'],
|
||||
axis_resources={'b': 'x', 'j': 'y'})(x, y)
|
||||
expected = np.einsum('bij,bjk->bik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
check('{b,i,j},{b,j,k}->{b,i,k}')
|
||||
check('{j,i,b},{j,b,k}->{i,b,k}') # order of named axes in the spec doesn't matter!
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def test_xeinsum_named_axes_unary_reduce_with_mesh(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(8, 6, 4)
|
||||
|
||||
def check(spec):
|
||||
out = xmap(partial(jnp.einsum, spec),
|
||||
in_axes=['b', 'i', 'j'],
|
||||
out_axes=['b'],
|
||||
axis_resources={'b': 'x', 'i': 'y'})(x)
|
||||
expected = np.einsum('bij->b', x)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
check('{b,i,j}->{b}')
|
||||
check('{b,j,i}->{b}') # order of named axes in the spec doesn't matter!
|
||||
check('{i,j,b}->{b}')
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def test_xeinsum_mixed_axes_unary_reduce_with_mesh(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(8, 6, 4, 5)
|
||||
|
||||
def check(spec):
|
||||
out = xmap(partial(jnp.einsum, spec),
|
||||
in_axes=['b', 'i', ...],
|
||||
out_axes=['b', ...],
|
||||
axis_resources={'b': 'x', 'i': 'y'})(x)
|
||||
expected = np.einsum('bijk->bk', x)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
check('jk{i,b}->k{b}')
|
||||
|
||||
|
||||
class XMapErrorTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
|
Loading…
x
Reference in New Issue
Block a user