Add linearization rule for pjit_p

This commit is contained in:
Dougal 2024-11-22 14:15:46 -08:00
parent 73fa0f48cb
commit b1d1dcf607
5 changed files with 112 additions and 27 deletions

View File

@ -105,22 +105,56 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
store.store(aux_primals)
return out_primals, out_tangents
def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars))
return core.Jaxpr(constvars=(),
invars=jaxpr.invars + jaxpr.constvars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects, debug_info=dbg)
def linearize_jaxpr(jaxpr, nonzeros):
primal_trace = pe.DynamicJaxprTrace()
tangent_trace = pe.DynamicJaxprTrace()
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
def new_arg(primal_aval, nz):
primal = primal_trace.new_arg(primal_aval)
tangent_aval = primal_aval.to_tangent_aval()
tangent = tangent_trace.new_arg(tangent_aval) if nz else Zero(tangent_aval)
return LinearizeTracer(lin_trace, primal, tangent)
tracers = [new_arg(v.aval, nz) for (v, nz) in zip(jaxpr.jaxpr.invars, nonzeros)]
with core.set_current_trace(lin_trace):
ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers)
out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
nzs_out = [type(t) is not Zero for t in out_tangents]
out_tangents = [tangent_trace.to_jaxpr_tracer(t)
for (nz, t) in zip(nzs_out, out_tangents) if nz]
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
del attrs_tracked # TODO: attrs
residuals_and_primals = (*tangent_consts, *out_primals)
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
num_residuals = len(tangent_consts)
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
del attrs_tracked # TODO: attrs
return core.ClosedJaxpr(primal_jaxpr, primal_consts), num_residuals, nzs_out, tangent_jaxpr
def direct_linearize(traceable, *primals, **kwargs):
has_aux = kwargs.pop('has_aux', False)
assert not has_aux
with core.take_current_trace() as parent_trace:
frame = pe.JaxprStackFrame()
tangent_trace = pe.DynamicJaxprTrace(frame)
tangent_trace = pe.DynamicJaxprTrace()
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
tag = core.TraceTag()
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag)
linearize_trace = LinearizeTrace(parent_trace, tangent_trace)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
with core.set_current_trace(linearize_trace):
ans = traceable.call_wrapped(*tracers)
out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
jaxpr, consts, attrs_tracked = frame.to_jaxpr(tangent_trace, out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
del attrs_tracked # TODO: attrs
return out_primals, out_tangents_pvals, jaxpr, consts
@ -469,8 +503,8 @@ call_transpose_param_updaters: dict[core.Primitive, Callable] = {}
class LinearizeTrace(Trace):
def __init__(self, parent_trace, tangent_trace, tag):
self.tag = tag
def __init__(self, parent_trace, tangent_trace, tag=None):
self.tag = core.TraceTag() if tag is None else tag
self.parent_trace = parent_trace
self.tangent_trace = tangent_trace
@ -509,18 +543,20 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
return primal
def fallback_linearize_rule(prim, _, *args, **kwargs):
assert not prim.multiple_results
def call_prim(*args_):
return prim.bind(*args_, **kwargs)
return [prim.bind(*args_, **kwargs)]
with config.use_direct_linearize(False):
out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize(
(out_primal,), (out_tangent_pval,), jaxpr, consts, *_maybe_aux = linearize(
lu.wrap_init(call_prim), *args, **kwargs)
def linearized(residuals, *tangents):
tangents_out = iter(core.eval_jaxpr(jaxpr, residuals, *tangents))
full_out = [pval.get_known() if pval.is_known() else next(tangents_out)
for pval in out_tangents_pvals]
assert next(tangents_out, None) is None
return full_out
return out_primals, [True for _ in out_primals], consts, linearized
out_tangent, = core.eval_jaxpr(jaxpr, residuals, *tangents)
return out_tangent
return out_primal, True, consts, linearized
class LinearizeTracer(Tracer):
__slots__ = ['primal', 'tangent']

View File

@ -1575,6 +1575,7 @@ class DynamicJaxprTracer(core.Tracer):
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
return self if val is None else get_referent(val)
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
return x.aval
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
@ -1805,8 +1806,8 @@ def _inline_literals(
class DynamicJaxprTrace(core.Trace):
def __init__(self, frame):
self.frame = frame
def __init__(self):
self.frame = JaxprStackFrame()
def invalidate(self):
# avoid cyclic refs
@ -2068,6 +2069,9 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return out_tracers
def to_jaxpr(self, out_tracers: Sequence[Tracer]):
return self.frame.to_jaxpr(self, out_tracers)
custom_staging_rules: dict[Primitive, Callable] = {}
@ -2166,10 +2170,8 @@ def trace_to_jaxpr_dynamic(
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
frame.debug_info = debug_info
trace = DynamicJaxprTrace(frame)
trace = DynamicJaxprTrace()
trace.frame.debug_info = debug_info
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
@ -2177,8 +2179,8 @@ def trace_to_jaxpr_dynamic(
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(trace.to_jaxpr_tracer, ans)
jaxpr, consts, attrs_tracked = frame.to_jaxpr(trace, out_tracers)
del trace, fun, frame, in_tracers, out_tracers, ans
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
del trace, fun, in_tracers, out_tracers, ans
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
@ -2188,7 +2190,7 @@ def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: DebugInfo | None = None
) -> tuple[Jaxpr, OutputType, list[Any]]:
trace = DynamicJaxprTrace(JaxprStackFrame())
trace = DynamicJaxprTrace()
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
trace.frame.debug_info = debug_info
in_avals, keep_inputs = unzip2(fun.in_type)

View File

@ -2400,9 +2400,10 @@ def _sin_lowering(ctx, x):
return sine(ctx, x)
return _nary_lower_hlo(hlo.sine, ctx, x)
def _sin_p_lin(_, x):
def _sin_p_lin(nzs, x):
nz, = nzs
cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
return (sin_p.bind(x), True, cos_x, lambda cos_x_, t: mul(t, cos_x_))
return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_))
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))

