mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add jaxpr eqn structured input, transpose progress
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
parent
1c9035efca
commit
a17f8e4ca8
@ -176,3 +176,10 @@ array_types = [onp.ndarray, onp.float64, onp.float32, onp.float16,
|
||||
for t in array_types:
|
||||
core.pytype_aval_mappings[t] = ConcreteArray
|
||||
ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
||||
|
||||
|
||||
def zeros_like_shaped_array(aval):
|
||||
assert isinstance(aval, ShapedArray)
|
||||
return onp.zeros(aval.shape, dtype=aval.dtype)
|
||||
|
||||
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
||||
|
@ -16,7 +16,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from .core import JaxTuple, lattice_join, Primitive
|
||||
from .core import JaxTuple, lattice_join, Primitive, AbstractTuple
|
||||
from .tree_util import register_pytree_node
|
||||
from .util import safe_map
|
||||
|
||||
@ -52,6 +52,15 @@ jaxval_zeros_likers = {}
|
||||
jaxval_zeros_likers[JaxTuple] = zeros_like_impl_jaxtuple
|
||||
|
||||
|
||||
def zeros_like_aval(aval):
|
||||
return aval_zeros_likers[type(aval)](aval)
|
||||
aval_zeros_likers = {}
|
||||
|
||||
def zeros_like_abstract_tuple(tup):
|
||||
return AbstractTuple(map(zeros_like_aval, tup))
|
||||
aval_zeros_likers[AbstractTuple] = zeros_like_abstract_tuple
|
||||
|
||||
|
||||
def zeros_like_jaxval(val):
|
||||
return zeros_like_p.bind(val)
|
||||
|
||||
|
15
jax/core.py
15
jax/core.py
@ -80,7 +80,8 @@ def jaxpr_as_fun(typed_jaxpr, *args):
|
||||
|
||||
|
||||
JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive',
|
||||
'bound_subjaxprs', 'destructure', 'params'])
|
||||
'bound_subjaxprs', 'restructure',
|
||||
'destructure', 'params'])
|
||||
|
||||
class Primitive(object):
|
||||
def __init__(self, name):
|
||||
@ -137,7 +138,11 @@ def eval_jaxpr(jaxpr, consts, freevar_vals, *args):
|
||||
map(write, jaxpr.invars, args)
|
||||
map(write, jaxpr.freevars, freevar_vals)
|
||||
for eqn in jaxpr.eqns:
|
||||
in_vals = map(read, eqn.invars)
|
||||
if not eqn.restructure:
|
||||
in_vals = map(read, eqn.invars)
|
||||
else:
|
||||
in_vals = [pack(map(read, invars)) if type(invars) is tuple else read(invars)
|
||||
for invars in eqn.invars]
|
||||
subfuns = [partial(eval_jaxpr, subjaxpr, map(read, const_bindings),
|
||||
map(read, freevar_bindings))
|
||||
for subjaxpr, const_bindings, freevar_bindings
|
||||
@ -615,7 +620,11 @@ def check_jaxpr(jaxpr):
|
||||
map(write, jaxpr.freevars)
|
||||
map(write, jaxpr.invars)
|
||||
for eqn in jaxpr.eqns:
|
||||
map(read, eqn.invars)
|
||||
if not eqn.restructure:
|
||||
map(read, eqn.invars)
|
||||
else:
|
||||
[map(read, invar) if type(invar) is tuple else read(invar)
|
||||
for invar in eqn.invars]
|
||||
for subjaxpr, constvars, freevars in eqn.bound_subjaxprs:
|
||||
map(read, freevars)
|
||||
map(read_const, constvars)
|
||||
|
@ -11,6 +11,7 @@ from jax.lax import _abstractify, _unpack_eqn
|
||||
from jax.abstract_arrays import ShapedArray
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import ad
|
||||
from jax import ad_util
|
||||
|
||||
|
||||
def pvals_with_zeros(zero_components, aval):
|
||||
@ -98,7 +99,7 @@ def _call_initial_partial_eval(trace, *tracers, **kwargs):
|
||||
*in_consts, jaxpr=jaxpr_1, consts=consts_1)
|
||||
residual_tracers = core.pack(map(trace.new_instantiated_const, residuals))
|
||||
eqn = core.JaxprEqn((residual_tracers,) + tracers, None, call_initial_p, (),
|
||||
False, dict(jaxpr=jaxpr_2, consts=consts_2))
|
||||
False, False, dict(jaxpr=jaxpr_2, consts=consts_2))
|
||||
return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
|
||||
|
||||
|
||||
@ -150,7 +151,7 @@ def update_arrays(i, aval, xs, x):
|
||||
else:
|
||||
return lax.dynamic_update_index_in_dim(xs, x[None, ...], i, axis=0)
|
||||
|
||||
_scan_const = pe.gensym('_consts')
|
||||
_scan_newvar = pe.gensym('_scan')
|
||||
|
||||
# scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
|
||||
def scan_initial(f, init, xs):
|
||||
@ -162,26 +163,27 @@ def scan_initial(f, init, xs):
|
||||
lu.wrap_init(f), (carry_pval, x_pval), instantiate=True)
|
||||
(carry_aval_out, y_aval), _ = pval_out
|
||||
assert carry_aval == carry_aval_out
|
||||
lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr, _scan_const)
|
||||
lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr, _scan_newvar)
|
||||
consts_aval, _ = _abstractify(core.pack(consts))
|
||||
in_avals = (consts_aval, carry_aval, x_aval)
|
||||
out_aval = core.AbstractTuple((carry_aval, y_aval))
|
||||
jaxpr = core.TypedJaxpr(lifted_jaxpr, (), in_avals, out_aval)
|
||||
length = leading_dim_size(xs)
|
||||
return scan_initial_p.bind(core.pack(consts), init, xs,
|
||||
length=length, jaxpr=jaxpr)
|
||||
forward=True, length=length, jaxpr=jaxpr)
|
||||
|
||||
|
||||
def _scan_initial_impl(consts, init, xs, length, jaxpr):
|
||||
def _scan_initial_impl(consts, init, xs, forward, length, jaxpr):
|
||||
_, _, x_aval = jaxpr.in_avals
|
||||
_, y_aval = jaxpr.out_aval
|
||||
ys_aval = promote_aval_rank(length, y_aval)
|
||||
|
||||
def body_fun(i, vals):
|
||||
idx = i if forward else length - i - 1
|
||||
carry, ys = vals
|
||||
x = index_arrays(i, x_aval, xs)
|
||||
x = index_arrays(idx, x_aval, xs)
|
||||
carry_out, y = core.jaxpr_as_fun(jaxpr)(consts, carry, x)
|
||||
ys_out = update_arrays(i, y_aval, ys, y)
|
||||
ys_out = update_arrays(idx, y_aval, ys, y)
|
||||
return (carry_out, ys_out)
|
||||
|
||||
ys_init = empty_arrays(ys_aval)
|
||||
@ -189,7 +191,7 @@ def _scan_initial_impl(consts, init, xs, length, jaxpr):
|
||||
return core.pack((carry, ys))
|
||||
|
||||
|
||||
def _scan_initial_jvp(primals, tangents, length, jaxpr):
|
||||
def _scan_initial_jvp(primals, tangents, forward, length, jaxpr):
|
||||
consts, init, xs = primals
|
||||
consts_dot, init_dot, xs_dot = tangents
|
||||
consts_aval, carry_aval, x_aval = jaxpr.in_avals
|
||||
@ -220,7 +222,8 @@ def _scan_initial_jvp(primals, tangents, length, jaxpr):
|
||||
xs_dual = core.pack((xs, nonzero_xs_dot))
|
||||
|
||||
carry_out_dual, ys_dual = scan_initial_p.bind(
|
||||
consts_dual, init_dual, xs_dual, length=length, jaxpr=jaxpr_jvp)
|
||||
consts_dual, init_dual, xs_dual,
|
||||
forward=forward, length=length, jaxpr=jaxpr_jvp)
|
||||
|
||||
ys, ys_dot = ys_dual
|
||||
ys_dot = ad.put_zeros(ad.TangentTuple, where_ys_zeros, ys_dot)
|
||||
@ -257,6 +260,8 @@ def binary_lattice_join(a, b):
|
||||
def _scan_initial_partial_eval(trace, *tracers, **kwargs):
|
||||
jaxpr = kwargs.pop('jaxpr')
|
||||
length = kwargs.pop('length')
|
||||
forward = kwargs.pop('forward')
|
||||
assert not kwargs
|
||||
in_pvs, in_consts = unzip2([t.pval for t in tracers])
|
||||
fc_consts, fc_init, fc_xs = map(is_const, in_pvs)
|
||||
|
||||
@ -278,20 +283,13 @@ def _scan_initial_partial_eval(trace, *tracers, **kwargs):
|
||||
out_pv = _put_known_pvs(fc_out, jaxpr.out_aval)
|
||||
|
||||
out_carry, (ys, residuals) = scan_initial_p.bind(
|
||||
*in_consts, length=length, jaxpr=jaxpr_1)
|
||||
*in_consts, forward=forward, length=length, jaxpr=jaxpr_1)
|
||||
out_const = core.pack((out_carry, ys))
|
||||
residual_tracers = core.pack(map(trace.new_instantiated_const, residuals))
|
||||
residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
|
||||
d, c, a = lifted_tracers
|
||||
new_tracers = (d, c, core.pack((a, residual_tracers))) # TODO nonlin pack
|
||||
# TODO adapt scan to
|
||||
# option #1:
|
||||
# scan :: (d -> c -> a -> b) -> d -> c -> [a] -> [b]
|
||||
# scan :: (d -> c -> a -> alin -> b) -> d -> c -> [a] -> [alin] -> [b]
|
||||
# option #2:
|
||||
# extend jaxpr language to have destructuring tuples of variables in invars
|
||||
# b = g(a, (x, a))
|
||||
eqn = core.JaxprEqn(new_tracers, None, scan_initial_p, (), False,
|
||||
dict(length=length, jaxpr=jaxpr_2))
|
||||
new_tracers = (d, c, (a, residuals_tracer))
|
||||
eqn = core.JaxprEqn(new_tracers, None, scan_initial_p, (), True, False,
|
||||
dict(forward=forward, length=length, jaxpr=jaxpr_2))
|
||||
return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
|
||||
|
||||
def _lift_tracer(trace, tracer, is_const):
|
||||
@ -316,25 +314,57 @@ def _put_known_pvs(is_known, aval):
|
||||
return pe.JaxprTracerTuple(map(_put_known_pvs, is_known, aval))
|
||||
|
||||
|
||||
def _scan_initial_transpose(ct, consts, init, xs, length, jaxpr):
|
||||
def _scan_initial_transpose(ct, consts, init, xs, forward, length, jaxpr):
|
||||
assert consts is None and init is None
|
||||
import ipdb; ipdb.set_trace() # TODO but xs is also None!
|
||||
assert type(xs) is tuple
|
||||
a, res = xs
|
||||
assert a is None and res is not None
|
||||
|
||||
# jaxpr :: d -> c -> (a, res) -> (c, b)
|
||||
# jaxpr_lifted :: res -> (d, c, a) -> (c, b)
|
||||
# jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
|
||||
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
|
||||
jaxpr_lifted = _move_res_and_uncurry(jaxpr)
|
||||
jaxpr_lifted = _move_res_and_uncurry(jaxpr, _scan_newvar)
|
||||
import ipdb; ipdb.set_trace()
|
||||
jaxpr_lifted_trans = transpose_jaxpr2(jaxpr_lifted)
|
||||
jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)
|
||||
|
||||
assert False
|
||||
# c_bar, bs_bar = ct
|
||||
# d_bar = zeros
|
||||
ct_c, ct_bs = ct
|
||||
carry_ct = core.pack((ct_c, ad_util.zeros_like_aval(jaxpr.in_avals[0])))
|
||||
|
||||
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
|
||||
# scan_p.bind :: (d -> c -> a -> (c, b)) -> d -> c -> [a] -> (c, [b])
|
||||
scan_initial_p.bind(
|
||||
core.unit, carry_ct, core.pack((ct_bs, res)),
|
||||
forward=not forward, length=length, jaxpr=jaxpr_trans)
|
||||
|
||||
def _move_res_and_uncurry(jaxpr, newvar):
|
||||
# jaxpr :: d -> c -> (a, res) -> (c, b)
|
||||
# jaxpr_lifted :: res -> (d, c, a) -> (c, b)
|
||||
assert len(jaxpr.in_avals) == 3
|
||||
assert type(jaxpr.in_avals[2]) is core.AbstractTuple
|
||||
|
||||
d_aval, c_aval, (a_aval, res_aval) = jaxpr.in_avals
|
||||
in_avals = [res_aval, core.AbstractTuple((d_aval, c_aval, a_aval))]
|
||||
|
||||
d, c, a_res = jaxpr.jaxpr.invars
|
||||
a = newvar()
|
||||
res = newvar()
|
||||
d_c_a = newvar()
|
||||
invars = [res, d_c_a]
|
||||
eqns = (
|
||||
[pe._unpack_eqn(d_c_a, [d, c, a]),
|
||||
pe._pack_eqn([a, res], a_res)]
|
||||
+ list(jaxpr.jaxpr.eqns))
|
||||
|
||||
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, jaxpr.jaxpr.freevars,
|
||||
invars, jaxpr.jaxpr.outvar, eqns)
|
||||
core.skip_checks or core.check_jaxpr(new_jaxpr)
|
||||
return core.TypedJaxpr(new_jaxpr, jaxpr.literals, in_avals, jaxpr.out_aval)
|
||||
|
||||
# return scan_initial_p.bind(core.unit,
|
||||
|
||||
# transpose_jaxpr :: (res -> a -> b) -> (res -> CT b -> CT a)
|
||||
# TODO either top-level restructure, or else munge somehow
|
||||
def transpose_jaxpr2(jaxpr):
|
||||
assert len(jaxpr.in_avals) == 2
|
||||
def transposed(res, b_bar):
|
||||
|
@ -145,6 +145,9 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in):
|
||||
def read_cotangent(v):
|
||||
return ct_env.get(v, zero)
|
||||
|
||||
def read_primal(v):
|
||||
return primal_env.get(v)
|
||||
|
||||
primal_env = {v: val for v, val in zip(jaxpr.freevars, freevar_vals)
|
||||
if val is not None}
|
||||
primal_env.update(zip(jaxpr.constvars, consts))
|
||||
@ -155,12 +158,16 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in):
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
cts_in = map(read_cotangent, eqn.outvars)
|
||||
ct_in = TangentTuple(cts_in) if eqn.destructure else cts_in[0]
|
||||
invals = map(primal_env.get, eqn.invars)
|
||||
if not eqn.restructure:
|
||||
invals = map(read_primal, eqn.invars)
|
||||
else:
|
||||
invals = [tuple(map(read_primal, v)) if type(v) is tuple
|
||||
else read_primal(v) for v in eqn.invars]
|
||||
if eqn.bound_subjaxprs:
|
||||
subjaxprs, sub_consts, sub_freevar_vals = unzip3([
|
||||
(subjaxpr,
|
||||
map(primal_env.get, const_vars),
|
||||
map(primal_env.get, bound_vars))
|
||||
map(read_primal, const_vars),
|
||||
map(read_primal, bound_vars))
|
||||
for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs])
|
||||
cts_out, ct_free_vars_out = get_primitive_transpose(eqn.primitive)(
|
||||
eqn.params, subjaxprs, sub_consts, sub_freevar_vals, invals, ct_in)
|
||||
|
@ -74,11 +74,11 @@ class JaxprTrace(Trace):
|
||||
avals = [t.aval for t in tracers]
|
||||
out_aval = primitive.abstract_eval(*avals, **params)
|
||||
partial_val = PartialVal((out_aval, unit))
|
||||
eqn = JaxprEqn(tracers, None, primitive, (), False, params)
|
||||
eqn = JaxprEqn(tracers, None, primitive, (), False, False, params)
|
||||
return JaxprTracer(self, partial_val, eqn)
|
||||
|
||||
def pack(self, tracers):
|
||||
eqn = JaxprEqn(tracers, None, core.pack_p, (), False, {})
|
||||
eqn = JaxprEqn(tracers, None, core.pack_p, (), False, False, {})
|
||||
pval = pack_pvals([t.pval for t in tracers])
|
||||
return JaxprTracer(self, pval, eqn)
|
||||
|
||||
@ -92,7 +92,8 @@ class JaxprTrace(Trace):
|
||||
const_tracers = map(self.new_instantiated_const, consts)
|
||||
env_tracers = map(self.full_raise, env)
|
||||
bound_subjaxpr = (jaxpr, const_tracers, env_tracers)
|
||||
eqn = JaxprEqn(tracers, None, call_primitive, (bound_subjaxpr,), False, params)
|
||||
eqn = JaxprEqn(tracers, None, call_primitive, (bound_subjaxpr,),
|
||||
False, False, params)
|
||||
return JaxprTracer(self, PartialVal((out_pv, out_pv_const)), eqn)
|
||||
|
||||
def process_map(self, call_primitive, f, tracers, params):
|
||||
@ -109,7 +110,8 @@ class JaxprTrace(Trace):
|
||||
jaxpr_converted.invars = list(it.chain(jaxpr.constvars, jaxpr.invars))
|
||||
invars = tuple(it.chain(const_tracers, tracers))
|
||||
bound_subjaxpr = (jaxpr_converted, (), env)
|
||||
eqn = JaxprEqn(invars, None, call_primitive, (bound_subjaxpr,), False, params)
|
||||
eqn = JaxprEqn(invars, None, call_primitive, (bound_subjaxpr,),
|
||||
False, False, params)
|
||||
return JaxprTracer(self, PartialVal((out_pv, out_const)), eqn)
|
||||
|
||||
def post_process_call(self, call_primitive, out_tracer):
|
||||
@ -124,7 +126,8 @@ class JaxprTrace(Trace):
|
||||
const_tracers = map(trace.new_instantiated_const, consts)
|
||||
env_tracers = map(trace.full_raise, env)
|
||||
bound_subjaxpr = (jaxpr, const_tracers, env_tracers)
|
||||
eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,), False, {})
|
||||
eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,),
|
||||
False, False, {})
|
||||
return JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), eqn)
|
||||
|
||||
return out, todo
|
||||
@ -149,7 +152,7 @@ def scan_process_primitive(trace, consts, init, xs, avals, jaxpr):
|
||||
avals=avals1, jaxpr=jaxpr1)
|
||||
|
||||
params_out = {'avals' : avals2, 'jaxpr' : jaxpr2}
|
||||
eqn = JaxprEqn([consts, init, xs], None, scan_p, (), False, params_out)
|
||||
eqn = JaxprEqn([consts, init, xs], None, scan_p, (), False, False, params_out)
|
||||
return JaxprTracer(trace, PartialVal((ans, ans_pv)), )
|
||||
|
||||
# in_pvs, in_consts = unzip2([t.pval for t in tracers])
|
||||
@ -278,7 +281,7 @@ class JaxprTracer(Tracer):
|
||||
if isinstance(pv, AbstractValue):
|
||||
const = [unit for _ in range(n)]
|
||||
key = object()
|
||||
eqn = JaxprEqn([self], [None]*n, core.identity_p, (), True, {})
|
||||
eqn = JaxprEqn([self], [None]*n, core.identity_p, (), False, True, {})
|
||||
def child_tracer(i, pval, c):
|
||||
d = Destructuring(i, eqn, key)
|
||||
return JaxprTracer(self.trace, PartialVal((pval, c)), d).full_lower()
|
||||
@ -416,12 +419,16 @@ ConstVar = namedtuple('ConstVar', ['val'])
|
||||
LambdaBinding = namedtuple('LambdaBinding', [])
|
||||
|
||||
def eqn_tracer_to_var(var, outvars, eqn):
|
||||
invars, _, primitive, bound_subjaxprs, destructure, params = eqn
|
||||
invars = map(var, invars)
|
||||
invars, _, primitive, bound_subjaxprs, restructure, destructure, params = eqn
|
||||
if not restructure:
|
||||
invars = map(var, invars)
|
||||
else:
|
||||
invars = [tuple(map(var, v)) if type(v) is tuple else var(v)
|
||||
for v in invars]
|
||||
new_bound_subjaxprs = [(j, map(var, c), map(var, f))
|
||||
for j, c, f in bound_subjaxprs]
|
||||
return JaxprEqn(invars, outvars, primitive,
|
||||
new_bound_subjaxprs, destructure, params)
|
||||
new_bound_subjaxprs, restructure, destructure, params)
|
||||
|
||||
|
||||
def tracers_to_jaxpr(in_tracers, out_tracer):
|
||||
@ -486,36 +493,18 @@ class Var(object):
|
||||
|
||||
def eqn_parents(eqn):
|
||||
subjaxpr_tracers = [it.chain(c, f) for _, c, f in eqn.bound_subjaxprs]
|
||||
return list(it.chain(eqn.invars, *subjaxpr_tracers))
|
||||
if not eqn.restructure:
|
||||
return list(it.chain(eqn.invars, *subjaxpr_tracers))
|
||||
else:
|
||||
invars = []
|
||||
for v in eqn.invars:
|
||||
if type(v) is tuple:
|
||||
invars.extend(v)
|
||||
else:
|
||||
invars.append(v)
|
||||
return list(it.chain(invars, *subjaxpr_tracers))
|
||||
|
||||
|
||||
def eval_jaxpr_raw(jaxpr, consts, freevar_vals, *args):
|
||||
assert all(map(core.valid_jaxtype, consts))
|
||||
assert all(map(core.valid_jaxtype, freevar_vals))
|
||||
assert all(map(core.valid_jaxtype, args))
|
||||
|
||||
def read(v):
|
||||
return env[v]
|
||||
|
||||
def write(v, val):
|
||||
env[v] = val
|
||||
|
||||
env = {}
|
||||
write(unitvar, unit)
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
map(write, jaxpr.freevars, freevar_vals)
|
||||
for eqn in jaxpr.eqns:
|
||||
in_vals = map(read, eqn.invars)
|
||||
subfuns = [partial(core.eval_jaxpr, subjaxpr, map(read, const_bindings),
|
||||
map(read, freevar_bindings))
|
||||
for subjaxpr, const_bindings, freevar_bindings
|
||||
in eqn.bound_subjaxprs]
|
||||
ans = eqn.primitive.impl(*(subfuns + in_vals), **eqn.params) # not bind!
|
||||
outvals = list(ans) if eqn.destructure else [ans]
|
||||
map(write, eqn.outvars, outvals)
|
||||
return read(jaxpr.outvar)
|
||||
|
||||
def compiled_call_impl(fun, *args, **kwargs):
|
||||
with new_master(JaxprTrace, True) as master:
|
||||
pvals = map(abstractify, args)
|
||||
@ -616,10 +605,10 @@ def _closure_convert_jaxpr(jaxpr, newvar):
|
||||
return lifted_jaxpr
|
||||
|
||||
def _unpack_eqn(invar, outvars):
|
||||
return core.JaxprEqn([invar], outvars, core.identity_p, (), True, {})
|
||||
return core.JaxprEqn([invar], outvars, core.identity_p, (), False, True, {})
|
||||
|
||||
def _pack_eqn(invars, outvar):
|
||||
return core.JaxprEqn(invars, [outvar], core.pack_p, (), False, {})
|
||||
return core.JaxprEqn(invars, [outvar], core.pack_p, (), False, False, {})
|
||||
|
||||
|
||||
def partial_eval_jaxpr2(jaxpr, first_components):
|
||||
|
@ -222,7 +222,7 @@ def replicated_comp(jaxpr, ax_env, const_vals, freevar_shapes, *arg_shapes):
|
||||
map(write, all_freevars, map(c.ParameterWithShape, freevar_shapes))
|
||||
map(write, jaxpr.invars, map(c.ParameterWithShape, arg_shapes))
|
||||
for eqn in jaxpr.eqns:
|
||||
in_nodes = map(read, eqn.invars)
|
||||
in_nodes = map(read, eqn.invars) # TODO
|
||||
if eqn.primitive in parallel_translation_rules:
|
||||
name = eqn.params['axis_name']
|
||||
params = {k: eqn.params[k] for k in eqn.params if k != 'axis_name'}
|
||||
|
@ -163,7 +163,7 @@ def jaxpr_computation(jaxpr, const_vals, freevar_shapes, *arg_shapes):
|
||||
map(write, all_freevars, map(c.ParameterWithShape, freevar_shapes))
|
||||
map(write, jaxpr.invars, map(c.ParameterWithShape, arg_shapes))
|
||||
for eqn in jaxpr.eqns:
|
||||
in_nodes = map(read, eqn.invars)
|
||||
in_nodes = map(read, eqn.invars) # TODO
|
||||
in_shapes = map(c.GetShape, in_nodes)
|
||||
subcs = [
|
||||
jaxpr_computation(
|
||||
|
@ -3819,10 +3819,10 @@ xla.translations[while_p] = _while_loop_translation_rule
|
||||
batching.primitive_batchers[while_p] = _while_loop_batching_rule
|
||||
|
||||
def _unpack_eqn(invar, outvars):
|
||||
return core.JaxprEqn([invar], outvars, core.identity_p, (), True, {})
|
||||
return core.JaxprEqn([invar], outvars, core.identity_p, (), False, True, {})
|
||||
|
||||
def _pack_eqn(invars, outvar):
|
||||
return core.JaxprEqn(invars, [outvar], core.pack_p, (), False, {})
|
||||
return core.JaxprEqn(invars, [outvar], core.pack_p, (), False, False, {})
|
||||
|
||||
|
||||
def _cond_abstract_eval(pred, true_op, true_consts, false_op, false_consts,
|
||||
|
Loading…
x
Reference in New Issue
Block a user