mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix batching formula of xmap
Turns out that once you insert multiple dimensions, things become much trickier than in the case of batching a one-dimensional map. Also strengthen our tests to make sure we don't depend too much on the semantics of the einsum batching rule.
This commit is contained in:
parent
5ed0633d54
commit
6884f21b60
@ -686,17 +686,24 @@ def _batch_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params
|
||||
for d, in_axis in zip(dims, params['in_axes']))
|
||||
f, mapped_dims_out = batching.batch_subtrace(f, self.main, mapped_dims_in)
|
||||
out_axes_thunk = params['out_axes_thunk']
|
||||
def axis_after_insertion(axis, inserted_named_axes):
  """Return the position of `axis` once the given axes have been inserted.

  The inserted positions (the values of `inserted_named_axes`) are visited in
  ascending order. Every inserted position that lies strictly before the
  current position pushes it one place to the right; the walk stops at the
  first inserted position that is at or beyond the current position.

  Args:
    axis: integer position before any insertion.
    inserted_named_axes: mapping whose values are the integer positions of
      the inserted axes (only the values are read).

  Returns:
    The shifted integer position.
  """
  shifted = axis
  positions = list(inserted_named_axes.values())
  positions.sort()
  for position in positions:
    if position < shifted:
      # An axis inserted in front of us moves us one slot to the right.
      shifted += 1
    else:
      # Positions are ascending, so nothing further can land in front of us.
      break
  return shifted
|
||||
# NOTE: This assumes that the choice of the dimensions over which outputs
|
||||
# are batched is entirely dependent on the function and not e.g. on the
|
||||
# data or its shapes.
|
||||
@as_hashable_function(closure=out_axes_thunk)
|
||||
def new_out_axes_thunk():
|
||||
return tuple(
|
||||
fmap_dims(out_axes, lambda a: a + (d is not not_mapped and d <= a))
|
||||
out_axes if d is not_mapped else
|
||||
fmap_dims(out_axes, lambda a, nd=axis_after_insertion(d, out_axes): a + (nd <= a))
|
||||
for out_axes, d in zip(out_axes_thunk(), mapped_dims_out()))
|
||||
new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
|
||||
vals_out = primitive.bind(f, *vals, **new_params)
|
||||
dims_out = tuple(d if d is not_mapped else d + sum(a < d for a in out_axes.values())
|
||||
dims_out = tuple(d if d is not_mapped else axis_after_insertion(d, out_axes)
|
||||
for d, out_axes in zip(mapped_dims_out(), out_axes_thunk()))
|
||||
return [batching.BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]
|
||||
batching.BatchTrace.process_xmap = _batch_trace_process_xmap # type: ignore
|
||||
|
@ -35,13 +35,14 @@ import jax.scipy as jscipy
|
||||
from jax import test_util as jtu
|
||||
from jax import vmap
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax.core import NamedShape
|
||||
from jax.experimental.maps import Mesh, mesh, xmap
|
||||
from jax.lib import xla_bridge
|
||||
from jax._src.util import curry, unzip2, split_list, prod
|
||||
from jax._src.lax.lax import DotDimensionNumbers
|
||||
from jax._src.lax.parallel import pgather
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import batching, pxla
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -112,6 +113,23 @@ def powerset(s):
|
||||
s = list(s)
|
||||
return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))
|
||||
|
||||
# -------------------- vmap test helpers --------------------
|
||||
|
||||
ensure_bdim_p = core.Primitive('ensure_bdim')
|
||||
ensure_bdim_p.def_abstract_eval(lambda x, **kwargs: core.raise_to_shaped(x))
|
||||
def _ensure_bdim_batcher(frame, vals_in, dims_in, axis_name, bdim):
|
||||
v, = vals_in
|
||||
d, = dims_in
|
||||
assert d is not batching.not_mapped
|
||||
return jnp.moveaxis(v, d, bdim), bdim
|
||||
batching.collective_rules[ensure_bdim_p] = _ensure_bdim_batcher
|
||||
batching.primitive_batchers[ensure_bdim_p] = lambda v, d: (v[0], d[0])
|
||||
core.axis_substitution_rules[ensure_bdim_p] = partial(jax._src.lax.parallel._subst_all_names_in_param,
|
||||
'axis_name')
|
||||
|
||||
def ensure_bdim(x, axis_name, bdim):
  """Bind the `ensure_bdim` primitive on `x`.

  Args:
    x: operand passed to the primitive.
    axis_name: a single axis name; it is wrapped in a one-element tuple before
      being handed to the primitive.
    bdim: forwarded unchanged as the primitive's `bdim` parameter.

  Returns:
    Whatever `ensure_bdim_p.bind` returns.
  """
  names = (axis_name,)
  return ensure_bdim_p.bind(x, axis_name=names, bdim=bdim)
