From a147046d18f12de464f4dce840e306bfa5a7e165 Mon Sep 17 00:00:00 2001
From: Anudhyan Boral
Date: Tue, 31 Aug 2021 19:36:18 -0700
Subject: [PATCH] Add unary xeinsum and allow named axis reductions for unary and binary xeinsums

---
 jax/_src/lax/parallel.py    |  94 ++++++++++++++++++------------
 jax/_src/numpy/lax_numpy.py |   3 +-
 tests/xmap_test.py          | 110 ++++++++++++++++++++++++++++++++++++
 3 files changed, 170 insertions(+), 37 deletions(-)

diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 7c302f54e..d31ebb3d3 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -36,7 +36,7 @@ from jax._src.lax import lax
 from jax._src.lax import slicing
 from jax._src.numpy import lax_numpy
 import jax._src.util as util
-from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis
+from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
 
@@ -419,50 +419,74 @@ def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ()),
                      precision=lax.canonicalize_precision(precision))
 
 
-def xeinsum(spec: str, x, y):
+def xeinsum(spec: str, *operands):
   in_spec, out_spec = spec.split('->')
-  (lhs_subs, lhs_named), (rhs_subs, rhs_named) = XeinsumSpecParser(in_spec).parse_args()
+  all_in_subs, all_in_named = unzip2(XeinsumSpecParser(in_spec).parse_args())
   (out_subs, out_named), = XeinsumSpecParser(out_spec).parse_args()
-  all_named = {*lhs_named, *rhs_named, *out_named}
-  all_subs = {*lhs_subs, *rhs_subs, *out_subs}
-  lhs_uniques = set(lhs_subs) - set(rhs_subs)
-  rhs_uniques = set(rhs_subs) - set(lhs_subs)
-  if all_subs & all_named:
-    raise NotImplementedError
-  if not set(out_named).issubset({*lhs_named, *rhs_named}):
-    raise ValueError
 
-  # if a named axis appears in both inputs and not the output, contract!
-  named_contract = list(all_named - set(out_named))
+  if len(operands) != len(all_in_named):
+    raise ValueError("Expecting the same number of argument specs in the "
+                     f"subscript ({in_spec}) as the number of operands. But got "
+                     f"{len(all_in_named)} argument specs for "
+                     f"{len(operands)} operands")
 
-  # if a subscript appears in both inputs and not the outputs, contract!
-  subs_contract = all_subs - set(out_subs)
+  if len(operands) > 2:
+    raise NotImplementedError("Only one or two operands are supported. "
+                              f"But got {len(operands)} operands")
 
-  lhs_reduce_axes = [lhs_subs.index(n) for n in lhs_uniques & subs_contract]
-  if lhs_reduce_axes:
-    x = lax._reduce_sum(x, lhs_reduce_axes)
-    for i in sorted(lhs_reduce_axes, reverse=True):
-      del lhs_subs[i]
+  # output subs and named axes must appear in at least one of the inputs.
+  if not set(out_named).issubset(set().union(*all_in_named)):
+    raise ValueError("Found named axes "
+                     f"{set(out_named) - set().union(*all_in_named)} "
+                     "appearing in the output spec but not in the input")
+  if not set(out_subs).issubset(set().union(*all_in_subs)):
+    raise ValueError("Found subscript(s) "
+                     f"{set(out_subs) - set().union(*all_in_subs)} "
+                     "appearing in the output spec but not in the input")
 
-  rhs_reduce_axes = [rhs_subs.index(n) for n in rhs_uniques & subs_contract]
-  if rhs_reduce_axes:
-    y = lax._reduce_sum(y, rhs_reduce_axes)
-    for i in sorted(rhs_reduce_axes, reverse=True):
-      del rhs_subs[i]
+  xs = list(operands)
+  for idx, (in_subs, in_named) in enumerate(safe_zip(all_in_subs, all_in_named)):
+    # if a subscript axis appears only in one input and not the output, reduce!
+    other_named = set().union(  # type: ignore
+        *[named for i, named in enumerate(all_in_named) if i != idx])
+    other_subs = set().union(  # type: ignore
+        *[subs for i, subs in enumerate(all_in_subs) if i != idx])
 
-  pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n))
-                        for n in subs_contract - (lhs_uniques | rhs_uniques))
+    subs_reduce = list(set(in_subs) - {*out_subs, *other_subs})
+    subs_reduce_axes = [in_subs.index(n) for n in subs_reduce]
+    named_reduce_axes = list(set(in_named) - {*out_named, *other_named})
 
-  # if a subscript apperas in both inputs _and_ the outputs, batch!
-  subs_batch = all_subs - subs_contract
-  if subs_batch & (lhs_uniques | rhs_uniques):
-    raise NotImplementedError
+    if subs_reduce_axes or named_reduce_axes:
+      xs[idx] = psum(xs[idx], axis_name=subs_reduce_axes + named_reduce_axes)
+      for i in sorted(subs_reduce_axes, reverse=True):
+        del all_in_subs[idx][i]
+      for named_axis in named_reduce_axes:
+        all_in_named[idx].remove(named_axis)
 
