Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Minor fixes for xmap docstring, xeinsum parser
The regression loss example in the xmap docstring was broken, and the xeinsum parser rejected empty parens (an empty named-axis group, '{}') even though it should accept them.
This commit is contained in:
parent 01485ec4ee
commit 926b2ad03f
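For concreteness, this is the kind of spec the parser fix unlocks, mirroring the new test added below (a minimal sketch; the `jax.experimental.maps` import path is the one from this era of JAX):

    from functools import partial
    import jax.numpy as jnp
    from jax.experimental.maps import xmap

    # '{j},{j}->{}' contracts the named axis j; named axes the spec does not
    # mention (i and k) are carried through implicitly. The empty output
    # group '{}' is what the old parser rejected.
    f = xmap(partial(jnp.einsum, '{j},{j}->{}'),
             in_axes=(['i', 'j'], ['j', 'k']),
             out_axes=['i', 'k'])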
@@ -451,8 +451,8 @@ class XeinsumSpecParser:
     axis_name = self.spec[self.pos:end]
     assert axis_name
-    self.pos = end + 1
-    return axis_name, self.spec[end] == ','
+    self.pos = end
+    return axis_name

   def maybe_take(self, char: str, on_eof: bool = False):
     if self.eof:
@@ -474,10 +474,15 @@ class XeinsumSpecParser:
       return True, (subscripts, names)
     else:
       assert self.maybe_take('{')
-      while True:
-        axis_name, cont = self.parse_axis_name()
+      first = True
+      while not self.maybe_take('}'):
+        if not first:
+          assert self.maybe_take(',')
+        first = False
+        if self.eof:
+          raise ValueError("Unterminated named axis brace")
+        axis_name = self.parse_axis_name()
         names.append(axis_name)
-        if not cont: break
       return self.maybe_take(',', False), (subscripts, names)

   def parse_args(self):
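With the new loop, an empty group parses cleanly: `maybe_take('}')` succeeds on the first iteration, so `parse_axis_name` is never asked to produce an empty name, and an unterminated brace now raises a clear `ValueError`. A standalone toy version of the same control flow (not the real `XeinsumSpecParser`, just an illustration of the loop shape):

    # Toy version of the fixed brace-parsing loop; parses one '{a,b,c}' group.
    def parse_braced_names(spec: str) -> list:
      assert spec[0] == '{'
      pos, names, first = 1, [], True
      while spec[pos:pos + 1] != '}':   # an empty '{}' exits immediately
        if pos >= len(spec):
          raise ValueError("Unterminated named axis brace")
        if not first:
          assert spec[pos] == ','
          pos += 1
        first = False
        end = pos
        while end < len(spec) and spec[end] not in ',}':
          end += 1                      # scan one axis name
        name = spec[pos:end]
        assert name                     # mirrors `assert axis_name` above
        names.append(name)
        pos = end
      return names

    assert parse_braced_names('{}') == []
    assert parse_braced_names('{i,j}') == ['i', 'j']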
@@ -1994,7 +1994,7 @@ alltrue = all
 sometrue = any

 def _axis_size(a, axis):
-  if not isinstance(axis, collections.abc.Sequence):
+  if not isinstance(axis, (tuple, list)):
     axis = (axis,)
   size = 1
   a_shape = shape(a)
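A plausible motivation for this one-line change (my inference; the commit message doesn't say): under xmap, reductions such as `jnp.mean(..., axis='batch')` receive named axes as strings, and `str` is itself a `collections.abc.Sequence`, so the old check failed to wrap a lone named axis into a tuple:

    import collections.abc

    # str slips through the old Sequence check, so a single named axis was
    # left unwrapped and later code would treat it as several axes:
    isinstance('batch', collections.abc.Sequence)  # True  -> old check skipped the wrap
    isinstance('batch', (tuple, list))             # False -> new check wraps as ('batch',)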
@@ -367,15 +367,15 @@ def xmap(fun: Callable,
   while named axes are just a convenient way to achieve batching. While this
   might seem like a silly example at first, it might turn out to be useful in
   practice, since in conjunction with ``axis_resources`` this makes it possible
-  to implement a distributed matrix-multiplication in just a few lines of code:
+  to implement a distributed matrix-multiplication in just a few lines of code::

-  >>> devices = np.array(jax.devices())[:4].reshape((2, 2))
-  >>> with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
-  ...   distributed_out = xmap(
-  ...     jnp.vdot,
-  ...     in_axes=({0: 'left', 1: 'right'}),
-  ...     out_axes=['left', 'right', ...],
-  ...     axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
+    devices = np.array(jax.devices())[:4].reshape((2, 2))
+    with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
+      distributed_out = xmap(
+        jnp.vdot,
+        in_axes=({0: 'left'}, {1: 'right'}),
+        out_axes=['left', 'right', ...],
+        axis_resources={'left': 'x', 'right': 'y'})(x, x.T)

   Still, the above examples are quite simple. After all, the xmapped
   computation was a simple NumPy function that didn't use the axis names at all!
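Besides converting the example from doctest (`>>>`) form to a plain literal block (hence the trailing `::`), the substantive fix is to `in_axes`: `xmap` takes one axis mapping per positional argument, and the old value is a single dict in redundant parentheses rather than a tuple of two mappings:

    # Old (buggy): parentheses don't make a tuple -- this is ONE dict that maps
    # both axes of the first argument and says nothing about the second:
    in_axes = ({0: 'left', 1: 'right'})
    # Fixed: a 2-tuple, one mapping per argument -- axis 0 of the first input
    # becomes 'left', axis 1 of the second becomes 'right':
    in_axes = ({0: 'left'}, {1: 'right'})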
@@ -384,8 +384,9 @@ def xmap(fun: Callable,
     def regression_loss(x, y, w, b):
       # Contract over in_features. Batch and out_features are present in
       # both inputs and output, so they don't need to be mentioned
-      y_pred = jnp.einsum('{in_features},{in_features}->{}') + b
-      return jnp.mean((y - y_pred) ** 2, axis='batch')
+      y_pred = jnp.einsum('{in_features},{in_features}->{}', x, w) + b
+      error = jnp.sum((y - y_pred) ** 2, axis='out_features')
+      return jnp.mean(error, axis='batch')

     xmap(regression_loss,
          in_axes=(['batch', 'in_features', ...],
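The docstring example had two separate bugs: the `jnp.einsum` call was missing its operands `x, w`, and the squared error was averaged only over 'batch', leaving 'out_features' unreduced. A named-axis trace of the fixed version (the axes of `w` and `y` are inferred from the docstring's `in_axes`, so treat them as assumptions):

    # x: {batch, in_features}  w: {in_features, out_features}  y: {batch, out_features}
    y_pred = jnp.einsum('{in_features},{in_features}->{}', x, w) + b  # -> {batch, out_features}
    error = jnp.sum((y - y_pred) ** 2, axis='out_features')           # -> {batch}
    loss = jnp.mean(error, axis='batch')                              # -> {} (a plain scalar)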
@@ -964,22 +964,19 @@ class PDotTests(XMapTestCase):
     x = rng.randn(3, 4)
     y = rng.randn(4, 5)

-    out = xmap(partial(jnp.einsum, '{i,j},{j,k}->{i,k}'),
-               in_axes=(['i', 'j'], ['j', 'k']),
-               out_axes=['i', 'k'])(x, y)
-    expected = np.einsum('ij,jk->ik', x, y)
-    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-    self.assertAllClose(out, expected, check_dtypes=True,
-                        atol=tol, rtol=tol)
-
-    # order of named axes in the spec doesn't matter!
-    out = xmap(partial(jnp.einsum, '{i,j},{k,j}->{k,i}'),
-               in_axes=(['i', 'j'], ['j', 'k']),
-               out_axes=['i', 'k'])(x, y)
-    expected = np.einsum('ij,jk->ik', x, y)
-    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-    self.assertAllClose(out, expected, check_dtypes=True,
-                        atol=tol, rtol=tol)
+    def check(spec):
+      out = xmap(partial(jnp.einsum, spec),
+                 in_axes=(['i', 'j'], ['j', 'k']),
+                 out_axes=['i', 'k'])(x, y)
+      expected = np.einsum('ij,jk->ik', x, y)
+      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
+      self.assertAllClose(out, expected, check_dtypes=True,
+                          atol=tol, rtol=tol)
+    check('{i,j},{j,k}->{i,k}')
+    check('{i,j},{k,j}->{k,i}')  # order of named axes in the spec doesn't matter!
+    check('{j},{k,j}->{k}')
+    check('{i,j},{j}->{i}')
+    check('{j},{j}->{}')

   def test_xeinsum_no_named_axes_vector_dot(self):
     rng = np.random.RandomState(0)
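The last three `check` specs exercise the parser fix and xeinsum's implicit batching: named axes that a spec does not mention (here `i` and/or `k`) are carried through to the output unchanged, which is why every spec, including the fully contracting '{j},{j}->{}', can be compared against the same positional result:

    # e.g. '{j},{j}->{}' with x: {i, j} and y: {j, k} -- j is contracted while
    # i and k pass through, matching np.einsum('ij,jk->ik', x, y):
    #   out[i, k] = sum_j x[i, j] * y[j, k]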