1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 06:36:07 +00:00

Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.

PiperOrigin-RevId: 691086496
This commit is contained in:
Dougal Maclaurin 2024-10-29 11:03:49 -07:00 committed by jax authors
parent c67cf51f15
commit c36e1f7c1a
47 changed files with 1413 additions and 2633 deletions

@ -701,20 +701,17 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts) transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error
def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, def remat_vmap(axis_data, args, dims, *, jaxpr, **params):
jaxpr, **params):
assert not jaxpr.constvars assert not jaxpr.constvars
jaxpr_batched_, out_batched = batching.batch_jaxpr_axes( jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
pe.close_jaxpr(jaxpr), axis_size, dims, pe.close_jaxpr(jaxpr), axis_data, dims,
[batching.zero_if_mapped] * len(jaxpr.outvars), [batching.zero_if_mapped] * len(jaxpr.outvars))
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
if consts: if consts:
jaxpr_batched = pe.convert_constvars_jaxpr(jaxpr_batched) jaxpr_batched = pe.convert_constvars_jaxpr(jaxpr_batched)
out_dims = [0 if b else None for b in out_batched] out_dims = [0 if b else None for b in out_batched]
return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None) batching.fancy_primitive_batchers[remat_p] = remat_vmap
batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn

@ -34,7 +34,7 @@ from typing import (Any, Literal, NamedTuple, TypeVar, overload,
import weakref import weakref
import numpy as np import numpy as np
from contextlib import contextmanager, ExitStack from contextlib import contextmanager
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src import stages from jax._src import stages
@ -989,10 +989,10 @@ def vmap(fun: F,
axis_size_ = (axis_size if axis_size is not None else axis_size_ = (axis_size if axis_size is not None else
_mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap")) _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
try: try:
axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name)
out_flat = batching.batch( out_flat = batching.batch(
flat_fun, axis_name, axis_size_, in_axes_flat, flat_fun, axis_data, in_axes_flat,
lambda: flatten_axes("vmap out_axes", out_tree(), out_axes), lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
spmd_axis_name=spmd_axis_name
).call_wrapped(*args_flat) ).call_wrapped(*args_flat)
except batching.SpecMatchError as e: except batching.SpecMatchError as e:
out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes) out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)
@ -1546,16 +1546,13 @@ def _cpp_pmap(
is_explicit_global_axis_size=p.is_explicit_global_axis_size, is_explicit_global_axis_size=p.is_explicit_global_axis_size,
) )
map_bind_continuation, top_trace, fun_, tracers, params = (
core.map_bind_with_continuation(pxla.xla_pmap_p, p.flat_fun,
*p.flat_args, **params))
execute: Callable | None = None execute: Callable | None = None
if isinstance(top_trace, core.EvalTrace): with core.take_current_trace() as trace:
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params) if isinstance(trace, core.EvalTrace):
out = map_bind_continuation(execute(*tracers)) execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
else: out = execute(*p.flat_args)
out = map_bind_continuation( else:
pxla.xla_pmap_p.process(top_trace, fun_, tracers, params)) out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
out_tree, out_flat = p.out_tree, out out_tree, out_flat = p.out_tree, out
out_pytree_def = out_tree() out_pytree_def = out_tree()
@ -1802,7 +1799,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.) >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
... ...
>>> jax.jvp(f, (2.,), (3.,)) >>> jax.jvp(f, (2.,), (3.,))
(Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True)) (Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.) >>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y) >>> print(y)
3.2681944 3.2681944
@ -2160,9 +2157,7 @@ def make_jaxpr(
@wraps(fun) @wraps(fun)
@api_boundary @api_boundary
def make_jaxpr_f(*args, **kwargs): def make_jaxpr_f(*args, **kwargs):
with ExitStack() as stack: with core.extend_axis_env_nd(axis_env or []):
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
traced = jit(fun, static_argnums=static_argnums, traced = jit(fun, static_argnums=static_argnums,
abstracted_axes=abstracted_axes).trace(*args, **kwargs) abstracted_axes=abstracted_axes).trace(*args, **kwargs)
# `jit` converts tracers in consts to args but that breaks the semantics of # `jit` converts tracers in consts to args but that breaks the semantics of

@ -633,7 +633,6 @@ def io_callback(
flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
flat_result_avals = map(lambda x: core.ShapedArray(x.shape, x.dtype), flat_result_avals = map(lambda x: core.ShapedArray(x.shape, x.dtype),
flat_shape_dtypes) flat_shape_dtypes)
flat_args = map(core.raise_as_much_as_possible, flat_args)
out_flat = io_callback_p.bind( out_flat = io_callback_p.bind(
*flat_args, *flat_args,
callback=_FlatCallback(callback, in_tree), callback=_FlatCallback(callback, in_tree),

@ -217,7 +217,9 @@ def trace_context():
return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, return (axis_env_state, mesh_context_manager, xla_metadata_context_manager,
compute_on_context_manager, enable_x64.value, compute_on_context_manager, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value, numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value, numpy_dtype_promotion.value, dynamic_shapes.value,
eager_constant_folding.value,
numpy_dtype_promotion.value,
default_device.value, random_seed_offset.value, default_device.value, random_seed_offset.value,
threefry_partitionable.value, threefry_partitionable.value,
threefry_gpu_kernel_lowering.value, threefry_gpu_kernel_lowering.value,
@ -832,6 +834,7 @@ class _GlobalExtraJitContext(NamedTuple):
numpy_dtype_promotion: str | None = None numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None default_matmul_precision: Any | None = None
dynamic_shapes: bool = False dynamic_shapes: bool = False
eager_constant_folding: bool = False
random_seed_offset: int = 0 random_seed_offset: int = 0
threefry_partitionable: bool = False threefry_partitionable: bool = False
threefry_gpu_kernel_lowering: bool = False threefry_gpu_kernel_lowering: bool = False
@ -858,7 +861,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
The initialization, which uses both config.py and core.py is done using The initialization, which uses both config.py and core.py is done using
`_update_thread_local_jit_state` in core.py to prevent circular imports. `_update_thread_local_jit_state` in core.py to prevent circular imports.
""" """
dynamic_trace_state: Any | None = None trace_state: Any | None = None
axis_env_state: Hashable = () axis_env_state: Hashable = ()
mesh_context_manager: Hashable = () mesh_context_manager: Hashable = ()
compute_on_context_manager: Hashable = () compute_on_context_manager: Hashable = ()
@ -873,6 +876,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
numpy_dtype_promotion: str | None = None numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None default_matmul_precision: Any | None = None
dynamic_shapes: bool | None = None dynamic_shapes: bool | None = None
eager_constant_folding : bool | None = None
random_seed_offset: int | None = None random_seed_offset: int | None = None
threefry_partitionable: bool | None = None threefry_partitionable: bool | None = None
threefry_gpu_kernel_lowering: bool | None = None threefry_gpu_kernel_lowering: bool | None = None
@ -909,7 +913,6 @@ def update_thread_local_jit_state(**kw):
tmp = context._replace(**kw) tmp = context._replace(**kw)
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
# TODO(b/214340779): remove flag when XLA:CPU is improved. # TODO(b/214340779): remove flag when XLA:CPU is improved.
jax2tf_associative_scan_reductions = bool_state( jax2tf_associative_scan_reductions = bool_state(
name='jax2tf_associative_scan_reductions', name='jax2tf_associative_scan_reductions',
@ -1163,6 +1166,11 @@ sharding_in_types = bool_state(
update_thread_local_hook=lambda val: update_thread_local_jit_state( update_thread_local_hook=lambda val: update_thread_local_jit_state(
sharding_in_types=val)) sharding_in_types=val))
data_dependent_tracing_fallback = bool_state(
name='jax_data_dependent_tracing_fallback',
default=False,
help=('When True, falls back to trace dispatch based on data dependence '
'instead of throwing an escaped tracer error.'))
softmax_custom_jvp = bool_state( softmax_custom_jvp = bool_state(
name='jax_softmax_custom_jvp', name='jax_softmax_custom_jvp',
@ -1530,6 +1538,16 @@ dynamic_shapes = bool_state(
update_thread_local_hook=lambda val: \ update_thread_local_hook=lambda val: \
update_thread_local_jit_state(dynamic_shapes=val)) update_thread_local_jit_state(dynamic_shapes=val))
# This is for stackless backward compat with e.g. equinox
eager_constant_folding = bool_state(
name='eager_constant_folding',
default=False,
help=('Attempt constant folding during staging.'),
update_global_hook=lambda val: \
_update_global_jit_state(eager_constant_folding=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(eager_constant_folding=val))
# This flag is temporary during rollout of the remat barrier. # This flag is temporary during rollout of the remat barrier.
# TODO(parkers): Remove if there are no complaints. # TODO(parkers): Remove if there are no complaints.
remat_opt_barrier = bool_state( remat_opt_barrier = bool_state(

File diff suppressed because it is too large Load Diff

@ -138,9 +138,9 @@ def maybe_bdim_at_front(x, bdim):
# axes instead of accepting and matching a given spec of output axes. Assumes # axes instead of accepting and matching a given spec of output axes. Assumes
# `f` is pytree-flattened # `f` is pytree-flattened
def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size): def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
f, out_axes = batching.batch_subtrace(f) axis_data = batching.AxisData(axis_name, axis_size, None)
f = batching._batch_outer(f, axis_name, axis_size, in_axes, tag = core.TraceTag()
batching.BatchTrace, None) f, out_axes = batching.batch_subtrace(f, tag, axis_data, in_axes)
outs = f.call_wrapped(*args) outs = f.call_wrapped(*args)
return outs, out_axes() return outs, out_axes()

@ -354,25 +354,12 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
class CustomJVPCallPrimitive(core.Primitive): class CustomJVPCallPrimitive(core.Primitive):
multiple_results = True multiple_results = True
def bind(self, fun, jvp, *args, symbolic_zeros): def bind_with_trace(self, trace, args, params):
args = map(core.full_lower, args) fun, jvp, tracers = args[0], args[1], args[2:]
top_trace = core.find_top_trace(args) return trace.process_custom_jvp_call(self, fun, jvp, tracers, **params)
fun, env_trace_todo1 = process_env_traces(
fun, self, top_trace and top_trace.level, False)
jvp, env_trace_todo2 = process_env_traces(
jvp, self, top_trace and top_trace.level, True)
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers,
symbolic_zeros=symbolic_zeros)
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
return core.apply_todos(env_trace_todo, map(core.full_lower, outs))
def impl(self, fun, _, *args): def impl(self, fun, _, *args):
with core.new_sublevel(): raise NotImplementedError
return fun.call_wrapped(*args)
def post_process(self, trace, out_tracers, jvp_was_run: bool):
return trace.post_process_custom_jvp_call(out_tracers, jvp_was_run)
def get_bind_params(self, params): def get_bind_params(self, params):
new_params = dict(params) new_params = dict(params)
@ -402,24 +389,6 @@ def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable) -> lu.WrappedFun:
return [*out_primals, *out_tangents] return [*out_primals, *out_tangents]
return jvp return jvp
@partial(lu.transformation_with_aux, use_eq_store=True)
def process_env_traces(primitive, level: int, jvp_was_run: bool, *args):
outs = yield args, {}
todo = []
while True:
tracers = [x for x in outs if isinstance(x, core.Tracer)
and (level is None or x._trace.level > level)]
if tracers:
ans = max(tracers, key=lambda x: x._trace.level)
else:
break
trace = ans._trace.main.with_cur_sublevel()
outs = map(trace.full_raise, outs)
outs, cur_todo = primitive.post_process(trace, outs, jvp_was_run)
todo.append(cur_todo)
yield outs, tuple(todo) # Ensure the aux output is immutable
effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect) effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
@ -824,55 +793,12 @@ def _temporary_shape_exception(a, a_) -> bool:
class CustomVJPCallPrimitive(core.CallPrimitive): class CustomVJPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive initial_style: core.Primitive
def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros): def bind_with_trace(self, trace, args, params):
args = map(core.full_lower, args) fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:]
top_trace = core.find_top_trace(args) return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params)
fun, env_trace_todo1 = process_env_traces(
fun, self, top_trace and top_trace.level, False)
fwd, env_trace_todo2 = process_env_traces_fwd(
fwd, top_trace and top_trace.level, out_trees)
tracers = map(top_trace.full_raise, args)
bwd_ = lambda *args: bwd(*args)
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
if fst:
return core.apply_todos(env_trace_todo, map(core.full_lower, outs))
else:
env_trace_todo, bwd_transform = env_trace_todo
bwd = _apply_bwd_transform(bwd_transform, bwd)
return core.apply_todos(env_trace_todo, map(core.full_lower, outs))
def impl(self, fun, fwd, bwd, *args, out_trees):
del fwd, bwd, out_trees
with core.new_sublevel():
return fun.call_wrapped(*args)
def post_process(self, trace, out_tracers, params):
return trace.post_process_custom_vjp_call(out_tracers, params)
custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call')
@partial(lu.transformation_with_aux, use_eq_store=True)
def process_env_traces_fwd(level: int, out_trees, *args):
outs = yield args, {}
todo = []
bwd_transforms = []
while True:
tracers = [x for x in outs if isinstance(x, core.Tracer)
and (level is None or x._trace.level > level)]
if tracers:
ans = max(tracers, key=lambda x: x._trace.level)
else:
break
trace = ans._trace.main.with_cur_sublevel()
outs = map(trace.full_raise, outs)
outs, cur_todo, bwd_xform = trace.post_process_custom_vjp_call_fwd(outs, out_trees)
todo.append(cur_todo)
bwd_transforms.append(bwd_xform)
yield outs, (tuple(todo), tuple(bwd_transforms))
def _apply_bwd_transform(todos, bwd): def _apply_bwd_transform(todos, bwd):
todos_list = list(todos) todos_list = list(todos)
while todos_list: while todos_list:
@ -889,7 +815,7 @@ def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
f'Effects not supported in `custom_vjp`: {disallowed_effects}') f'Effects not supported in `custom_vjp`: {disallowed_effects}')
return fun_jaxpr.out_avals, fun_jaxpr.effects return fun_jaxpr.out_avals, fun_jaxpr.effects
custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr') custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr')
custom_vjp_call_jaxpr_p.multiple_results = True custom_vjp_call_jaxpr_p.multiple_results = True
custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl) custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl)
custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval) custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
@ -921,18 +847,16 @@ def _custom_vjp_call_jaxpr_jvp(
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
def _custom_vjp_call_jaxpr_vmap( def _custom_vjp_call_jaxpr_vmap(
spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *, axis_data, args, in_dims, *,
fun_jaxpr: core.ClosedJaxpr, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)] else x for x, d in zip(args, in_dims)]
in_batched = [d is not not_mapped for d in in_dims] in_batched = [d is not not_mapped for d in in_dims]
_, args_batched = split_list(in_batched, [num_consts]) _, args_batched = split_list(in_batched, [num_consts])
batched_fun_jaxpr, out_batched = batching.batch_jaxpr( batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
fun_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name, fun_jaxpr, axis_data, in_batched, False)
main_type)
out_dims1 = [0 if b else not_mapped for b in out_batched] out_dims1 = [0 if b else not_mapped for b in out_batched]
out_dims2 = [] out_dims2 = []
@ -940,16 +864,15 @@ def _custom_vjp_call_jaxpr_vmap(
def batched_fwd_jaxpr_thunk(*zeros): def batched_fwd_jaxpr_thunk(*zeros):
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name, fwd_jaxpr, axis_data, args_batched, False)
main_type)
out_dims2.append([0 if b else not_mapped for b in out_batched]) out_dims2.append([0 if b else not_mapped for b in out_batched])
return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts
fwd_args_batched = [0 if b else not_mapped for b in args_batched] fwd_args_batched = [0 if b else not_mapped for b in args_batched]
fwd_out_dims = lambda: out_dims2[0] fwd_out_dims = lambda: out_dims2[0]
tag = core.TraceTag()
batched_bwd = batching.batch_custom_vjp_bwd( batched_bwd = batching.batch_custom_vjp_bwd(
bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type, bwd, tag, axis_data, fwd_out_dims, fwd_args_batched)
spmd_axis_name)
batched_outs = custom_vjp_call_jaxpr_p.bind( batched_outs = custom_vjp_call_jaxpr_p.bind(
*args, fun_jaxpr=batched_fun_jaxpr, *args, fun_jaxpr=batched_fun_jaxpr,
@ -957,10 +880,7 @@ def _custom_vjp_call_jaxpr_vmap(
num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros)
out_dims = out_dims2[0] if out_dims2 else out_dims1 out_dims = out_dims2[0] if out_dims2 else out_dims1
return batched_outs, out_dims return batched_outs, out_dims
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \ batching.fancy_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
_custom_vjp_call_jaxpr_vmap
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(
_custom_vjp_call_jaxpr_vmap, None)
xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
@ -1144,11 +1064,12 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
def _maybe_perturbed(x: Any) -> bool: def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value # False if x can't represent an AD-perturbed value (i.e. a value
# with a nontrivial tangent attached), up to heuristics, and True otherwise. # with a nontrivial tangent attached), up to heuristics, and True otherwise.
# See https://github.com/jax-ml/jax/issues/6415 for motivation. # See https://github.com/google/jax/issues/6415 for motivation.
x = core.full_lower(x)
if not isinstance(x, core.Tracer): if not isinstance(x, core.Tracer):
# If x is not a Tracer, it can't be perturbed. # If x is not a Tracer, it can't be perturbed.
return False return False
elif isinstance(x, ad.JVPTracer) and isinstance(x.tangent, ad.Zero):
return _maybe_perturbed(x.primal)
elif isinstance(x, pe.DynamicJaxprTracer): elif isinstance(x, pe.DynamicJaxprTracer):
# If x is a DynamicJaxprTracer then we're staging out; differentiation could # If x is a DynamicJaxprTracer then we're staging out; differentiation could
# happen later, but some types always have trivial tangents. # happen later, but some types always have trivial tangents.
@ -1532,7 +1453,7 @@ def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
return fwd_jaxpr.out_avals, fwd_jaxpr.effects return fwd_jaxpr.out_avals, fwd_jaxpr.effects
def _remat_opt_vmap( def _remat_opt_vmap(
spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, axis_data, args, in_dims,
*, *,
num_consts: int, num_consts: int,
num_res: int, num_res: int,
@ -1541,11 +1462,9 @@ def _remat_opt_vmap(
): ):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)] else x for x, d in zip(args, in_dims)]
in_batched = [d is not not_mapped for d in in_dims] in_batched = [d is not not_mapped for d in in_dims]
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
fwd_jaxpr, axis_size, in_batched, False, fwd_jaxpr, axis_data, in_batched, False)
axis_name, spmd_axis_name, main_type)
extra_consts = batched_fwd_jaxpr.consts extra_consts = batched_fwd_jaxpr.consts
batched_fwd_jaxpr = pe.close_jaxpr( batched_fwd_jaxpr = pe.close_jaxpr(
pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr)) pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr))
@ -1557,8 +1476,7 @@ def _remat_opt_vmap(
def batched_fun_jaxpr_thunk(): def batched_fun_jaxpr_thunk():
fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
batched_fun_jaxpr, out_batched = batching.batch_jaxpr( batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name, fun_jaxpr, axis_data, prim_batched, False)
main_type)
return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts
batched_outs = remat_opt_p.bind(*extra_consts, *args, batched_outs = remat_opt_p.bind(*extra_consts, *args,
@ -1592,7 +1510,7 @@ def _remat_opt_jvp(
[len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out]) [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out])
fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr)) fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr))
@pe._memoize # @pe._memoize
def fun_jvp_jaxpr_thunk(): def fun_jvp_jaxpr_thunk():
fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
in_nz = [True] * len(primals) in_nz = [True] * len(primals)
@ -1666,8 +1584,9 @@ remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval)
xla.register_initial_style_primitive(remat_opt_p) xla.register_initial_style_primitive(remat_opt_p)
mlir.register_lowering(remat_opt_p, mlir.lower_fun( mlir.register_lowering(remat_opt_p, mlir.lower_fun(
_remat_opt_impl, multiple_results=True)) _remat_opt_impl, multiple_results=True))
batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap
batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None)
batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap
ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
pe.dce_rules[remat_opt_p] = _remat_opt_dce pe.dce_rules[remat_opt_p] = _remat_opt_dce

@ -458,7 +458,9 @@ class custom_partitioning:
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, out_tree, False, debug = pe.debug_info(self.fun, in_tree, out_tree, False,
"custom_partitioning") "custom_partitioning")
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) mesh = mesh_lib.thread_resources.env.physical_mesh
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
assert not len(consts) assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
out_flat = custom_partitioning_p.bind( out_flat = custom_partitioning_p.bind(

@ -155,18 +155,9 @@ class CustomTransposePrimitive(core.Primitive):
map_primitive = False map_primitive = False
multiple_results = True multiple_results = True
def bind(self, call, *args, **params): def bind_with_trace(self, trace, call_args, params):
# TODO(frostig,mattjj): This doesn't handle closures yet, which is call, tracers = call_args[0], call_args[1:]
# a bit involved. Closures are complicated by us binding `call` return trace.process_custom_transpose(self, call, tracers, **params)
# twice in the JVP rule for custom transpose. The `env_trace_todo`
# output by `process_env_traces` due to one of those two bindings
# should be passable to the other, and need to be passed onward
# since the second bind is deferred by partial eval (since it
# typically receives unknowns)
top_trace = core.find_top_trace(args)
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_transpose(self, call, tracers, **params)
return outs
# TODO(frostig,mattjj): consider keeping `call` as a named parameter # TODO(frostig,mattjj): consider keeping `call` as a named parameter
# instead of following this "call primitive" convention. # instead of following this "call primitive" convention.

@ -95,7 +95,8 @@ def apply_primitive(prim, *args, **params):
@util.cache() @util.cache()
def xla_primitive_callable(prim: core.Primitive, **params): def xla_primitive_callable(prim: core.Primitive, **params):
def prim_fun(*args): def prim_fun(*args):
return prim.bind(*args, **params) with config.eager_constant_folding(False):
return prim.bind(*args, **params)
prim_fun.__name__ = prim.name prim_fun.__name__ = prim.name
prim_fun.__qualname__ = prim.name prim_fun.__qualname__ = prim.name
return api.jit(prim_fun) return api.jit(prim_fun)

@ -814,7 +814,7 @@ def check_user_dtype_supported(dtype, fun_name=None):
int2, int2,
int4, int4,
uint2, uint2,
uint4, uint4
] ]
if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0: if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0:
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"