View File

@ -2107,6 +2107,52 @@ def _pjit_jvp(primals_in, tangents_in,
ad.primitive_jvps[pjit_p] = _pjit_jvp
def _pjit_linearization(nzs, *primals_in, jaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs)
# constvars will become residuals. Move them to the end of the ordinary args.
res_shardings = (UNSPECIFIED,) * num_residuals
res_layouts = (None,) * num_residuals
res_donated = (False,) * num_residuals
def tangent_fun(consts_, *tangents):
tangents_nz = _filter_zeros(nzs, tangents)
assert len(consts_) == num_residuals
return pjit_p.bind(*(*tangents_nz, *consts_),
jaxpr=tangent_jaxpr,
in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings,
out_shardings=_filter_zeros(nzs_out, out_shardings),
in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts,
out_layouts=_filter_zeros(nzs_out, out_layouts),
resource_env=resource_env,
donated_invars=_filter_zeros(nzs, donated_invars) + res_donated,
name=name,
keep_unused=keep_unused,
inline=inline,
compiler_options_kvs=compiler_options_kvs)
def _filter_zeros(is_nz_l, l):
return tuple(x for nz, x in zip(is_nz_l, l) if nz)
ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr,
in_shardings=in_shardings,
out_shardings=(*res_shardings, *out_shardings),
in_layouts=in_layouts,
out_layouts=(*res_layouts, *out_layouts),
resource_env=resource_env,
donated_invars=donated_invars,
name=name,
keep_unused=keep_unused,
inline=inline,
compiler_options_kvs=compiler_options_kvs)
residuals_ans, primal_ans = split_list(ans, [num_residuals])
return primal_ans, nzs_out, residuals_ans, tangent_fun
ad.primitive_linearizations[pjit_p] = _pjit_linearization
def _pjit_partial_eval(trace, *in_tracers,
jaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, resource_env, donated_invars,

View File

@ -4818,7 +4818,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(ans1, ans2)
def sin_of_sin(x):
return lax.sin(lax.sin(x))
return lax.sin(jax.jit(lax.sin)(x))
check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0))