Minor fixes for xmap docstring, xeinsum parser

The regression-loss example in the xmap docstring was broken, and
the xeinsum parser didn't accept an empty named-axis list ({}) even though it should.
Adam Paszke 2021-02-09 17:01:26 +00:00
parent 01485ec4ee
commit 926b2ad03f
4 changed files with 35 additions and 32 deletions
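
For context, the "empty braces" case is an xeinsum output spec that names no axes at all, e.g. '{j},{j}->{}', which contracts away every named axis it mentions while unmentioned named axes pass through like batch axes; the test added at the bottom of this commit exercises exactly that. A minimal sketch of the now-accepted call (the jax.experimental.maps import path and the tolerance are assumptions for illustration):

    from functools import partial
    import numpy as np
    import jax.numpy as jnp
    from jax.experimental.maps import xmap  # assumed experimental import path

    rng = np.random.RandomState(0)
    x, y = rng.randn(3, 4), rng.randn(4, 5)

    # '{j},{j}->{}' names only the contracted axis; the unmentioned named axes
    # ('i' on x, 'k' on y) behave like batch axes and survive into the output.
    out = xmap(partial(jnp.einsum, '{j},{j}->{}'),
               in_axes=(['i', 'j'], ['j', 'k']),
               out_axes=['i', 'k'])(x, y)
    np.testing.assert_allclose(out, np.einsum('ij,jk->ik', x, y), rtol=1e-4, atol=1e-4)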


@@ -451,8 +451,8 @@ class XeinsumSpecParser:
     axis_name = self.spec[self.pos:end]
     assert axis_name
-    self.pos = end + 1
-    return axis_name, self.spec[end] == ','
+    self.pos = end
+    return axis_name
 
   def maybe_take(self, char: str, on_eof: bool = False):
     if self.eof:
@@ -474,10 +474,15 @@ class XeinsumSpecParser:
       return True, (subscripts, names)
     else:
       assert self.maybe_take('{')
-      while True:
-        axis_name, cont = self.parse_axis_name()
+      first = True
+      while not self.maybe_take('}'):
+        if not first:
+          assert self.maybe_take(',')
+        first = False
+        if self.eof:
+          raise ValueError("Unterminated named axis brace")
+        axis_name = self.parse_axis_name()
         names.append(axis_name)
-        if not cont: break
       return self.maybe_take(',', False), (subscripts, names)
 
   def parse_args(self):
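
To see why '{}' used to be rejected: the old parse_axis_name unconditionally read a name up to the next delimiter and reported whether that delimiter was a comma, so the caller's while-True loop always demanded at least one name between the braces. The rewritten loop checks for the closing brace before reading anything. A self-contained sketch of that control flow (simplified standalone function, not the actual class):

    def parse_named_axes(spec: str, pos: int):
      """Parse a '{a,b,...}' group (possibly '{}') with spec[pos] == '{'."""
      assert spec[pos] == '{'
      pos += 1
      names = []
      first = True
      # Looking for '}' before reading a name is what makes '{}' legal.
      while pos < len(spec) and spec[pos] != '}':
        if not first:
          assert spec[pos] == ','
          pos += 1
        first = False
        end = pos
        while end < len(spec) and spec[end] not in ',}':
          end += 1
        axis_name = spec[pos:end]
        assert axis_name
        names.append(axis_name)
        pos = end
      if pos == len(spec):
        raise ValueError("Unterminated named axis brace")
      return names, pos + 1  # index just past the closing '}'

    # parse_named_axes('{i,j}', 0) == (['i', 'j'], 5)
    # parse_named_axes('{}', 0)    == ([], 2)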


@@ -1994,7 +1994,7 @@ alltrue = all
 sometrue = any
 
 def _axis_size(a, axis):
-  if not isinstance(axis, collections.abc.Sequence):
+  if not isinstance(axis, (tuple, list)):
     axis = (axis,)
   size = 1
   a_shape = shape(a)
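
The switch away from collections.abc.Sequence matters because a single named axis arrives as a string, and a str is itself a Sequence, so the old check left it unwrapped and the rest of the function would then iterate over its characters; only tuples and lists should count as "multiple axes" here. That motivation is inferred rather than stated in the commit; a quick illustration:

    import collections.abc

    axis = 'batch'                               # a single named axis
    isinstance(axis, collections.abc.Sequence)   # True  -> old check left it unwrapped
    isinstance(axis, (tuple, list))              # False -> new check wraps it as ('batch',)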


@@ -367,15 +367,15 @@ def xmap(fun: Callable,
   while named axes are just a convenient way to achieve batching. While this
   might seem like a silly example at first, it might turn out to be useful in
   practice, since with conjuction with ``axis_resources`` this makes it possible
-  to implement a distributed matrix-multiplication in just a few lines of code:
+  to implement a distributed matrix-multiplication in just a few lines of code::
 
-    >>> devices = np.array(jax.devices())[:4].reshape((2, 2))
-    >>> with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
-    ...   distributed_out = xmap(
-    ...     jnp.vdot,
-    ...     in_axes=({0: 'left', 1: 'right'}),
-    ...     out_axes=['left', 'right', ...],
-    ...     axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
+    devices = np.array(jax.devices())[:4].reshape((2, 2))
+    with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
+      distributed_out = xmap(
+        jnp.vdot,
+        in_axes=({0: 'left'}, {1: 'right'}),
+        out_axes=['left', 'right', ...],
+        axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
 
   Still, the above examples are quite simple. After all, the xmapped
   computation was a simple NumPy function that didn't use the axis names at all!
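
The literal block above assumes an array x and at least four devices already exist. A hedged setup that makes it runnable (the array contents, its shape, and the jax.experimental.maps import path are assumptions; both mapped dimensions only need to be divisible by the corresponding mesh axis sizes, 2 and 2 here):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.maps import xmap, mesh  # assumed experimental import path

    x = jnp.arange(64.0).reshape(8, 8)
    devices = np.array(jax.devices())[:4].reshape((2, 2))
    with mesh(devices, ('x', 'y')):
      distributed_out = xmap(
          jnp.vdot,
          in_axes=({0: 'left'}, {1: 'right'}),
          out_axes=['left', 'right', ...],
          axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
    # Each entry is vdot(x[i, :], x[j, :]), so distributed_out == x @ x.T,
    # computed across the 2x2 device mesh.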
@@ -384,8 +384,9 @@ def xmap(fun: Callable,
     def regression_loss(x, y, w, b):
       # Contract over in_features. Batch and out_features are present in
       # both inputs and output, so they don't need to be mentioned
-      y_pred = jnp.einsum('{in_features},{in_features}->{}') + b
-      return jnp.mean((y - y_pred) ** 2, axis='batch')
+      y_pred = jnp.einsum('{in_features},{in_features}->{}', x, w) + b
+      error = jnp.sum((y - y_pred) ** 2, axis='out_features')
+      return jnp.mean(error, axis='batch')
 
     xmap(regression_loss,
          in_axes=(['batch', 'in_features', ...],
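
The hunk above stops mid-call; for readability, here is one plausible way the surrounding xmap invocation continues. The axis assignments below are assumptions consistent with the einsum spec and the two reductions in the example, not text taken from this diff:

    loss = xmap(regression_loss,
                in_axes=(['batch', 'in_features', ...],         # x
                         ['batch', 'out_features', ...],        # y
                         ['in_features', 'out_features', ...],  # w
                         ['out_features', ...]),                # b
                out_axes=[...])(x, y, w, b)  # no named axes remain in the output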


@@ -964,22 +964,19 @@ class PDotTests(XMapTestCase):
     x = rng.randn(3, 4)
     y = rng.randn(4, 5)
-    out = xmap(partial(jnp.einsum, '{i,j},{j,k}->{i,k}'),
-               in_axes=(['i', 'j'], ['j', 'k']),
-               out_axes=['i', 'k'])(x, y)
-    expected = np.einsum('ij,jk->ik', x, y)
-    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-    self.assertAllClose(out, expected, check_dtypes=True,
-                        atol=tol, rtol=tol)
-    # order of named axes in the spec doesn't matter!
-    out = xmap(partial(jnp.einsum, '{i,j},{k,j}->{k,i}'),
-               in_axes=(['i', 'j'], ['j', 'k']),
-               out_axes=['i', 'k'])(x, y)
-    expected = np.einsum('ij,jk->ik', x, y)
-    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-    self.assertAllClose(out, expected, check_dtypes=True,
-                        atol=tol, rtol=tol)
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=(['i', 'j'], ['j', 'k']),
+                 out_axes=['i', 'k'])(x, y)
+      expected = np.einsum('ij,jk->ik', x, y)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+    check('{i,j},{j,k}->{i,k}')
+    check('{i,j},{k,j}->{k,i}')  # order of named axes in the spec doesn't matter!
+    check('{j},{k,j}->{k}')
+    check('{i,j},{j}->{i}')
+    check('{j},{j}->{}')
 
   def test_xeinsum_no_named_axes_vector_dot(self):
     rng = np.random.RandomState(0)