mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
ship it
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
parent
0988f6d8d5
commit
4c2ec3e442
@ -57,7 +57,7 @@ def zeros_like_aval(aval):
|
||||
aval_zeros_likers = {}
|
||||
|
||||
def zeros_like_abstract_tuple(tup):
|
||||
return AbstractTuple(map(zeros_like_aval, tup))
|
||||
return JaxTuple(map(zeros_like_aval, tup))
|
||||
aval_zeros_likers[AbstractTuple] = zeros_like_abstract_tuple
|
||||
|
||||
|
||||
|
12
jax/core.py
12
jax/core.py
@ -67,11 +67,12 @@ class TypedJaxpr(namedtuple('TypedJaxpr', ['jaxpr', 'literals', 'in_avals', 'out
|
||||
|
||||
@curry
|
||||
def jaxpr_as_fun(typed_jaxpr, *args):
|
||||
from jax.lax import _abstractify
|
||||
for arg, in_aval in zip(args, typed_jaxpr.in_avals):
|
||||
from jax.lax import _abstractify # TODO
|
||||
invars = typed_jaxpr.jaxpr.invars
|
||||
for arg, in_aval, varname in zip(args, typed_jaxpr.in_avals, invars):
|
||||
arg_aval, _ = _abstractify(arg)
|
||||
if arg_aval != in_aval:
|
||||
raise TypeError("input type mismatch")
|
||||
raise TypeError("input type mismatch for arg {}".format(varname))
|
||||
out = eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, (), *args)
|
||||
out_aval, _ = _abstractify(out)
|
||||
if out_aval != typed_jaxpr.out_aval:
|
||||
@ -157,7 +158,10 @@ def eval_jaxpr(jaxpr, consts, freevar_vals, *args):
|
||||
# TODO enforce a specific set of types for jaxpr vars
|
||||
def pat_fmap(f, v, *xs):
|
||||
if type(v) in (tuple, list):
|
||||
return tuple(map(partial(pat_fmap, f), v, *xs))
|
||||
if len(xs) == 1 and xs[0] is None:
|
||||
return tuple(map(partial(pat_fmap, f), v, [None] * len(v)))
|
||||
else:
|
||||
return tuple(map(partial(pat_fmap, f), v, *xs))
|
||||
else:
|
||||
return f(v, *xs)
|
||||
|
||||
|
@ -186,9 +186,17 @@ def _scan_initial_impl(consts, init, xs, forward, length, jaxpr):
|
||||
return (carry_out, ys_out)
|
||||
|
||||
ys_init = empty_arrays(ys_aval)
|
||||
carry, ys = lax.fori_loop(0, length, body_fun, (init, ys_init))
|
||||
# carry, ys = lax.fori_loop(0, length, body_fun, (init, ys_init))
|
||||
carry, ys = fori_loop(0, length, body_fun, (init, ys_init))
|
||||
return core.pack((carry, ys))
|
||||
|
||||
# TODO remove
|
||||
def fori_loop(start, stop, body_fun, init_val):
|
||||
carry = init_val
|
||||
for i in range(start, stop):
|
||||
carry = body_fun(i, carry)
|
||||
return carry
|
||||
|
||||
|
||||
def _scan_initial_jvp(primals, tangents, forward, length, jaxpr):
|
||||
consts, init, xs = primals
|
||||
@ -322,22 +330,31 @@ def _scan_initial_transpose(ct, consts, init, xs, forward, length, jaxpr):
|
||||
# jaxpr :: d -> c -> (a, res) -> (c, b) # TODO assuming restructuring input
|
||||
# jaxpr_lifted :: res -> (d, c, a) -> (c, b)
|
||||
# jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
|
||||
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
|
||||
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
|
||||
jaxpr_lifted = rearrange_binders(
|
||||
lambda d, c, a_res: (a_res[1], (d, c, a_res[0])), jaxpr)
|
||||
jaxpr_lifted_trans, out_tree = transpose_jaxpr2(jaxpr_lifted)
|
||||
jaxpr_lifted_trans = transpose_jaxpr2(jaxpr_lifted)
|
||||
jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)
|
||||
import ipdb; ipdb.set_trace()
|
||||
|
||||
ct_c, ct_bs = ct
|
||||
carry_ct = core.pack((ct_c, ad_util.zeros_like_aval(jaxpr.in_avals[0])))
|
||||
c_aval, b_aval = jaxpr.out_aval
|
||||
d_aval, c_aval2, _ = jaxpr.in_avals
|
||||
assert c_aval == c_aval2
|
||||
bs_aval = promote_aval_rank(length, b_aval)
|
||||
ct_d = ad_util.zeros_like_aval(d_aval)
|
||||
ct_c, ct_bs = ad.instantiate_zeros_aval(core.AbstractTuple((c_aval, bs_aval)), ct)
|
||||
carry_ct = core.pack((ct_c, ct_d))
|
||||
|
||||
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
|
||||
core.check_jaxpr(jaxpr_trans.jaxpr)
|
||||
unit_aval, (ct_c_aval, ct_d_aval), (ct_b_aval, _) = jaxpr_trans.in_avals
|
||||
assert core.lattice_join(ct_c_aval, core.get_aval(ct_c)) == ct_c_aval
|
||||
assert core.lattice_join(ct_d_aval, core.get_aval(ct_d)) == ct_d_aval
|
||||
|
||||
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
|
||||
# scan_p.bind :: (d -> c -> a -> (c, b)) -> d -> c -> [a] -> (c, [b])
|
||||
out = scan_initial_p.bind(
|
||||
core.unit, carry_ct, core.pack((ct_bs, res)),
|
||||
forward=not forward, length=length, jaxpr=jaxpr_trans)
|
||||
import ipdb; ipdb.set_trace()
|
||||
(ct_init, ct_consts), ct_as = out
|
||||
return ct_consts, ct_init, (ct_as, None)
|
||||
|
||||
def rearrange_binders(f, typed_jaxpr):
|
||||
jaxpr = typed_jaxpr.jaxpr.copy()
|
||||
@ -351,21 +368,24 @@ _scan_newvar = pe.gensym('_scan')
|
||||
|
||||
def _move_stuff_and_add_add(typed_jaxpr):
|
||||
# jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
|
||||
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
|
||||
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
|
||||
|
||||
res_aval, (CTc_aval, CTb_aval) = typed_jaxpr.in_avals
|
||||
CTd_aval, CTc_aval2, CTa_aval = typed_jaxpr.out_aval
|
||||
assert CTc_aval == CTc_aval2
|
||||
in_avals = (core.AbstractTuple(()), core.AbstractTuple((CTc_aval, CTd_aval)),
|
||||
core.AbstractTuple((CTb_aval, res_aval)))
|
||||
out_aval = core.AbstractTuple((core.AbstractTuple((CTd_aval, CTc_aval)),
|
||||
out_aval = core.AbstractTuple((core.AbstractTuple((CTc_aval, CTd_aval)),
|
||||
CTa_aval))
|
||||
|
||||
jaxpr = typed_jaxpr.jaxpr.copy()
|
||||
# TODO assume not restructuring input
|
||||
assert not any(type(invar) is tuple for invar in jaxpr.invars)
|
||||
|
||||
# munge input side
|
||||
CTc_in = _scan_newvar()
|
||||
CTb_in = _scan_newvar()
|
||||
CTd_in = _scan_newvar()
|
||||
res_in, CTc_CTb_in = jaxpr.invars
|
||||
jaxpr.invars = ((), (CTc_in, CTd_in), (CTb_in, res_in))
|
||||
jaxpr.eqns = (
|
||||
@ -373,7 +393,6 @@ def _move_stuff_and_add_add(typed_jaxpr):
|
||||
jaxpr.eqns)
|
||||
|
||||
# munge output side
|
||||
CTd_in = _scan_newvar()
|
||||
CTd_new = _scan_newvar()
|
||||
CTd_sum = _scan_newvar()
|
||||
CTc = _scan_newvar()
|
||||
@ -384,33 +403,32 @@ def _move_stuff_and_add_add(typed_jaxpr):
|
||||
jaxpr.eqns +
|
||||
[_unpack_eqn(jaxpr.outvar, [CTd_new, CTc, CTa]),
|
||||
_add_any_eqn(CTd_sum, CTd_new, CTd_in),
|
||||
_pack_eqn([CTd_sum, CTc], partial_out),
|
||||
_pack_eqn([CTc, CTd_sum], partial_out),
|
||||
_pack_eqn([partial_out, CTa], outvar)])
|
||||
jaxpr.outvar = outvar
|
||||
|
||||
# TODO should really have a check_typed_jaxpr
|
||||
core.skip_checks or core.check_jaxpr(jaxpr)
|
||||
return core.TypedJaxpr(jaxpr, typed_jaxpr.literals,
|
||||
in_avals, out_aval)
|
||||
|
||||
def _add_any_eqn(tot, a, b):
  """Build a `core.JaxprEqn` that applies `ad_util.add_jaxvals_p` to `a`, `b`.

  `[a, b]` and `[tot]` are passed as the first two constructor arguments
  (presumably the equation's input and output variables -- confirm against
  the `JaxprEqn` definition). The remaining positional arguments are the
  fixed values `()`, `False`, `False` and `{}`.
  """
  operands = [a, b]
  results = [tot]
  return core.JaxprEqn(operands, results, ad_util.add_jaxvals_p,
                       (), False, False, {})
|
||||
|
||||
|
||||
# transpose_jaxpr :: (res -> a -> b) -> (res -> CT b -> CT a)
|
||||
def transpose_jaxpr2(jaxpr):
|
||||
assert len(jaxpr.in_avals) == 2
|
||||
nones = core.pat_fmap(lambda _: None, jaxpr.jaxpr.invars[1])
|
||||
|
||||
@lu.wrap_init
|
||||
def transposed(res, b_bar):
|
||||
_, a_bar = ad.backward_pass(jaxpr.jaxpr, jaxpr.literals, (),
|
||||
(res, nones), b_bar)
|
||||
_, (_, a_bar) = ad.backward_pass(jaxpr.jaxpr, jaxpr.literals, (),
|
||||
(res, None), b_bar)
|
||||
a_bar = ad.instantiate_zeros_aval(jaxpr.in_avals[1], a_bar)
|
||||
return a_bar
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_out(*args):
|
||||
ans = yield args
|
||||
yield pytree_to_jaxtupletree(ans)
|
||||
|
||||
transposed, out_tree = flatten_out(transposed)
|
||||
transposed_jaxpr = make_typed_jaxpr(transposed, (jaxpr.in_avals[0], jaxpr.out_aval))
|
||||
return transposed_jaxpr, out_tree()
|
||||
return transposed_jaxpr
|
||||
|
||||
def make_typed_jaxpr(traceable, in_avals):
|
||||
pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
|
||||
|
@ -21,7 +21,7 @@ import itertools as it
|
||||
from . import partial_eval as pe
|
||||
from .. import core as core
|
||||
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
|
||||
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
|
||||
zeros_like_p, zero, Zero)
|
||||
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
|
||||
from ..tree_util import process_pytree, build_tree, register_pytree_node, tree_map
|
||||
@ -185,12 +185,29 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in):
|
||||
cts_out = [zero for _ in eqn.invars]
|
||||
map(write_cotangent, eqn.invars, cts_out)
|
||||
|
||||
cotangents_out = core.pat_fmap(
|
||||
lambda var, argval: read_cotangent(var) if argval is None else None,
|
||||
jaxpr.invars, args)
|
||||
freevar_cts = core.pat_fmap(read_cotangent, jaxpr.freevars)
|
||||
cotangents_out = core.pat_fmap(lambda v, _: read_cotangent(v), jaxpr.invars, None)
|
||||
cotangents_out = tuple(map(pack_cotangents_like_caller, args, cotangents_out))
|
||||
return freevar_cts, cotangents_out
|
||||
|
||||
def pack_cotangents_like_caller(arg, ct):
  """Shape the cotangent `ct` to follow the tuple structure of `arg`.

  Args:
    arg: an argument value, `None`, or a (possibly nested) `tuple` of those.
    ct: cotangent data with matching nesting wherever `arg` is a tuple.

  Returns:
    * for a `tuple` `arg`: a tuple built by recursing pairwise over the
      elements of `arg` and `ct`;
    * for `arg is None`: `recursively_pack(ct)`;
    * for any other `arg`: `None`.
  """
  if type(arg) is tuple:
    # Exact-type check on purpose: only plain Python tuples are descended.
    return tuple(map(pack_cotangents_like_caller, arg, ct))
  if arg is not None:
    # A concrete argument value was supplied, so no cotangent is returned.
    return None
  return recursively_pack(ct)
