mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
remove vestigial ad.reducing_transposes table
these were an xmap / avals-with-names named axis thing, but that stuff is gone so we can simplify
This commit is contained in:
parent
101168740e
commit
6172a1f1d5
@ -652,7 +652,7 @@ def remat_transpose(out_cts, *in_primals, jaxpr, **params):
|
||||
for x in in_primals]
|
||||
assert next(in_cts_nz_, None) is next(in_zeros_, None) is None
|
||||
return in_cts
|
||||
ad.reducing_transposes[remat_p] = remat_transpose
|
||||
ad.primitive_transposes[remat_p] = remat_transpose
|
||||
|
||||
# TODO(mattjj): move this to ad.py
|
||||
def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],
|
||||
|
@ -277,9 +277,6 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
|
||||
call_jaxpr = params.pop('call_jaxpr')
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(
|
||||
params, call_jaxpr, invals, cts_in, cts_in_avals)
|
||||
elif eqn.primitive in reducing_transposes:
|
||||
cts_out = reducing_transposes[eqn.primitive](
|
||||
cts_in, *invals, **eqn.params)
|
||||
else:
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(
|
||||
cts_in, *invals, **eqn.params)
|
||||
@ -586,8 +583,6 @@ class LinearizeTracer(Tracer):
|
||||
|
||||
primitive_jvps : dict[core.Primitive, Callable] = {}
|
||||
primitive_transposes: dict[core.Primitive, Callable] = {}
|
||||
# transpose rules that internally perform reductions over the given named axes
|
||||
reducing_transposes: dict[core.Primitive, Callable] = {}
|
||||
primitive_linearizations : dict[core.Primitive, Callable] = {}
|
||||
|
||||
def deflinear(primitive, transpose_rule):
|
||||
@ -871,3 +866,6 @@ class CustomVJPException(Exception):
|
||||
"closed-over value into the custom_vjp function as an argument, and "
|
||||
"adapting the custom_vjp fwd and bwd rules.")
|
||||
super().__init__(msg)
|
||||
|
||||
# TODO(mattjj): remove this vestigial dict
|
||||
reducing_transposes: dict[core.Primitive, Callable] = {}
|
||||
|
@ -780,7 +780,7 @@ cond_p.multiple_results = True
|
||||
cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
|
||||
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
|
||||
ad.primitive_jvps[cond_p] = _cond_jvp
|
||||
ad.reducing_transposes[cond_p] = _cond_transpose
|
||||
ad.primitive_transposes[cond_p] = _cond_transpose
|
||||
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
|
||||
batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule
|
||||
xla.register_initial_style_primitive(cond_p)
|
||||
|
@ -1228,7 +1228,7 @@ scan_p.multiple_results = True
|
||||
scan_p.def_impl(partial(dispatch.apply_primitive, scan_p))
|
||||
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
|
||||
ad.primitive_jvps[scan_p] = _scan_jvp
|
||||
ad.reducing_transposes[scan_p] = _scan_transpose
|
||||
ad.primitive_transposes[scan_p] = _scan_transpose
|
||||
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
|
||||
xla.register_initial_style_primitive(scan_p)
|
||||
mlir.register_lowering(scan_p,
|
||||
|
@ -2385,7 +2385,7 @@ def _pjit_transpose(cts_in, *primals_in,
|
||||
_set_states(attrs_tracked, final_states)
|
||||
|
||||
return tree_unflatten(cts_out_treedef, nz_cts_out)
|
||||
ad.reducing_transposes[pjit_p] = _pjit_transpose
|
||||
ad.primitive_transposes[pjit_p] = _pjit_transpose
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
|
Loading…
x
Reference in New Issue
Block a user