mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make it possible to vmap xmapped functions
Or, perhaps more importantly, make it possible to nest xmaps that don't specify any `axis_resources`. The math is a little tricky, so I've added a fairly strong test that enumerates a wide range of ways to interleave vmapped and xmapped axes in both the inputs and the output. Thanks to that test, I caught one very subtle bug in the dynamic tracing rule for xmap (sorting by dimension names instead of positional axes).
This commit is contained in:
parent
040d268cf8
commit
6d2b307ced
@ -31,9 +31,10 @@ from ..api_util import flatten_fun_nokwargs, flatten_axes
|
||||
from ..interpreters import partial_eval as pe
|
||||
from ..interpreters import pxla
|
||||
from ..interpreters import xla
|
||||
from ..interpreters import batching
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from .._src.util import safe_map, safe_zip, HashableFunction
|
||||
from .._src.util import safe_map, safe_zip, HashableFunction, as_hashable_function, unzip2
|
||||
from .._src.lax.parallel import _axis_index_translation_rule
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
@ -252,7 +253,7 @@ def xmap(fun: Callable,
|
||||
axis_sizes = _get_axis_sizes(args_flat, in_axes_flat)
|
||||
out_flat = xmap_p.bind(
|
||||
fun_flat, *args_flat,
|
||||
name=fun.__name__,
|
||||
name=getattr(fun, '__name__', '<unnamed function>'),
|
||||
in_axes=tuple(in_axes_flat),
|
||||
out_axes_thunk=out_axes_thunk,
|
||||
axis_sizes=FrozenDict(axis_sizes),
|
||||
@ -305,7 +306,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
*in_avals)
|
||||
else:
|
||||
# We have to trace again, because `f` is a linear function, so we can't just return it.
|
||||
final_jaxpr, _, final_consts = pe.trace_to_jaxpr_final(f, in_avals)
|
||||
final_jaxpr, out_avals, final_consts = pe.trace_to_jaxpr_final(f, in_avals)
|
||||
return core.jaxpr_as_fun(core.ClosedJaxpr(final_jaxpr, final_consts))
|
||||
|
||||
class EvaluationPlan(NamedTuple):
|
||||
@ -403,6 +404,39 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
return out_tracers
|
||||
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap # type: ignore
|
||||
|
||||
def _batch_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params):
  """Batching rule applied when an ``xmap_p`` call is traced under a BatchTrace.

  If no operand is batched, the primitive is bound unchanged. Otherwise the
  batched dimension is threaded through the xmap: the positional entries of
  ``in_axes`` / ``out_axes`` are shifted to make room for the batch dimension,
  and the batch dimension itself is re-expressed relative to the array with
  the named axes removed (inputs) or inserted (outputs).

  Args:
    self: the trace this rule is installed on (its ``main`` is passed to
      ``batching.batch_subtrace``).
    primitive: must be ``xmap_p``.
    f: the wrapped function being xmapped.
    tracers: input tracers; each provides ``val`` and ``batch_dim``.
    params: xmap parameters; ``in_axes`` and ``out_axes_thunk`` are rewritten.

  Returns:
    A list of ``batching.BatchTracer`` objects, one per output.
  """
  not_mapped = batching.not_mapped
  vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
  assert primitive is xmap_p
  if all(dim is not_mapped for dim in dims):
    return primitive.bind(f, *vals, **params)

  # Every batched operand must agree on the size of the batched dimension.
  assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1

  def fmap_dims(axes, fn):
    # Apply `fn` to every positional axis while keeping the axis names.
    return AxisNamePos((name, fn(axis)) for name, axis in axes.items())

  def axis_after_insertion(axis, inserted_named_axes):
    # `axis` indexes an array that does not contain the named axes yet.
    # Return its index once the named axes have been inserted at their
    # (final-array) positions: each named axis that lands before the running
    # position pushes it one slot to the right.
    for inserted_axis in sorted(inserted_named_axes.values()):
      if inserted_axis >= axis:
        break
      axis += 1
    return axis

  def shift_out_axes(out_axes, d):
    # Make room for the batch dimension in the positional out axes. `d` is
    # expressed in the coordinates of the inner output (named axes absent),
    # so it has to be translated to final-output coordinates before it can
    # be compared with the entries of `out_axes`.
    if d is not_mapped:
      return fmap_dims(out_axes, lambda a: a)
    batch_axis = axis_after_insertion(d, out_axes)
    return fmap_dims(out_axes, lambda a: a + (batch_axis <= a))

  # Named axes located at or after the batch dimension move one slot right.
  new_in_axes = tuple(
    fmap_dims(in_axes, lambda a, d=d: a + (d is not not_mapped and d <= a))
    for d, in_axes in zip(dims, params['in_axes']))
  # Position of the batch dimension once the named axes have been removed.
  new_dims = tuple(
    d if d is not_mapped else d - sum(a < d for a in in_axes.values())
    for d, in_axes in zip(dims, params['in_axes']))
  f, dims_out_thunk = batching.batch_subtrace(f, self.main, new_dims)
  out_axes_thunk = params['out_axes_thunk']

  # NOTE: This assumes that the choice of the dimensions over which outputs
  #       are batched is entirely dependent on the function and not e.g. on the
  #       data or its shapes.
  @as_hashable_function(closure=out_axes_thunk)
  def new_out_axes_thunk():
    return tuple(
      shift_out_axes(out_axes, d)
      for out_axes, d in zip(out_axes_thunk(), dims_out_thunk()))

  new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
  vals_out = primitive.bind(f, *vals, **new_params)
  # Report the batch dimension in the coordinates of the final outputs, i.e.
  # the same translation that was used to shift the out axes above.
  dims_out = tuple(
    d if d is not_mapped else axis_after_insertion(d, out_axes)
    for d, out_axes in zip(dims_out_thunk(), out_axes_thunk()))
  return [batching.BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]