|
||||
|
||||
# -------------------- Axis resources generation --------------------
|
||||
|
||||
AxisResources = Dict[str, Union[str, Tuple[str, ...]]]
|
||||
@ -444,21 +462,26 @@ class XMapTest(XMapTestCase):
|
||||
for vmap_dim_y in [*range(2 + len(xmap_dim_y)), None]:
|
||||
if vmap_dim_x is None and vmap_dim_y is None:
|
||||
continue
|
||||
for vmap_dim_z in range(2 + len(xmap_axes)):
|
||||
for vmap_as_xmap in [False, True]:
|
||||
yield {"testcase_name":
|
||||
f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_"
|
||||
f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_"
|
||||
f"vout={vmap_dim_z}_vmap_as_xmap={vmap_as_xmap}",
|
||||
"xmap_in_axes": (xmap_dim_x, xmap_dim_y),
|
||||
"xmap_out_axes": xmap_dim_z,
|
||||
"vmap_in_axes": (vmap_dim_x, vmap_dim_y),
|
||||
"vmap_out_axes": vmap_dim_z,
|
||||
"vmap_as_xmap": vmap_as_xmap}
|
||||
for vmap_dim_result in range(3):
|
||||
for vmap_dim_z in range(2 + len(xmap_axes)):
|
||||
for vmap_as_xmap in [False, True]:
|
||||
yield {"testcase_name":
|
||||
f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_"
|
||||
f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_"
|
||||
f"vout={vmap_dim_z}_vresult={vmap_dim_result}_vmap_as_xmap={vmap_as_xmap}",
|
||||
"xmap_in_axes": (xmap_dim_x, xmap_dim_y),
|
||||
"xmap_out_axes": xmap_dim_z,
|
||||
"vmap_in_axes": (vmap_dim_x, vmap_dim_y),
|
||||
"vmap_out_axes": vmap_dim_z,
|
||||
"vmap_result_axis": vmap_dim_result,
|
||||
"vmap_as_xmap": vmap_as_xmap}
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(VmapOfXmapCases()))
|
||||
@ignore_xmap_warning()
|
||||
def testNestedMap(self, xmap_in_axes, xmap_out_axes, vmap_in_axes, vmap_out_axes, vmap_as_xmap):
|
||||
def testNestedMap(self,
|
||||
xmap_in_axes, xmap_out_axes,
|
||||
vmap_in_axes, vmap_out_axes, vmap_result_axis,
|
||||
vmap_as_xmap):
|
||||
"""Test various vmap(xmap) and xmap(xmap) combinations.
|
||||
|
||||
The outer map always introduces a single dimension, the inner map introduces one or two.
|
||||
@ -474,7 +497,7 @@ class XMapTest(XMapTestCase):
|
||||
xind = ['n', 'k']
|
||||
yind = ['k', 'm']
|
||||
zind = ['n', 'm']
|
||||
f = partial(jnp.einsum, 'nk,km->nm')
|
||||
f = lambda x, y: ensure_bdim(jnp.einsum('nk,km->nm', x, y), 'v', vmap_result_axis)
|
||||
|
||||
for pos, name in sorted(xin_x.items()):
|
||||
xshape.insert(pos, xmap_sizes[name])
|
||||
@ -501,7 +524,7 @@ class XMapTest(XMapTestCase):
|
||||
{vin_y: 'v'} if vin_y is not None else {}),
|
||||
out_axes={vmap_out_axes: 'v'})
|
||||
else:
|
||||
do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes)
|
||||
do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes, axis_name='v')
|
||||
|
||||
fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes))
|
||||
fref = partial(jnp.einsum, f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user