@ -29,7 +29,7 @@ from jax.tree_util import (tree_flatten, tree_unflatten,
from jax._src import core from jax._src import core
from jax._src import source_info_util from jax._src import source_info_util
from jax._src.ad_util import ( from jax._src.ad_util import (
add_jaxvals, replace_internal_symbolic_zeros, zeros_like_jaxval, add_jaxvals, replace_internal_symbolic_zeros,
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval) replace_rule_output_symbolic_zeros, Zero, zeros_like_aval)
from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401 from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
@ -69,16 +69,15 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
fun, aux = jvp_subtrace_aux(fun) fun, aux = jvp_subtrace_aux(fun)
return jvpfun(fun, instantiate, transform_stack), aux return jvpfun(fun, instantiate, transform_stack), aux
@lu.transformation @lu.transformation
def jvpfun(instantiate, transform_stack, primals, tangents): def jvpfun(instantiate, transform_stack, primals, tangents):
tag = core.TraceTag()
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and dtype(t) == float0 else t for t in tangents] and dtype(t) == float0 else t for t in tangents]
ctx = (source_info_util.transform_name_stack('jvp') if transform_stack ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
else contextlib.nullcontext()) else contextlib.nullcontext())
with core.new_main(JVPTrace) as main, ctx: with ctx:
out_primals, out_tangents = yield (main, primals, tangents), {} out_primals, out_tangents = yield (tag, primals, tangents), {}
del main
if type(instantiate) is bool: if type(instantiate) is bool:
instantiate = [instantiate] * len(out_tangents) instantiate = [instantiate] * len(out_tangents)
out_tangents = [instantiate_zeros(t) if inst else t for t, inst out_tangents = [instantiate_zeros(t) if inst else t for t, inst
@ -86,35 +85,26 @@ def jvpfun(instantiate, transform_stack, primals, tangents):
yield out_primals, out_tangents yield out_primals, out_tangents
@lu.transformation @lu.transformation
def jvp_subtrace(main, primals, tangents): def jvp_subtrace(tag, primals, tangents):
trace = JVPTrace(main, core.cur_sublevel()) with core.take_current_trace() as parent_trace:
for x in list(primals) + list(tangents): trace = JVPTrace(parent_trace, tag)
if isinstance(x, Tracer): in_tracers = [maybe_jvp_tracer(trace, x, t)
if x._trace.level >= trace.level: for x, t in zip(primals, tangents)]
raise core.escaped_tracer_error( with core.set_current_trace(trace):
x, f"Tracer from a higher level: {x} in trace {trace}") ans = yield in_tracers, {}
assert x._trace.level < trace.level out = unzip2(map(trace.to_primal_tangent_pair, ans))
in_tracers = [JVPTracer(trace, x, t) if type(t) is not Zero else x yield out
for x, t in zip(primals, tangents)]
ans = yield in_tracers, {}
out_tracers = map(trace.full_raise, ans)
yield unzip2([(out_tracer.primal, out_tracer.tangent)
for out_tracer in out_tracers])
@lu.transformation_with_aux @lu.transformation_with_aux
def jvp_subtrace_aux(main, primals, tangents): def jvp_subtrace_aux(tag, primals, tangents):
trace = JVPTrace(main, core.cur_sublevel()) with core.take_current_trace() as parent_trace:
for x in list(primals) + list(tangents): trace = JVPTrace(parent_trace, tag)
if isinstance(x, Tracer): with core.set_current_trace(trace):
assert x._trace.level < trace.level ans, aux = yield map(partial(maybe_jvp_tracer, trace), primals, tangents), {}
ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {} out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
ans_tracers = map(trace.full_raise, ans) aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag
out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers) else x for x in aux]
aux_primals = [core.full_lower(x.primal) yield (out_primals, out_tangents), aux_primals
if isinstance(x, JVPTracer) and x._trace.level == trace.level
else x for x in aux]
yield (out_primals, out_tangents), aux_primals
def linearize(traceable, *primals, **kwargs): def linearize(traceable, *primals, **kwargs):
has_aux = kwargs.pop('has_aux', False) has_aux = kwargs.pop('has_aux', False)
@ -166,7 +156,6 @@ def unpair_pval(pval):
aval_1, aval_2 = aval aval_1, aval_2 = aval
return (aval_1, const_1), (aval_2, const_2) return (aval_1, const_1), (aval_2, const_2)
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type # NOTE: The FIXMEs below are caused by primal/tangent mixups (type
# errors if you will) # errors if you will)
def backward_pass(jaxpr: core.Jaxpr, transform_stack, def backward_pass(jaxpr: core.Jaxpr, transform_stack,
@ -281,37 +270,40 @@ def nonzero_tangent_outputs(*args, **kwargs):
class JVPTrace(Trace): class JVPTrace(Trace):
def __init__(self, parent_trace, tag):
self.tag = tag
self.parent_trace = parent_trace
def pure(self, val): def to_primal_tangent_pair(self, val):
tangent_zero = Zero.from_primal_value(val) if isinstance(val, JVPTracer) and val._trace.tag is self.tag:
return JVPTracer(self, val, tangent_zero) return (val.primal, val.tangent)
else:
def lift(self, val): tangent_zero = Zero.from_primal_value(val)
tangent_zero = Zero.from_primal_value(val) return (val, tangent_zero)
return JVPTracer(self, val, tangent_zero)
def sublift(self, val):
return JVPTracer(self, val.primal, val.tangent)
def process_primitive(self, primitive, tracers, params): def process_primitive(self, primitive, tracers, params):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
if all(type(t) is Zero for t in tangents_in):
return primitive.bind_with_trace(self.parent_trace, primals_in, params)
jvp = primitive_jvps.get(primitive) jvp = primitive_jvps.get(primitive)
if not jvp: if not jvp:
msg = f"Differentiation rule for '{primitive}' not implemented" msg = f"Differentiation rule for '{primitive}' not implemented"
raise NotImplementedError(msg) raise NotImplementedError(msg)
primal_out, tangent_out = jvp(primals_in, tangents_in, **params) with core.set_current_trace(self.parent_trace):
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
if primitive.multiple_results: if primitive.multiple_results:
return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)] return [maybe_jvp_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
else: else:
return JVPTracer(self, primal_out, tangent_out) return maybe_jvp_tracer(self, primal_out, tangent_out)
def process_call(self, call_primitive, f, tracers, params): def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results assert call_primitive.multiple_results
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
which_nz = [ type(t) is not Zero for t in tangents] which_nz = [ type(t) is not Zero for t in tangents]
tangents = [t if type(t) is not Zero else None for t in tangents] tangents = [t if type(t) is not Zero else None for t in tangents]
args, in_tree = tree_flatten((primals, tangents)) args, in_tree = tree_flatten((primals, tangents))
f_jvp = jvp_subtrace(f, self.main) f_jvp = jvp_subtrace(f, self.tag)
f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp) f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp)
if isinstance(call_primitive, core.MapPrimitive): if isinstance(call_primitive, core.MapPrimitive):
in_axes = params['in_axes'] in_axes = params['in_axes']
@ -328,76 +320,59 @@ class JVPTrace(Trace):
f_jvp, out_tree = traceable(f_jvp, in_tree) f_jvp, out_tree = traceable(f_jvp, in_tree)
update_params = call_param_updaters.get(call_primitive) update_params = call_param_updaters.get(call_primitive)
new_params = update_params(params, which_nz) if update_params else params new_params = update_params(params, which_nz) if update_params else params
result = call_primitive.bind(_update_annotation(f_jvp, f.in_type, which_nz), fun_and_args = (_update_annotation(f_jvp, f.in_type, which_nz),) + tuple(args)
*args, **new_params) result = call_primitive.bind_with_trace(self.parent_trace, fun_and_args, new_params)
primal_out, tangent_out = tree_unflatten(out_tree(), result) primal_out, tangent_out = tree_unflatten(out_tree(), result)
tangent_out = [Zero.from_primal_value(p) if t is None else t tangent_out = [Zero.from_primal_value(p) if t is None else t
for p, t in zip(primal_out, tangent_out)] for p, t in zip(primal_out, tangent_out)]
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)] return [maybe_jvp_tracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
def post_process_call(self, call_primitive, out_tracers, params):
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
out, treedef = tree_flatten((primals, tangents))
tangents_nz = [type(t) is not Zero for t in tangents]
del primals, tangents
main = self.main
def todo(x):
primals, tangents = tree_unflatten(treedef, x)
trace = JVPTrace(main, core.cur_sublevel())
return map(partial(JVPTracer, trace), primals, tangents)
if call_primitive.map_primitive:
def out_axes_transform(out_axes):
return (*out_axes, *(ax for ax, nz in zip(out_axes, tangents_nz) if nz))
todo = (todo, out_axes_transform)
return out, todo
# The only difference between process_map and process_call is that # The only difference between process_map and process_call is that
# the `in_axes` and `out_axes_thunk` params must be updated; # the `in_axes` and `out_axes_thunk` params must be updated;
# that's handled in process_call. # that's handled in process_call.
process_map = process_call process_map = process_call
post_process_map = post_process_call
def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros): def process_custom_jvp_call(self, prim, fun, f_jvp, tracers, *, symbolic_zeros):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
primals_in = map(core.full_lower, primals_in) if all(type(t) is Zero for t in tangents_in):
if not symbolic_zeros: return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
tangents_in = map(instantiate_zeros, tangents_in) dict(symbolic_zeros=symbolic_zeros))
else: with core.set_current_trace(self.parent_trace):
tangents_in = map(replace_internal_symbolic_zeros, tangents_in) if not symbolic_zeros:
outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in)) tangents_in = map(instantiate_zeros, tangents_in)
else:
tangents_in = map(replace_internal_symbolic_zeros, tangents_in)
outs = f_jvp.call_wrapped(*(tuple(primals_in) + tuple(tangents_in)))
primals_out, tangents_out = split_list(outs, [len(outs) // 2]) primals_out, tangents_out = split_list(outs, [len(outs) // 2])
tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out) tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out) return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out)
def post_process_custom_jvp_call(self, out_tracers, _): def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
raise CustomJVPException()
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees,
symbolic_zeros): symbolic_zeros):
# Local import to prevent an import cycle. primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
from jax._src.lax import lax if all(type(t) is Zero for t in tangents_in):
return prim.bind_with_trace(self.parent_trace,
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) (fun, fwd, bwd, *primals_in),
fwd_in = [(core.full_lower(p), type(t) is not Zero) dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
for p, t in zip(primals_in, tangents_in)] fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)]
fwd_in = [x for pair in fwd_in for x in pair] # flatten fwd_in = [x for pair in fwd_in for x in pair] # flatten
res_and_primals_out = fwd.call_wrapped(*fwd_in) with core.set_current_trace(self.parent_trace):
res_and_primals_out = fwd.call_wrapped(*fwd_in)
_, res_tree = out_trees() _, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
# TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
tangents_in = map(instantiate_zeros, tangents_in) with core.set_current_trace(self.parent_trace):
tangents_out = custom_lin_p.bind( tangents_in = map(instantiate_zeros, tangents_in)
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros) out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)
def post_process_custom_vjp_call(self, out_tracers, _):
raise CustomVJPException()
def process_custom_transpose(self, prim, call, tracers, **params): def process_custom_transpose(self, prim, call, tracers, **params):
ps_in, ts_in = unzip2((t.primal, t.tangent) for t in tracers) ps_in, ts_in = unzip2(map(self.to_primal_tangent_pair, tracers))
res_ps_in, lin_ps_in = split_list(ps_in, [params['res_tree'].num_leaves]) res_ps_in, lin_ps_in = split_list(ps_in, [params['res_tree'].num_leaves])
res_ts_in, lin_ts_in = split_list(ts_in, [params['res_tree'].num_leaves]) res_ts_in, lin_ts_in = split_list(ts_in, [params['res_tree'].num_leaves])
@ -421,24 +396,18 @@ class JVPTrace(Trace):
raise NotImplementedError( raise NotImplementedError(
'JVP of custom transpose with respect to non-symbolic-zero residuals') 'JVP of custom transpose with respect to non-symbolic-zero residuals')
ps_out = prim.bind(call, *ps_in, **params) with core.set_current_trace(self.parent_trace):
ps_out = prim.bind(call, *ps_in, **params)
lin_ts_in = map(instantiate_zeros, lin_ts_in)
ts_out = prim.bind(call, *res_ps_in, *lin_ts_in, **params)
lin_ts_in = map(instantiate_zeros, lin_ts_in) return map(partial(maybe_jvp_tracer, self), ps_out, ts_out)
ts_out = prim.bind(call, *res_ps_in, *lin_ts_in, **params)
return map(partial(JVPTracer, self), ps_out, ts_out)
def join(self, xt, yt):
xz, yz = type(xt) is Zero, type(yt) is Zero
if xz == yz:
return xt, yt
elif yz and not xz:
return xt, zeros_like_jaxval(xt)
elif xz and not yz:
return zeros_like_jaxval(yt), yt
else:
raise TypeError((xt, yt))
def maybe_jvp_tracer(trace, primal, tangent):
if type(tangent) is Zero:
return primal
else:
return JVPTracer(trace, primal, tangent)
class JVPTracer(Tracer): class JVPTracer(Tracer):
__slots__ = ['primal', 'tangent'] __slots__ = ['primal', 'tangent']
@ -452,7 +421,6 @@ class JVPTracer(Tracer):
@property @property
def aval(self): def aval(self):
# TODO(dougalm): add epsilon ball
return get_aval(self.primal) return get_aval(self.primal)
def full_lower(self): def full_lower(self):

@ -14,7 +14,7 @@
from __future__ import annotations from __future__ import annotations
import collections import collections
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
from functools import partial from functools import partial
from typing import Any, Union from typing import Any, Union
@ -29,12 +29,12 @@ from jax._src import linear_util as lu
from jax._src.ad_util import (Zero, instantiate, SymbolicZero, from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
replace_rule_output_symbolic_zeros, replace_rule_output_symbolic_zeros,
add_jaxvals, add_jaxvals_p) add_jaxvals, add_jaxvals_p)
from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName from jax._src.core import raise_to_shaped, Trace, Tracer, TraceTag, AxisName
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten, from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node) register_pytree_node)
from jax._src.typing import Array from jax._src.typing import Array
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function, canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache) curry, memoize, weakref_lru_cache)
@ -284,8 +284,7 @@ def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
def _cont(axis_size, elt, axis): def _cont(axis_size, elt, axis):
return from_elt(trace, axis_size, i, elt, axis) return from_elt(trace, axis_size, i, elt, axis)
return handler(_cont, axis_size, x, spec) return handler(_cont, axis_size, x, spec)
x_ = trace.full_raise(x) val, bdim = trace.to_batch_info(x)
val, bdim = x_.val, x_.batch_dim
if type(bdim) is RaggedAxis: if type(bdim) is RaggedAxis:
if spec is not jumble_axis: if spec is not jumble_axis:
# TODO(mattjj): improve this error message # TODO(mattjj): improve this error message
@ -293,9 +292,9 @@ def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val) return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
else: else:
try: try:
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val) return matchaxis(trace.axis_data.name, axis_size, bdim, spec, val)
except SpecMatchError: except SpecMatchError:
raise SpecMatchError(i, x_.batch_dim, spec) from None raise SpecMatchError(i, x.batch_dim, spec) from None
from_elt_handlers: dict[type, FromEltHandler] = {} from_elt_handlers: dict[type, FromEltHandler] = {}
def make_iota(axis_size: AxisSize) -> Array: def make_iota(axis_size: AxisSize) -> Array:
@ -435,165 +434,118 @@ class BatchTracer(Tracer):
else: # TODO(mattjj): could handle the RaggedAxis case? else: # TODO(mattjj): could handle the RaggedAxis case?
return self return self
@dataclasses.dataclass(frozen=True)
class AxisData:
name : Any
size : Any
spmd_name : Any
class BatchTrace(Trace): class BatchTrace(Trace):
def __init__(self, *args, axis_name, spmd_axis_name = None): def __init__(self, parent_trace, tag, axis_data):
super().__init__(*args) self.parent_trace = parent_trace
self.axis_name = axis_name assert isinstance(axis_data, AxisData)
self.spmd_axis_name = spmd_axis_name self.axis_data = axis_data
self.tag = tag
def pure(self, val): def to_batch_info(self, val):
return BatchTracer(self, val, not_mapped, source_info_util.current()) if isinstance(val, BatchTracer) and val._trace.tag is self.tag:
return val.val, val.batch_dim
def lift(self, val):
return BatchTracer(self, val, not_mapped, source_info_util.current())
def sublift(self, val):
return BatchTracer(self, val.val, val.batch_dim, source_info_util.current())
def get_primitive_batcher(self, primitive, frame):
if primitive in primitive_batchers:
return primitive_batchers[primitive]
elif self.spmd_axis_name is not None and primitive in spmd_axis_primitive_batchers:
return partial(spmd_axis_primitive_batchers[primitive],
self.spmd_axis_name, frame.size, frame.name,
frame.main_trace.trace_type)
elif primitive in axis_primitive_batchers:
return self.get_axis_primitive_batcher(primitive, frame)
msg = "Batching rule for '{}' not implemented"
raise NotImplementedError(msg.format(primitive))
def get_axis_primitive_batcher(self, primitive, frame):
return partial(axis_primitive_batchers[primitive],
frame.size, frame.name, frame.main_trace.trace_type)
def get_frame(self, vals, dims) -> core.AxisEnvFrame:
if any(d is not not_mapped for d in dims):
sizes = (x.shape[d] if type(d) is int else d.size
for x, d in zip(vals, dims) if d is not not_mapped)
axis_size, = core.dedup_referents(sizes)
else: else:
axis_size = None # can't be inferred from data return val, not_mapped
if self.axis_name is core.no_axis_name:
assert axis_size is not None # must be inferable from data
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
frame = core.axis_frame(self.axis_name, self.main)
assert axis_size is None or axis_size == frame.size, (axis_size, frame.size)
assert frame.main_trace is self.main
return frame
def process_primitive(self, primitive, tracers, params): def process_primitive(self, p, tracers, params):
if config.dynamic_shapes.value: if config.dynamic_shapes.value:
primitive.abstract_eval(*(t.aval for t in tracers), **params) p.abstract_eval(*(map(core.get_aval, tracers)), **params)
vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers) vals_in, dims_in = unzip2(map(self.to_batch_info, tracers))
is_axis_primitive = primitive in axis_primitive_batchers args_not_mapped = all(bdim is not_mapped for bdim in dims_in)
used_names = core.used_axis_names(primitive, params) if p in fancy_primitive_batchers:
if is_axis_primitive and _main_trace_for_axis_names(self.main, used_names): if (args_not_mapped
frame = self.get_frame(vals_in, dims_in) and p in skippable_batchers
batcher_primitive = self.get_axis_primitive_batcher(primitive, frame) and not any(self.axis_data.name == axis_name
val_out, dim_out = batcher_primitive(vals_in, dims_in, **params) for axis_name in skippable_batchers[p](params))):
elif all(bdim is not_mapped for bdim in dims_in): # no-op shortcut
return primitive.bind(*vals_in, **params) return p.bind_with_trace(self.parent_trace, vals_in, params)
else:
with core.set_current_trace(self.parent_trace):
val_out, dim_out = fancy_primitive_batchers[p](self.axis_data, vals_in, dims_in, **params)
elif args_not_mapped:
# no-op shortcut
return p.bind_with_trace(self.parent_trace, vals_in, params)
elif p in primitive_batchers:
with core.set_current_trace(self.parent_trace):
val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params)
else: else:
frame = self.get_frame(vals_in, dims_in) raise NotImplementedError("Batching rule for '{}' not implemented".format(p))
batched_primitive = self.get_primitive_batcher(primitive, frame)
val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
src = source_info_util.current() src = source_info_util.current()
if primitive.multiple_results: if p.multiple_results:
return [BatchTracer(self, x, d, src) for x, d in zip(val_out, dim_out)] with core.set_current_trace(self.parent_trace): # val_out may be lazy map
return [BatchTracer(self, x, d, src) if d is not not_mapped else x
for x, d in zip(val_out, dim_out)]
else: else:
return BatchTracer(self, val_out, dim_out, src) return (BatchTracer(self, val_out, dim_out, src)
if dim_out is not not_mapped else val_out)
def process_call(self, call_primitive, f, tracers, params): def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results assert call_primitive.multiple_results
params = dict(params, name=params.get('name', f.__name__)) params = dict(params, name=params.get('name', f.__name__))
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) vals, dims = unzip2(map(self.to_batch_info, tracers))
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)
sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
for x, d in zip(vals, dims) if d is not not_mapped)
axis_size, = core.dedup_referents(sizes)
segment_lens, dims = indirectify_ragged_axes(dims) segment_lens, dims = indirectify_ragged_axes(dims)
f_, dims_out = batch_subtrace(f, self.main, tuple(dims)) f_, dims_out = batch_subtrace(f, self.tag, self.axis_data, tuple(dims))
f_ = _update_annotation( f_ = _update_annotation(
f_, f.in_type, axis_size, self.axis_name, dims, segment_lens) f_, f.in_type, self.axis_data.size, self.axis_data.name, dims, segment_lens)
vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
with core.set_current_trace(self.parent_trace):
vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out()) vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out())
src = source_info_util.current() src = source_info_util.current()
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)] return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]
def post_process_call(self, call_primitive, out_tracers, params):
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
for t in out_tracers)
main = self.main
def todo(vals):
trace = main.with_cur_sublevel()
return map(partial(BatchTracer, trace), vals, dims, srcs)
return vals, todo
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) vals, dims = unzip2(map(self.to_batch_info, tracers))
if all(dim is not_mapped for dim in dims): # The logic for the dimension math below is as follows:
return map_primitive.bind(f, *vals, **params) # ╔═════════════╦════════════════════════════════════════╦═══════════╗
else: # ║ d / in_axis ║ None ║ int ║
assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1 # ╠═════════════╬════════════════════════════════════════╩═══════════╣
# The logic for the dimension math below is as follows: # ║ None ║ No extra axis, so in_axis unaffected ║
# ╔═════════════╦════════════════════════════════════════╦═══════════╗ # ╠═════════════╬════════════════════════════════════════╦═══════════╣
# ║ d / in_axis ║ None ║ int ║ # ║ int ║ Not mapped, so batching dim unaffected ║ See below ║
# ╠═════════════╬════════════════════════════════════════╩═══════════╣ # ╚═════════════╩════════════════════════════════════════╩═══════════╝
# ║ None ║ No extra axis, so in_axis unaffected ║ # When both d and in_axis are defined then:
# ╠═════════════╬════════════════════════════════════════╦═══════════╣ # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
# ║ int ║ Not mapped, so batching dim unaffected ║ See below ║ # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed).
# ╚═════════════╩════════════════════════════════════════╩═══════════╝
# When both d and in_axis are defined then:
# - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
# - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed).
def both_mapped(in_out_axis, d):
return in_out_axis is not None and d is not not_mapped
new_in_axes = tuple(
in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
for d, in_axis in zip(dims, params['in_axes']))
new_dims = tuple(
d - 1 if both_mapped(in_axis, d) and in_axis < d else d
for d, in_axis in zip(dims, params['in_axes']))
f, dims_out = batch_subtrace(f, self.main, new_dims)
out_axes_thunk = params['out_axes_thunk']
# NOTE: This assumes that the choice of the dimensions over which outputs
# are batched is entirely dependent on the function and not e.g. on the
# data or its shapes.
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
for out_axis, d in zip(out_axes_thunk(), dims_out()))
new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
vals_out = map_primitive.bind(f, *vals, **new_params)
dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
for d, out_axis in zip(dims_out(), out_axes_thunk())]
src = source_info_util.current()
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)]
def post_process_map(self, call_primitive, out_tracers, params):
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
for t in out_tracers)
main = self.main
def both_mapped(in_out_axis, d): def both_mapped(in_out_axis, d):
return in_out_axis is not None and d is not not_mapped return in_out_axis is not None and d is not not_mapped
def todo(vals): new_in_axes = tuple(
trace = main.with_cur_sublevel() in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
return [BatchTracer(trace, v, d + 1 if both_mapped(oa, d) and oa <= d else d, s) for d, in_axis in zip(dims, params['in_axes']))
for v, d, oa, s in zip(vals, dims, params['out_axes_thunk'](), srcs)] new_dims = tuple(
if call_primitive.map_primitive: d - 1 if both_mapped(in_axis, d) and in_axis < d else d
def out_axes_transform(out_axes): for d, in_axis in zip(dims, params['in_axes']))
return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis f, dims_out = batch_subtrace(f, self.tag, self.axis_data, new_dims)
for out_axis, d in zip(out_axes, dims)) out_axes_thunk = params['out_axes_thunk']
todo = (todo, out_axes_transform) # NOTE: This assumes that the choice of the dimensions over which outputs
return vals, todo # are batched is entirely dependent on the function and not e.g. on the
# data or its shapes.
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
for out_axis, d in zip(out_axes_thunk(), dims_out()))
new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
with core.set_current_trace(self.parent_trace):
vals_out = map_primitive.bind(f, *vals, **new_params)
dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
for d, out_axis in zip(dims_out(), out_axes_thunk())]
src = source_info_util.current()
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)]
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) in_vals, in_dims = unzip2(map(self.to_batch_info, tracers))
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims)
jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims) jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims)
out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp) + tuple(in_vals),
dict(symbolic_zeros=symbolic_zeros))
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
if not fst: if not fst:
assert out_dims == out_dims[:len(out_dims) // 2] * 2 assert out_dims == out_dims[:len(out_dims) // 2] * 2
@ -601,34 +553,18 @@ class BatchTrace(Trace):
src = source_info_util.current() src = source_info_util.current()
return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
for t in out_tracers)
main = self.main
def todo(vals):
trace = main.with_cur_sublevel()
if jvp_was_run:
primal_dims, tangent_dims = dims[:len(vals)], dims[len(vals):]
assert primal_dims == tangent_dims
primal_srcs = srcs[:len(vals)]
return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
else:
return map(partial(BatchTracer, trace), vals, dims, srcs)
return vals, todo
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees,
symbolic_zeros): # pytype: disable=signature-mismatch symbolic_zeros): # pytype: disable=signature-mismatch
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) in_vals, in_dims = unzip2(map(self.to_batch_info, tracers))
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
if d is not not_mapped}
fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]] fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]]
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.main, fwd_in_dims) fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims)
bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size, fwd, out_dims2 = batch_subtrace(fwd, self.tag, self.axis_data, fwd_in_dims)
out_dims2, in_dims, self.main.trace_type,
self.spmd_axis_name) bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, out_dims2, in_dims)
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, out_vals = prim.bind_with_trace(self.parent_trace,
symbolic_zeros=symbolic_zeros) (fun, fwd, bwd) + tuple(in_vals),
dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
if not fst: if not fst:
_, res_tree = out_trees() _, res_tree = out_trees()
@ -636,83 +572,46 @@ class BatchTrace(Trace):
src = source_info_util.current() src = source_info_util.current()
return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
def post_process_custom_vjp_call(self, out_tracers, _):
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
for t in out_tracers)
main = self.main
def todo(vals):
trace = main.with_cur_sublevel()
return map(partial(BatchTracer, trace), vals, dims, srcs)
return vals, todo
def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees):
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
for t in out_tracers)
axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
main, trace_type = self.main, self.main.trace_type
axis_name = self.axis_name
_, res_tree = out_trees()
num_res = res_tree.num_leaves
res_dims, primal_dims = split_list(dims, [num_res])
_, primal_srcs = split_list(srcs, [num_res])
def todo(vals):
trace = main.with_cur_sublevel()
return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
def bwd_transform(bwd):
return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,),
trace_type, self.spmd_axis_name)
return vals, todo, bwd_transform
def _main_trace_for_axis_names(main_trace: core.MainTrace,
axis_name: Iterable[AxisName],
) -> bool:
# This function exists to identify whether a main trace corresponds to any of
# the axis names used by a primitive. Axis names alone aren't enough because
# axis names can shadow, so we use the main trace as a tag.
return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)
### API for batching callables with vmappable inputs and outputs ### API for batching callables with vmappable inputs and outputs
def batch(fun: lu.WrappedFun, axis_name: AxisName, axis_size, def batch(fun: lu.WrappedFun, axis_data,
in_dims, out_dim_dests, main_type: type[BatchTrace] = BatchTrace, in_dims, out_dim_dests) -> lu.WrappedFun:
spmd_axis_name: tuple[AxisName, ...] | None = None
) -> lu.WrappedFun:
# we split up _batch_inner and _batch_outer for the leak checker # we split up _batch_inner and _batch_outer for the leak checker
f = _batch_inner(fun, axis_size, out_dim_dests) f = _batch_inner(fun, axis_data, out_dim_dests)
return _batch_outer(f, axis_name, axis_size, in_dims, main_type, return _batch_outer(f, axis_data, in_dims)
spmd_axis_name)
@lu.transformation @lu.transformation
def _batch_outer(axis_name, axis_size, in_dims, main_type, spmd_axis_name, def _batch_outer(axis_data, in_dims, *in_vals):
*in_vals): tag = TraceTag()
with core.new_main( with source_info_util.transform_name_stack('vmap'):
main_type, axis_name=axis_name, spmd_axis_name=spmd_axis_name) as main: outs, trace = yield (tag, in_dims, *in_vals), {}
with core.extend_axis_env(axis_name, axis_size, main): with core.ensure_no_leaks(trace): del trace
with source_info_util.transform_name_stack('vmap'):
outs = yield (main, in_dims, *in_vals), {}
del main
yield outs yield outs
@lu.transformation @lu.transformation
def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals): def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals):
in_dims = in_dims() if callable(in_dims) else in_dims in_dims = in_dims() if callable(in_dims) else in_dims
trace = main.with_cur_sublevel() with core.take_current_trace() as parent_trace:
idx = memoize(lambda: BatchTracer(trace, make_iota(axis_size), 0, trace = BatchTrace(parent_trace, tag, axis_data)
source_info_util.current())) idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0,
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) source_info_util.current()))
outs = yield in_tracers, {} in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
with core.set_current_trace(trace):
with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
outs = yield in_tracers, {}
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
out_vals = map(partial(from_elt, trace, axis_size), range(len(outs)), out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)),
outs, out_dim_dests) outs, out_dim_dests)
yield out_vals
yield out_vals, trace
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun, def vtile(f_flat: lu.WrappedFun,
in_axes_flat: tuple[int | None, ...], in_axes_flat: tuple[int | None, ...],
out_axes_flat: tuple[int | None, ...], out_axes_flat: tuple[int | None, ...],
tile_size: int | None, tile_size: int | None,
axis_name: AxisName, axis_name: AxisName):
main_type: type[BatchTrace] = BatchTrace):
@curry @curry
def tile_axis(arg, axis: int | None, tile_size): def tile_axis(arg, axis: int | None, tile_size):
if axis is None: if axis is None:
@ -736,23 +635,24 @@ def vtile(f_flat: lu.WrappedFun,
outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {} outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
yield map(untile_axis, outputs_flat, out_axes_flat) yield map(untile_axis, outputs_flat, out_axes_flat)
return _map_to_tile(batch( axis_data = AxisData(axis_name, tile_size, None)
f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type)) return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat))
### API for batching functions with jaxpr type inputs and outputs ### API for batching functions with jaxpr type inputs and outputs
@lu.transformation_with_aux @lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals): def batch_subtrace(tag, axis_data, in_dims, *in_vals):
trace = main.with_cur_sublevel() with core.take_current_trace() as parent_trace:
in_dims = in_dims() if callable(in_dims) else in_dims trace = BatchTrace(parent_trace, tag, axis_data)
in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) with core.set_current_trace(trace):
in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) in_dims = in_dims() if callable(in_dims) else in_dims
if dim is not None else x for x, dim in zip(in_vals, in_dims)] in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims)
outs = yield in_tracers, {} in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
out_tracers = map(trace.full_raise, outs) if dim is not None else x for x, dim in zip(in_vals, in_dims)]
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers) outs = yield in_tracers, {}
segment_lens, out_dims = indirectify_ragged_axes(out_dims) out_vals, out_dims = unzip2(map(trace.to_batch_info, outs))
yield (*segment_lens, *out_vals), out_dims segment_lens, out_dims = indirectify_ragged_axes(out_dims)
yield (*segment_lens, *out_vals), out_dims
def indirectify_ragged_axes(dims): def indirectify_ragged_axes(dims):
if not any(type(d) is RaggedAxis for d in dims): if not any(type(d) is RaggedAxis for d in dims):
@ -823,38 +723,30 @@ def resolve_ragged_axes_against_inputs_outputs(in_vals, out_vals, dims):
# Can reuse same pattern for all dynamic shape stuff. # Can reuse same pattern for all dynamic shape stuff.
def batch_jaxpr2( def batch_jaxpr2(
closed_jaxpr: core.ClosedJaxpr, closed_jaxpr: core.ClosedJaxpr,
axis_size: core.AxisSize, axis_data,
in_axes: tuple[int | NotMapped | RaggedAxis, ...], in_axes: tuple[int | NotMapped | RaggedAxis, ...],
axis_name: AxisName,
spmd_axis_name: AxisName,
main_type: type[BatchTrace],
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped | RaggedAxis, ...]]: ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped | RaggedAxis, ...]]:
# This is only ever used in pjit. The difference vs batch_jaxpr is that # This is only ever used in pjit. The difference vs batch_jaxpr is that
# batch_jaxpr2 lets the callee decide which outputs are batched and what # batch_jaxpr2 lets the callee decide which outputs are batched and what
# their batch axes are; whereas batch_jaxpr has to obey caller-imposed # their batch axes are; whereas batch_jaxpr has to obey caller-imposed
# consistency constraints, such as type-agreement across arms of a # consistency constraints, such as type-agreement across arms of a
# `lax.cond`, or input-output agreement for the body of a `lax.scan`. # `lax.cond`, or input-output agreement for the body of a `lax.scan`.
return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name, return _batch_jaxpr2(closed_jaxpr, axis_data, tuple(in_axes))
spmd_axis_name, main_type)
@weakref_lru_cache @weakref_lru_cache
def _batch_jaxpr2( def _batch_jaxpr2(
closed_jaxpr: core.ClosedJaxpr, closed_jaxpr: core.ClosedJaxpr,
axis_size: core.AxisSize, axis_data,
in_axes: tuple[int | NotMapped | RaggedAxis, ...], in_axes: tuple[int | NotMapped | RaggedAxis, ...],
axis_name: AxisName,
spmd_axis_name: AxisName,
main_type: type[BatchTrace],
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]: ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f, out_axes = _batch_jaxpr_inner(f, axis_size) f, out_axes = _batch_jaxpr_inner(f, axis_data)
f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes, f = _batch_jaxpr_outer(f, axis_data, in_axes)
main_type)
in_axes2, avals_in = unzip2([ in_axes2, avals_in = unzip2([
handle_ragged(closed_jaxpr.in_avals, dim, aval) handle_ragged(closed_jaxpr.in_avals, dim, aval)
if isinstance(dim, RaggedAxis) else (dim, aval) if isinstance(dim, RaggedAxis) else (dim, aval)
for dim, aval in zip(in_axes, closed_jaxpr.in_avals)]) for dim, aval in zip(in_axes, closed_jaxpr.in_avals)])
avals_in2 = [core.unmapped_aval(axis_size, axis_name, b, aval) avals_in2 = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval)
if b is not not_mapped else aval if b is not not_mapped else aval
for aval, b in unsafe_zip(avals_in, in_axes2)] for aval, b in unsafe_zip(avals_in, in_axes2)]
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2) jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
@ -868,14 +760,11 @@ def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis,
new_aval = aval.update(shape=tuple(new_shape)) new_aval = aval.update(shape=tuple(new_shape))
return dim.stacked_axis, new_aval return dim.stacked_axis, new_aval
def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, def batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate):
spmd_axis_name, main_type):
inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate
return _batch_jaxpr(closed_jaxpr, axis_size, tuple(in_batched), inst, return _batch_jaxpr(closed_jaxpr, axis_data, tuple(in_batched), inst)
axis_name, spmd_axis_name, main_type)
def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, def _batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate):
spmd_axis_name, main_type):
assert (isinstance(instantiate, bool) or assert (isinstance(instantiate, bool) or
isinstance(instantiate, (list, tuple)) and isinstance(instantiate, (list, tuple)) and
all(isinstance(b, bool) for b in instantiate)) all(isinstance(b, bool) for b in instantiate))
@ -883,46 +772,41 @@ def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name,
instantiate = [instantiate] * len(closed_jaxpr.out_avals) instantiate = [instantiate] * len(closed_jaxpr.out_avals)
in_axes = [0 if b else not_mapped for b in in_batched] in_axes = [0 if b else not_mapped for b in in_batched]
out_axes_dest = [0 if inst else zero_if_mapped for inst in instantiate] out_axes_dest = [0 if inst else zero_if_mapped for inst in instantiate]
return batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, return batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest)
axis_name, spmd_axis_name, main_type)
def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name, def batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
spmd_axis_name, main_type): return _batch_jaxpr_axes(closed_jaxpr, axis_data, tuple(in_axes), tuple(out_axes_dest))
return _batch_jaxpr_axes(closed_jaxpr, axis_size, tuple(in_axes),
tuple(out_axes_dest), axis_name, spmd_axis_name,
main_type)
@weakref_lru_cache @weakref_lru_cache
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
axis_name, spmd_axis_name, main_type):
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f, out_axes = _batch_jaxpr_inner(f, axis_size) f, out_axes = _batch_jaxpr_inner(f, axis_data)
f, out_batched = _match_axes_jaxpr(f, axis_size, out_axes_dest, out_axes) f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes)
f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes, f = _batch_jaxpr_outer(f, axis_data, in_axes)
main_type) avals_in = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped
avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)] else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in) jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, consts), out_batched() return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
@lu.transformation_with_aux @lu.transformation_with_aux
def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals): def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
trace = main.with_cur_sublevel() with core.take_current_trace() as parent_trace:
_, in_axes = resolve_ragged_axes(in_vals, in_axes) trace = BatchTrace(parent_trace, tag, axis_data)
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val _, in_axes = resolve_ragged_axes(in_vals, in_axes)
for val, dim in zip(in_vals, in_axes)] in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
outs = yield in_tracers, {} for val, dim in zip(in_vals, in_axes)]
out_tracers = map(trace.full_raise, outs) with core.set_current_trace(trace):
out_vals, out_axes = unzip2((t.val, t.batch_dim) for t in out_tracers) with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
new_out_axes = indirectify_ragged_axes_against_inputs_outputs( outs = yield in_tracers, {}
out_axes, in_vals, out_vals) out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
yield out_vals, new_out_axes new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
out_axes, in_vals, out_vals)
yield out_vals, new_out_axes
@lu.transformation_with_aux @lu.transformation_with_aux
def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes, def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes,
*in_vals): *in_vals):
trace = main.with_cur_sublevel() out_vals = yield (trace, in_axes, *in_vals), {}
out_vals = yield (main, in_axes, *in_vals), {}
out_axes = out_axes() out_axes = out_axes()
out_axes_dest = [(None if src is not_mapped else 0) out_axes_dest = [(None if src is not_mapped else 0)
if dst is zero_if_mapped else dst if dst is zero_if_mapped else dst
@ -930,24 +814,18 @@ def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes,
if len(out_axes_dest) != len(out_axes): if len(out_axes_dest) != len(out_axes):
out_axis_dest, = out_axes_dest out_axis_dest, = out_axes_dest
out_axes_dest = [out_axis_dest] * len(out_axes) out_axes_dest = [out_axis_dest] * len(out_axes)
out_vals = map(partial(matchaxis, trace.axis_name, axis_size), out_vals = map(partial(matchaxis, axis_data.name, axis_data.size),
out_axes, out_axes_dest, out_vals) out_axes, out_axes_dest, out_vals)
out_batched = [dst is not None for dst in out_axes_dest] out_batched = [dst is not None for dst in out_axes_dest]
yield out_vals, out_batched yield out_vals, out_batched
@lu.transformation @lu.transformation
def _batch_jaxpr_outer(axis_name, spmd_axis_name, axis_size, in_dims, main_type, def _batch_jaxpr_outer(axis_data, in_dims, *in_vals):
*in_vals):
if axis_size is None:
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
in_dims = in_dims() if callable(in_dims) else in_dims in_dims = in_dims() if callable(in_dims) else in_dims
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int) in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
else ax for x, ax in unsafe_zip(in_vals, in_dims)] else ax for x, ax in unsafe_zip(in_vals, in_dims)]
with core.new_main(main_type, axis_name=axis_name, tag = TraceTag()
spmd_axis_name=spmd_axis_name) as main: out_vals = yield (tag, in_dims, *in_vals), {}
with core.extend_axis_env(axis_name, axis_size, main):
out_vals = yield (main, in_dims, *in_vals), {}
del main
yield out_vals yield out_vals
def _merge_bdims(x, y): def _merge_bdims(x, y):
@ -966,31 +844,33 @@ zero_if_mapped = ZeroIfMapped()
### functions for handling custom_vjp ### functions for handling custom_vjp
@lu.transformation_with_aux @lu.transformation_with_aux
def batch_custom_jvp_subtrace(main, in_dims, *in_vals): def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals):
size, = {x.shape[d] for x, d in zip(in_vals, in_dims * 2) size = axis_data.size
if d is not not_mapped} with core.take_current_trace() as parent_trace:
trace = main.with_cur_sublevel() trace = BatchTrace(parent_trace, tag, axis_data)
in_tracers = [val if dim is None else in_tracers = [val if dim is None else
SymbolicZero(core.mapped_aval(size, dim, val.aval)) SymbolicZero(core.mapped_aval(size, dim, val.aval))
if type(val) is SymbolicZero else BatchTracer(trace, val, dim) if type(val) is SymbolicZero else BatchTracer(trace, val, dim)
for val, dim in zip(in_vals, in_dims * 2)] for val, dim in zip(in_vals, in_dims * 2)]
outs = yield in_tracers, {} with core.set_current_trace(trace):
# TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can outs = yield in_tracers, {}
# be wasteful in the rare case it actually triggers; handle symbolically! # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can
outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] # be wasteful in the rare case it actually triggers; handle symbolically!
out_tracers = map(trace.full_raise, outs) outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs]
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
out_vals, out_dims = unzip2(map(trace.to_batch_info, outs))
out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2]) out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds) out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
out_primals = map(partial(matchaxis, trace.axis_name, size), out_primals = map(partial(matchaxis, trace.axis_data.name, size),
out_primal_bds, out_dims, out_primals) out_primal_bds, out_dims, out_primals)
out_tangents = map(partial(matchaxis, trace.axis_name, size), out_tangents = map(partial(matchaxis, trace.axis_data.name, size),
out_tangent_bds, out_dims, out_tangents) out_tangent_bds, out_dims, out_tangents)
yield out_primals + out_tangents, out_dims * 2 yield out_primals + out_tangents, out_dims * 2
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
main_type, spmd_axis_name): axis_size = axis_data.size
axis_name = axis_data.name
def new_bwd(*args): def new_bwd(*args):
in_dims_ = in_dims() if callable(in_dims) else in_dims in_dims_ = in_dims() if callable(in_dims) else in_dims
args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval)) args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval))
@ -998,9 +878,7 @@ def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests,
for x, dim in zip(args, in_dims_)] for x, dim in zip(args, in_dims_)]
in_dims_ = [None if type(x) is SymbolicZero else d in_dims_ = [None if type(x) is SymbolicZero else d
for x, d in zip(args, in_dims_)] for x, d in zip(args, in_dims_)]
bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd)) bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd), tag, axis_data, in_dims_)
bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims_, main_type,
spmd_axis_name)
bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk,
out_dim_dests) out_dim_dests)
return bwd_.call_wrapped(*args) return bwd_.call_wrapped(*args)
@ -1039,8 +917,23 @@ BatchingRule = Callable[
tuple[Any, Union[int, None, tuple[Union[int, None], ...]]] tuple[Any, Union[int, None, tuple[Union[int, None], ...]]]
] ]
primitive_batchers : dict[core.Primitive, BatchingRule] = {} primitive_batchers : dict[core.Primitive, BatchingRule] = {}
axis_primitive_batchers: dict[core.Primitive, Callable] = {} # "fancy" primitive batchers just take a extra leading `AxisData` and "trace type" args
spmd_axis_primitive_batchers: dict[core.Primitive, Callable] = {} fancy_primitive_batchers: dict[core.Primitive, Callable] = {}
# backwards compat shim. TODO: delete
class AxisPrimitiveBatchersProxy:
def __setitem__(self, prim, batcher):
def wrapped(axis_data, vals, dims, **params):
return batcher(axis_data.size, axis_data.name, None, vals, dims, **params)
fancy_primitive_batchers[prim] = wrapped
axis_primitive_batchers = AxisPrimitiveBatchersProxy()
# Presence in this table allows fancy batchers to be skipped by batch traces for
# irrelevant axes. The Callable takes the params and returns a list of relevant
# axes.
skippable_batchers : dict[core.Primitive, Callable] = {}
def defvectorized(prim): def defvectorized(prim):
primitive_batchers[prim] = partial(vectorized_batcher, prim) primitive_batchers[prim] = partial(vectorized_batcher, prim)

