[shard-map] handle closed-over vmap tracers

This commit is contained in:
Matthew Johnson 2023-03-30 16:43:04 -07:00
parent 69c9660aab
commit 7c3c46c807
2 changed files with 24 additions and 9 deletions

View File

@ -39,7 +39,7 @@ from jax._src import util
from jax._src.core import Tracer
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, fft, linalg)
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
as_hashable_function, memoize, partition_list,
merge_lists)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
@ -721,13 +721,7 @@ def _shard_map_batch(
else ns for ns, d in zip(new_in_names, in_dims)]
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
out_names = out_names_thunk()
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
for ax in names} for names, d in zip(out_names, out_dims())]
if spmd_axis_name is not None:
out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
else ns for ns, d in zip(out_names_, out_dims())]
return out_names_
return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk())
new_params = dict(mesh=mesh, in_names=new_in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
@ -737,6 +731,28 @@ def _shard_map_batch(
return map(make_tracer, out_vals, out_dims())
batching.BatchTrace.process_shard_map = _shard_map_batch
def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names,
out_names_thunk, check_rep):
del mesh, in_names, out_names_thunk, check_rep
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
for t in out_tracers)
m = trace.main
def todo(vals):
trace = m.with_cur_sublevel()
return map(partial(batching.BatchTracer, trace), vals, dims, srcs)
out_names_transform = partial(_batch_out_names, trace.spmd_axis_name, dims)
return vals, (todo, out_names_transform)
batching.BatchTrace.post_process_shard_map = _shard_map_batch_post_process
def _batch_out_names(spmd_axis_name, dims, out_names):
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
for ax in names} for names, d in zip(out_names, dims)]
if spmd_axis_name is not None:
out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
else ns for ns, d in zip(out_names_, dims)]
return out_names_
# Autodiff
def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,

View File

@ -858,7 +858,6 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
sample(config.FLAGS.jax_num_generated_cases,
partial(sample_shmap_batched, 5)))
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
raise unittest.SkipTest("need BatchTrace.post_process_shard_map") # TODO
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)