mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[remove-units] avoid unit-generating function in jax.linear_transpose
This commit is contained in:
parent
fb4731d40e
commit
65bff3c856
@ -2515,8 +2515,8 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
|
||||
in_dtypes = map(dtypes.dtype, in_avals)
|
||||
|
||||
in_pvals = map(pe.PartialVal.unknown, in_avals)
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
|
||||
instantiate=True)
|
||||
jaxpr, out_pvals, const = pe.trace_to_jaxpr_nounits(flat_fun, in_pvals,
|
||||
instantiate=True)
|
||||
out_avals, _ = unzip2(out_pvals)
|
||||
out_dtypes = map(dtypes.dtype, out_avals)
|
||||
if not (all(dtypes.issubdtype(d, np.inexact) for d in in_dtypes + out_dtypes)
|
||||
@ -2527,22 +2527,21 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
|
||||
f"but got {in_dtypes} -> {out_dtypes}.")
|
||||
|
||||
@api_boundary
|
||||
def transposed_fun(consts, out_cotangent):
|
||||
out_cotangents, out_tree2 = tree_flatten(out_cotangent)
|
||||
def transposed_fun(const, out_cotangent):
|
||||
out_cts, out_tree2 = tree_flatten(out_cotangent)
|
||||
if out_tree() != out_tree2:
|
||||
raise TypeError("cotangent tree does not match function output, "
|
||||
f"expected {out_tree()} but got {out_tree2}")
|
||||
if not all(map(core.typecheck, out_avals, out_cotangents)):
|
||||
if not all(map(core.typecheck, out_avals, out_cts)):
|
||||
raise TypeError("cotangent type does not match function output, "
|
||||
f"expected {out_avals} but got {out_cotangents}")
|
||||
f"expected {out_avals} but got {out_cts}")
|
||||
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
|
||||
in_cotangents = map(
|
||||
ad.instantiate_zeros,
|
||||
ad.backward_pass(jaxpr, reduce_axes, True, consts, dummies, out_cotangents))
|
||||
return tree_unflatten(in_tree, in_cotangents)
|
||||
in_cts = ad.backward_pass(jaxpr, reduce_axes, True, const, dummies, out_cts)
|
||||
in_cts = map(ad.instantiate_zeros, in_cts)
|
||||
return tree_unflatten(in_tree, in_cts)
|
||||
|
||||
# Ensure that transposed_fun is a PyTree
|
||||
return Partial(transposed_fun, consts)
|
||||
return Partial(transposed_fun, const)
|
||||
|
||||
|
||||
def make_jaxpr(fun: Callable,
|
||||
|
Loading…
x
Reference in New Issue
Block a user