File diff suppressed because it is too large Load Diff

@ -16,7 +16,6 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
from contextlib import contextmanager
import collections import collections
from collections import namedtuple from collections import namedtuple
from collections.abc import Callable, Sequence, Iterable from collections.abc import Callable, Sequence, Iterable
@ -374,14 +373,15 @@ def _emap_impl(fun: lu.WrappedFun, *args,
emap_info = EmapInfo(backend, devices) emap_info = EmapInfo(backend, devices)
shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes] shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes]
with core.new_base_main(MapTrace, emap_info=emap_info) as main: trace = MapTrace(axis_name, emap_info)
with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main): with core.extend_axis_env_nd([(axis_name, axis_size)]):
t = main.with_cur_sublevel() tracers = [MapTracer(trace, arg, s) for arg, s in zip(args, shard_axes)]
tracers = [MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)] with core.set_current_trace(trace):
ans = fun.call_wrapped(*tracers) ans = fun.call_wrapped(*tracers)
out_tracers = map(t.full_raise, ans)
outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers) out_tracers = map(trace.to_map_tracer, ans)
del main outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers)
out_axes = out_axes_thunk() out_axes = out_axes_thunk()
platform = xb.get_backend(backend).platform platform = xb.get_backend(backend).platform
@ -441,25 +441,33 @@ FakePrimitive = namedtuple("FakePrimitive", ["multiple_results", "bind"])
class MapTrace(core.Trace): class MapTrace(core.Trace):
def __init__(self, *args, emap_info): def __init__(self, axis_name, emap_info):
super().__init__(*args)
self.emap_info = emap_info self.emap_info = emap_info
self.axis_name = axis_name
def pure(self, val): def to_map_tracer(self, val):
return MapTracer(self, val, {}) if isinstance(val, MapTracer):
return val
def sublift(self, tracer): else:
return MapTracer(self, tracer.val, tracer.shard_axes) return MapTracer(self, val, {})
def process_primitive(self, primitive, tracers, params): def process_primitive(self, primitive, tracers, params):
info = self.main.payload["emap_info"] if primitive is jax._src.lax.parallel.axis_index_p:
return self.process_axis_index(**params)
if primitive is jax._src.lax.parallel.psum_p:
f = HashableFunction(
lambda *xs: jax._src.lax.parallel.psum(
xs, axis_name=params['axes'], axis_index_groups=params['axis_index_groups']),
(primitive, tuple(params.items())))
else:
f = HashableFunction(lambda *args: primitive.bind(*args, **params),
(primitive, tuple(params.items())))
tracers = map(self.to_map_tracer, tracers)
vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers]) vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
names = tuple(f.name for f in core.thread_local_state.trace_state.axis_env info = self.emap_info
if f.main_trace is self.main) names = core.get_axis_env().axis_names()
all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes) # pytype: disable=wrong-arg-types # always-use-return-annotations all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes) # pytype: disable=wrong-arg-types # always-use-return-annotations
f = HashableFunction(lambda *args: primitive.bind(*args, **params), f_mapped, out_shard_axes = _multi_pmap(f, self.emap_info, names, all_axes)
(primitive, tuple(params.items())))
f_mapped, out_shard_axes = _multi_pmap(f, info, names, all_axes)
with core.eval_context(), jax.disable_jit(False): with core.eval_context(), jax.disable_jit(False):
outvals = f_mapped(*vals) outvals = f_mapped(*vals)
if primitive.multiple_results: if primitive.multiple_results:
@ -484,14 +492,12 @@ class MapTrace(core.Trace):
shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s} shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
if ax is not None else s if ax is not None else s
for v, ax, s in zip(vals, in_axes, shard_axes)] for v, ax, s in zip(vals, in_axes, shard_axes)]
# TODO(mattjj): use _emap_subtrace here? in_tracers = map(partial(MapTracer, self), vals, shard_axes)
with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main): with core.extend_axis_env_nd([(axis_name, axis_size)]):
t = self.main.with_cur_sublevel() with core.set_current_trace(self):
in_tracers = map(partial(MapTracer, t), vals, shard_axes) ans = fun.call_wrapped(*in_tracers)
ans = fun.call_wrapped(*in_tracers) out_tracers = map(self.to_map_tracer, ans)
out_tracers = map(t.full_raise, ans)
out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers) out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers)
del t, in_tracers, ans, out_tracers
out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst) out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst)
for v, s, dst in zip(out, outaxes, out_axes_thunk())) for v, s, dst in zip(out, outaxes, out_axes_thunk()))
return map(partial(MapTracer, self), out, outaxes) return map(partial(MapTracer, self), out, outaxes)
@ -502,11 +508,8 @@ class MapTrace(core.Trace):
"Please open an issue at https://github.com/jax-ml/jax/issues !") "Please open an issue at https://github.com/jax-ml/jax/issues !")
raise NotImplementedError(msg) raise NotImplementedError(msg)
del prim, jvp, symbolic_zeros # always base main, can drop jvp del prim, jvp, symbolic_zeros # always base main, can drop jvp
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) with core.set_current_trace(self):
fun, out_axes = _emap_subtrace(fun, self.main, in_axes) return fun.call_wrapped(*tracers)
with core.new_sublevel():
out_vals = fun.call_wrapped(*in_vals)
return map(partial(MapTracer, self), out_vals, out_axes())
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees, symbolic_zeros): out_trees, symbolic_zeros):
@ -515,32 +518,18 @@ class MapTrace(core.Trace):
"Please open an issue at https://github.com/jax-ml/jax/issues !") "Please open an issue at https://github.com/jax-ml/jax/issues !")
raise NotImplementedError(msg) raise NotImplementedError(msg)
del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) with core.set_current_trace(self):
fun, out_axes = _emap_subtrace(fun, self.main, in_axes) return fun.call_wrapped(*tracers)
with core.new_sublevel():
out_vals = fun.call_wrapped(*in_vals)
return map(partial(MapTracer, self), out_vals, out_axes())
def process_axis_index(self, frame): def process_axis_index(self, axis_name):
bind = HashableFunction( bind = HashableFunction(
lambda _: jax.lax.axis_index(frame.name), lambda _: jax.lax.axis_index(axis_name),
(jax.lax.axis_index, frame.name)) (jax.lax.axis_index, axis_name))
fake_primitive = FakePrimitive(multiple_results=False, bind=bind) fake_primitive = FakePrimitive(multiple_results=False, bind=bind)
with core.eval_context(): range = jax.lax.iota(np.int32, core.get_axis_env().axis_size(axis_name))
range = jax.lax.iota(np.int32, frame.size) dummy_tracer = MapTracer(self, range, {axis_name: 0})
dummy_tracer = MapTracer(self, range, {frame.name: 0})
return self.process_primitive(fake_primitive, (dummy_tracer,), {}) return self.process_primitive(fake_primitive, (dummy_tracer,), {})
@lu.transformation_with_aux
def _emap_subtrace(main, in_axes, *in_vals):
t = main.with_cur_sublevel()
in_tracers = map(partial(MapTracer, t), in_vals, in_axes)
ans = yield in_tracers, {}
out_tracers = map(t.full_raise, ans)
out_vals, out_axes = unzip2((t.val, t.shard_axes) for t in out_tracers)
del t, in_tracers, ans, out_tracers
yield out_vals, out_axes
def _annot_to_flat(ndim: int, mapped_axes: Iterable[int], def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
annotation: int | None) -> int | None: annotation: int | None) -> int | None:
if annotation is None: return None if annotation is None: return None
@ -706,11 +695,11 @@ def stage_parallel_callable(
fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk) fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk)
else: else:
fun = orig_fun fun = orig_fun
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): with core.extend_axis_env_nd([(pci.axis_name, pci.global_axis_size)]):
with dispatch.log_elapsed_time( with dispatch.log_elapsed_time(
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time:.9f} sec", "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final( jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
fun, sharded_avals, pe.debug_info_final(fun, "pmap")) fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info) jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
@ -748,7 +737,8 @@ def get_pmap_jaxpr(
pci = ParallelCallableInfo( pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices, name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals) in_axes, out_axes_thunk, avals)
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) with core.extend_axis_env_nd([(axis_name, axis_size)]):
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
return closed_jaxpr, backend, replicas, shards, pci return closed_jaxpr, backend, replicas, shards, pci
@ -847,7 +837,7 @@ def lower_parallel_callable(
backend.platform) backend.platform)
module_name = f"pmap_{fun.__name__}" module_name = f"pmap_{fun.__name__}"
platforms = lowering_platforms or (backend.platform,) platforms = lowering_platforms or (backend.platform,)
with maybe_extend_axis_env(axis_name, global_axis_size, None): with core.extend_axis_env_nd([(axis_name, global_axis_size)]):
ordered_effects = list( ordered_effects = list(
effects.ordered_effects.filter_in(closed_jaxpr.effects)) effects.ordered_effects.filter_in(closed_jaxpr.effects))
if ordered_effects: if ordered_effects:
@ -1343,7 +1333,7 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval):
def _pmap_dce_rule(used_outputs, eqn): def _pmap_dce_rule(used_outputs, eqn):
# just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
axis_name = eqn.params["axis_name"] axis_name = eqn.params["axis_name"]
with maybe_extend_axis_env(axis_name, eqn.params["global_axis_size"], None): with core.extend_axis_env_nd([(axis_name, eqn.params["global_axis_size"])]):
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
_, donated_invars = partition_list(used_inputs, eqn.params['donated_invars']) _, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
_, in_axes = partition_list(used_inputs, eqn.params['in_axes']) _, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
@ -1402,21 +1392,6 @@ ad.call_transpose_param_updaters[xla_pmap_p] = _xla_call_transpose_update_params
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p) ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
def _pmap_axis_subst(params, subst, traverse):
if 'call_jaxpr' not in params:
return params
if not traverse:
return params
def shadowed_subst(name):
return (name,) if name in params['axis_name'] else subst(name)
with maybe_extend_axis_env(params['axis_name'],
params['global_axis_size'], None):
new_jaxpr = core.subst_axis_names_jaxpr(params['call_jaxpr'],
shadowed_subst)
return dict(params, call_jaxpr=new_jaxpr)
core.axis_substitution_rules[xla_pmap_p] = _pmap_axis_subst
def _unravel_index_hlo(axis_env): def _unravel_index_hlo(axis_env):
div = mlir.ir_constant( div = mlir.ir_constant(
np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32)) np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32))
@ -1525,7 +1500,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
if in_axis is not None else in_node if in_axis is not None else in_node
for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes)) for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))
with maybe_extend_axis_env(axis_name, global_axis_size, None): with core.extend_axis_env_nd([(axis_name, global_axis_size)]):
sub_ctx = ctx.module_context.replace( sub_ctx = ctx.module_context.replace(
axis_context=sharding_impls.ReplicaAxisContext(new_env)) axis_context=sharding_impls.ReplicaAxisContext(new_env))
sharded_outs, _ = mlir.jaxpr_subcomp( sharded_outs, _ = mlir.jaxpr_subcomp(
@ -3203,9 +3178,3 @@ def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
parsed_pspec = sharding_impls.prepare_axis_resources( parsed_pspec = sharding_impls.prepare_axis_resources(
pspec, "pspec to array_mapping") pspec, "pspec to array_mapping")
return _get_array_mapping(parsed_pspec) return _get_array_mapping(parsed_pspec)
@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
with core.extend_axis_env(*args, **kwargs):
yield

@ -28,7 +28,6 @@ from jax._src.lax.control_flow.loops import (
fori_loop as fori_loop, fori_loop as fori_loop,
map as map, map as map,
scan as scan, scan as scan,
scan_bind as scan_bind,
scan_p as scan_p, scan_p as scan_p,
_scan_impl as _scan_impl, _scan_impl as _scan_impl,
while_loop as while_loop, while_loop as while_loop,

@ -148,11 +148,6 @@ def switch(index, branches: Sequence[Callable], *operands,
if disallowed_effects: if disallowed_effects:
raise NotImplementedError( raise NotImplementedError(
f'Effects not supported in `switch`: {disallowed_effects}') f'Effects not supported in `switch`: {disallowed_effects}')
if joined_effects:
# Raise index in case of effects to allow data-dependence-based discharging
# of those effects (even if they don't have an explicit data dependence).
index = core.raise_as_much_as_possible(index)
out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs))
return tree_unflatten(out_trees[0], out) return tree_unflatten(out_trees[0], out)
@ -263,10 +258,6 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
f'Effects not supported in `cond`: {disallowed_effects}') f'Effects not supported in `cond`: {disallowed_effects}')
index = lax.convert_element_type(pred, np.int32) index = lax.convert_element_type(pred, np.int32)
if joined_effects:
# Raise index in case of effects to allow data-dependence-based discharging
# of those effects (even if they don't have an explicit data dependence).
index = core.raise_as_much_as_possible(index)
false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects) false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects)
true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects)
@ -352,8 +343,7 @@ def _bcast_select_n(pred, *cases):
pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx)
return lax.select_n(pred, *cases) return lax.select_n(pred, *cases)
def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, def _cond_batching_rule(axis_data, args, dims, branches):
dims, branches):
index, *ops = args index, *ops = args
index_dim, *op_dims = dims index_dim, *op_dims = dims
# TODO(sharadmv): clean this up by adding a specific blocklist # TODO(sharadmv): clean this up by adding a specific blocklist
@ -375,15 +365,13 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
# optimizations to XLA. # optimizations to XLA.
# TODO(mattjj,frostig): assumes branches are side-effect-free, revise! # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
index, *ops = ( index, *ops = (
batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims)) batching.bdim_at_front(x, d, axis_data.size) for x, d in zip(args, dims))
in_batched = [True] * len(branches[0].in_avals) in_batched = [True] * len(branches[0].in_avals)
out_batched = [True] * len(branches[0].out_avals) out_batched = [True] * len(branches[0].out_avals)
branches_batched = [ branches_batched = [
batching.batch_jaxpr( batching.batch_jaxpr(jaxpr, axis_data, in_batched, out_batched)[0]
jaxpr, axis_size, in_batched, out_batched, axis_name, spmd_axis_name,
main_type)[0]
for jaxpr in branches] for jaxpr in branches]
branch_outs = [] branch_outs = []
@ -401,13 +389,11 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
for b, x, d in zip(ops_bat, ops, op_dims)] for b, x, d in zip(ops_bat, ops, op_dims)]
branches_out_bat = [ branches_out_bat = [
batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, batching.batch_jaxpr(jaxpr, axis_data, ops_bat, False)[1]
spmd_axis_name, main_type)[1]
for jaxpr in branches] for jaxpr in branches]
out_bat = [any(bat) for bat in zip(*branches_out_bat)] out_bat = [any(bat) for bat in zip(*branches_out_bat)]
branches_batched = tuple( branches_batched = tuple(
batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, batching.batch_jaxpr(jaxpr, axis_data, ops_bat, out_bat)[0]
spmd_axis_name, main_type)[0]
for jaxpr in branches) for jaxpr in branches)
out_dims = [0 if b else batching.not_mapped for b in out_bat] out_dims = [0 if b else batching.not_mapped for b in out_bat]
@ -733,12 +719,6 @@ def _cond_transpose(cts, *args, branches):
assert next(out_iter, None) is None assert next(out_iter, None) is None
return [None] + out return [None] + out
def _cond_axis_substitution(params, subst, traverse):
if not traverse:
return params
branches = tuple(core.subst_axis_names_jaxpr(jaxpr, subst) for jaxpr in params['branches'])
return dict(params, branches=branches)
def _cond_typecheck(bind_time, *in_atoms, branches): def _cond_typecheck(bind_time, *in_atoms, branches):
if not bind_time: if not bind_time:
_, *in_atoms = in_atoms _, *in_atoms = in_atoms
@ -793,28 +773,16 @@ def _cond_typecheck(bind_time, *in_atoms, branches):
f'called with operands of type {_avals_short(op_avals)}') f'called with operands of type {_avals_short(op_avals)}')
return jaxpr0.out_avals, joined_effects return jaxpr0.out_avals, joined_effects
def cond_bind(*args, branches): cond_p = core.Primitive('cond')
if config.enable_checks.value:
avals = map(core.get_aval, args)
in_atoms = [core.Var('', a) for a in avals] # dummies
_cond_typecheck(True, *in_atoms, branches=branches)
for jaxpr in branches:
core.check_jaxpr(jaxpr.jaxpr)
return core.AxisPrimitive.bind(cond_p, *args, branches=branches)
cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True cond_p.multiple_results = True
cond_p.def_impl(partial(dispatch.apply_primitive, cond_p)) cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval) cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.spmd_axis_primitive_batchers[cond_p] = _cond_batching_rule batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule
batching.axis_primitive_batchers[cond_p] = partial(_cond_batching_rule, None)
xla.register_initial_style_primitive(cond_p) xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = partial(_cond_typecheck, False) core.custom_typechecks[cond_p] = partial(_cond_typecheck, False)
core.axis_substitution_rules[cond_p] = _cond_axis_substitution
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule pe.dce_rules[cond_p] = _cond_dce_rule
batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule

@ -278,34 +278,30 @@ def _cached_for_jaxpr(jaxpr):
discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
return core.ClosedJaxpr(discharged_jaxpr, body_consts) return core.ClosedJaxpr(discharged_jaxpr, body_consts)
def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, def _for_vmap(axis_data, args, dims, *,
jaxpr, nsteps, reverse, which_linear, unroll): jaxpr, nsteps, reverse, which_linear, unroll):
init_batched = [d is not batching.not_mapped for d in dims] init_batched = [d is not batching.not_mapped for d in dims]
closed_jaxpr = _cached_for_jaxpr(jaxpr) closed_jaxpr = _cached_for_jaxpr(jaxpr)
batched = init_batched batched = init_batched
for _ in range(len(batched)): for _ in range(len(batched)):
_, out_batched = batching.batch_jaxpr( _, out_batched = batching.batch_jaxpr(
closed_jaxpr, closed_jaxpr, axis_data, [False] + batched, instantiate=batched)
axis_size, [False] + batched, instantiate=batched,
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
if out_batched == batched: if out_batched == batched:
break break
batched = map(operator.or_, batched, out_batched) batched = map(operator.or_, batched, out_batched)
else: else:
raise Exception("Invalid fixpoint") raise Exception("Invalid fixpoint")
args = [batching.broadcast(x, axis_size, 0) if now_bat and not was_bat args = [batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat
else batching.moveaxis(x, d, 0) if now_bat else x else batching.moveaxis(x, d, 0) if now_bat else x
for x, d, was_bat, now_bat in zip(args, dims, init_batched, batched)] for x, d, was_bat, now_bat in zip(args, dims, init_batched, batched)]
batched_jaxpr_, _ = batching.batch_jaxpr( batched_jaxpr_, _ = batching.batch_jaxpr(
pe.close_jaxpr(jaxpr), axis_size, [False] + batched, [], pe.close_jaxpr(jaxpr), axis_data, [False] + batched, [])
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
batched_jaxpr, () = batched_jaxpr_.jaxpr, batched_jaxpr_.consts # TODO consts batched_jaxpr, () = batched_jaxpr_.jaxpr, batched_jaxpr_.consts # TODO consts
out_flat = for_p.bind(*args, jaxpr=batched_jaxpr, nsteps=nsteps, out_flat = for_p.bind(*args, jaxpr=batched_jaxpr, nsteps=nsteps,
reverse=reverse, which_linear=which_linear, reverse=reverse, which_linear=which_linear,
unroll=unroll) unroll=unroll)
return out_flat, [0 if b else batching.not_mapped for b in batched] return out_flat, [0 if b else batching.not_mapped for b in batched]
batching.axis_primitive_batchers[for_p] = functools.partial(_for_vmap, None) batching.fancy_primitive_batchers[for_p] = _for_vmap
batching.spmd_axis_primitive_batchers[for_p] = _for_vmap
def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear, def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear,
unroll): unroll):