|
||||
|
||||
def recursively_pack(ct):
  """Pack a (possibly nested) plain tuple of cotangents.

  Anything that is not exactly a `tuple` is returned unchanged. A tuple has
  each element processed recursively first. The processed elements are then
  wrapped in a `TangentTuple` when at least one of them is the symbolic
  `zero` or is itself a `TangentTuple`; otherwise they are handed to `pack`.
  """
  if type(ct) is not tuple:
    return ct
  packed_elts = tuple(map(recursively_pack, ct))
  needs_tangent_tuple = any(
      elt is zero or isinstance(elt, TangentTuple) for elt in packed_elts)
  return TangentTuple(packed_elts) if needs_tangent_tuple else pack(packed_elts)
|
||||
|
||||
def get_primitive_transpose(p):
|
||||
try:
|
||||
return primitive_transposes[p]
|
||||
@ -409,6 +426,14 @@ def instantiate_zeros(example, tangent):
|
||||
else:
|
||||
return tangent
|
||||
|
||||
def instantiate_zeros_aval(aval, tangent):
  """Replace symbolic zeros in `tangent` using `aval` as the template.

  Args:
    aval: abstract value matching `tangent`; iterated element-wise alongside
      `tangent` when `tangent` is a `TangentTuple`.
    tangent: the symbolic `zero`, a `TangentTuple`, or any other value.

  Returns:
    `zeros_like_aval(aval)` when `tangent` is `zero`; the `pack` of the
    element-wise recursive results when `tangent` is a `TangentTuple`;
    otherwise `tangent` unchanged.
  """
  if tangent is zero:
    return zeros_like_aval(aval)
  if not isinstance(tangent, TangentTuple):
    return tangent
  return pack(map(instantiate_zeros_aval, aval, tangent))
