remove vestigial ad.reducing_transposes table

these were an xmap / avals-with-names named axis thing, but that stuff is gone so we can simplify
This commit is contained in:
Matthew Johnson 2024-12-05 05:44:40 +00:00
parent 101168740e
commit 6172a1f1d5
5 changed files with 7 additions and 9 deletions

View File

@ -652,7 +652,7 @@ def remat_transpose(out_cts, *in_primals, jaxpr, **params):
for x in in_primals]
assert next(in_cts_nz_, None) is next(in_zeros_, None) is None
return in_cts
ad.reducing_transposes[remat_p] = remat_transpose
ad.primitive_transposes[remat_p] = remat_transpose
# TODO(mattjj): move this to ad.py
def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],

View File

@ -277,9 +277,6 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
call_jaxpr = params.pop('call_jaxpr')
cts_out = get_primitive_transpose(eqn.primitive)(
params, call_jaxpr, invals, cts_in, cts_in_avals)
elif eqn.primitive in reducing_transposes:
cts_out = reducing_transposes[eqn.primitive](
cts_in, *invals, **eqn.params)
else:
cts_out = get_primitive_transpose(eqn.primitive)(
cts_in, *invals, **eqn.params)
@ -586,8 +583,6 @@ class LinearizeTracer(Tracer):
primitive_jvps : dict[core.Primitive, Callable] = {}
primitive_transposes: dict[core.Primitive, Callable] = {}
# transpose rules that internally perform reductions over the given named axes
reducing_transposes: dict[core.Primitive, Callable] = {}
primitive_linearizations : dict[core.Primitive, Callable] = {}
def deflinear(primitive, transpose_rule):
@ -871,3 +866,6 @@ class CustomVJPException(Exception):
"closed-over value into the custom_vjp function as an argument, and "
"adapting the custom_vjp fwd and bwd rules.")
super().__init__(msg)
# TODO(mattjj): remove this vestigial dict
reducing_transposes: dict[core.Primitive, Callable] = {}

View File

@ -780,7 +780,7 @@ cond_p.multiple_results = True
cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
ad.primitive_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)

View File

@ -1228,7 +1228,7 @@ scan_p.multiple_results = True
scan_p.def_impl(partial(dispatch.apply_primitive, scan_p))
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.reducing_transposes[scan_p] = _scan_transpose
ad.primitive_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.register_initial_style_primitive(scan_p)
mlir.register_lowering(scan_p,

View File

@ -2385,7 +2385,7 @@ def _pjit_transpose(cts_in, *primals_in,
_set_states(attrs_tracked, final_states)
return tree_unflatten(cts_out_treedef, nz_cts_out)
ad.reducing_transposes[pjit_p] = _pjit_transpose
ad.primitive_transposes[pjit_p] = _pjit_transpose
@weakref_lru_cache