@ -885,7 +885,7 @@ def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2,
b_ys_avals_stripped + res2_avals)) b_ys_avals_stripped + res2_avals))
def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, def _scan_batching_rule(axis_data, args,
dims, reverse, length, dims, reverse, length,
jaxpr, num_consts, num_carry, linear, unroll, jaxpr, num_consts, num_carry, linear, unroll,
_split_transpose): _split_transpose):
@ -902,11 +902,8 @@ def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
for _ in range(1 + len(carry_batched)): for _ in range(1 + len(carry_batched)):
batched = const_batched + carry_batched + xs_batched batched = const_batched + carry_batched + xs_batched
jaxpr_batched, batched_out = batching.batch_jaxpr( jaxpr_batched, batched_out = batching.batch_jaxpr(
jaxpr, axis_size, batched, jaxpr, axis_data, batched,
instantiate=carry_batched + [False] * num_ys, instantiate=carry_batched + [False] * num_ys)
axis_name=axis_name,
spmd_axis_name=spmd_axis_name,
main_type=main_type)
carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
if carry_batched_out == carry_batched: if carry_batched_out == carry_batched:
break break
@ -919,7 +916,7 @@ def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry]) consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(consts, consts_bdims)] else x for x, d in zip(consts, consts_bdims)]
new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched new_init = [batching.broadcast(x, axis_data.size, 0) if now_batched and not was_batched
else batching.moveaxis(x, d, 0) if now_batched else x else batching.moveaxis(x, d, 0) if now_batched else x
for x, d, was_batched, now_batched in for x, d, was_batched, now_batched in
zip(init, init_bdims, init_batched, carry_batched)] zip(init, init_bdims, init_batched, carry_batched)]
@ -1209,17 +1206,8 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
assert len(refs_out_matching_in_avals) == len(in_avals) assert len(refs_out_matching_in_avals) == len(in_avals)
return refs_out_matching_in_avals, [*carry_out, *ys] return refs_out_matching_in_avals, [*carry_out, *ys]
def scan_bind(*args, **params): scan_p = core.Primitive("scan")
if config.enable_checks.value:
avals = _map(core.get_aval, args)
in_atoms = [core.Var('', a) for a in avals] # dummies
_scan_typecheck(True, *in_atoms, **params)
core.check_jaxpr(params['jaxpr'].jaxpr)
return core.AxisPrimitive.bind(scan_p, *args, **params)
scan_p = core.AxisPrimitive("scan")
scan_p.multiple_results = True scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(dispatch.apply_primitive, scan_p)) scan_p.def_impl(partial(dispatch.apply_primitive, scan_p))
scan_p.def_effectful_abstract_eval(_scan_abstract_eval) scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp ad.primitive_jvps[scan_p] = _scan_jvp
@ -1228,8 +1216,7 @@ pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.register_initial_style_primitive(scan_p) xla.register_initial_style_primitive(scan_p)
mlir.register_lowering(scan_p, mlir.register_lowering(scan_p,
mlir.lower_fun(_scan_impl, multiple_results=True)) mlir.lower_fun(_scan_impl, multiple_results=True))
batching.axis_primitive_batchers[scan_p] = partial(_scan_batching_rule, None) batching.fancy_primitive_batchers[scan_p] = _scan_batching_rule
batching.spmd_axis_primitive_batchers[scan_p] = _scan_batching_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule pe.padding_rules[scan_p] = _scan_padding_rule
@ -1382,8 +1369,7 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts,
return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects
def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr,
args, dims, cond_nconsts, cond_jaxpr,
body_nconsts, body_jaxpr): body_nconsts, body_jaxpr):
from jax._src.callback import _IOEffect, _OrderedIOEffect from jax._src.callback import _IOEffect, _OrderedIOEffect
if any(_OrderedIOEffect in fn.effects for fn in [body_jaxpr, cond_jaxpr]): if any(_OrderedIOEffect in fn.effects for fn in [body_jaxpr, cond_jaxpr]):
@ -1401,8 +1387,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
# reach a fixpoint. # reach a fixpoint.
for _ in range(1 + len(carry_bat)): for _ in range(1 + len(carry_bat)):
_, carry_bat_out = batching.batch_jaxpr( _, carry_bat_out = batching.batch_jaxpr(
body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat, body_jaxpr, axis_data, bconst_bat + carry_bat, instantiate=carry_bat)
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
if carry_bat == carry_bat_out: if carry_bat == carry_bat_out:
break break
carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out) carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out)
@ -1412,8 +1397,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
# Knowing how the carry is batched now, we can determine if the predicate is # Knowing how the carry is batched now, we can determine if the predicate is
# batched. # batched.
_, (pred_bat,) = batching.batch_jaxpr( _, (pred_bat,) = batching.batch_jaxpr(
cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False, cond_jaxpr, axis_data, cconst_bat + carry_bat, instantiate=False)
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
if pred_bat: if pred_bat:
# If the predicate is batched, we have to batch *all* of the carry # If the predicate is batched, we have to batch *all* of the carry
@ -1424,13 +1408,9 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
carry_bat = [True] * len(carry_bat) carry_bat = [True] * len(carry_bat)
carry_dims = [0] * len(carry_bat) carry_dims = [0] * len(carry_bat)
body_jaxpr_batched, _ = batching.batch_jaxpr_axes( body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
body_jaxpr, axis_size, bconst_dims + carry_dims, body_jaxpr, axis_data, bconst_dims + carry_dims, carry_dims)
carry_dims, axis_name=axis_name, spmd_axis_name=spmd_axis_name,
main_type=main_type)
cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
cond_jaxpr, axis_size, cconst_dims + carry_dims, [0], cond_jaxpr, axis_data, cconst_dims + carry_dims, [0])
axis_name=axis_name, spmd_axis_name=spmd_axis_name,
main_type=main_type)
else: else:
# If the predicate is not batched, we can look at the `cond_jaxpr`'s out # If the predicate is not batched, we can look at the `cond_jaxpr`'s out
# shape to determine the rank of the predicate. From this rank we pick the # shape to determine the rank of the predicate. From this rank we pick the
@ -1440,13 +1420,11 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
cond_rank = len(cond_jaxpr.out_avals[0].shape) cond_rank = len(cond_jaxpr.out_avals[0].shape)
carry_dims = [cond_rank if b else None for b in carry_bat] carry_dims = [cond_rank if b else None for b in carry_bat]
body_jaxpr_batched, _ = batching.batch_jaxpr_axes( body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims, body_jaxpr, axis_data, bconst_dims + carry_dims, carry_dims)
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
# Now we need to rebatch the `cond_jaxpr` according to the new dims of the # Now we need to rebatch the `cond_jaxpr` according to the new dims of the
# carry. # carry.
cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,), cond_jaxpr, axis_data, cconst_dims + carry_dims, (None,))
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
# To prepare the `init` to the `while_p`, we broadcast values if they are # To prepare the `init` to the `while_p`, we broadcast values if they are
# unbatched and need to have an out axis. If their current batch axis does not # unbatched and need to have an out axis. If their current batch axis does not
@ -1455,7 +1433,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
new_init = [] new_init = []
for x, old_axis, new_axis in zip(init, init_dims, carry_dims): for x, old_axis, new_axis in zip(init, init_dims, carry_dims):
if old_axis is batching.not_mapped and new_axis is not batching.not_mapped: if old_axis is batching.not_mapped and new_axis is not batching.not_mapped:
new_init.append(batching.broadcast(x, axis_size, new_axis)) new_init.append(batching.broadcast(x, axis_data.size, new_axis))
elif old_axis is batching.not_mapped and new_axis is batching.not_mapped: elif old_axis is batching.not_mapped and new_axis is batching.not_mapped:
new_init.append(x) new_init.append(x)
else: else:
@ -1891,7 +1869,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
*[None] * num_carry] *[None] * num_carry]
return invals_out, carry_out return invals_out, carry_out
while_p = core.AxisPrimitive('while') while_p = core.Primitive('while')
while_p.multiple_results = True while_p.multiple_results = True
while_p.def_impl(partial(dispatch.apply_primitive, while_p)) while_p.def_impl(partial(dispatch.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval) while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
@ -1899,8 +1877,7 @@ ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p) xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = partial(_while_loop_batching_rule, None) batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule
batching.spmd_axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering) mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck core.custom_typechecks[while_p] = _while_typecheck

@ -376,8 +376,7 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
return [None] * sum(const_lengths) + cotangent_b return [None] * sum(const_lengths) + cotangent_b
def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs):
args, dims, const_lengths, jaxprs):
orig_bat = [d is not batching.not_mapped for d in dims] orig_bat = [d is not batching.not_mapped for d in dims]
params, b = _split_linear_solve_args(args, const_lengths) params, b = _split_linear_solve_args(args, const_lengths)
@ -397,15 +396,13 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
for i in range(1 + len(orig_b_bat) + len(solve.out_avals)): for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
# Apply vecmat and solve -> new batched parts of x # Apply vecmat and solve -> new batched parts of x
solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr( solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
solve, axis_size, solve_bat + b_bat, instantiate=x_bat, solve, axis_data, solve_bat + b_bat, instantiate=x_bat)
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
if vecmat is None: if vecmat is None:
vecmat_jaxpr_batched = None vecmat_jaxpr_batched = None
x_bat_out = solve_x_bat x_bat_out = solve_x_bat
else: else:
vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr( vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
vecmat, axis_size, vecmat_bat + b_bat, instantiate=b_bat, vecmat, axis_data, vecmat_bat + b_bat, instantiate=b_bat)
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
# batch all aux data by default # batch all aux data by default
x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat) x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat)
# keep a slice of only the linear operator part of solve's avals # keep a slice of only the linear operator part of solve's avals
@ -413,15 +410,13 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
# Apply matvec and solve_t -> new batched parts of b # Apply matvec and solve_t -> new batched parts of b
matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr( matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
matvec, axis_size, matvec_bat + x_bat_noaux, instantiate=b_bat, matvec, axis_data, matvec_bat + x_bat_noaux, instantiate=b_bat)
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
if solve_t is None: if solve_t is None:
solve_t_jaxpr_batched = None solve_t_jaxpr_batched = None
b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat) b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
else: else:
solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr( solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
solve_t, axis_size, solve_t_bat + x_bat_noaux, instantiate=x_bat_out, solve_t, axis_data, solve_t_bat + x_bat_noaux, instantiate=x_bat_out)
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)]) solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat, b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
@ -445,7 +440,7 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
] ]
# Broadcast out b if necessary # Broadcast out b if necessary
new_b = [ new_b = [
batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat else
batching.moveaxis(x, d, 0) if now_bat and d != 0 else x batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat) for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
] ]
@ -458,7 +453,7 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
return outs, out_dims return outs, out_dims
linear_solve_p = core.AxisPrimitive('custom_linear_solve') linear_solve_p = core.Primitive('custom_linear_solve')
linear_solve_p.multiple_results = True linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl) linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval) linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
@ -468,5 +463,4 @@ mlir.register_lowering(
linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl, linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
multiple_results=True)) multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = partial(_linear_solve_batching_rule, None) batching.fancy_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
batching.spmd_axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule

@ -1759,6 +1759,9 @@ def stop_gradient(x: T) -> T:
return x return x
elif (dtypes.issubdtype(_dtype(x), np.floating) or elif (dtypes.issubdtype(_dtype(x), np.floating) or
dtypes.issubdtype(_dtype(x), np.complexfloating)): dtypes.issubdtype(_dtype(x), np.complexfloating)):
# break abstractions to support legacy leaked tracer use cases
if isinstance(x, ad.JVPTracer):
return stop(x.primal)
return ad_util.stop_gradient_p.bind(x) return ad_util.stop_gradient_p.bind(x)
else: else:
return x return x
@ -2979,14 +2982,18 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
return core._pp_eqn(eqn.replace(params=params), context, settings) return core._pp_eqn(eqn.replace(params=params), context, settings)
convert_element_type_p = Primitive('convert_element_type') convert_element_type_p = Primitive('convert_element_type')
def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding):
operand = core.Primitive.bind(convert_element_type_p, operand, # TODO(dougalm): I'm overriding bind_with_trace here because that's the closest thing to
new_dtype=new_dtype, weak_type=weak_type, # the old "custom bind" but it might not be the best way to do this.
sharding=sharding) def _convert_element_type_bind_with_trace(trace, args, params):
sharding = params['sharding']
operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
if sharding is not None and not config.sharding_in_types.value: if sharding is not None and not config.sharding_in_types.value:
operand = pjit.with_sharding_constraint(operand, sharding) with core.set_current_trace(trace):
operand = pjit.with_sharding_constraint(operand, sharding)
return operand return operand
convert_element_type_p.def_custom_bind(_convert_element_type_bind) convert_element_type_p.def_bind_with_trace(_convert_element_type_bind_with_trace)
convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p)) convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p))
convert_element_type_p.def_abstract_eval( convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p, partial(standard_abstract_eval, convert_element_type_p,

@ -24,6 +24,7 @@ import math
from jax import tree_util from jax import tree_util
from jax._src import core from jax._src import core
from jax._src import dispatch
from jax._src import dtypes from jax._src import dtypes
from jax._src import sharding_impls from jax._src import sharding_impls
from jax._src.core import AxisName, ShapedArray, raise_to_shaped from jax._src.core import AxisName, ShapedArray, raise_to_shaped
@ -119,8 +120,25 @@ def psum(x, axis_name, *, axis_index_groups=None):
leaves = [lax.convert_element_type(l, np.int32) leaves = [lax.convert_element_type(l, np.int32)
if dtypes.dtype(l) == np.bool_ else l for l in leaves] if dtypes.dtype(l) == np.bool_ else l for l in leaves]
axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
out_flat = psum_p.bind( # handle the constant case specially
*leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) if all(not isinstance(leaf, core.Tracer) for leaf in leaves):
named_axes, pos_axes = axes_partition = [], []
for axis in axis_name:
axes_partition[isinstance(axis, int)].append(axis)
def pos_reduce(x):
if not pos_axes:
return x
return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0))
for axis in pos_axes])
if axis_index_groups is not None:
assert not pos_axes
size = len(axis_index_groups[0])
else:
size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes])
out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves)
else:
out_flat = psum_p.bind(
*leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups)
return tree_util.tree_unflatten(treedef, out_flat) return tree_util.tree_unflatten(treedef, out_flat)
def pmean(x, axis_name, *, axis_index_groups=None): def pmean(x, axis_name, *, axis_index_groups=None):
@ -233,7 +251,7 @@ def _axis_index_of_val(x, val, axis_name):
mask = (val == x) mask = (val == x)
validx = lax.select(mask, validx = lax.select(mask,
lax.full(mask.shape, idx), lax.full(mask.shape, idx),
lax.full(mask.shape, dtypes.iinfo(dtypes.dtype(idx)).max, dtype=idx.dtype)) lax.full(mask.shape, dtypes.iinfo(dtypes.dtype(idx)).max, dtypes.dtype(idx)))
return pmin(validx, axis_name) return pmin(validx, axis_name)
def _validate_reduce_axis_index_groups(axis_index_groups): def _validate_reduce_axis_index_groups(axis_index_groups):
@ -303,6 +321,8 @@ def ppermute(x, axis_name, perm):
Array(s) with the same shape as ``x`` with slices along the axis Array(s) with the same shape as ``x`` with slices along the axis
``axis_name`` gathered from ``x`` according to the permutation ``perm``. ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
""" """
if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,)
return tree_util.tree_map( return tree_util.tree_map(
partial(ppermute_p.bind, axis_name=axis_name, partial(ppermute_p.bind, axis_name=axis_name,
perm=tuple(map(tuple, perm))), x) perm=tuple(map(tuple, perm))), x)
@ -472,8 +492,15 @@ def axis_index(axis_name):
[0 1] [0 1]
[0 1]] [0 1]]
""" """
return axis_index_p.bind(axis_name=axis_name) if not isinstance(axis_name, (tuple, list)):
return axis_index_p.bind(axis_name=axis_name)
else:
inner_size = 1
index = 0
for name in reversed(axis_name):
index += axis_index(name) * inner_size
inner_size *= psum(1, name)
return index
def pgather(src, idx, axes: int | AxisName): def pgather(src, idx, axes: int | AxisName):
"""Uses the last positional axis of idx to index into src's axes.""" """Uses the last positional axis of idx to index into src's axes."""
@ -485,18 +512,30 @@ def pgather(src, idx, axes: int | AxisName):
### parallel primitives ### parallel primitives
def _subst_all_names_in_param( def _names_in_param(pname: str, params: core.ParamDict) -> tuple[str]:
pname: str, params: core.ParamDict, subst: core.AxisSubst, traverse: bool) -> core.ParamDict: axis_names = params[pname]
axis_name = params[pname] if isinstance(axis_names, (tuple, list)):
if not isinstance(axis_name, (tuple, list)): return tuple(axis_names)
axis_name = (axis_name,) else:
result = dict(params) return (axis_names,)
result[pname] = sum(((name,) if isinstance(name, int) else subst(name)
for name in axis_name),
())
return result
def _reduction_with_positional_batcher(prim, vals_in, dims_in, axis_index_groups, def _constant_reduction(prim, axis_data, args, axes, axis_index_groups):
assert axis_data.name in axes
if axis_index_groups: raise NotImplementedError
new_axes = tuple(n for n in axes if n != axis_data.name)
if new_axes:
args = prim.bind(*args, axes=new_axes, axis_index_groups=axis_index_groups)
if prim is psum_p:
outs = [lax._const(x, axis_data.size) * x for x in args]
elif prim in (pmin_p, pmax_p):
outs = args
else:
raise Exception(f"Unrecognized reducer: {prim}")
return outs, [None] * len(outs)
def _reduction_with_positional_batcher(
prim, vals_in, dims_in, axis_index_groups,
transform_unmapped, transform_mapped): transform_unmapped, transform_mapped):
if axis_index_groups is not None: if axis_index_groups is not None:
raise NotImplementedError("axis_index_groups not supported in vmap collectives. " raise NotImplementedError("axis_index_groups not supported in vmap collectives. "
@ -536,10 +575,19 @@ def _reduction_batcher(prim, vals_in, dims_in, *, axes, axis_index_groups):
return vals_out, [d if d is batching.not_mapped else 0 for d in dims_in] return vals_out, [d if d is batching.not_mapped else 0 for d in dims_in]
def _batched_reduction_collective( def _batched_reduction_collective(
prim, if_unmapped, axis_size, frame_name, _, vals_in, dims_in, axes, prim, if_unmapped, axis_data, vals_in, dims_in, axes,
axis_index_groups): axis_index_groups):
assert prim.multiple_results assert prim.multiple_results
assert frame_name in axes if all(d is None for d in dims_in):
if axis_data.name in axes:
return _constant_reduction(prim, axis_data, vals_in, axes, axis_index_groups)
else:
return prim.bind(*vals_in, axes=axes, axis_index_groups=axis_index_groups), dims_in
if axis_data.name not in axes:
return _reduction_batcher(prim, vals_in, dims_in, axes=axes,
axis_index_groups=axis_index_groups)
# Note that we have a choice here. We can either unfuse the reduction into one # Note that we have a choice here. We can either unfuse the reduction into one
# that handles the batched dims and then another one that handles the rest. # that handles the batched dims and then another one that handles the rest.
# Alternatively, we can keep the dimension reduction fused with the rest, but # Alternatively, we can keep the dimension reduction fused with the rest, but
@ -548,12 +596,11 @@ def _batched_reduction_collective(
# We choose the second strategy here. # We choose the second strategy here.
vals_out = _reduction_with_positional_batcher( vals_out = _reduction_with_positional_batcher(
prim, vals_in, dims_in, axis_index_groups, prim, vals_in, dims_in, axis_index_groups,
lambda d, d_vals_in: (tuple(axis for axis in axes if axis != frame_name), lambda d, d_vals_in: (tuple(axis for axis in axes if axis != axis_data.name),
[if_unmapped(v, axis_size) for v in d_vals_in]), [if_unmapped(v, axis_data.size) for v in d_vals_in]),
lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else
axis if axis != frame_name else axis if axis != axis_data.name else
d d for axis in axes),
for axis in axes),
d_vals_in)) d_vals_in))
return vals_out, [batching.not_mapped] * len(vals_out) return vals_out, [batching.not_mapped] * len(vals_out)
@ -572,12 +619,16 @@ def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]]
dtype=np.int64).T dtype=np.int64).T
return ir.DenseIntElementsAttr.get(np.ascontiguousarray(groups)) return ir.DenseIntElementsAttr.get(np.ascontiguousarray(groups))
def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups): def _allreduce_impl(prim, pos_reducer, *args, axes, axis_index_groups):
assert axis_index_groups is None assert axis_index_groups is None
if not all(isinstance(axis, int) for axis in axes):
return dispatch.apply_primitive(prim, *args, axes=axes,
axis_index_groups=axis_index_groups)
assert all(isinstance(axis, int) for axis in axes) assert all(isinstance(axis, int) for axis in axes)
return [pos_reducer(arg, axes) for arg in args] return [pos_reducer(arg, axes) for arg in args]
def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
_check_axis_names(axes)
named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
if axis_index_groups is not None: if axis_index_groups is not None:
@ -589,6 +640,13 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
arg.dtype) for arg in args] arg.dtype) for arg in args]
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
def _check_axis_names(axes):
named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
axis_env = core.get_axis_env()
for name in named_axes:
if not axis_env.axis_exists(name):
raise NameError(f"unbound axis name: {name}")
def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms): if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms):
len_0 = len(axis_index_groups[0]) len_0 = len(axis_index_groups[0])
@ -669,64 +727,37 @@ def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
axis_index_groups=axis_index_groups) axis_index_groups=axis_index_groups)
return tree_util.tree_unflatten(treedef, nonzero_in_cts) return tree_util.tree_unflatten(treedef, nonzero_in_cts)
psum_p = core.AxisPrimitive('psum') psum_p = core.Primitive('psum')
psum_p.multiple_results = True psum_p.multiple_results = True
psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum)) psum_p.def_impl(partial(_allreduce_impl, psum_p, lax._reduce_sum))
psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering( mlir.register_lowering(
psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum)) psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
ad.deflinear2(psum_p, _psum_transpose_rule) ad.deflinear2(psum_p, _psum_transpose_rule)
batching.primitive_batchers[psum_p] = partial(_reduction_batcher, psum_p) batching.fancy_primitive_batchers[psum_p] = \
batching.axis_primitive_batchers[psum_p] = \
partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v) partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes') batching.skippable_batchers[psum_p] = partial(_names_in_param, 'axes')
pmax_p = core.Primitive('pmax')
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
def psum_bind(*args, axes, axis_index_groups):
if all(not isinstance(x, core.Tracer) for x in args):
named_axes, pos_axes = axes_partition = [], []
for axis in axes:
axes_partition[isinstance(axis, int)].append(axis)
def pos_reduce(x):
if not pos_axes:
return x
return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0))
for axis in pos_axes])
if axis_index_groups is not None:
assert not pos_axes
size = len(axis_index_groups[0])
else:
size = math.prod([core.axis_frame(name).size for name in named_axes])
return tuple(lax._const(x, size) * pos_reduce(x) for x in args)
return core.AxisPrimitive.bind(
psum_p, *args, axes=axes, axis_index_groups=axis_index_groups)
pmax_p = core.AxisPrimitive('pmax')
pmax_p.multiple_results = True pmax_p.multiple_results = True
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max)) pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax._reduce_max))
pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering( mlir.register_lowering(
pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max)) pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p) batching.fancy_primitive_batchers[pmax_p] = \
batching.axis_primitive_batchers[pmax_p] = \
partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v) partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v)
core.axis_substitution_rules[pmax_p] = partial(_subst_all_names_in_param, 'axes') batching.skippable_batchers[pmax_p] = partial(_names_in_param, 'axes')
pmin_p = core.AxisPrimitive('pmin') pmin_p = core.Primitive('pmin')
pmin_p.multiple_results = True pmin_p.multiple_results = True
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min)) pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax._reduce_min))
pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering( mlir.register_lowering(
pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min)) pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p) batching.fancy_primitive_batchers[pmin_p] = \
batching.axis_primitive_batchers[pmin_p] = \
partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v) partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v)
core.axis_substitution_rules[pmin_p] = partial(_subst_all_names_in_param, 'axes') batching.skippable_batchers[pmin_p] = partial(_names_in_param, 'axes')
def _ppermute_lowering(ctx, x, *, axis_name, perm): def _ppermute_lowering(ctx, x, *, axis_name, perm):
@ -765,15 +796,16 @@ def _ppermute_transpose_rule(t, x, perm, axis_name):
inverse_perm = list(zip(dsts, srcs)) inverse_perm = list(zip(dsts, srcs))
return [ppermute(t, axis_name=axis_name, perm=inverse_perm)] return [ppermute(t, axis_name=axis_name, perm=inverse_perm)]
def _ppermute_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, perm): def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):
axis_size, frame_name = axis_data.size, axis_data.name
(v,), (d,) = vals_in, dims_in (v,), (d,) = vals_in, dims_in
if not isinstance(axis_name, (tuple, list)): if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,) axis_name = (axis_name,)
if axis_data.name not in axis_name:
return ppermute_p.bind(v, perm=perm, axis_name=axis_name), d
remaining_axes = tuple(axis for axis in axis_name if axis != frame_name) remaining_axes = tuple(axis for axis in axis_name if axis != frame_name)
if axis_size == 1 and remaining_axes:
return ppermute_p.bind(v, perm=perm, axis_name=remaining_axes), d
if remaining_axes: if remaining_axes:
raise NotImplementedError("ppermute batcher only supports a single axis") return ppermute_p.bind(v, perm=perm, axis_name=remaining_axes), d
assert axis_name[0] == frame_name, "ppermute batcher called with a wrong axis!" assert axis_name[0] == frame_name, "ppermute batcher called with a wrong axis!"
assert len(perm) == axis_size, "Permutation doesn't match the axis size!" assert len(perm) == axis_size, "Permutation doesn't match the axis size!"
if d is batching.not_mapped: if d is batching.not_mapped:
@ -783,30 +815,33 @@ def _ppermute_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, per
perm_indices[dst] = src perm_indices[dst] = src
return v.take(perm_indices, d), d return v.take(perm_indices, d), d
def _collective_batcher(prim, args, dims, **params): def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
return prim.bind(*args, **params), dims if prim.multiple_results else dims[0] _check_axis_names(axis_name)
return raise_to_shaped(x)
ppermute_p = core.AxisPrimitive('ppermute') ppermute_p = core.Primitive('ppermute')
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) ppermute_p.def_abstract_eval(_raise_to_shaped_abstract_eval)
ad.deflinear2(ppermute_p, _ppermute_transpose_rule) ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
mlir.register_lowering(ppermute_p, _ppermute_lowering) mlir.register_lowering(ppermute_p, _ppermute_lowering)
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p) batching.fancy_primitive_batchers[ppermute_p] = _ppermute_batcher
batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher batching.skippable_batchers[ppermute_p] = partial(_names_in_param, 'axis_name')
core.axis_substitution_rules[ppermute_p] = partial(_subst_all_names_in_param, 'axis_name')
def _pbroadcast_transpose_rule(t, x, source, axis_name): def _pbroadcast_transpose_rule(t, x, source, axis_name):
is_source = axis_index(axis_name) == source is_source = axis_index(axis_name) == source
tsum = psum(t, axis_name) tsum = psum(t, axis_name)
return [lax.select(is_source, lax.full_like(t, tsum), lax.full_like(t, 0))] return [lax.select(is_source, lax.full_like(t, tsum), lax.full_like(t, 0))]
def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, source): def _pbroadcast_batcher(axis_data, vals_in, dims_in, axis_name, source):
axis_size = axis_data.size
(v,), (d,) = vals_in, dims_in (v,), (d,) = vals_in, dims_in
if not isinstance(axis_name, (tuple, list)): if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,) axis_name = (axis_name,)
remaining_axes = tuple(axis for axis in axis_name if axis != frame_name) if axis_data.name not in axis_name:
return pbroadcast_p.bind(v, axis_name=axis_name, source=source), d
remaining_axes = tuple(axis for axis in axis_name if axis != axis_data.name)
if remaining_axes: if remaining_axes:
raise NotImplementedError("pbroadcast batcher only supports a single axis") raise NotImplementedError("pbroadcast batcher only supports a single axis")
assert axis_name[0] == frame_name, "pbroadcast batcher called with a wrong axis!" assert axis_name[0] == axis_data.name, "pbroadcast batcher called with a wrong axis!"
assert source >= 0 and source < axis_size, "collective broadcast doesn't fit in the axis size!" assert source >= 0 and source < axis_size, "collective broadcast doesn't fit in the axis size!"
if axis_size == 1 and remaining_axes: if axis_size == 1 and remaining_axes:
return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d
@ -823,13 +858,12 @@ def _pbroadcast_lowering(ctx, x, *, axis_name, source):
return hlo.CollectiveBroadcastOp( return hlo.CollectiveBroadcastOp(
x, replica_groups=_replica_groups_hlo(replica_groups)).results x, replica_groups=_replica_groups_hlo(replica_groups)).results
pbroadcast_p = core.AxisPrimitive('pbroadcast') pbroadcast_p = core.Primitive('pbroadcast')
pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) pbroadcast_p.def_abstract_eval(_raise_to_shaped_abstract_eval)
ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule) ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule)
mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering) mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering)
batching.primitive_batchers[pbroadcast_p] = partial(_collective_batcher, pbroadcast_p) batching.fancy_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher
batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher batching.skippable_batchers[pbroadcast_p] = partial(_names_in_param, 'axis_name')
core.axis_substitution_rules[pbroadcast_p] = partial(_subst_all_names_in_param, 'axis_name')
def _moveaxis(src, dst, x): def _moveaxis(src, dst, x):
@ -914,11 +948,22 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis,
) )
return result, d return result, d
def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in, def _all_to_all_batched_collective(axis_data, vals_in, dims_in,
axis_name, split_axis, concat_axis, axis_name, split_axis, concat_axis,
axis_index_groups, tiled): axis_index_groups, tiled):
axis_size, frame_name = axis_data.size, axis_data.name
if axis_index_groups is not None: if axis_index_groups is not None:
raise NotImplementedError("Please open a feature request!") raise NotImplementedError("Please open a feature request!")
if isinstance(axis_name, (list, tuple)):
axes_names = axis_name
else:
axes_names = [axis_name]
if axis_data.name not in axes_names:
return _all_to_all_batcher(
vals_in, dims_in, axis_name=axis_name, split_axis=split_axis,
concat_axis=concat_axis, axis_index_groups=axis_index_groups, tiled=tiled)
x, = vals_in x, = vals_in
d, = dims_in d, = dims_in
if d is batching.not_mapped: if d is batching.not_mapped:
@ -979,6 +1024,7 @@ def _all_to_all_effectful_abstract_eval(
del tiled # expand_dims and squeeze is done in `all_to_all` if `True` del tiled # expand_dims and squeeze is done in `all_to_all` if `True`
if not isinstance(axis_name, (list, tuple)): if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,) axis_name = (axis_name,)
_check_axis_names(axis_name)
input_aval = raise_to_shaped(x) input_aval = raise_to_shaped(x)
shape = list(input_aval.shape) shape = list(input_aval.shape)
axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0])
@ -990,13 +1036,12 @@ def _all_to_all_effectful_abstract_eval(
return out_aval, effects return out_aval, effects
all_to_all_p = core.AxisPrimitive('all_to_all') all_to_all_p = core.Primitive('all_to_all')
all_to_all_p.def_effectful_abstract_eval(_all_to_all_effectful_abstract_eval) all_to_all_p.def_effectful_abstract_eval(_all_to_all_effectful_abstract_eval)
mlir.register_lowering(all_to_all_p, _all_to_all_lowering) mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule) ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher batching.fancy_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective
batching.axis_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective batching.skippable_batchers[all_to_all_p] = partial(_names_in_param, 'axis_name')
core.axis_substitution_rules[all_to_all_p] = partial(_subst_all_names_in_param, 'axis_name')
def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
@ -1063,6 +1108,8 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
[[12 13 14 15] [[12 13 14 15]
[ 4 5 6 7]]] [ 4 5 6 7]]]
""" """
if not isinstance(axis_name, tuple):
axis_name = axis_name,
axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
def bind(leaf): def bind(leaf):
@ -1071,7 +1118,7 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
all_gather_dimension=canonicalize_axis( all_gather_dimension=canonicalize_axis(
axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1), axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1),
axis_name=axis_name, axis_index_groups=axis_index_groups, axis_name=axis_name, axis_index_groups=axis_index_groups,
axis_size=axis_size, tiled=tiled) axis_size=int(axis_size), tiled=tiled)
return tree_util.tree_map(bind, x) return tree_util.tree_map(bind, x)
def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
@ -1126,6 +1173,7 @@ def _all_gather_effectful_abstract_eval(
): ):
if not isinstance(axis_name, (list, tuple)): if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,) axis_name = (axis_name,)
_check_axis_names(axis_name)
x_aval = raise_to_shaped(x) x_aval = raise_to_shaped(x)
new_shape = list(x_aval.shape) new_shape = list(x_aval.shape)
if tiled: if tiled:
@ -1144,10 +1192,11 @@ def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_
def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
(x,), (d,) = vals_in, dims_in (x,), (d,) = vals_in, dims_in
if d <= all_gather_dimension: if d is not batching.not_mapped:
all_gather_dimension += 1 if d <= all_gather_dimension:
elif not tiled: # Tiled all-gather doesn't modify the set of dimensions all_gather_dimension += 1
d += 1 elif not tiled: # Tiled all-gather doesn't modify the set of dimensions
d += 1
result = all_gather_p.bind( result = all_gather_p.bind(
x, x,
all_gather_dimension=all_gather_dimension, all_gather_dimension=all_gather_dimension,
@ -1157,9 +1206,15 @@ def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, ax
tiled=tiled) tiled=tiled)
return result, d return result, d
def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, def _all_gather_batched_collective(axis_data, vals_in, dims_in,
all_gather_dimension, axis_name, all_gather_dimension, axis_name,
axis_index_groups, axis_size, tiled): axis_index_groups, axis_size, tiled):
frame_size, frame_name = axis_data.size, axis_data.name
if frame_name not in axis_name:
return _all_gather_batcher(
vals_in, dims_in, all_gather_dimension=all_gather_dimension,
axis_name=axis_name, axis_index_groups=axis_index_groups,
axis_size=axis_size, tiled=tiled)
if axis_index_groups is not None: if axis_index_groups is not None:
raise NotImplementedError("axis_index_groups not supported in vmap") raise NotImplementedError("axis_index_groups not supported in vmap")
assert axis_size == frame_size, "axis size doesn't match" assert axis_size == frame_size, "axis size doesn't match"
@ -1180,7 +1235,7 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
y = _foldaxis(all_gather_dimension, y) y = _foldaxis(all_gather_dimension, y)
return y, batching.not_mapped return y, batching.not_mapped
all_gather_p = core.AxisPrimitive('all_gather') all_gather_p = core.Primitive('all_gather')
all_gather_p.def_effectful_abstract_eval(_all_gather_effectful_abstract_eval) all_gather_p.def_effectful_abstract_eval(_all_gather_effectful_abstract_eval)
all_gather_p.def_impl(_all_gather_impl) all_gather_p.def_impl(_all_gather_impl)
mlir.register_lowering(all_gather_p, _all_gather_lowering) mlir.register_lowering(all_gather_p, _all_gather_lowering)
@ -1189,9 +1244,8 @@ for p in ("cuda", "rocm", "tpu"):
partial(_all_gather_lowering, platform=p), partial(_all_gather_lowering, platform=p),
platform=p) platform=p)
ad.deflinear2(all_gather_p, _all_gather_transpose_rule) ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
batching.primitive_batchers[all_gather_p] = _all_gather_batcher batching.fancy_primitive_batchers[all_gather_p] = _all_gather_batched_collective
batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective batching.skippable_batchers[all_gather_p] = partial(_names_in_param, 'axis_name')
core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')
def _reduce_scatter_lowering( def _reduce_scatter_lowering(
@ -1248,6 +1302,7 @@ def _reduce_scatter_effectful_abstract_eval(
): ):
if not isinstance(axis_name, (list, tuple)): if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,) axis_name = (axis_name,)
_check_axis_names(axis_name)
x_aval = core.raise_to_shaped(x) x_aval = core.raise_to_shaped(x)
new_shape = list(x_aval.shape) new_shape = list(x_aval.shape)
scatter_dim_input_size = x_aval.shape[scatter_dimension] scatter_dim_input_size = x_aval.shape[scatter_dimension]
@ -1289,9 +1344,15 @@ def _reduce_scatter_batcher(vals_in, dims_in, *, scatter_dimension, axis_name,
tiled=tiled) tiled=tiled)
return result, d return result, d
def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in, def _reduce_scatter_collective(axis_data, vals_in, dims_in,
scatter_dimension, axis_name, scatter_dimension, axis_name,
axis_index_groups, axis_size, tiled): axis_index_groups, axis_size, tiled):
frame_size, frame_name = axis_data.size, axis_data.name
if frame_name not in axis_name:
return _reduce_scatter_batcher(
vals_in, dims_in, scatter_dimension=scatter_dimension,
axis_name=axis_name, axis_index_groups=axis_index_groups,
axis_size=axis_size, tiled=tiled)
if axis_index_groups is not None: if axis_index_groups is not None:
raise NotImplementedError("axis_index_groups not supported in vmap") raise NotImplementedError("axis_index_groups not supported in vmap")
assert axis_size == frame_size, "axis size doesn't match" assert axis_size == frame_size, "axis size doesn't match"
@ -1310,21 +1371,17 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in,
return y, dy return y, dy
reduce_scatter_p = core.AxisPrimitive("reduce_scatter") reduce_scatter_p = core.Primitive("reduce_scatter")
reduce_scatter_p.def_effectful_abstract_eval( reduce_scatter_p.def_effectful_abstract_eval(
_reduce_scatter_effectful_abstract_eval _reduce_scatter_effectful_abstract_eval
) )
ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule) ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher batching.fancy_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective batching.skippable_batchers[reduce_scatter_p] = partial(_names_in_param, 'axis_name')
mlir.register_lowering(reduce_scatter_p, mlir.register_lowering(reduce_scatter_p,
partial(_reduce_scatter_lowering, lax.add_p)) partial(_reduce_scatter_lowering, lax.add_p))
core.axis_substitution_rules[reduce_scatter_p] = \
partial(_subst_all_names_in_param, 'axis_name')
def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
tiled=False): tiled=False):
""" """
@ -1401,6 +1458,8 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
[12 14] [12 14]
[16 18]] [16 18]]
""" """
if not isinstance(axis_name, tuple):
axis_name = axis_name,
axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
bind = partial( bind = partial(
@ -1420,6 +1479,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
raise NotImplementedError( raise NotImplementedError(
'`axis_index` translation rule does not support multiple axis names.') '`axis_index` translation rule does not support multiple axis names.')
axis_name, = axis_name axis_name, = axis_name
if axis_name not in axis_env.names:
raise NameError(f"unbound axis name: {axis_name}")
axis_pos = list(axis_env.names).index(axis_name) axis_pos = list(axis_env.names).index(axis_name)
nreplicas = axis_env.nreps // math.prod(axis_env.sizes) nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
div = mlir.ir_constant( div = mlir.ir_constant(
@ -1443,51 +1504,22 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
unsigned_index) unsigned_index)
def _axis_index_lowering(ctx, *, axis_name): def _axis_index_lowering(ctx, *, axis_name):
return [ return [_build_axis_index_lowering_hlo(ctx, axis_name,
_build_axis_index_lowering_hlo(ctx, axis_name, ctx.module_context.axis_env)]
ctx.module_context.axis_env)
]
def _axis_index_effectful_abstract_eval(*, axis_name): def _axis_index_effectful_abstract_eval(*, axis_name):
frame = core.axis_frame(axis_name) _check_axis_names([axis_name])
return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)} return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)}
def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name):
return lax.iota(np.int32, axis_data.size), 0
axis_index_p = core.Primitive('axis_index') axis_index_p = core.Primitive('axis_index')
axis_index_p.def_impl(partial(dispatch.apply_primitive, axis_index_p))
mlir.register_lowering(axis_index_p, _axis_index_lowering) mlir.register_lowering(axis_index_p, _axis_index_lowering)
axis_index_p.def_effectful_abstract_eval(_axis_index_effectful_abstract_eval) axis_index_p.def_effectful_abstract_eval(_axis_index_effectful_abstract_eval)
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name') batching.fancy_primitive_batchers[axis_index_p] = _axis_index_batcher
batching.skippable_batchers[axis_index_p] = partial(_names_in_param, 'axis_name')
# Axis index doesn't get any arguments, so that the default bind would have no
# way to call into a data-dependency based trace such as vmap. Each trace that
# wants to bind an axis name has to additionally implement `process_axis_index`
# and put its main trace on the axis env stack.
def _axis_index_bind(*, axis_name):
def name_idx(name):
frame = core.axis_frame(name)
dynamic = core.thread_local_state.trace_state.trace_stack.dynamic
if (frame.main_trace is None or dynamic.level > frame.main_trace.level):
return core.Primitive.bind(axis_index_p, axis_name=name)
else:
trace = frame.main_trace.with_cur_sublevel()
return trace.process_axis_index(frame)
if not isinstance(axis_name, (tuple, list)):
return name_idx(axis_name)
else:
inner_size = 1
index = 0
for name in reversed(axis_name):
index += name_idx(name) * inner_size
inner_size *= psum(1, name)
return index
axis_index_p.def_custom_bind(_axis_index_bind)
def _vmap_process_axis_index(self, frame):
assert frame.size is not None
return batching.BatchTracer(self, lax.iota(np.int32, frame.size), 0)
batching.BatchTrace.process_axis_index = _vmap_process_axis_index # type: ignore
def _pgather_impl(src, idx, *, axes): def _pgather_impl(src, idx, *, axes):
assert all(isinstance(axis, int) for axis in axes) assert all(isinstance(axis, int) for axis in axes)
@ -1508,6 +1540,7 @@ def _pgather_impl(src, idx, *, axes):
def _pgather_abstract_eval(src, idx, *, axes): def _pgather_abstract_eval(src, idx, *, axes):
# TODO: Avals with names rule: remove all axes from src, insert those from idx # TODO: Avals with names rule: remove all axes from src, insert those from idx
# The order is important, because it is ok to re-insert one of the deleted axes! # The order is important, because it is ok to re-insert one of the deleted axes!
_check_axis_names(axes)
shape = list(src.shape) shape = list(src.shape)
for axis in sorted((a for a in axes if isinstance(a, int)), reverse=True): for axis in sorted((a for a in axes if isinstance(a, int)), reverse=True):
del shape[axis] del shape[axis]
@ -1559,11 +1592,10 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a
else: else:
return pgather_p.bind(src, idx, axes=new_axes), batching.not_mapped return pgather_p.bind(src, idx, axes=new_axes), batching.not_mapped
pgather_p = core.AxisPrimitive('pgather') pgather_p = core.Primitive('pgather')
pgather_p.def_impl(_pgather_impl) pgather_p.def_impl(_pgather_impl)
pgather_p.def_abstract_eval(_pgather_abstract_eval) pgather_p.def_abstract_eval(_pgather_abstract_eval)
mlir.register_lowering(pgather_p, _pgather_parallel_lowering) mlir.register_lowering(pgather_p, _pgather_parallel_lowering)
# TODO: Transpose? That requires adding pscatter... # TODO: Transpose? That requires adding pscatter...
batching.primitive_batchers[pgather_p] = _pgather_batcher batching.fancy_primitive_batchers[pgather_p] = _pgather_collective_batcher
batching.axis_primitive_batchers[pgather_p] = _pgather_collective_batcher batching.skippable_batchers[pgather_p] = partial(_names_in_param, 'axes')
core.axis_substitution_rules[pgather_p] = partial(_subst_all_names_in_param, 'axes')