|
||||
batching.BatchTrace.process_xmap = _batch_trace_process_xmap # type: ignore
|
||||
|
||||
|
||||
# -------- nested xmap handling --------
|
||||
|
||||
@ -537,7 +571,7 @@ def _delete_aval_axes(aval, axes: AxisNamePos):
|
||||
def _insert_aval_axes(aval, axes: AxisNamePos, axis_sizes):
  """Return a copy of `aval` with the named axes inserted as positional ones.

  Args:
    aval: a `core.ShapedArray` that does not yet contain the named axes.
    axes: mapping from axis name to the position that axis occupies in the
      resulting shape.
    axis_sizes: mapping from axis name to its size.

  Returns:
    A `core.ShapedArray` with the extended shape and the dtype of `aval`.
  """
  assert isinstance(aval, core.ShapedArray)
  shape = list(aval.shape)
  # Insert in increasing *positional* order. Iterating in name order would
  # let a later insertion at a smaller index shift the axes inserted before
  # it, so they would end up at the wrong positions.
  for name, axis in sorted(axes.items(), key=lambda x: x[1]):
    shape.insert(axis, axis_sizes[name])
  return core.ShapedArray(tuple(shape), aval.dtype)
|
||||
|
||||
|
@ -18,6 +18,7 @@ import functools
|
||||
import itertools
|
||||
import os
|
||||
import unittest
|
||||
from itertools import product, permutations
|
||||
from unittest import SkipTest, skip, skipIf
|
||||
|
||||
import numpy as np
|
||||
@ -62,7 +63,6 @@ def tearDownModule():
|
||||
os.environ["XLA_FLAGS"] = prev_xla_flags
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
|
||||
|
||||
@curry
|
||||
def with_mesh(named_shape, f):
|
||||
if not named_shape:
|
||||
@ -178,22 +178,6 @@ class XMapTest(jtu.JaxTestCase):
|
||||
python_should_be_executing = False
|
||||
fm(x)
|
||||
|
||||
@skip("Need to implement vmap(xmap)")
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testNestedVectorize(self):
  """Nest an xmap without axis resources inside one mapped over mesh axis 'x'."""
  def inner(y):
    return jnp.sin(y)
  # The inner xmap names the leading axis 'b' and emits it as output axis 1.
  inner_mapped = xmap(inner, in_axes=['b', ...], out_axes=[None, 'b', ...])

  def outer(x):
    return inner_mapped(x * 2)
  # The outer xmap names input axis 1 'a' and emits it as the leading axis.
  f = xmap(outer, in_axes=[None, 'a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})

  xshape = (4, 2, 5)
  x = jnp.arange(np.prod(xshape)).reshape(xshape)
  expected = jnp.sin(x * 2).transpose((1, 2, 0))
  self.assertAllClose(f(x), expected)
|
||||
|
||||
@skip("Need to implement vmap(xmap)")
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2), ('y', 3)])
|
||||
@ -255,6 +239,89 @@ class XMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f_mapped(x, x), expected)
|
||||
run_test()
|
||||
|
||||
def VmapOfXmapCases():
  """Yield parameter dicts covering vmap-of-xmap axis placements.

  For every pair of xmap input-axis assignments (none, 'x' only, or 'x' and
  'y' at distinct positions), every placement of the union of those named axes
  in the output, every vmap input dimension (or None, but not None for both
  inputs), every vmap output dimension and both values of `vmap_as_xmap`, one
  dict is yielded with the keys `testcase_name`, `xmap_in_axes`,
  `xmap_out_axes`, `vmap_in_axes`, `vmap_out_axes` and `vmap_as_xmap`.
  """
  xmap_in_axes = ([{}] +
                  [{i: 'x'} for i in range(3)] +
                  [{i: 'x', j: 'y'} for i in range(4) for j in range(4) if i != j])
  for xmap_dim_x, xmap_dim_y in product(xmap_in_axes, repeat=2):
    # Named axes used by either input; all of them appear in the output.
    xmap_axes = sorted(set(xmap_dim_x.values()) | set(xmap_dim_y.values()))
    num_axes = len(xmap_axes)
    xmap_out_axes = [dict(zip(dims, xmap_axes))
                     for dims in permutations(range(2 + num_axes), num_axes)]
    vmap_dims_x = [*range(2 + len(xmap_dim_x)), None]
    vmap_dims_y = [*range(2 + len(xmap_dim_y)), None]
    for xmap_dim_z in xmap_out_axes:
      for vmap_dim_x, vmap_dim_y in product(vmap_dims_x, vmap_dims_y):
        # The outer map has to map over at least one of the inputs.
        if vmap_dim_x is None and vmap_dim_y is None:
          continue
        for vmap_dim_z, vmap_as_xmap in product(range(2 + len(xmap_axes)),
                                                [False, True]):
          yield {"testcase_name":
                     f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_"
                     f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_"
                     f"vout={vmap_dim_z}_vmap_as_xmap={vmap_as_xmap}",
                 "xmap_in_axes": (xmap_dim_x, xmap_dim_y),
                 "xmap_out_axes": xmap_dim_z,
                 "vmap_in_axes": (vmap_dim_x, vmap_dim_y),
                 "vmap_out_axes": vmap_dim_z,
                 "vmap_as_xmap": vmap_as_xmap}
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(VmapOfXmapCases()))
@ignore_xmap_warning()
def testNestedMap(self, xmap_in_axes, xmap_out_axes, vmap_in_axes, vmap_out_axes, vmap_as_xmap):
  """Check vmap(xmap) and xmap(xmap) compositions against a plain einsum.

  The inner xmap wraps a matrix product and maps over the named axes 'x' and
  'y'; the outer map (a vmap, or an xmap over the name 'v') adds exactly one
  more dimension. The reference result is a single einsum whose subscripts
  spell out every mapped axis explicitly.
  """
  vmap_size = 7
  xmap_sizes = {'x': 11, 'y': 13}
  xin_x, xin_y = xmap_in_axes
  vin_x, vin_y = vmap_in_axes

  def add_named_axes(shape, index, named_axes):
    # Grow `shape` and the einsum subscripts `index` with the xmapped axes,
    # visiting them in increasing positional order.
    for pos, name in sorted(named_axes.items()):
      shape.insert(pos, xmap_sizes[name])
      index.insert(pos, name)

  f = partial(jnp.einsum, 'nk,km->nm')
  xshape, xind = [2, 3], ['n', 'k']
  yshape, yind = [3, 5], ['k', 'm']
  zshape, zind = [2, 5], ['n', 'm']

  add_named_axes(xshape, xind, xin_x)
  add_named_axes(yshape, yind, xin_y)
  add_named_axes(zshape, zind, xmap_out_axes)

  # The outer map's dimension is inserted last, on top of the xmapped axes.
  if vin_x is not None:
    xshape.insert(vin_x, vmap_size)
    xind.insert(vin_x, 'v')
  if vin_y is not None:
    yshape.insert(vin_y, vmap_size)
    yind.insert(vin_y, 'v')
  zshape.insert(vmap_out_axes, vmap_size)
  zind.insert(vmap_out_axes, 'v')

  if not vmap_as_xmap:
    do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes)
  else:
    outer_in_axes = ({vin_x: 'v'} if vin_x is not None else {},
                     {vin_y: 'v'} if vin_y is not None else {})
    do_vmap = partial(xmap, in_axes=outer_in_axes, out_axes={vmap_out_axes: 'v'})

  fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes))
  fref = partial(jnp.einsum, f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}")

  rng = np.random.RandomState(0)
  x = rng.randn(*xshape)
  y = rng.randn(*yshape)
  self.assertAllClose(fm(x, y), fref(x, y))
|
||||
|
||||
|
||||
class XMapTestSPMD(XMapTest):
|
||||
"""Re-executes all tests with the SPMD partitioner enabled"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user