-  pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n))
-                     for n in subs_batch)
+  if len(operands) == 1:
+    return xs[0]
+
+  if len(operands) == 2:
+    x, y = xs
+    lhs_subs, rhs_subs = all_in_subs
+    lhs_named, rhs_named = all_in_named
+
+    # if a named axis appears in both inputs and not the output, contract!
+    named_contract = list((set(lhs_named) & set(rhs_named)) - set(out_named))
+
+    # if a subscript appears in both inputs and not the outputs, contract!
+    subs_contract = (set(lhs_subs) & set(rhs_subs)) - set(out_subs)
+
+    pos_contract = unzip2((lhs_subs.index(n), rhs_subs.index(n))
+                          for n in subs_contract)
+
+    # if a subscript appears in both inputs _and_ the outputs, batch!
+    subs_batch = (set(lhs_subs) & set(rhs_subs)) - subs_contract
+    pos_batch = unzip2((lhs_subs.index(n), rhs_subs.index(n)) for n in subs_batch)
+
+    return pdot(x, y, axis_name=named_contract,
+                pos_contract=pos_contract, pos_batch=pos_batch)
 
-  return pdot(x, y, axis_name=named_contract,
-              pos_contract=pos_contract, pos_batch=pos_batch)
 
 class XeinsumSpecParser:
   spec: str
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index b401f482e..83025e6ac 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -2826,8 +2826,7 @@ def einsum(*operands, out=None, optimize='optimal', precision=None,
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
 
-  if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0] and
-      len(operands[1:]) == 2):
+  if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0]):
     return lax.xeinsum(*operands)
 
   optimize = 'optimal' if optimize is True else optimize
diff --git a/tests/xmap_test.py b/tests/xmap_test.py
index 5c31405e2..c65a7f442 100644
--- a/tests/xmap_test.py
+++ b/tests/xmap_test.py
@@ -1332,6 +1332,15 @@ class PDotTests(XMapTestCase):
     expected = np.einsum('ij,ij->i', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)
 
+  def test_xeinsum_no_named_axes_batch_matmul(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(3, 5, 4)
+    y = rng.randn(3, 4, 2)
+    out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
+    expected = np.einsum('bij,bjk->bik', x, y)
+    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+    self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
+
   def test_xeinsum_no_named_axes_reduce_sum(self):
     rng = self.rng()
     x = rng.randn(3)
@@ -1341,6 +1350,107 @@ class PDotTests(XMapTestCase):
     self.assertAllClose(out, expected, check_dtypes=True)
 
 
+  def test_xeinsum_no_named_axes_reduce_and_contract(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(3, 5, 4)
+    y = rng.randn(2, 4, 2)
+    out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
+    expected = np.einsum('bij,cjk->ik', x, y)
+    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+    self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
+
+  def test_xeinsum_named_axes_reduce(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(3, 4)
+    y = rng.randn(5,)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=(['i', 'j'], ['k']),
+                 out_axes=['i', 'k'])(x, y)
+      expected = np.einsum('ij,k->ik', x, y)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+    check('{i,j},{k}->{i,k}')
+
+  @jtu.with_mesh([('x', 2), ('y', 2)])
+  def test_xeinsum_named_axes_reduce_with_mesh(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(6, 4)
+    y = rng.randn(8,)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=(['i', 'j'], ['k']),
+                 out_axes=['i', 'k'],
+                 axis_resources={'i': 'x', 'k': 'y'})(x, y)
+      expected = np.einsum('ij,k->ik', x, y)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+
+    check('{i,j},{k}->{i,k}')
+    check('{i,j},{k}->{k,i}')  # order of named axes in the spec doesn't matter!
+    check('{j,i},{k}->{i,k}')
+    check('{j,i},{k}->{k,i}')
+
+  @jtu.with_mesh([('x', 2), ('y', 2)])
+  def test_xeinsum_named_axes_batch_matmul_with_mesh(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(8, 3, 4)
+    y = rng.randn(8, 4, 5)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=(['b', 'i', 'j'], ['b', 'j', 'k']),
+                 out_axes=['b', 'i', 'k'],
+                 axis_resources={'b': 'x', 'j': 'y'})(x, y)
+      expected = np.einsum('bij,bjk->bik', x, y)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+
+    check('{b,i,j},{b,j,k}->{b,i,k}')
+    check('{j,i,b},{j,b,k}->{i,b,k}')  # order of named axes in the spec doesn't matter!
+
+  @jtu.with_mesh([('x', 2), ('y', 2)])
+  def test_xeinsum_named_axes_unary_reduce_with_mesh(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(8, 6, 4)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=['b', 'i', 'j'],
+                 out_axes=['b'],
+                 axis_resources={'b': 'x', 'i': 'y'})(x)
+      expected = np.einsum('bij->b', x)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+
+    check('{b,i,j}->{b}')
+    check('{b,j,i}->{b}')  # order of named axes in the spec doesn't matter!
+    check('{i,j,b}->{b}')
+
+  @jtu.with_mesh([('x', 2), ('y', 2)])
+  def test_xeinsum_mixed_axes_unary_reduce_with_mesh(self):
+    rng = np.random.RandomState(0)
+    x = rng.randn(8, 6, 4, 5)
+
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=['b', 'i', ...],
+                 out_axes=['b', ...],
+                 axis_resources={'b': 'x', 'i': 'y'})(x)
+      expected = np.einsum('bijk->bk', x)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+
+    check('jk{i,b}->k{b}')
+
+
 class XMapErrorTest(jtu.JaxTestCase):
 
   @jtu.with_mesh([('x', 2)])