|
||||
|
||||
@transformation_with_aux
|
||||
def traceable(in_tree_def, new_primals, new_tangents):
|
||||
new_tangents = build_tree(in_tree_def, new_tangents)
|
||||
|
@ -32,19 +32,19 @@ def f(c, a):
|
||||
as_ = np.ones((5, 3))
|
||||
c = np.ones(4)
|
||||
|
||||
# print scan_reference(f, c, as_)
|
||||
# print scan_initial(f, c, as_)
|
||||
# print
|
||||
print scan_reference(f, c, as_)
|
||||
print scan_initial(f, c, as_)
|
||||
print
|
||||
|
||||
# print jvp(lambda c, as_: scan_reference(f, c, as_), (c, as_), (c, as_))[1]
|
||||
# print jvp(lambda c, as_: scan_initial(f, c, as_), (c, as_), (c, as_))[1]
|
||||
# print
|
||||
print jvp(lambda c, as_: scan_reference(f, c, as_), (c, as_), (c, as_))[1]
|
||||
print jvp(lambda c, as_: scan_initial(f, c, as_), (c, as_), (c, as_))[1]
|
||||
print
|
||||
|
||||
# print linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
|
||||
# print linearize(lambda c, as_: scan_initial(f, c, as_), c, as_)[1](c, as_)
|
||||
# print
|
||||
print linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
|
||||
print linearize(lambda c, as_: scan_initial(f, c, as_), c, as_)[1](c, as_)
|
||||
print
|
||||
|
||||
# print grad(lambda c, as_: scan_reference(f, c, as_)[0].sum())(c, as_)
|
||||
print grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_)
|
||||
print grad(lambda c, as_: list(scan_initial(f, c, as_))[0].sum())(c, as_)
|
||||
print
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user