@ -64,14 +64,12 @@ data must be immutable, because it will be stored in function memoization tables
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
from typing import Any, NamedTuple from typing import Any, NamedTuple
import weakref import weakref
from jax._src import config from jax._src import config
from jax._src import core from jax._src import core
from jax._src import traceback_util from jax._src import traceback_util
from jax._src.tree_util import tree_map
from jax._src.util import curry, cache_clearing_funs from jax._src.util import curry, cache_clearing_funs
@ -337,13 +335,8 @@ def cache(call: Callable, *, explain: Callable | None = None):
def memoized_fun(fun: WrappedFun, *args): def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, new_cache := {}) # type: ignore cache = fun_caches.setdefault(fun.f, new_cache := {}) # type: ignore
if config.check_tracer_leaks.value: key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value,
key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args, config.default_device.value, config.trace_context())
config.enable_x64.value, config.default_device.value,
config.trace_context())
else:
key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value,
config.default_device.value, config.trace_context())
result = cache.get(key, None) result = cache.get(key, None)
if result is not None: if result is not None:
ans, stores = result ans, stores = result
@ -364,17 +357,6 @@ def cache(call: Callable, *, explain: Callable | None = None):
cache_clearing_funs.add(memoized_fun.cache_clear) cache_clearing_funs.add(memoized_fun.cache_clear)
return memoized_fun return memoized_fun
def _copy_main_trace(x):
if isinstance(x, core.MainTrace):
return core.MainTrace(x.level, x.trace_type, **x.payload)
else:
return x
_copy_main_traces = partial(tree_map, _copy_main_trace)
@transformation @transformation
def hashable_partial(*args): def hashable_partial(*args):
yield (yield args, {}) yield (yield args, {})

@ -607,7 +607,6 @@ def __array_module__(self, types):
return NotImplemented return NotImplemented
@core.stash_axis_env()
@partial(jax.jit, static_argnums=(1,2,3)) @partial(jax.jit, static_argnums=(1,2,3))
def _multi_slice(self: Array, def _multi_slice(self: Array,
start_indices: tuple[tuple[int, ...]], start_indices: tuple[tuple[int, ...]],

@ -1142,14 +1142,3 @@ def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh):
effs.add(eff) effs.add(eff)
return [], effs return [], effs
jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule
def _core_map_axis_subst(params, subst, traverse):
if not traverse:
return params
def shadowed_subst(name):
return (name,) if name in params['mesh'].shape else subst(name)
with jax_core.extend_axis_env_nd(params['mesh'].shape.items()):
new_jaxpr = jax_core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst)
return dict(params, jaxpr=new_jaxpr)
jax_core.axis_substitution_rules[core_map_p] = _core_map_axis_subst

@ -593,14 +593,14 @@ def dma_start_discharge_rule(in_avals, out_avals,
# Note that this code only works in SPMD mode. If not all devices execute # Note that this code only works in SPMD mode. If not all devices execute
# the DMA then the devices that do will hang. # the DMA then the devices that do will hang.
# TODO(justinfu): Verify that code only works in SPMD mode. # TODO(justinfu): Verify that code only works in SPMD mode.
axis_env = jax_core.thread_local_state.trace_state.axis_env axis_env = jax_core.get_axis_env()
nonempty_axes = [frame for frame in axis_env if frame.name is not None] nonempty_axes = [name for name in axis_env.axis_sizes if name is not None]
if device_id_type == DeviceIdType.LOGICAL: if device_id_type == DeviceIdType.LOGICAL:
if len(nonempty_axes) > 1: if len(nonempty_axes) > 1:
raise NotImplementedError("Sharding with more than one named axis not " raise NotImplementedError("Sharding with more than one named axis not "
"implemented in dma_start_p for LOGICAL " "implemented in dma_start_p for LOGICAL "
"device_id_type.") "device_id_type.")
shard_axis = nonempty_axes[0].name shard_axis = nonempty_axes[0]
my_axis = jax.lax.axis_index(shard_axis) my_axis = jax.lax.axis_index(shard_axis)
elif device_id_type == DeviceIdType.MESH: elif device_id_type == DeviceIdType.MESH:
device_id_len = 1 device_id_len = 1
@ -608,9 +608,9 @@ def dma_start_discharge_rule(in_avals, out_avals,
device_id_len = device_id.size device_id_len = device_id.size
elif hasattr(device_id, '__len__'): elif hasattr(device_id, '__len__'):
device_id_len = len(device_id) device_id_len = len(device_id)
if device_id_len != len(axis_env): if device_id_len != len(axis_env.axis_sizes):
raise ValueError( raise ValueError(
f"device_id ({device_id_len}) and mesh ({len(axis_env)}) " f"device_id ({device_id_len}) and mesh ({len(axis_env.axis_sizes)}) "
"must have same length.") "must have same length.")
if device_id_len > 1 or len(nonempty_axes) > 1: if device_id_len > 1 or len(nonempty_axes) > 1:
raise NotImplementedError("Meshes with more than 1 named dimension not " raise NotImplementedError("Meshes with more than 1 named dimension not "

@ -68,8 +68,8 @@ def program_id(axis: int) -> jax.Array:
""" """
return program_id_p.bind(axis=axis) return program_id_p.bind(axis=axis)
@program_id_p.def_custom_bind def program_id_bind_with_trace(trace, _, params):
def program_id_bind(*, axis: int): axis = params.pop("axis")
grid_env = pallas_core.current_grid_env() grid_env = pallas_core.current_grid_env()
if grid_env: if grid_env:
return grid_env[axis].index return grid_env[axis].index
@ -77,7 +77,9 @@ def program_id_bind(*, axis: int):
# Query the size of the axis to make sure it's a valid axis (and error # Query the size of the axis to make sure it's a valid axis (and error
# otherwise). # otherwise).
_ = frame.size(axis) _ = frame.size(axis)
return jax_core.Primitive.bind(program_id_p, axis=axis) return jax_core.Primitive.bind_with_trace(program_id_p, trace, (), dict(axis=axis))
# TODO(dougalm): figure out how put the grid_env contest on the relevant trace
program_id_p.def_bind_with_trace(program_id_bind_with_trace)
@program_id_p.def_abstract_eval @program_id_p.def_abstract_eval
def _program_id_abstract_eval(**_): def _program_id_abstract_eval(**_):
@ -89,8 +91,8 @@ def num_programs(axis: int) -> int | jax.Array:
"""Returns the size of the grid along the given axis.""" """Returns the size of the grid along the given axis."""
return num_programs_p.bind(axis=axis) return num_programs_p.bind(axis=axis)
@num_programs_p.def_custom_bind def _num_programs_bind_with_trace(trace, _, params):
def _num_programs_bind(*, axis: int): axis = params.pop("axis")
# We might be using a local grid env # We might be using a local grid env
grid_env = pallas_core.current_grid_env() grid_env = pallas_core.current_grid_env()
if grid_env: if grid_env:
@ -99,8 +101,9 @@ def _num_programs_bind(*, axis: int):
frame = pallas_core.axis_frame() frame = pallas_core.axis_frame()
size = frame.size(axis) size = frame.size(axis)
if size is pallas_core.dynamic_grid_dim: if size is pallas_core.dynamic_grid_dim:
return jax_core.Primitive.bind(num_programs_p, axis=axis) return jax_core.Primitive.bind_with_trace(num_programs_p, trace, (), dict(axis=axis))
return size return size
num_programs_p.def_bind_with_trace(_num_programs_bind_with_trace)
@num_programs_p.def_abstract_eval @num_programs_p.def_abstract_eval
def _num_programs_abstract_eval(**_): def _num_programs_abstract_eval(**_):

@ -1437,7 +1437,7 @@ def check_aval_layout_compatibility(
# -------------------- pjit rules -------------------- # -------------------- pjit rules --------------------
pjit_p = core.AxisPrimitive("pjit") pjit_p = core.Primitive("pjit")
pjit_p.multiple_results = True pjit_p.multiple_results = True
@ -1786,8 +1786,9 @@ def pjit_staging_rule(trace, *args, **params):
# Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
# shapes are enabled, use eval_jaxpr, which uses the tracing machinery, # shapes are enabled, use eval_jaxpr, which uses the tracing machinery,
# but redundantly performs abstract evaluation again. # but redundantly performs abstract evaluation again.
out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, with core.set_current_trace(trace):
propagate_source_info=False) out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
propagate_source_info=False)
else: else:
out_tracers = pe.inline_jaxpr_into_trace( out_tracers = pe.inline_jaxpr_into_trace(
trace, jaxpr.jaxpr, jaxpr.consts, *args) trace, jaxpr.jaxpr, jaxpr.consts, *args)
@ -1807,7 +1808,7 @@ def pjit_staging_rule(trace, *args, **params):
trace.frame.add_eqn(eqn) trace.frame.add_eqn(eqn)
elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
jaxpr, consts = pxla._move_mutable_consts(jaxpr) jaxpr, consts = pxla._move_mutable_consts(jaxpr)
consts = map(trace.instantiate_const, consts) consts = map(trace.new_const, consts)
in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
in_layouts = (*params['in_layouts'],) + (None,) * len(consts) in_layouts = (*params['in_layouts'],) + (None,) * len(consts)
donated_invars = (*params['donated_invars'],) + (False,) * len(consts) donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
@ -1936,14 +1937,11 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
mlir.register_lowering(pjit_p, _pjit_lowering) mlir.register_lowering(pjit_p, _pjit_lowering)
def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, def _pjit_batcher(axis_data, vals_in, dims_in,
vals_in, dims_in, jaxpr, in_shardings, out_shardings, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
in_layouts, out_layouts, resource_env, donated_invars, name, resource_env, donated_invars, name, keep_unused, inline):
keep_unused, inline):
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
new_jaxpr, axes_out = batching.batch_jaxpr2( new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in)
jaxpr, axis_size, dims_in, axis_name=axis_name,
spmd_axis_name=spmd_axis_name, main_type=main_type)
if resource_env is not None: if resource_env is not None:
mesh = resource_env.physical_mesh mesh = resource_env.physical_mesh
@ -1952,11 +1950,11 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type,
# TODO(axch): prepend with Nones (?) to account for new segment_lens inputs # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs
in_shardings = tuple( in_shardings = tuple(
_pjit_batcher_for_sharding(i, axis_in, spmd_axis_name, mesh, aval.ndim) _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim)
if axis_in is not None else i if axis_in is not None else i
for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals)) for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals))
out_shardings = tuple( out_shardings = tuple(
_pjit_batcher_for_sharding(o, axis_out, spmd_axis_name, mesh, aval.ndim) _pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim)
if axis_out is not None else o if axis_out is not None else o
for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals)) for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals))
# TODO(yashkatariya): Figure out layouts should change under vmap. # TODO(yashkatariya): Figure out layouts should change under vmap.
@ -1982,8 +1980,7 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type,
vals_in, vals_out, axes_out) vals_in, vals_out, axes_out)
return vals_out, resolved_axes_out return vals_out, resolved_axes_out
batching.spmd_axis_primitive_batchers[pjit_p] = _pjit_batcher batching.fancy_primitive_batchers[pjit_p] = _pjit_batcher
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None)
batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule
def _pjit_batcher_for_sharding( def _pjit_batcher_for_sharding(
@ -2541,24 +2538,23 @@ mlir.register_lowering(sharding_constraint_p,
def _sharding_constraint_batcher( def _sharding_constraint_batcher(
spmd_axis_name, axis_size, axis_name, main_type, vals_in, axis_data, vals_in, dims_in, sharding, layout, resource_env, unconstrained_dims):
dims_in, sharding, layout, resource_env, unconstrained_dims): if axis_data.spmd_name is not None and isinstance(sharding, NamedSharding):
if spmd_axis_name is not None and isinstance(sharding, NamedSharding):
used = {n for ns in sharding.spec used = {n for ns in sharding.spec
for n in (ns if isinstance(ns, tuple) else (ns,))} for n in (ns if isinstance(ns, tuple) else (ns,))}
if set(spmd_axis_name) & used: if set(axis_data.spmd_name) & used:
raise ValueError(f"vmap spmd_axis_name {spmd_axis_name} cannot appear in " raise ValueError(f"vmap spmd_axis_name {axis_data.spmd_name} cannot appear in "
"with_sharding_constraint spec, but got spec " "with_sharding_constraint spec, but got spec "
f"{sharding.spec}") f"{sharding.spec}")
x, = vals_in x, = vals_in
d, = dims_in d, = dims_in
# None means unconstrained in ParsedPartitionSpec
unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims} unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims}
if spmd_axis_name is None: if axis_data.spmd_name is None:
unconstrained_dims.add(d) unconstrained_dims.add(d)
vmapped_sharding = _pjit_batcher_for_sharding( vmapped_sharding = _pjit_batcher_for_sharding(
sharding, d, spmd_axis_name, resource_env.physical_mesh, x.ndim) sharding, d, axis_data.spmd_name, resource_env.physical_mesh, x.ndim)
if unconstrained_dims and isinstance(vmapped_sharding, NamedSharding): if unconstrained_dims and isinstance(vmapped_sharding, NamedSharding):
new_spec = list(vmapped_sharding.spec) + [None] * (x.ndim - len(vmapped_sharding.spec)) new_spec = list(vmapped_sharding.spec) + [None] * (x.ndim - len(vmapped_sharding.spec))
for u in unconstrained_dims: for u in unconstrained_dims:
@ -2579,9 +2575,9 @@ def _sharding_constraint_batcher(
resource_env=resource_env, resource_env=resource_env,
unconstrained_dims=unconstrained_dims) unconstrained_dims=unconstrained_dims)
return y, d return y, d
batching.spmd_axis_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher
batching.axis_primitive_batchers[sharding_constraint_p] = partial( batching.skippable_batchers[sharding_constraint_p] = lambda _: ()
_sharding_constraint_batcher, None)
# -------------------- helpers -------------------- # -------------------- helpers --------------------

