mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shard-map] handle closed-over vmap tracers
This commit is contained in:
parent
69c9660aab
commit
7c3c46c807
@ -39,7 +39,7 @@ from jax._src import util
|
||||
from jax._src.core import Tracer
|
||||
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
|
||||
windowed_reductions, fft, linalg)
|
||||
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
|
||||
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
|
||||
as_hashable_function, memoize, partition_list,
|
||||
merge_lists)
|
||||
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
|
||||
@ -721,13 +721,7 @@ def _shard_map_batch(
|
||||
else ns for ns, d in zip(new_in_names, in_dims)]
|
||||
@as_hashable_function(closure=out_names_thunk)
|
||||
def new_out_names_thunk():
|
||||
out_names = out_names_thunk()
|
||||
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
|
||||
for ax in names} for names, d in zip(out_names, out_dims())]
|
||||
if spmd_axis_name is not None:
|
||||
out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
|
||||
else ns for ns, d in zip(out_names_, out_dims())]
|
||||
return out_names_
|
||||
return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk())
|
||||
|
||||
new_params = dict(mesh=mesh, in_names=new_in_names,
|
||||
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
|
||||
@ -737,6 +731,28 @@ def _shard_map_batch(
|
||||
return map(make_tracer, out_vals, out_dims())
|
||||
batching.BatchTrace.process_shard_map = _shard_map_batch
|
||||
|
||||
def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names,
|
||||
out_names_thunk, check_rep):
|
||||
del mesh, in_names, out_names_thunk, check_rep
|
||||
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
|
||||
for t in out_tracers)
|
||||
m = trace.main
|
||||
def todo(vals):
|
||||
trace = m.with_cur_sublevel()
|
||||
return map(partial(batching.BatchTracer, trace), vals, dims, srcs)
|
||||
out_names_transform = partial(_batch_out_names, trace.spmd_axis_name, dims)
|
||||
return vals, (todo, out_names_transform)
|
||||
batching.BatchTrace.post_process_shard_map = _shard_map_batch_post_process
|
||||
|
||||
def _batch_out_names(spmd_axis_name, dims, out_names):
|
||||
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
|
||||
for ax in names} for names, d in zip(out_names, dims)]
|
||||
if spmd_axis_name is not None:
|
||||
out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
|
||||
else ns for ns, d in zip(out_names_, dims)]
|
||||
return out_names_
|
||||
|
||||
|
||||
# Autodiff
|
||||
|
||||
def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
|
@ -858,7 +858,6 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
sample(config.FLAGS.jax_num_generated_cases,
|
||||
partial(sample_shmap_batched, 5)))
|
||||
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
|
||||
raise unittest.SkipTest("need BatchTrace.post_process_shard_map") # TODO
|
||||
mesh = self.make_mesh(mesh)
|
||||
args = map(jnp.array, args)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user