[remove-units] avoid unit-generating function in jax.linear_transpose

This commit is contained in:
Matthew Johnson 2022-04-29 16:36:57 -07:00
parent fb4731d40e
commit 65bff3c856

View File

@ -2515,8 +2515,8 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
in_dtypes = map(dtypes.dtype, in_avals)
in_pvals = map(pe.PartialVal.unknown, in_avals)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
instantiate=True)
jaxpr, out_pvals, const = pe.trace_to_jaxpr_nounits(flat_fun, in_pvals,
instantiate=True)
out_avals, _ = unzip2(out_pvals)
out_dtypes = map(dtypes.dtype, out_avals)
if not (all(dtypes.issubdtype(d, np.inexact) for d in in_dtypes + out_dtypes)
@ -2527,22 +2527,21 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
f"but got {in_dtypes} -> {out_dtypes}.")
@api_boundary
def transposed_fun(consts, out_cotangent):
out_cotangents, out_tree2 = tree_flatten(out_cotangent)
def transposed_fun(const, out_cotangent):
out_cts, out_tree2 = tree_flatten(out_cotangent)
if out_tree() != out_tree2:
raise TypeError("cotangent tree does not match function output, "
f"expected {out_tree()} but got {out_tree2}")
if not all(map(core.typecheck, out_avals, out_cotangents)):
if not all(map(core.typecheck, out_avals, out_cts)):
raise TypeError("cotangent type does not match function output, "
f"expected {out_avals} but got {out_cotangents}")
f"expected {out_avals} but got {out_cts}")
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
in_cotangents = map(
ad.instantiate_zeros,
ad.backward_pass(jaxpr, reduce_axes, True, consts, dummies, out_cotangents))
return tree_unflatten(in_tree, in_cotangents)
in_cts = ad.backward_pass(jaxpr, reduce_axes, True, const, dummies, out_cts)
in_cts = map(ad.instantiate_zeros, in_cts)
return tree_unflatten(in_tree, in_cts)
# Ensure that transposed_fun is a PyTree
return Partial(transposed_fun, consts)
return Partial(transposed_fun, const)
def make_jaxpr(fun: Callable,