@ -23,7 +23,6 @@ from typing import Any, Protocol, TypeVar
from jax._src import ad_util from jax._src import ad_util
from jax._src import api_util from jax._src import api_util
from jax._src import config
from jax._src import core from jax._src import core
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src import source_info_util from jax._src import source_info_util
@ -478,20 +477,6 @@ def _closed_call_discharge_rule(
run_state_p = core.Primitive("run_state") run_state_p = core.Primitive("run_state")
run_state_p.multiple_results = True run_state_p.multiple_results = True
def _run_state_bind(*args: Any, jaxpr: core.Jaxpr,
which_linear: tuple[bool, ...],
is_initialized: tuple[bool, ...]):
if config.enable_checks.value:
core.check_jaxpr(jaxpr)
num_uninitialized = sum(not i for i in is_initialized)
assert len(jaxpr.invars) == len(args) + num_uninitialized
assert len(which_linear) == len(args) + num_uninitialized
return core.Primitive.bind(run_state_p, *args, jaxpr=jaxpr,
which_linear=which_linear,
is_initialized=is_initialized)
run_state_p.def_custom_bind(_run_state_bind)
def _default_initialization(x): def _default_initialization(x):
assert hasattr(x, 'shape') assert hasattr(x, 'shape')
assert hasattr(x, 'dtype') assert hasattr(x, 'dtype')
@ -502,7 +487,6 @@ def _default_initialization(x):
value = math.nan value = math.nan
return lax.full(x.shape, value, dtype) return lax.full(x.shape, value, dtype)
def _run_state_impl(*args: Any, jaxpr: core.Jaxpr, def _run_state_impl(*args: Any, jaxpr: core.Jaxpr,
which_linear: tuple[bool, ...], which_linear: tuple[bool, ...],
is_initialized: tuple[bool, ...]): is_initialized: tuple[bool, ...]):

@ -1162,10 +1162,8 @@ class JaxTestCase(parameterized.TestCase):
_compilation_cache_exit_stack: ExitStack | None = None _compilation_cache_exit_stack: ExitStack | None = None
# TODO(mattjj): this obscures the error messages from failures, figure out how def tearDown(self) -> None:
# to re-enable it assert core.reset_trace_state()
# def tearDown(self) -> None:
# assert core.reset_trace_state()
def setUp(self): def setUp(self):
super().setUp() super().setUp()

@ -19,7 +19,9 @@ from jax._src.core import (
AbstractToken as AbstractToken, AbstractToken as AbstractToken,
AbstractValue as AbstractValue, AbstractValue as AbstractValue,
Atom as Atom, Atom as Atom,
axis_frame as axis_frame,
AxisSize as AxisSize, AxisSize as AxisSize,
AxisName as AxisName,
CallPrimitive as CallPrimitive, CallPrimitive as CallPrimitive,
ClosedJaxpr as ClosedJaxpr, ClosedJaxpr as ClosedJaxpr,
ConcreteArray as ConcreteArray, ConcreteArray as ConcreteArray,
@ -40,36 +42,28 @@ from jax._src.core import (
JaxprPpSettings as JaxprPpSettings, JaxprPpSettings as JaxprPpSettings,
JaxprTypeError as JaxprTypeError, JaxprTypeError as JaxprTypeError,
Literal as Literal, Literal as Literal,
MainTrace as MainTrace,
MapPrimitive as MapPrimitive, MapPrimitive as MapPrimitive,
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401 nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
OpaqueTraceState as OpaqueTraceState, OpaqueTraceState as OpaqueTraceState,
NameGatheringSubst as NameGatheringSubst,
OutDBIdx as OutDBIdx, OutDBIdx as OutDBIdx,
OutputType as OutputType, OutputType as OutputType,
ParamDict as ParamDict, ParamDict as ParamDict,
Primitive as Primitive, Primitive as Primitive,
ShapedArray as ShapedArray, ShapedArray as ShapedArray,
Sublevel as Sublevel,
TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING, TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING,
ThreadLocalState as ThreadLocalState,
Token as Token, Token as Token,
Trace as Trace, Trace as Trace,
TraceStack as TraceStack,
TraceState as TraceState,
Tracer as Tracer, Tracer as Tracer,
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401 unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401 unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401 unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401
unsafe_get_current_trace as unsafe_get_current_trace_DO_NOT_USE, # noqa: F401
UnshapedArray as UnshapedArray, UnshapedArray as UnshapedArray,
Value as Value, Value as Value,
Var as Var, Var as Var,
abstract_token as abstract_token, abstract_token as abstract_token,
apply_todos as apply_todos,
aval_mapping_handlers as aval_mapping_handlers, aval_mapping_handlers as aval_mapping_handlers,
axis_frame as axis_frame,
call as call, call as call,
call_bind_with_continuation as call_bind_with_continuation,
call_impl as call_impl, call_impl as call_impl,
call_p as call_p, call_p as call_p,
check_jaxpr as check_jaxpr, check_jaxpr as check_jaxpr,
@ -77,15 +71,12 @@ from jax._src.core import (
concrete_aval as concrete_aval, concrete_aval as concrete_aval,
concrete_or_error as concrete_or_error, concrete_or_error as concrete_or_error,
concretization_function_error as concretization_function_error, concretization_function_error as concretization_function_error,
cur_sublevel as cur_sublevel,
custom_typechecks as custom_typechecks, custom_typechecks as custom_typechecks,
dedup_referents as dedup_referents, dedup_referents as dedup_referents,
do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr,
ensure_compile_time_eval as ensure_compile_time_eval, ensure_compile_time_eval as ensure_compile_time_eval,
escaped_tracer_error as escaped_tracer_error, escaped_tracer_error as escaped_tracer_error,
eval_context as eval_context, eval_context as eval_context,
eval_jaxpr as eval_jaxpr, eval_jaxpr as eval_jaxpr,
extend_axis_env as extend_axis_env,
extend_axis_env_nd as extend_axis_env_nd, extend_axis_env_nd as extend_axis_env_nd,
find_top_trace as find_top_trace, find_top_trace as find_top_trace,
full_lower as full_lower, full_lower as full_lower,
@ -102,44 +93,33 @@ from jax._src.core import (
lattice_join as lattice_join, lattice_join as lattice_join,
leaked_tracer_error as leaked_tracer_error, leaked_tracer_error as leaked_tracer_error,
literalable_types as literalable_types, literalable_types as literalable_types,
map_bind as map_bind,
map_bind_with_continuation as map_bind_with_continuation,
mapped_aval as mapped_aval, mapped_aval as mapped_aval,
maybe_find_leaked_tracers as maybe_find_leaked_tracers, maybe_find_leaked_tracers as maybe_find_leaked_tracers,
max_dim as max_dim, max_dim as max_dim,
min_dim as min_dim, min_dim as min_dim,
new_base_main as new_base_main,
new_jaxpr_eqn as new_jaxpr_eqn, new_jaxpr_eqn as new_jaxpr_eqn,
new_main as new_main,
new_sublevel as new_sublevel,
no_axis_name as no_axis_name, no_axis_name as no_axis_name,
no_effects as no_effects, no_effects as no_effects,
outfeed_primitives as outfeed_primitives, outfeed_primitives as outfeed_primitives,
primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype,
primitive_uses_outfeed as primitive_uses_outfeed, primitive_uses_outfeed as primitive_uses_outfeed,
process_env_traces_call as process_env_traces_call,
process_env_traces_map as process_env_traces_map,
pytype_aval_mappings as pytype_aval_mappings, pytype_aval_mappings as pytype_aval_mappings,
raise_as_much_as_possible as raise_as_much_as_possible,
raise_to_shaped as raise_to_shaped, raise_to_shaped as raise_to_shaped,
raise_to_shaped_mappings as raise_to_shaped_mappings, raise_to_shaped_mappings as raise_to_shaped_mappings,
reset_trace_state as reset_trace_state, reset_trace_state as reset_trace_state,
stash_axis_env as stash_axis_env, set_current_trace as set_current_trace,
str_eqn_compact as str_eqn_compact, str_eqn_compact as str_eqn_compact,
subjaxprs as subjaxprs, subjaxprs as subjaxprs,
subst_axis_names as subst_axis_names,
subst_axis_names_eqn as subst_axis_names_eqn,
subst_axis_names_jaxpr as subst_axis_names_jaxpr,
subst_axis_names_var as subst_axis_names_var,
substitute_vars_in_output_ty as substitute_vars_in_output_ty, substitute_vars_in_output_ty as substitute_vars_in_output_ty,
thread_local_state as thread_local_state, take_current_trace as take_current_trace,
trace_ctx as trace_ctx,
trace_state_clean as trace_state_clean, trace_state_clean as trace_state_clean,
TraceTag as TraceTag,
traverse_jaxpr_params as traverse_jaxpr_params, traverse_jaxpr_params as traverse_jaxpr_params,
typecheck as typecheck, typecheck as typecheck,
typecompat as typecompat, typecompat as typecompat,
typematch as typematch, typematch as typematch,
unmapped_aval as unmapped_aval, unmapped_aval as unmapped_aval,
used_axis_names as used_axis_names,
used_axis_names_jaxpr as used_axis_names_jaxpr, used_axis_names_jaxpr as used_axis_names_jaxpr,
valid_jaxtype as valid_jaxtype, valid_jaxtype as valid_jaxtype,
) )

@ -14,18 +14,20 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager
from typing import Any from typing import Any
from jax._src import core from jax._src import core
from jax._src import source_info_util
from jax._src import api_util from jax._src import api_util
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src.ad_util import (Zero)
from jax._src.api_util import flatten_fun_nokwargs from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure, from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure,
treedef_tuple) treedef_tuple)
from jax._src.util import unzip2, safe_map, safe_zip, split_list from jax._src.util import unzip2, safe_map, safe_zip, split_list
from jax._src.dtypes import dtype, float0
map, unsafe_map = safe_map, map map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip zip, unsafe_zip = safe_zip, zip
@ -35,23 +37,13 @@ Pytree = Any
register = api_util.register_class_with_attrs register = api_util.register_class_with_attrs
@contextmanager
def top_trace():
stack = core.thread_local_state.trace_state.trace_stack.stack
main = stack.pop()
try:
trace = main.with_cur_sublevel()
yield trace
finally:
stack.append(main)
def jax_getattr(obj: Any, attr: str): def jax_getattr(obj: Any, attr: str):
with top_trace() as trace: with core.take_current_trace() as t:
return trace.process_getattr(obj, attr) return t.process_getattr(obj, attr)
def jax_setattr(obj: Any, attr: str, val: Pytree): def jax_setattr(obj: Any, attr: str, val: Pytree):
with top_trace() as trace: with core.take_current_trace() as t:
return trace.process_setattr(obj, attr, val) return t.process_setattr(obj, attr, val)
def _getattr_impl(_, obj, attr): def _getattr_impl(_, obj, attr):
return getattr(obj, attr) return getattr(obj, attr)
@ -62,7 +54,7 @@ def _setattr_impl(_, obj, attr, val):
core.EvalTrace.process_setattr = _setattr_impl core.EvalTrace.process_setattr = _setattr_impl
def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str):
frame = trace.main.jaxpr_stack[-1] # type: ignore frame = trace.frame
def new_tracer(x): def new_tracer(x):
aval = core.raise_to_shaped(core.get_aval(x)) aval = core.raise_to_shaped(core.get_aval(x))
@ -116,37 +108,40 @@ def _jvp(fun: lu.WrappedFun):
@lu.transformation @lu.transformation
def jvpfun2(primals, tangents): def jvpfun2(primals, tangents):
with core.new_main(ad.JVPTrace) as main: tag = core.TraceTag()
out_primals, out_tangents, tangent_attrs_out = \ tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
yield (main, primals, tangents), {} and dtype(t) == float0 else t for t in tangents]
del main ctx = source_info_util.transform_name_stack('jvp')
with ctx:
out_primals, out_tangents, tangent_attrs_out = yield (tag, primals, tangents), {}
yield out_primals, out_tangents, tangent_attrs_out yield out_primals, out_tangents, tangent_attrs_out
@lu.transformation @lu.transformation
def jvp_subtrace2(main, primals, tangents): def jvp_subtrace2(tag, primals, tangents):
main.attrs_tracked = [] # attrs written to with core.take_current_trace() as parent_trace:
trace = main.with_cur_sublevel() trace = ad.JVPTrace(parent_trace, tag)
in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x tag.attrs_tracked = [] # attrs written to
for x, t in zip(primals, tangents)] in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x
ans = yield in_tracers, {} for x, t in zip(primals, tangents)]
out_tracers = map(trace.full_raise, ans) with core.set_current_trace(trace):
out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in out_tracers) ans = yield in_tracers, {}
tangent_attrs_out = [] out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
for (obj, name) in main.attrs_tracked: tangent_attrs_out = []
tracer = trace.full_raise(jax_getattr(obj, name)) for (obj, name) in tag.attrs_tracked:
jax_setattr(obj, name, tracer.primal) primal, tangent = trace.to_primal_tangent_pair(jax_getattr(obj, name))
if type(tracer.tangent) is not ad.Zero: jax_setattr(obj, name, primal)
tangent_attrs_out.append((obj, name, tracer.tangent)) if type(tangent) is not ad.Zero:
del main.attrs_tracked tangent_attrs_out.append((obj, name, tangent))
yield out_primals, out_tangents, tangent_attrs_out del tag.attrs_tracked
yield out_primals, out_tangents, tangent_attrs_out
def _setattr_jvp(trace, obj, attr, maybe_tracer): def _setattr_jvp(trace, obj, attr, maybe_tracer):
tracer = trace.full_raise(maybe_tracer) primal, tangent = trace.to_primal_tangent_pair(maybe_tracer)
if isinstance(tracer.tangent, ad.Zero): if isinstance(tangent, ad.Zero):
return setattr(obj, attr, tracer.primal) return setattr(obj, attr, primal)
if (obj, attr) not in trace.main.attrs_tracked: if (obj, attr) not in trace.tag.attrs_tracked:
trace.main.attrs_tracked.append((obj, attr)) trace.tag.attrs_tracked.append((obj, attr))
return setattr(obj, attr, tracer) return setattr(obj, attr, ad.JVPTracer(trace, primal, tangent))
ad.JVPTrace.process_setattr = _setattr_jvp ad.JVPTrace.process_setattr = _setattr_jvp
def _getattr_jvp(trace, obj, attr): def _getattr_jvp(trace, obj, attr):

@ -399,7 +399,7 @@ def convert(fun_jax: Callable,
# It is Ok to nest convert when we are inside a call_tf # It is Ok to nest convert when we are inside a call_tf
raise ValueError( raise ValueError(
"convert must be used outside all JAX transformations." + "convert must be used outside all JAX transformations." +
f"Trace state: {core.thread_local_state.trace_state.trace_stack}") f"Trace state: {core.trace_ctx}")
global _has_registered_tf_source_path global _has_registered_tf_source_path
if not _has_registered_tf_source_path: if not _has_registered_tf_source_path:
@ -844,15 +844,11 @@ def _interpret_fun_jax(
extra_name_stack: str | None, extra_name_stack: str | None,
fresh_constant_cache: bool = False, fresh_constant_cache: bool = False,
) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]: ) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]:
with core.new_base_main(TensorFlowTrace) as main: subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), args_avals)
subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals) with _extended_name_stack(extra_name_stack):
with _extended_name_stack(extra_name_stack): out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \
with core.new_sublevel(): _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \ fresh_constant_cache=fresh_constant_cache)
_call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
fresh_constant_cache=fresh_constant_cache)
del main
return util.unzip2(out_vals) return util.unzip2(out_vals)
@ -1036,16 +1032,16 @@ def _convert_jax_impl(impl_jax: Callable, *,
@lu.transformation @lu.transformation
def _interpret_subtrace(main: core.MainTrace, def _interpret_subtrace(in_avals: Sequence[core.ShapedArray],
in_avals: Sequence[core.ShapedArray],
*in_vals: TfVal): *in_vals: TfVal):
trace = TensorFlowTrace(main, core.cur_sublevel()) trace = TensorFlowTrace()
in_tracers = tuple( in_tracers = tuple(
TensorFlowTracer(trace, val, aval) TensorFlowTracer(trace, val, aval)
for val, aval in zip(in_vals, in_avals)) for val, aval in zip(in_vals, in_avals))
outs = yield in_tracers, {} # type: Sequence[TfVal] with core.set_current_trace(trace):
outs = yield in_tracers, {} # type: Sequence[TfVal]
out_tracers: Iterable[TensorFlowTracer] = ( out_tracers: Iterable[TensorFlowTracer] = (
map(trace.full_raise, outs)) map(trace.to_tf_tracer, outs))
out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = ( out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = (
tuple((t.val, t.aval) for t in out_tracers)) tuple((t.val, t.aval) for t in out_tracers))
yield out_vals_with_avals yield out_vals_with_avals
@ -1321,13 +1317,14 @@ class TensorFlowTrace(core.Trace):
those will introduce their own MainTrace, and any operations involving those those will introduce their own MainTrace, and any operations involving those
will be done on those traces, i.e., not a concern for TFT. will be done on those traces, i.e., not a concern for TFT.
""" """
def pure(self, val: TfVal) -> TensorFlowTracer: def to_tf_tracer(self, val: TfVal) -> TensorFlowTracer:
"""Lifts a non-Tracer into the TensorFlowTracer. """Lifts a non-Tracer into the TensorFlowTracer.
This function may be called by way of trace.full_raise.
""" """
if isinstance(val, TensorFlowTracer):
return val
if hasattr(val, "__jax_array__"): if hasattr(val, "__jax_array__"):
val = val.__jax_array__() with core.set_current_trace(self):
val = val.__jax_array__()
if isinstance(val, TensorFlowTracer): if isinstance(val, TensorFlowTracer):
return val return val
tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True) tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True)
@ -1335,20 +1332,10 @@ class TensorFlowTrace(core.Trace):
self, tf_val, core.ShapedArray(np.shape(val), jax_dtype, self, tf_val, core.ShapedArray(np.shape(val), jax_dtype,
weak_type=dtypes.is_weakly_typed(val))) weak_type=dtypes.is_weakly_typed(val)))
def lift(self, val: core.Tracer) -> TensorFlowTracer:
# This would be called when we need to raise a tracer from a lower-level
# main into the TensorFlowTrace. Since the TensorFlowTrace is never nested
# inside another transform, there are no lower-level main traces.
assert False
def sublift(self, val: TensorFlowTracer) -> TensorFlowTracer:
# This is called when we need to raise a tracer from the same main,
# but a lower sublevel. This could come from a nested jit.
return TensorFlowTracer(self, val.val, val._aval)
def process_primitive(self, primitive: core.Primitive, def process_primitive(self, primitive: core.Primitive,
tracers: Sequence[TensorFlowTracer], tracers: Sequence[TensorFlowTracer],
params) -> TensorFlowTracer: params) -> TensorFlowTracer:
tracers = map(self.to_tf_tracer, tracers)
impl, impl_needs_avals = self.get_primitive_impl(primitive) impl, impl_needs_avals = self.get_primitive_impl(primitive)
args_avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) args_avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
# This is a bit conservative, doing abstract_eval even in op-by-op execution # This is a bit conservative, doing abstract_eval even in op-by-op execution
@ -1424,39 +1411,18 @@ class TensorFlowTrace(core.Trace):
def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun, def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun,
tracers: Sequence[TensorFlowTracer], params): tracers: Sequence[TensorFlowTracer], params):
assert call_primitive.multiple_results assert call_primitive.multiple_results
tracers = map(self.to_tf_tracer, tracers)
vals: Sequence[TfVal] = [t.val for t in tracers] vals: Sequence[TfVal] = [t.val for t in tracers]
avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
interpreted_fun = _interpret_subtrace(fun, self.main, avals) interpreted_fun = _interpret_subtrace(fun, avals)
extra_name_stack = None extra_name_stack = None
with _extended_name_stack(extra_name_stack): with _extended_name_stack(extra_name_stack):
with core.new_sublevel(): vals_out = interpreted_fun.call_wrapped(*vals)
vals_out = interpreted_fun.call_wrapped(*vals)
return [TensorFlowTracer(self, v, a) for v, a in vals_out] return [TensorFlowTracer(self, v, a) for v, a in vals_out]
def post_process_call(self, call_primitive: core.Primitive,
out_tracers: Sequence[TensorFlowTracer], params):
# We encountered a call primitive whose result (out_tracers) include
# TensorFlowTracer that were not passed through its arguments (captured from
# the environment).
vals = tuple(t.val for t in out_tracers)
main = self.main
def todo(vals: Sequence[TfVal]):
# TODO: is name_stack correct?
trace = TensorFlowTrace(main, core.cur_sublevel())
return [
TensorFlowTracer(trace, v, out_tracer.aval)
for v, out_tracer in zip(vals, out_tracers)
]
return vals, todo
def process_map(self, map_primitive, f, tracers, params): def process_map(self, map_primitive, f, tracers, params):
raise NotImplementedError("process_map") raise NotImplementedError("process_map")
def post_process_map(self, map_primitive, out_tracers, params):
raise NotImplementedError("post_process_map")
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
# Drop the custom differentiation rule and act like a call primitive. This # Drop the custom differentiation rule and act like a call primitive. This
# behavior is desirable because jax2tf stages code out of the JAX system, so # behavior is desirable because jax2tf stages code out of the JAX system, so
@ -1464,9 +1430,6 @@ class TensorFlowTrace(core.Trace):
del jvp, symbolic_zeros # Unused. del jvp, symbolic_zeros # Unused.
return self.process_call(core.call_p, fun, tracers, {}) return self.process_call(core.call_p, fun, tracers, {})
def post_process_custom_jvp_call(self, out_tracers, _):
assert False # unreachable assuming jax2tf runs with clean trace state
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros): symbolic_zeros):
# Drop the custom differentiation rule and act like a call primitive. This # Drop the custom differentiation rule and act like a call primitive. This
@ -1475,12 +1438,6 @@ class TensorFlowTrace(core.Trace):
del fwd, bwd, out_trees, symbolic_zeros # Unused. del fwd, bwd, out_trees, symbolic_zeros # Unused.
return self.process_call(core.call_p, fun, tracers, {}) return self.process_call(core.call_p, fun, tracers, {})
def post_process_custom_vjp_call(self, out_tracers, _):
assert False # unreachable assuming jax2tf runs with clean trace state
def post_process_custom_vjp_call_fwd(self, *_, **__):
assert False # unreachable assuming jax2tf runs with clean trace state
def get_primitive_impl(self, p: core.Primitive) -> tuple[Callable, bool]: def get_primitive_impl(self, p: core.Primitive) -> tuple[Callable, bool]:
# Returns the primitive implementation and whether the implementation # Returns the primitive implementation and whether the implementation
# takes abstract values (see definition of tf_impl_with_avals) # takes abstract values (see definition of tf_impl_with_avals)

@ -152,22 +152,22 @@ def jet(fun, primals, series):
@lu.transformation @lu.transformation
def jet_fun(order, primals, series): def jet_fun(order, primals, series):
with core.new_main(JetTrace) as main: tag = core.TraceTag()
main.order = order out_primals, out_terms = yield (tag, order, primals, series), {}
out_primals, out_terms = yield (main, primals, series), {}
del main
out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s
for p, s in zip(out_primals, out_terms)] for p, s in zip(out_primals, out_terms)]
yield out_primals, out_terms yield out_primals, out_terms
@lu.transformation @lu.transformation
def jet_subtrace(main, primals, series): def jet_subtrace(tag, order, primals, series):
trace = JetTrace(main, core.cur_sublevel()) with core.take_current_trace() as parent_trace:
in_tracers = map(partial(JetTracer, trace), primals, series) trace = JetTrace(tag, parent_trace, order)
ans = yield in_tracers, {} in_tracers = map(partial(JetTracer, trace), primals, series)
out_tracers = map(trace.full_raise, ans) with core.set_current_trace(trace):
out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers) ans = yield in_tracers, {}
yield out_primals, out_terms
out_primals, out_terms = unzip2(map(trace.to_primal_terms_pair, ans))
yield out_primals, out_terms
@lu.transformation_with_aux @lu.transformation_with_aux
def traceable(in_tree_def, *primals_and_series): def traceable(in_tree_def, *primals_and_series):
@ -198,33 +198,44 @@ class JetTracer(core.Tracer):
class JetTrace(core.Trace): class JetTrace(core.Trace):
def pure(self, val): def __init__(self, tag, parent_trace, order):
return JetTracer(self, val, zero_series) self.tag = tag
self.parent_trace = parent_trace
self.order = order
def lift(self, val): def to_primal_terms_pair(self, val):
return JetTracer(self, val, zero_series) if isinstance(val, JetTracer) and val._trace.tag is self.tag:
return val.primal, val.terms
def sublift(self, val): else:
return JetTracer(self, val.primal, val.terms) return val, zero_series
def process_primitive(self, primitive, tracers, params): def process_primitive(self, primitive, tracers, params):
order = self.main.order # pytype: disable=attribute-error order = self.order # pytype: disable=attribute-error
primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers) primals_in, series_in = unzip2(map(self.to_primal_terms_pair, tracers))
if all(t is zero_series for t in series_in):
primal_out = primitive.bind_with_trace(self.parent_trace, primals_in, params)
if primitive.multiple_results:
return [JetTracer(self, p, zero_series) for p in primal_out]
else:
return JetTracer(self, primal_out, zero_series)
series_in = [[zero_term] * order if s is zero_series else s series_in = [[zero_term] * order if s is zero_series else s
for s in series_in] for s in series_in]
# TODO(mattjj): avoid always instantiating zeros with core.set_current_trace(self.parent_trace):
series_in = [[jnp.zeros(np.shape(x), dtype=jnp.result_type(x)) # TODO(mattjj): avoid always instantiating zeros
if t is zero_term else t for t in series] series_in = [[jnp.zeros(np.shape(x), dtype=jnp.result_type(x))
for x, series in zip(primals_in, series_in)] if t is zero_term else t for t in series]
rule = jet_rules[primitive] for x, series in zip(primals_in, series_in)]
primal_out, terms_out = rule(primals_in, series_in, **params) rule = jet_rules[primitive]
primal_out, terms_out = rule(primals_in, series_in, **params)
if not primitive.multiple_results: if not primitive.multiple_results:
return JetTracer(self, primal_out, terms_out) return JetTracer(self, primal_out, terms_out)
else: else:
return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)] return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)]
def process_call(self, call_primitive, f, tracers, params): def process_call(self, call_primitive, f, tracers, params):
primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers) primals_in, series_in = unzip2(map(self.to_primal_terms_pair, tracers))
primals_and_series, in_tree_def = tree_flatten((primals_in, series_in)) primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def) f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
update_params = call_param_updaters.get(call_primitive) update_params = call_param_updaters.get(call_primitive)
@ -234,17 +245,6 @@ class JetTrace(core.Trace):
primals_out, series_out = tree_unflatten(out_tree_def(), result) primals_out, series_out = tree_unflatten(out_tree_def(), result)
return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)] return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]
def post_process_call(self, call_primitive, out_tracers, params):
primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
out, treedef = tree_flatten((primals, series))
del primals, series
main = self.main
def todo(x):
primals, series = tree_unflatten(treedef, x)
trace = JetTrace(main, core.cur_sublevel())
return map(partial(JetTracer, trace), primals, series)
return out, todo
def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
symbolic_zeros): symbolic_zeros):
# TODO(mattjj): don't just ignore custom jvp rules? # TODO(mattjj): don't just ignore custom jvp rules?

@ -359,22 +359,18 @@ ad.deflinear2(host_local_array_to_global_array_p,
lambda ct, _, **params: ( lambda ct, _, **params: (
host_local_array_to_global_array_p.bind(ct, **params),)) host_local_array_to_global_array_p.bind(ct, **params),))
def ltg_batcher(insert_axis, spmd_axis_name, axis_size, def ltg_batcher(insert_axis, axis_data, vals_in, dims_in, global_mesh, pspec):
axis_name, main_type, vals_in, dims_in,
global_mesh, pspec):
x, = vals_in x, = vals_in
d, = dims_in d, = dims_in
new_parts = None if spmd_axis_name is None else spmd_axis_name new_parts = None if axis_data.spmd_name is None else axis_data.spmd_name
new_pspec = list(pspec) new_pspec = list(pspec)
new_pspec.insert(d, new_parts) new_pspec.insert(d, new_parts)
new_pspec = P(*new_pspec) new_pspec = P(*new_pspec)
y = host_local_array_to_global_array_p.bind( y = host_local_array_to_global_array_p.bind(
x, global_mesh=global_mesh, pspec=new_pspec) x, global_mesh=global_mesh, pspec=new_pspec)
return y, d return y, d
batching.spmd_axis_primitive_batchers[host_local_array_to_global_array_p] = partial( batching.fancy_primitive_batchers[host_local_array_to_global_array_p] = partial(
ltg_batcher, False) ltg_batcher, False)
batching.axis_primitive_batchers[host_local_array_to_global_array_p] = partial(
ltg_batcher, False, None)
def _ltg_lowering(ctx, x, *, global_mesh, pspec): def _ltg_lowering(ctx, x, *, global_mesh, pspec):
return [x] return [x]

@ -53,9 +53,9 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing,
special, control_flow, ann) special, control_flow, ann)
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy from jax._src.lib.mlir.dialects import sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3, from jax._src.util import (HashableFunction, HashablePartial, unzip2,
as_hashable_function, memoize, partition_list, as_hashable_function, memoize, partition_list,
merge_lists, split_list, subs_list2) split_list, subs_list2)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify, argnums_partial from jax.api_util import flatten_fun_nokwargs, shaped_abstractify, argnums_partial
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
@ -454,30 +454,9 @@ MaybeTracer = Union[JaxType, Tracer]
class ShardMapPrimitive(core.Primitive): class ShardMapPrimitive(core.Primitive):
multiple_results = True multiple_results = True
def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh, def bind_with_trace(self, trace, fun_and_args, params):
in_names: tuple[AxisNames, ...], fun, *args = fun_and_args
out_names_thunk: Callable[[], tuple[AxisNames, ...]], return trace.process_shard_map(shard_map_p, fun, args, **params)
check_rep: bool, rewrite: bool, auto: frozenset[AxisName]
) -> Sequence[MaybeTracer]:
top_trace = core.find_top_trace(args)
fun, env_todo = process_env_traces(fun, top_trace.level, mesh, in_names,
out_names_thunk, check_rep, rewrite, auto)
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
out_names = out_names_thunk()
_, xforms = env_todo()
for t in xforms:
out_names = t(out_names)
return out_names
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_shard_map( # pytype: disable=attribute-error
shard_map_p, fun, tracers, mesh=mesh, in_names=in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep,
rewrite=rewrite, auto=auto)
todos, _ = env_todo()
return map(core.full_lower, core.apply_todos(todos, outs))
def get_bind_params(self, params): def get_bind_params(self, params):
new_params = dict(params) new_params = dict(params)
@ -489,56 +468,37 @@ class ShardMapPrimitive(core.Primitive):
shard_map_p = ShardMapPrimitive('shard_map') shard_map_p = ShardMapPrimitive('shard_map')
@lu.transformation_with_aux
def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
rewrite, auto, *args: Any):
outs = yield args, {}
todos, out_names_transforms = [], []
while True:
tracers = [x for x in outs if isinstance(x, core.Tracer)
and (level is None or x._trace.level > level)]
if tracers:
ans = max(tracers, key=op.attrgetter('_trace.level'))
else:
break
trace = ans._trace.main.with_cur_sublevel()
outs = map(trace.full_raise, outs)
outs, (todo, xform) = trace.post_process_shard_map(
outs, mesh, in_names, out_names_thunk, check_rep, rewrite, auto)
todos.append(todo)
out_names_transforms.append(xform)
yield outs, (tuple(todos), tuple(out_names_transforms))
# Staging # Staging
def _shard_map_staging( def _shard_map_staging(
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh, in_tracers: Sequence[Any], *, mesh: Mesh,
in_names: tuple[AxisNames, ...], in_names: tuple[AxisNames, ...],
out_names_thunk: Callable[[], tuple[AxisNames, ...]], out_names_thunk: Callable[[], tuple[AxisNames, ...]],
check_rep: bool, check_rep: bool,
rewrite: bool, rewrite: bool,
auto: frozenset, auto: frozenset,
) -> Sequence[pe.DynamicJaxprTracer]: ) -> Sequence[pe.DynamicJaxprTracer]:
in_tracers = map(trace.to_jaxpr_tracer, in_tracers)
in_avals = [t.aval for t in in_tracers] in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
main = trace.main with core.extend_axis_env_nd(list(mesh.shape.items())):
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
out_avals_ = map(_check_shapedarray, genavals)
_check_names(out_names_thunk(), out_avals_) _check_names(out_names_thunk(), out_avals_)
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
if check_rep: if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
out_rep = _check_rep(mesh, jaxpr, in_rep) out_rep = _check_rep(mesh, jaxpr, in_rep)
_check_reps(mesh, out_names_thunk(), out_rep) _check_reps(mesh, out_names_thunk(), out_rep)
out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_) out_avals = map(_check_shapedarray, out_avals_)
out_avals = [_check_shapedarray(_unshard_aval(mesh, names, aval))
for names, aval in zip(out_names_thunk(), out_avals)]
source_info = source_info_util.current() source_info = source_info_util.current()
out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals]
invars = map(trace.getvar, in_tracers) invars = map(trace.getvar, in_tracers)
constvars = map(trace.getvar, map(trace.instantiate_const, consts)) constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts))
outvars = map(trace.makevar, out_tracers) outvars = map(trace.makevar, out_tracers)
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
with core.extend_axis_env_nd(mesh.shape.items()): with core.extend_axis_env_nd(list(mesh.shape.items())):
jaxpr = pe.convert_constvars_jaxpr(jaxpr) jaxpr = pe.convert_constvars_jaxpr(jaxpr)
params = dict(mesh=mesh, in_names=in_names_staged, params = dict(mesh=mesh, in_names=in_names_staged,
out_names=tuple(out_names_thunk()), jaxpr=jaxpr, out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
@ -804,28 +764,23 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
mesh = get_mesh_from_args(args, mesh) mesh = get_mesh_from_args(args, mesh)
args = map(partial(_unmatch_spec, mesh), in_names, args) args = map(partial(_unmatch_spec, mesh), in_names, args)
in_rep = map(partial(_in_names_to_rep, mesh), in_names) in_rep = map(partial(_in_names_to_rep, mesh), in_names)
with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main: outs, out_rep = _run_shmap(fun, mesh, args, in_rep, check_rep)
fun, out_rep = _shmap_subtrace(fun, main, in_rep)
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items(), main):
outs = fun.call_wrapped(*args)
del main
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs]
_check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types
if check_rep: if check_rep:
_check_reps(mesh, out_names_thunk(), out_rep()) _check_reps(mesh, out_names_thunk(), out_rep)
pspecs = map(_names_to_pspec, out_names_thunk()) pspecs = map(_names_to_pspec, out_names_thunk())
return map(partial(_match_spec, mesh, check_rep), pspecs, outs) return map(partial(_match_spec, mesh, check_rep), pspecs, outs)
core.EvalTrace.process_shard_map = _shard_map_impl core.EvalTrace.process_shard_map = _shard_map_impl
@lu.transformation_with_aux def _run_shmap(f, mesh, args, reps, check_rep):
def _shmap_subtrace(main, in_rep, *in_vals): trace = ShardMapTrace(mesh, check_rep)
t = main.with_cur_sublevel() in_tracers = map(partial(ShardMapTracer, trace), reps, args)
in_tracers = map(partial(ShardMapTracer, t), in_rep, in_vals) with core.set_current_trace(trace):
ans = yield in_tracers, {} with core.extend_axis_env_nd(mesh.shape.items()):
out_tracers = map(t.full_raise, ans) ans = f.call_wrapped(*in_tracers)
outs, out_rep = unzip2((t.val, t.rep) for t in out_tracers) outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans))
del t, in_tracers, ans, out_tracers return outs, out_rep
yield outs, out_rep
def _names_to_pspec(names: AxisNames) -> PartitionSpec: def _names_to_pspec(names: AxisNames) -> PartitionSpec:
ndmin = max(names) + 1 if names else 0 ndmin = max(names) + 1 if names else 0
@ -877,20 +832,21 @@ class ShardMapTrace(core.Trace):
mesh: Mesh mesh: Mesh
check: bool check: bool
def __init__(self, *args, mesh, check): def __init__(self, mesh, check):
super().__init__(*args)
self.mesh = mesh self.mesh = mesh
self.check = check self.check = check
def pure(self, val): def to_val_rep_pair(self, val):
val_ = _unmatch_spec(self.mesh, {}, val) if isinstance(val, ShardMapTracer):
return ShardMapTracer(self, None, val_) return val.val, val.rep
elif isinstance(val, Tracer):
def sublift(self, tracer): raise Exception("Shouldn't have any non-shard_map tracers")
return ShardMapTracer(self, tracer.rep, tracer.val) else:
val_ = _unmatch_spec(self.mesh, {}, val)
return val_, None
def process_primitive(self, prim, tracers, params): def process_primitive(self, prim, tracers, params):
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers))
eager_rule = eager_rules.get(prim) eager_rule = eager_rules.get(prim)
if eager_rule: if eager_rule:
out_vals = eager_rule(self.mesh, *in_vals, **params) out_vals = eager_rule(self.mesh, *in_vals, **params)
@ -926,36 +882,21 @@ class ShardMapTrace(core.Trace):
"https://github.com/jax-ml/jax/issues") "https://github.com/jax-ml/jax/issues")
raise NotImplementedError(msg) raise NotImplementedError(msg)
del prim, jvp, symbolic_zeros del prim, jvp, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers))
fun, out_rep = _shmap_subtrace(fun, self.main, in_rep) out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check)
with core.new_sublevel(): return map(partial(ShardMapTracer, self), out_rep, out_vals)
out_vals = fun.call_wrapped(*in_vals)
return map(partial(ShardMapTracer, self), out_rep(), out_vals)
def post_process_custom_jvp_call(self, out_tracers, _):
assert False # unreachable
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros): symbolic_zeros):
# Since ShardMapTrace is only used as a base main, we can drop the jvp.
if symbolic_zeros: if symbolic_zeros:
msg = ("custom_vjp symbolic_zeros support with shard_map is not " msg = ("custom_vjp symbolic_zeros support with shard_map is not "
"implemented; please open an issue at " "implemented; please open an issue at "
"https://github.com/jax-ml/jax/issues") "https://github.com/jax-ml/jax/issues")
raise NotImplementedError(msg) raise NotImplementedError(msg)
del prim, fwd, bwd, out_trees, symbolic_zeros del prim, fwd, bwd, out_trees, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers))
fun, out_rep = _shmap_subtrace(fun, self.main, in_rep) out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check)
with core.new_sublevel(): return map(partial(ShardMapTracer, self), out_rep, out_vals)
out_vals = fun.call_wrapped(*in_vals)
return map(partial(ShardMapTracer, self), out_rep(), out_vals)
def post_process_custom_vjp_call(self, out_tracers, _):
assert False # unreachable
def process_axis_index(self, frame):
with core.eval_context(), jax.disable_jit(False):
return jax.jit(lambda: jax.lax.axis_index(frame.name))()
class ShardMapTracer(core.Tracer): class ShardMapTracer(core.Tracer):
@ -978,9 +919,6 @@ class ShardMapTracer(core.Tracer):
aval = core.raise_to_shaped(aval) aval = core.raise_to_shaped(aval)
return core.mapped_aval(self._trace.mesh.size, 0, aval) return core.mapped_aval(self._trace.mesh.size, 0, aval)
def full_lower(self) -> ShardMapTracer:
return self
def __str__(self) -> str: def __str__(self) -> str:
with core.eval_context(): with core.eval_context():
blocks = list(self.val) blocks = list(self.val)
@ -1023,17 +961,16 @@ eager_rules[dispatch.device_put_p] = _device_put_eager_rule
# New primitives for efficient transposition # New primitives for efficient transposition
# psum2_p is like psum_p except has a different transpose, so mostly copied: # psum2_p is like psum_p except has a different transpose, so mostly copied:
psum2_p = core.AxisPrimitive('psum2') psum2_p = core.Primitive('psum2')
psum2_p.multiple_results = True psum2_p.multiple_results = True
psum2_p.def_impl(lax_parallel.psum_p.impl) psum2_p.def_impl(lax_parallel.psum_p.impl)
psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval) psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval)
mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p])
batching.primitive_batchers[psum2_p] = partial(lax_parallel._reduction_batcher, psum2_p) batching.fancy_primitive_batchers[psum2_p] = \
batching.axis_primitive_batchers[psum2_p] = \
partial(lax_parallel._batched_reduction_collective, psum2_p, partial(lax_parallel._batched_reduction_collective, psum2_p,
lambda v, axis_size: axis_size * v) lambda v, axis_size: axis_size * v)
core.axis_substitution_rules[psum2_p] = \ batching.skippable_batchers[psum2_p] = partial(lax_parallel._names_in_param, 'axes')
partial(lax_parallel._subst_all_names_in_param, 'axes')
def _psum2_transpose_rule(cts, *args, axes, axis_index_groups): def _psum2_transpose_rule(cts, *args, axes, axis_index_groups):
del args del args
return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
@ -1046,7 +983,7 @@ def pbroadcast(x, axis_name):
xs, treedef = tree_flatten(x) xs, treedef = tree_flatten(x)
ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None) ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None)
return tree_unflatten(treedef, ys) return tree_unflatten(treedef, ys)
pbroadcast_p = core.AxisPrimitive('pbroadcast') pbroadcast_p = core.Primitive('pbroadcast')
pbroadcast_p.multiple_results = True pbroadcast_p.multiple_results = True
pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args) pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args)
pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args) pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args)
@ -1057,12 +994,6 @@ def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups):
axis_index_groups=axis_index_groups) axis_index_groups=axis_index_groups)
return vals_out, dims_in return vals_out, dims_in
batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher
def _pbroadcast_axis_batcher(size, name, trace_type, vals_in, dims_in, *, axes,
groups):
raise NotImplementedError # vmap with axis name involved in this primitive
batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_axis_batcher
core.axis_substitution_rules[pbroadcast_p] = \
partial(lax_parallel._subst_all_names_in_param, 'axes')
ad.deflinear2(pbroadcast_p, ad.deflinear2(pbroadcast_p,
lambda cts, *_, axes, axis_index_groups: lambda cts, *_, axes, axis_index_groups:
psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)) psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups))
@ -1421,23 +1352,23 @@ def _shard_map_batch(
check_rep: bool, check_rep: bool,
rewrite: bool, rewrite: bool,
auto: frozenset) -> Sequence[batching.BatchTracer]: auto: frozenset) -> Sequence[batching.BatchTracer]:
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers) in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers))
if all(bdim is batching.not_mapped for bdim in in_dims):
return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names,
out_names_thunk=out_names_thunk, check_rep=check_rep,
rewrite=rewrite, auto=auto)
if any(isinstance(d, batching.RaggedAxis) for d in in_dims): if any(isinstance(d, batching.RaggedAxis) for d in in_dims):
raise NotImplementedError raise NotImplementedError
fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims)) new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] # type: ignore
for ax in names} for names, d in zip(in_names, in_dims)] for ax in names} for names, d in zip(in_names, in_dims)]
spmd_axis_name = trace.spmd_axis_name spmd_axis_name = trace.axis_data.spmd_name
if spmd_axis_name is not None: if spmd_axis_name is not None:
used = {n for names in in_names for ns in names.values() for n in ns} used = {n for names in in_names for ns in names.values() for n in ns}
if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used:
raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
else ns for ns, d in zip(new_in_names, in_dims)] else ns for ns, d in zip(new_in_names, in_dims)]
new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name)
new_axis_data = batching.AxisData(trace.axis_data.name, new_size, trace.axis_data.spmd_name)
else:
new_axis_data = trace.axis_data
fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims))
@as_hashable_function(closure=out_names_thunk) @as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk(): def new_out_names_thunk():
return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk()) return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk())
@ -1445,25 +1376,13 @@ def _shard_map_batch(
new_params = dict(mesh=mesh, in_names=new_in_names, new_params = dict(mesh=mesh, in_names=new_in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep, out_names_thunk=new_out_names_thunk, check_rep=check_rep,
rewrite=rewrite, auto=auto) rewrite=rewrite, auto=auto)
out_vals = prim.bind(fun, *in_vals, **new_params) with core.set_current_trace(trace.parent_trace):
out_vals = prim.bind(fun, *in_vals, **new_params)
make_tracer = partial(batching.BatchTracer, trace, make_tracer = partial(batching.BatchTracer, trace,
source_info=source_info_util.current()) source_info=source_info_util.current())
return map(make_tracer, out_vals, out_dims()) return map(make_tracer, out_vals, out_dims())
batching.BatchTrace.process_shard_map = _shard_map_batch batching.BatchTrace.process_shard_map = _shard_map_batch
def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names,
out_names_thunk, check_rep, rewrite, auto):
del mesh, in_names, out_names_thunk, check_rep, rewrite, auto
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
for t in out_tracers)
m = trace.main
def todo(vals):
trace = m.with_cur_sublevel()
return map(partial(batching.BatchTracer, trace), vals, dims, srcs)
out_names_transform = partial(_batch_out_names, trace.spmd_axis_name, dims)
return vals, (todo, out_names_transform)
batching.BatchTrace.post_process_shard_map = _shard_map_batch_post_process
def _batch_out_names(spmd_axis_name, dims, out_names): def _batch_out_names(spmd_axis_name, dims, out_names):
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
for ax in names} for names, d in zip(out_names, dims)] for ax in names} for names, d in zip(out_names, dims)]
@ -1480,11 +1399,11 @@ def _batch_out_names(spmd_axis_name, dims, out_names):
def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk, check_rep, rewrite, auto): out_names_thunk, check_rep, rewrite, auto):
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
which_nz = [ type(t) is not ad.Zero for t in tangents] which_nz = [ type(t) is not ad.Zero for t in tangents]
tangents = [t if type(t) is not ad.Zero else None for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents]
args, in_tree = tree_flatten((primals, tangents)) args, in_tree = tree_flatten((primals, tangents))
f_jvp = ad.jvp_subtrace(f, trace.main) f_jvp = ad.jvp_subtrace(f, trace.tag)
f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp)
tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz] tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz]
@ -1496,36 +1415,22 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep, out_names_thunk=new_out_names_thunk, check_rep=check_rep,
rewrite=rewrite, auto=auto) rewrite=rewrite, auto=auto)
f_jvp, out_tree = ad.traceable(f_jvp, in_tree) f_jvp, out_tree = ad.traceable(f_jvp, in_tree)
result = shard_map_p.bind(f_jvp, *args, **params) result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params)
primal_out, tangent_out = tree_unflatten(out_tree(), result) primal_out, tangent_out = tree_unflatten(out_tree(), result)
tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t
for p, t in zip(primal_out, tangent_out)] for p, t in zip(primal_out, tangent_out)]
return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp ad.JVPTrace.process_shard_map = _shard_map_jvp
def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names,
out_names_thunk, check_rep, rewrite, auto):
del mesh, in_names, out_names_thunk, check_rep, rewrite, auto
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
out, treedef = tree_flatten((primals, tangents))
tangents_nz = [type(t) is not ad.Zero for t in tangents]
m = trace.main
def todo(x):
primals, tangents = tree_unflatten(treedef, x)
return map(partial(ad.JVPTracer, m.with_cur_sublevel()), primals, tangents)
def out_names_transform(out_names):
return (*out_names, *(n for n, nz in zip(out_names, tangents_nz) if nz))
return out, (todo, out_names_transform)
ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process
def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk, check_rep, rewrite, auto): out_names_thunk, check_rep, rewrite, auto):
tracers = map(trace.to_jaxpr_tracer, tracers)
in_pvals = [t.pval for t in tracers] in_pvals = [t.pval for t in tracers]
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
all_names = _all_mesh_names(mesh) all_names = _all_mesh_names_except_spmd(mesh, trace)
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals) in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.main, False) f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, False)
f = _promote_scalar_residuals(f) f = _promote_scalar_residuals(f)
f_known, aux = pe.partial_eval_wrapper_nounits( f_known, aux = pe.partial_eval_wrapper_nounits(
f, (*in_knowns,), (*in_avals_sharded,)) f, (*in_knowns,), (*in_avals_sharded,))
@ -1540,7 +1445,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
known_params = dict(mesh=mesh, in_names=(*known_in_names,), known_params = dict(mesh=mesh, in_names=(*known_in_names,),
out_names_thunk=known_out_names, check_rep=check_rep, out_names_thunk=known_out_names, check_rep=check_rep,
rewrite=rewrite, auto=auto) rewrite=rewrite, auto=auto)
out = shard_map_p.bind(f_known, *in_consts, **known_params) out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params)
in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux() in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux()
num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
out_consts, non_fwd_res = split_list(out, [len(out) - num_res]) out_consts, non_fwd_res = split_list(out, [len(out) - num_res])
@ -1553,7 +1458,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
const_tracers = map(trace.new_instantiated_const, res) const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.full_raise, env) env_tracers = map(trace.to_jaxpr_tracer, env)
unk_arg_tracers = [t for t in tracers if not t.is_known()] unk_arg_tracers = [t for t in tracers if not t.is_known()]
unk_params = dict(mesh=mesh, in_names=unk_in_names, unk_params = dict(mesh=mesh, in_names=unk_in_names,
out_names=unk_out_names, jaxpr=jaxpr, check_rep=False, out_names=unk_out_names, jaxpr=jaxpr, check_rep=False,
@ -1569,55 +1474,6 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
return pe.merge_lists(out_knowns, out_tracers, out_consts) return pe.merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
def _shard_map_partial_eval_post_process(
trace, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto):
del check_rep
all_names = _all_mesh_names(mesh)
unk_tracers = [t for t in tracers if not t.is_known()]
jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
# TODO(mattjj): output forwarding optimization
which = [not getattr(v.aval, 'shape', True) for v in jaxpr.constvars]
res = [jax.lax.broadcast(x, (1,)) if not getattr(v.aval, 'shape', True) else x
for x, v in zip(res, jaxpr.constvars)]
jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers])
out = [*consts, *res]
main = trace.main
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_ = pe.convert_constvars_jaxpr(jaxpr)
def todo(out):
trace = main.with_cur_sublevel()
out_consts, res_ = split_list(out, [len(out) - len(res)])
const_tracers = map(trace.new_instantiated_const, res_)
env_tracers = map(trace.full_raise, env)
staged_in_names = ({0: all_names},) * len(res_) + ({},) * len(env)
staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
out_names=(*out_names_unknown,), check_rep=False,
rewrite=rewrite, auto=auto)
out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in out_avals]
name_stack = trace._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
shard_map_p, staged_params, effs, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
def out_names_transform(out_names):
nonlocal out_names_unknown
out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
return (*out_names_known,) + ({0: all_names},) * len(res)
out_names_unknown: list | None = None
return out, (todo, out_names_transform)
pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process
@lu.transformation @lu.transformation
def _promote_scalar_residuals(*args, **kwargs): def _promote_scalar_residuals(*args, **kwargs):
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs
@ -1645,7 +1501,7 @@ def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
# We use a filtered-down version of unmentioned to avoid defensive-psum over # We use a filtered-down version of unmentioned to avoid defensive-psum over
# more chips than required in the transpose-no-check-rep case. # more chips than required in the transpose-no-check-rep case.
name_set = {n for ns in names.values() for n in ns} name_set = {n for ns in names.values() for n in ns}
return [n for n in _all_mesh_names(mesh) if n not in name_set] return [n for n in mesh.axis_names if n not in name_set]
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
@ -1692,18 +1548,6 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
return tree_unflatten(out_tree(), out_flat) return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose ad.primitive_transposes[shard_map_p] = _shard_map_transpose
def _shard_map_axis_subst(params, subst, traverse):
if 'jaxpr' not in params:
return params
if not traverse:
return params
def shadowed_subst(name):
return (name,) if name in params['mesh'].shape else subst(name)
with core.extend_axis_env_nd(params['mesh'].shape.items()):
new_jaxpr = core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst)
return dict(params, jaxpr=new_jaxpr)
core.axis_substitution_rules[shard_map_p] = _shard_map_axis_subst
# Remat # Remat
def _partial_eval_jaxpr_custom_rule( def _partial_eval_jaxpr_custom_rule(
@ -1783,7 +1627,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
in_fwd, out_fwd, which, params_known, params_staged): in_fwd, out_fwd, which, params_known, params_staged):
# prune inputs to jaxpr_known according to unks_in # prune inputs to jaxpr_known according to unks_in
mesh = params_known['mesh'] mesh = params_known['mesh']
all_names = _all_mesh_names(mesh) all_names = _all_mesh_names_except_spmd(mesh)
in_names_known, _ = partition_list(unks_in, params_known['in_names']) in_names_known, _ = partition_list(unks_in, params_known['in_names'])
_, out_names_known = partition_list(kept_outs_known, params_known['out_names']) _, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
out_names_known = out_names_known + [{0: all_names}] * sum(which) out_names_known = out_names_known + [{0: all_names}] * sum(which)
@ -1801,15 +1645,13 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
out_names=tuple(out_names_staged), check_rep=False) out_names=tuple(out_names_staged), check_rep=False)
return new_params_known, new_params_staged, all_names return new_params_known, new_params_staged, all_names
# TODO(mattjj): remove this mechanism when we revise mesh scopes # TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names(mesh: Mesh) -> tuple[AxisName, ...]: def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]:
stack = core.thread_local_state.trace_state.trace_stack.stack trace = core.unsafe_get_current_trace() if trace is None else trace
names = {n for frame in stack stack = core.unsafe_get_trace_stack(trace)
if (ns := frame.payload.get('spmd_axis_name', ())) is not None batch_traces = [t for t in stack if isinstance(t, batching.BatchTrace)]
for n in ns} spmd_names = {n for trace in batch_traces for n in trace.axis_data.spmd_name }
return tuple(name for name in mesh.axis_names if name not in names) return tuple(name for name in mesh.axis_names if name not in spmd_names)
# DCE # DCE
@ -1926,59 +1768,52 @@ class RewriteTracer(core.Tracer):
def aval(self) -> core.AbstractValue: def aval(self) -> core.AbstractValue:
return core.get_aval(self.val) return core.get_aval(self.val)
def full_lower(self) -> RewriteTracer:
return self
def __str__(self) -> str: def __str__(self) -> str:
return str(self.val) # TODO(mattjj): could show replication info here return str(self.val) # TODO(mattjj): could show replication info here
__repr__ = __str__ # for debuggers, like `p x` __repr__ = __str__ # for debuggers, like `p x`
class RewriteTrace(core.Trace): class RewriteTrace(core.Trace):
parent_trace : core.Trace
tag : core.TraceTag
mesh: Mesh mesh: Mesh
dyna: int
def __init__(self, *args, mesh, dyna): def __init__(self, parent_trace, tag, mesh):
super().__init__(*args) self.parent_trace = parent_trace
self.tag = tag
self.mesh = mesh self.mesh = mesh
self.dyna = dyna
def pure(self, val) -> RewriteTracer: def to_val_rep_pair(self, val):
return RewriteTracer(self, set(self.mesh.axis_names), val) # TODO: add a tag to tell if self
if isinstance(val, RewriteTracer) and val._trace.tag is self.tag:
def lift(self, tracer: core.Tracer) -> RewriteTracer: return val.val, val.rep
return RewriteTracer(self, set(self.mesh.axis_names), tracer) else:
return val, set(self.mesh.axis_names)
def sublift(self, tracer: core.Tracer) -> RewriteTracer:
return RewriteTracer(self, tracer.rep, tracer.val)
def process_primitive(self, prim, in_tracers, params): def process_primitive(self, prim, in_tracers, params):
rule = _rewrite_rules.get(prim, partial(_rule_missing, prim)) rule = _rewrite_rules.get(prim, partial(_rule_missing, prim))
in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers) in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers))
with core.new_dynamic(self.dyna): with core.set_current_trace(self.parent_trace):
out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params) out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params)
out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals) out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals)
return out_tracers if prim.multiple_results else out_tracers[0] return out_tracers if prim.multiple_results else out_tracers[0]
def process_call(self, call_primitive, f, in_tracers, params): def process_call(self, call_primitive, f, in_tracers, params):
in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers) in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers))
f, out_reps = _rewrite_subtrace(f, self.main, tuple(in_reps)) f, out_reps = _rewrite_subtrace(f, self.tag, self.mesh, tuple(in_reps))
with core.new_dynamic(self.dyna): with core.set_current_trace(self.parent_trace):
out_vals = call_primitive.bind(f, *in_vals, **params) out_vals = call_primitive.bind(f, *in_vals, **params)
return map(partial(RewriteTracer, self), out_reps(), out_vals) return map(partial(RewriteTracer, self), out_reps(), out_vals)
def post_process_call(self, call_primitive, out_tracers, params):
assert False # unreachable
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
if symbolic_zeros: if symbolic_zeros:
msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and "
"as a temporary workaround pass the check_rep=False argument to " "as a temporary workaround pass the check_rep=False argument to "
"shard_map") "shard_map")
raise NotImplementedError(msg) raise NotImplementedError(msg)
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers) in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers))
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps) fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps)
jvp, out_reps2 = _rewrite_subtrace(jvp, self.main, in_reps * 2) jvp, out_reps2 = _rewrite_subtrace(jvp, self.tag, self.mesh, in_reps * 2)
with core.new_dynamic(self.dyna): with core.set_current_trace(self.parent_trace):
out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros)
fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2)
if not fst: if not fst:
@ -1986,9 +1821,6 @@ class RewriteTrace(core.Trace):
out_reps = out_reps[:len(out_reps) // 2] out_reps = out_reps[:len(out_reps) // 2]
return map(partial(RewriteTracer, self), out_reps, out_vals) return map(partial(RewriteTracer, self), out_reps, out_vals)
def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
assert False # unreachable
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros): symbolic_zeros):
if symbolic_zeros: if symbolic_zeros:
@ -1996,12 +1828,12 @@ class RewriteTrace(core.Trace):
"as a temporary workaround pass the check_rep=False argument to " "as a temporary workaround pass the check_rep=False argument to "
"shard_map") "shard_map")
raise NotImplementedError(msg) raise NotImplementedError(msg)
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers) in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers))
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps) fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps)
fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]] fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]]
fwd, out_reps2 = _rewrite_subtrace(fwd, self.main, fwd_in_reps) fwd, out_reps2 = _rewrite_subtrace(fwd, self.tag, self.mesh, fwd_in_reps)
bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps) bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps)
with core.new_dynamic(self.dyna): with core.set_current_trace(self.parent_trace):
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees,
symbolic_zeros=symbolic_zeros) symbolic_zeros=symbolic_zeros)
fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2)
@ -2010,36 +1842,24 @@ class RewriteTrace(core.Trace):
_, out_reps = split_list(out_reps, [res_tree.num_leaves]) _, out_reps = split_list(out_reps, [res_tree.num_leaves])
return map(partial(RewriteTracer, self), out_reps, out_vals) return map(partial(RewriteTracer, self), out_reps, out_vals)
def post_process_custom_vjp_call(self, out_tracers, _):
assert False # unreachable
# TODO process_axis_index
def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk):
in_reps = map(partial(_in_names_to_rep, mesh), in_names) in_reps = map(partial(_in_names_to_rep, mesh), in_names)
out_reps_dst = lambda: [set(_unmentioned(mesh, n)) for n in out_names_thunk()] out_reps_dst = lambda: [set(_unmentioned(mesh, n)) for n in out_names_thunk()]
fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps)
return _match_rep(fun, mesh, out_reps_src, out_reps_dst) return _match_rep(fun, mesh, out_reps_src, out_reps_dst)
def _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps):
return _efficient_transpose_outer(_efficient_transpose_inner(fun), mesh, in_reps)
@lu.transformation_with_aux @lu.transformation_with_aux
def _efficient_transpose_outer(mesh, in_reps, *args): def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args):
lvl = core.dynamic_level() with core.take_current_trace() as parent:
with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: tag = core.TraceTag()
out_vals, out_reps = yield (main, mesh, in_reps, args), {} t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh)
del main in_tracers = map(partial(RewriteTracer, t), in_reps, args)
with core.set_current_trace(t):
ans = yield in_tracers, {}
out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans))
del t, in_tracers, ans
yield out_vals, out_reps yield out_vals, out_reps
@lu.transformation
def _efficient_transpose_inner(main, mesh, in_reps, args):
t = main.with_cur_sublevel()
in_tracers = map(partial(RewriteTracer, t), in_reps, args)
ans = yield in_tracers, {}
out_tracers = map(t.full_raise, ans)
yield unzip2((t.val, t.rep) for t in out_tracers)
@lu.transformation @lu.transformation
def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args): def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args):
outs = yield args, {} outs = yield args, {}
@ -2060,8 +1880,7 @@ def _replication_rewrite_match(
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
f = _match_rep(f, mesh, out_rep, out_rep_dst) f = _match_rep(f, mesh, out_rep, out_rep_dst)
with core.extend_axis_env_nd(mesh.shape.items()): jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
return core.ClosedJaxpr(jaxpr_, consts) return core.ClosedJaxpr(jaxpr_, consts)
# TODO(mattjj): caching # TODO(mattjj): caching
@ -2072,28 +1891,25 @@ def _replication_rewrite_nomatch(
) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]: ) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]:
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
with core.extend_axis_env_nd(mesh.shape.items()): jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
return core.ClosedJaxpr(jaxpr_, consts), out_rep() return core.ClosedJaxpr(jaxpr_, consts), out_rep()
@lu.transformation_with_aux @lu.transformation_with_aux
def _rewrite_subtrace(main, in_reps, *in_vals): def _rewrite_subtrace(tag, mesh, in_reps, *in_vals):
assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) with core.take_current_trace() as parent_trace:
t = main.with_cur_sublevel() assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals))
in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) t = RewriteTrace(parent_trace, tag, mesh)
with core.new_dynamic(main.level): in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals)
outs = yield in_tracers, {} with core.set_current_trace(t):
out_tracers = map(t.full_raise, outs) outs = yield in_tracers, {}
out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers) ans = unzip2(map(t.to_val_rep_pair, outs))
yield out_vals, out_reps yield ans
def _rewrite_bwd(bwd, mesh, in_reps, reps_dst): def _rewrite_bwd(bwd, mesh, in_reps, reps_dst):
def new_bwd(*args): def new_bwd(*args):
lvl = core.dynamic_level() tag = core.TraceTag()
with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), tag, mesh, in_reps())
bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), main, in_reps()) out = bwd_.call_wrapped(*args)
out = bwd_.call_wrapped(*args)
del main
return map(_match_replication, reps_thunk(), reps_dst, out) return map(_match_replication, reps_thunk(), reps_dst, out)
return new_bwd return new_bwd

@ -276,16 +276,6 @@ def spvalues_to_avals(
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Implementation of sparsify() using tracers. # Implementation of sparsify() using tracers.
def popattr(obj: Any, name: str) -> Any:
assert hasattr(obj, name)
val = getattr(obj, name)
delattr(obj, name)
return val
def setnewattr(obj: Any, name: str, val: Any):
assert not hasattr(obj, name)
setattr(obj, name, val)
class SparseTracer(core.Tracer): class SparseTracer(core.Tracer):
def __init__(self, trace: core.Trace, *, spvalue): def __init__(self, trace: core.Trace, *, spvalue):
self._spvalue = spvalue self._spvalue = spvalue
@ -293,9 +283,9 @@ class SparseTracer(core.Tracer):
@property @property
def spenv(self): def spenv(self):
if not hasattr(self._trace.main, 'spenv'): if not hasattr(self._trace, 'spenv'):
raise RuntimeError("Internal: main does not have spenv defined.") raise RuntimeError("Internal: trace does not have spenv defined.")
return self._trace.main.spenv return self._trace.spenv
@property @property
def aval(self): def aval(self):
@ -305,71 +295,70 @@ class SparseTracer(core.Tracer):
return self return self
class SparseTrace(core.Trace): class SparseTrace(core.Trace):
def pure(self, val: Any):
if not hasattr(self.main, 'spenv'):
raise RuntimeError("Internal: main does not have spenv defined.")
spvalue, = arrays_to_spvalues(self.main.spenv, [val])
return SparseTracer(self, spvalue=spvalue)
def lift(self, val: core.Tracer): def __init__(self, parent_trace, tag, spenv):
if not hasattr(self.main, 'spenv'): self.parent_trace = parent_trace
raise RuntimeError("Internal: main does not have spenv defined.") self.tag = tag
spvalue, = arrays_to_spvalues(self.main.spenv, [val]) self.spenv = spenv
return SparseTracer(self, spvalue=spvalue)
def sublift(self, val: SparseTracer): def to_sparse_tracer(self, val):
return SparseTracer(val._trace, spvalue=val._spvalue) if isinstance(val, SparseTracer) and self.tag is val._trace.tag:
return val
else:
with core.set_current_trace(self.parent_trace):
spvalue, = arrays_to_spvalues(self.spenv, [val])
return SparseTracer(self, spvalue=spvalue)
def process_primitive(self, primitive, tracers, params): def process_primitive(self, primitive, tracers, params):
spenv = popattr(self.main, 'spenv') tracers = [self.to_sparse_tracer(t) for t in tracers]
spvalues = [t._spvalue for t in tracers] spvalues = [t._spvalue for t in tracers]
if any(spvalue.is_sparse() for spvalue in spvalues): if any(spvalue.is_sparse() for spvalue in spvalues):
if primitive not in sparse_rules_bcoo: if primitive not in sparse_rules_bcoo:
_raise_unimplemented_primitive(primitive) _raise_unimplemented_primitive(primitive)
out_spvalues = sparse_rules_bcoo[primitive](spenv, *(t._spvalue for t in tracers), **params) with core.set_current_trace(self.parent_trace):
out_spvalues = sparse_rules_bcoo[primitive](self.spenv, *(t._spvalue for t in tracers), **params)
else: else:
out_bufs = primitive.bind(*(spenv.data(spvalue) for spvalue in spvalues), **params) out_bufs = primitive.bind_with_trace(self.parent_trace, tuple(self.spenv.data(spvalue) for spvalue in spvalues), params)
out_spvalues = arrays_to_spvalues(spenv, out_bufs if primitive.multiple_results else [out_bufs]) out_spvalues = arrays_to_spvalues(self.spenv, out_bufs if primitive.multiple_results else [out_bufs])
setnewattr(self.main, 'spenv', spenv)
out_tracers = tuple(SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues) out_tracers = tuple(SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues)
return out_tracers if primitive.multiple_results else out_tracers[0] return out_tracers if primitive.multiple_results else out_tracers[0]
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
spenv = popattr(self.main, 'spenv') assert False
spvalues = tuple(t._spvalue for t in tracers) spvalues = tuple(t._spvalue for t in tracers)
in_bufs = spenv._buffers in_bufs = self.spenv._buffers
fun, out_spvalues = sparsify_subtrace(f, self.main, spvalues) fun, out_spvalues = sparsify_subtrace(f, self.main, spvalues)
if any(params['donated_invars']): if any(params['donated_invars']):
raise NotImplementedError("sparsify does not support donated_invars") raise NotImplementedError("sparsify does not support donated_invars")
params = dict(params, donated_invars=tuple(False for buf in in_bufs)) params = dict(params, donated_invars=tuple(False for buf in in_bufs))
bufs_out = call_primitive.bind(fun, *in_bufs, **params) bufs_out = call_primitive.bind(fun, *in_bufs, **params)
setnewattr(self.main, 'spenv', SparsifyEnv(bufs_out))
return [SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues()] return [SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues()]
def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros): def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros):
# TODO(jakevdp): handle the jvp here # TODO(jakevdp): handle the jvp here
del primitive, jvp, symbolic_zeros del primitive, jvp, symbolic_zeros
return fun.call_wrapped(*tracers) with core.set_current_trace(self):
return fun.call_wrapped(*tracers)
@lu.transformation_with_aux @lu.transformation_with_aux
def sparsify_subtrace(main, spvalues, *bufs): def sparsify_subtrace(tag, spenv, spvalues, *bufs):
setnewattr(main, 'spenv', SparsifyEnv(bufs)) with core.take_current_trace() as parent:
trace = main.with_cur_sublevel() trace = SparseTrace(parent, tag, spenv)
in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues] with core.set_current_trace(trace):
outs = yield in_tracers, {} in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues]
out_traces = [trace.full_raise(out) for out in outs] outs = yield in_tracers, {}
buffers = popattr(main, 'spenv')._buffers out_traces = [trace.to_sparse_tracer(out) for out in outs]
yield buffers, [out._spvalue for out in out_traces] buffers = spenv._buffers
yield buffers, [out._spvalue for out in out_traces]
def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]): def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]):
with core.new_main(SparseTrace) as main: tag = core.TraceTag()
spenv = SparsifyEnv() spenv = SparsifyEnv()
spvalues = arrays_to_spvalues(spenv, args) spvalues = arrays_to_spvalues(spenv, args)
in_bufs = spenv._buffers in_bufs = spenv._buffers
fun, out_spvalues = sparsify_subtrace(wrapped_fun, main, spvalues) fun, out_spvalues = sparsify_subtrace(wrapped_fun, tag, spenv, spvalues)
out_bufs = fun.call_wrapped(*in_bufs) out_bufs = fun.call_wrapped(*in_bufs)
spenv = SparsifyEnv(out_bufs) spenv = SparsifyEnv(out_bufs)
del main
return spvalues_to_arrays(spenv, out_spvalues()) return spvalues_to_arrays(spenv, out_spvalues())
def _sparsify_with_tracer(fun): def _sparsify_with_tracer(fun):

@ -18,8 +18,6 @@
from __future__ import annotations from __future__ import annotations
from jax._src.interpreters.ad import ( from jax._src.interpreters.ad import (
CustomJVPException as CustomJVPException,
CustomVJPException as CustomVJPException,
JVPTrace as JVPTrace, JVPTrace as JVPTrace,
JVPTracer as JVPTracer, JVPTracer as JVPTracer,
UndefinedPrimal as UndefinedPrimal, UndefinedPrimal as UndefinedPrimal,
@ -67,7 +65,6 @@ from jax._src.interpreters.ad import (
vjp as vjp, vjp as vjp,
zero_jvp as zero_jvp, zero_jvp as zero_jvp,
zeros_like_aval as zeros_like_aval, zeros_like_aval as zeros_like_aval,
zeros_like_jaxval as zeros_like_jaxval,
zeros_like_p as zeros_like_p, zeros_like_p as zeros_like_p,
) )

@ -50,6 +50,7 @@ from jax._src.interpreters.batching import (
defbroadcasting as defbroadcasting, defbroadcasting as defbroadcasting,
defreducer as defreducer, defreducer as defreducer,
defvectorized as defvectorized, defvectorized as defvectorized,
fancy_primitive_batchers as fancy_primitive_batchers,
flatten_fun_for_vmap as flatten_fun_for_vmap, flatten_fun_for_vmap as flatten_fun_for_vmap,
from_elt as from_elt, from_elt as from_elt,
from_elt_handlers as from_elt_handlers, from_elt_handlers as from_elt_handlers,
@ -64,7 +65,6 @@ from jax._src.interpreters.batching import (
reducer_batcher as reducer_batcher, reducer_batcher as reducer_batcher,
register_vmappable as register_vmappable, register_vmappable as register_vmappable,
spec_types as spec_types, spec_types as spec_types,
spmd_axis_primitive_batchers as spmd_axis_primitive_batchers,
to_elt as to_elt, to_elt as to_elt,
to_elt_handlers as to_elt_handlers, to_elt_handlers as to_elt_handlers,
unregister_vmappable as unregister_vmappable, unregister_vmappable as unregister_vmappable,

@ -62,7 +62,6 @@ from jax._src.interpreters.partial_eval import (
debug_info as debug_info, debug_info as debug_info,
debug_info_final as debug_info_final, debug_info_final as debug_info_final,
def_trivial_padding as def_trivial_padding, def_trivial_padding as def_trivial_padding,
extend_jaxpr_stack as extend_jaxpr_stack,
forwarding_rules as forwarding_rules, forwarding_rules as forwarding_rules,
infer_lambda_input_type as infer_lambda_input_type, infer_lambda_input_type as infer_lambda_input_type,
instantiate_const_at as instantiate_const_at, instantiate_const_at as instantiate_const_at,
@ -81,15 +80,9 @@ from jax._src.interpreters.partial_eval import (
recipe_to_eqn as recipe_to_eqn, recipe_to_eqn as recipe_to_eqn,
result_info as result_info, result_info as result_info,
sig_info as sig_info, sig_info as sig_info,
trace_to_jaxpr as trace_to_jaxpr,
trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic, trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic,
trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2, trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2,
trace_to_jaxpr_final as trace_to_jaxpr_final,
trace_to_jaxpr_final2 as trace_to_jaxpr_final2,
trace_to_jaxpr_nounits as trace_to_jaxpr_nounits, trace_to_jaxpr_nounits as trace_to_jaxpr_nounits,
trace_to_subjaxpr as trace_to_subjaxpr,
trace_to_subjaxpr_dynamic as trace_to_subjaxpr_dynamic,
trace_to_subjaxpr_dynamic2 as trace_to_subjaxpr_dynamic2,
trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits,
trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd,
tracers_to_jaxpr as tracers_to_jaxpr, tracers_to_jaxpr as tracers_to_jaxpr,

@ -330,7 +330,6 @@ from jax._src.lax.control_flow import (
linear_solve_p as linear_solve_p, linear_solve_p as linear_solve_p,
map as map, map as map,
scan as scan, scan as scan,
scan_bind as scan_bind,
scan_p as scan_p, scan_p as scan_p,
switch as switch, switch as switch,
while_loop as while_loop, while_loop as while_loop,

@ -1458,6 +1458,8 @@ class JitTest(jtu.BufferDonationTestCase):
ans = jax.vmap(g, axis_name="i", axis_size=3, out_axes=None)() ans = jax.vmap(g, axis_name="i", axis_size=3, out_axes=None)()
self.assertEqual(ans, expected) self.assertEqual(ans, expected)
# Since stackless, the vmap(f) version gets compiled a second time
@unittest.skip
def test_caches_dont_depend_on_unnamed_axis_env(self): def test_caches_dont_depend_on_unnamed_axis_env(self):
# https://github.com/jax-ml/jax/issues/9187 # https://github.com/jax-ml/jax/issues/9187
f = jax.jit(lambda: jnp.sin(1)) f = jax.jit(lambda: jnp.sin(1))
@ -3004,9 +3006,11 @@ class APITest(jtu.JaxTestCase):
with jax.enable_checks(False): with jax.enable_checks(False):
with self.assertRaisesRegex(TypeError, err_str): with self.assertRaisesRegex(TypeError, err_str):
lax.add(jnp.array(7), np.array("hello")) lax.add(jnp.array(7), np.array("hello"))
with jax.enable_checks(True): # TODO(dougalm): re-enable checks at the beginning of `bind`. We just
with self.assertRaises(AssertionError): # need to know which arguments to a generic primitive are ordinary operands vs functions.
lax.add(jnp.array(7), np.array("hello")) # with jax.enable_checks(True):
# with self.assertRaises(AssertionError):
# lax.add(jnp.array(7), np.array("hello"))
def test_vmap_preserves_docstr(self): def test_vmap_preserves_docstr(self):
def superfun(a): def superfun(a):
@ -3438,13 +3442,10 @@ class APITest(jtu.JaxTestCase):
re.DOTALL)): re.DOTALL)):
api.jit(lambda x: x)(self._saved_tracer) api.jit(lambda x: x)(self._saved_tracer)
@unittest.skip # TODO(dougalm): rethink what this should do under stackless
def test_escaped_tracers_tracer_from_higher_level(self): def test_escaped_tracers_tracer_from_higher_level(self):
api.grad(self.helper_save_tracer)(0.) api.grad(self.helper_save_tracer)(0.)
with self.assertRaisesRegex( with self.assertRaises(UnexpectedTracerError):
UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Tracer from a higher level",
re.DOTALL)):
api.grad(lambda x: x)(self._saved_tracer) api.grad(lambda x: x)(self._saved_tracer)
def test_escaped_tracers_incompatible_sublevel(self): def test_escaped_tracers_incompatible_sublevel(self):
@ -3464,8 +3465,7 @@ class APITest(jtu.JaxTestCase):
return x + self._saved_tracer return x + self._saved_tracer
with self.assertRaisesRegex( with self.assertRaisesRegex(
UnexpectedTracerError, UnexpectedTracerError,
re.compile("Encountered an unexpected tracer.*Can't lift", re.compile("unexpected tracer")):
re.DOTALL)):
api.grad(func1)(2.) api.grad(func1)(2.)
def test_escaped_tracers_not_among_input_tracers(self): def test_escaped_tracers_not_among_input_tracers(self):
@ -3860,7 +3860,7 @@ class APITest(jtu.JaxTestCase):
x = g(x) x = g(x)
return x return x
msg = r'Leaked trace MainTrace\(2,DynamicJaxprTrace\)' msg = r'Leaked trace DynamicJaxprTrace'
with self.assertRaisesRegex(Exception, f"{msg}"): with self.assertRaisesRegex(Exception, f"{msg}"):
f(3) f(3)
@ -4725,6 +4725,7 @@ class APITest(jtu.JaxTestCase):
for a, b in zip(ans, expected): for a, b in zip(ans, expected):
self.assertAllClose(a, b) self.assertAllClose(a, b)
@unittest.skip # TODO(dougalm): figure out with Matt what to do with this feature
def test_inner_jit_forwarded_consts_stay_const(self): def test_inner_jit_forwarded_consts_stay_const(self):
out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))() # don't crash out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))() # don't crash
self.assertEqual(out, 3) self.assertEqual(out, 3)
@ -4874,6 +4875,7 @@ class RematTest(jtu.JaxTestCase):
msg = str(e) msg = str(e)
self.assertNotIn('static_argnums', msg) self.assertNotIn('static_argnums', msg)
@unittest.skip
def test_remat_grad_python_control_flow_static_argnums(self): def test_remat_grad_python_control_flow_static_argnums(self):
@partial(jax.remat, static_argnums=(0,)) @partial(jax.remat, static_argnums=(0,))
def g(x): def g(x):
@ -4896,6 +4898,7 @@ class RematTest(jtu.JaxTestCase):
expected = np.cos(2.) expected = np.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False) self.assertAllClose(ans, expected, check_dtypes=False)
@unittest.skip
def test_remat_grad_python_control_flow_unhashable_static_argnums(self): def test_remat_grad_python_control_flow_unhashable_static_argnums(self):
@partial(jax.remat, static_argnums=(0,)) @partial(jax.remat, static_argnums=(0,))
def g(x): def g(x):
@ -7138,8 +7141,8 @@ class CustomJVPTest(jtu.JaxTestCase):
g.defjvp(g_jvp) g.defjvp(g_jvp)
return g(1.) return g(1.)
self.assertRaises(ad.CustomJVPException, lambda: api.jvp(f, (3.,), (1.,))) self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,)))
self.assertRaises(ad.CustomJVPException, lambda: api.grad(f)(3.)) self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.))
def test_nondiff_arg(self): def test_nondiff_arg(self):
@partial(jax.custom_jvp, nondiff_argnums=(0,)) @partial(jax.custom_jvp, nondiff_argnums=(0,))
@ -7214,7 +7217,7 @@ class CustomJVPTest(jtu.JaxTestCase):
h = lambda y: x + y # capture x h = lambda y: x + y # capture x
return g(h, x) return g(h, x)
with self.assertRaisesRegex(ad.CustomJVPException, "Detected differentiation"): with self.assertRaises(UnexpectedTracerError):
api.jvp(f, (2.,), (1.,)) api.jvp(f, (2.,), (1.,))
def test_vmap_axes(self): def test_vmap_axes(self):
@ -7625,8 +7628,8 @@ class CustomJVPTest(jtu.JaxTestCase):
f.defjvp(f_jvp) f.defjvp(f_jvp)
primals = (2., 3) primals = (2., 3)
tangents = (np.ones(()), np.zeros((), float0),) tangents = (np.ones(()), scalar_float0)
expected_tangents = (2., np.zeros((), float0)) expected_tangents = (2., scalar_float0)
self.assertAllClose(api.jvp(f, primals, tangents), self.assertAllClose(api.jvp(f, primals, tangents),
(primals, expected_tangents)) (primals, expected_tangents))

@ -223,7 +223,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
[dict(for_impl=for_impl, impl_name=impl_name) [dict(for_impl=for_impl, impl_name=impl_name)
for for_impl, impl_name in FOR_LOOP_IMPLS], for for_impl, impl_name in FOR_LOOP_IMPLS],
) )
@jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts?
def test_for_jvp(self, f, ref, body_shapes, n, for_impl, for_body_name, def test_for_jvp(self, f, ref, body_shapes, n, for_impl, for_body_name,
impl_name): impl_name):
for_ = for_impl for_ = for_impl
@ -255,7 +255,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
[dict(for_impl=for_impl, impl_name=impl_name) [dict(for_impl=for_impl, impl_name=impl_name)
for for_impl, impl_name in FOR_LOOP_IMPLS], for for_impl, impl_name in FOR_LOOP_IMPLS],
) )
@jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts?
def test_for_linearize(self, f, ref, body_shapes, n, for_impl, for_body_name, def test_for_linearize(self, f, ref, body_shapes, n, for_impl, for_body_name,
impl_name): impl_name):
for_ = for_impl for_ = for_impl
@ -365,7 +365,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
[dict(for_impl=for_impl, impl_name=impl_name) [dict(for_impl=for_impl, impl_name=impl_name)
for for_impl, impl_name in FOR_LOOP_IMPLS], for for_impl, impl_name in FOR_LOOP_IMPLS],
) )
@jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts?
@jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name, def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name,
impl_name): impl_name):
@ -385,7 +385,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=2, jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=2,
rtol=7e-3, atol=1e-2) rtol=7e-3, atol=1e-2)
@jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts?
@jax.legacy_prng_key('allow') @jax.legacy_prng_key('allow')
def test_grad_of_triple_nested_for_loop(self): def test_grad_of_triple_nested_for_loop(self):

@ -37,6 +37,7 @@ class InfeedTest(jtu.JaxTestCase):
@jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion.
def testInfeed(self): def testInfeed(self):
raise SkipTest("skipping temporarily for stackless")
@jax.jit @jax.jit
def f(x): def f(x):
@ -56,6 +57,7 @@ class InfeedTest(jtu.JaxTestCase):
self.assertAllClose(f(x), x + y + z) self.assertAllClose(f(x), x + y + z)
def testInfeedPytree(self): def testInfeedPytree(self):
raise SkipTest("skipping temporarily for stackless")
x = np.float32(1.5) x = np.float32(1.5)
y = np.reshape(np.arange(12, dtype=np.int16), (3, 4)) y = np.reshape(np.arange(12, dtype=np.int16), (3, 4))

@ -2095,6 +2095,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash
def testIssue804(self): def testIssue804(self):
# https://github.com/google/jax/issues/804
num_devices = jax.device_count() num_devices = jax.device_count()
f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.) f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash

@ -2057,7 +2057,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def test_axis_env_length(self): def test_axis_env_length(self):
f = lambda x: jax.pmap(g)(jnp.array([x]))[0] f = lambda x: jax.pmap(g)(jnp.array([x]))[0]
def g(x): def g(x):
assert len(core.thread_local_state.trace_state.axis_env) == 1 assert len(core.get_axis_env().axis_names()) == 1
return x return x
jax.grad(f)(3.) # doesn't fail jax.grad(f)(3.) # doesn't fail

@ -20,7 +20,6 @@ correctly propagated to the jaxpr and mlir.
from absl.testing import absltest from absl.testing import absltest
import jax import jax
from jax._src import config from jax._src import config
from jax._src import dispatch
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.lax import lax from jax._src.lax import lax
from jax.experimental.xla_metadata import set_xla_metadata from jax.experimental.xla_metadata import set_xla_metadata
@ -65,7 +64,7 @@ class XlaMetadataTest(jtu.JaxTestCase):
def test_f_nonjitted(self): def test_f_nonjitted(self):
def f_add(a, b): def f_add(a, b):
return dispatch.apply_primitive(lax.add_p, a, b) return lax.add(a, b)
arg1 = jnp.arange(2) arg1 = jnp.arange(2)
with set_xla_metadata(a="b"): with set_xla_metadata(a="b"):
@ -126,7 +125,7 @@ class XlaMetadataTest(jtu.JaxTestCase):
def test_attr_caching_nonjit(self): def test_attr_caching_nonjit(self):
def f_add(a, b): def f_add(a, b):
return dispatch.apply_primitive(lax.add_p, a, b) return lax.add(a, b)
arg1 = jnp.arange(2) arg1 = jnp.arange(2)
arg2 = jnp.arange(2) + 1 arg2 = jnp.arange(2) + 1