mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add linearization rule for pjit_p
This commit is contained in:
parent
73fa0f48cb
commit
b1d1dcf607
@ -105,22 +105,56 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
|
||||
store.store(aux_primals)
|
||||
return out_primals, out_tangents
|
||||
|
||||
def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
  """Return a copy of ``jaxpr`` whose constvars are appended to its invars.

  The result has no constvars; the former constvars come *after* the original
  invars. When debug info is present, its ``arg_names`` is padded with one
  ``None`` per moved constvar so it stays aligned with the new invars.
  """
  debug_info = jaxpr.debug_info
  if debug_info:
    unnamed = (None,) * len(jaxpr.constvars)
    debug_info = debug_info._replace(arg_names=debug_info.arg_names + unnamed)
  return core.Jaxpr(
      constvars=(),
      invars=jaxpr.invars + jaxpr.constvars,
      outvars=jaxpr.outvars,
      eqns=jaxpr.eqns,
      effects=jaxpr.effects,
      debug_info=debug_info,
  )
|
||||
|
||||
def linearize_jaxpr(jaxpr, nonzeros):
  """Split a closed jaxpr into a primal jaxpr and a linear tangent jaxpr.

  Args:
    jaxpr: closed jaxpr to linearize (its ``.jaxpr`` and ``.consts`` are used).
    nonzeros: one flag per input; a false flag means that input's tangent is a
      symbolic ``Zero`` and gets no tangent-jaxpr argument.

  Returns:
    A 4-tuple ``(primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr)``:
      * ``primal_jaxpr``: closed jaxpr whose outputs are the residuals followed
        by the primal outputs.
      * ``num_residuals``: how many leading outputs of ``primal_jaxpr`` are
        residuals.
      * ``nzs_out``: one flag per output, false where the output tangent is a
        symbolic ``Zero``.
      * ``tangent_jaxpr``: closed jaxpr taking the nonzero input tangents
        followed by the residuals, producing only the nonzero output tangents.
  """
  primal_trace = pe.DynamicJaxprTrace()
  tangent_trace = pe.DynamicJaxprTrace()
  lin_trace = LinearizeTrace(primal_trace, tangent_trace)

  def new_arg(primal_aval, nz):
    # Every input gets a primal argument; only nonzero tangents get a
    # tangent argument, the rest are represented symbolically.
    primal = primal_trace.new_arg(primal_aval)
    tangent_aval = primal_aval.to_tangent_aval()
    tangent = tangent_trace.new_arg(tangent_aval) if nz else Zero(tangent_aval)
    return LinearizeTracer(lin_trace, primal, tangent)

  tracers = [new_arg(v.aval, nz) for (v, nz) in zip(jaxpr.jaxpr.invars, nonzeros)]
  with core.set_current_trace(lin_trace):
    ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers)

  out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
  nzs_out = [type(t) is not Zero for t in out_tangents]
  out_tangents = [tangent_trace.to_jaxpr_tracer(t)
                  for (nz, t) in zip(nzs_out, out_tangents) if nz]
  tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
  del attrs_tracked  # TODO: attrs

  # The values captured by the tangent jaxpr become the residuals emitted by
  # the primal jaxpr. Neither those captured values nor the primal outputs are
  # guaranteed to already be tracers of `primal_trace`, so convert them first,
  # exactly as is done for the tangent outputs above.
  residuals_and_primals = [primal_trace.to_jaxpr_tracer(x)
                           for x in (*tangent_consts, *out_primals)]
  primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
  del attrs_tracked  # TODO: attrs

  num_residuals = len(tangent_consts)
  # Residuals are passed as trailing ordinary arguments of the tangent jaxpr.
  tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
  return core.ClosedJaxpr(primal_jaxpr, primal_consts), num_residuals, nzs_out, tangent_jaxpr
|
||||
|
||||
def direct_linearize(traceable, *primals, **kwargs):
  """Linearize ``traceable`` at ``primals`` using a ``LinearizeTrace``.

  Primal computation runs in the trace that is current on entry, while tangent
  computation is staged into a fresh ``pe.DynamicJaxprTrace``.

  Args:
    traceable: wrapped function, invoked via ``call_wrapped``.
    *primals: primal inputs at which to linearize.
    **kwargs: only ``has_aux`` is read, and it must be falsy.

  Returns:
    ``(out_primals, out_tangents_pvals, jaxpr, consts)`` where ``jaxpr`` and
    ``consts`` describe the staged tangent computation and every entry of
    ``out_tangents_pvals`` is an unknown ``PartialVal``.
  """
  has_aux = kwargs.pop('has_aux', False)
  assert not has_aux
  with core.take_current_trace() as parent_trace:
    # The trace owns its frame; no JaxprStackFrame is built here.
    tangent_trace = pe.DynamicJaxprTrace()
    tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
    # LinearizeTrace mints its own tag when none is supplied.
    linearize_trace = LinearizeTrace(parent_trace, tangent_trace)
    tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
    with core.set_current_trace(linearize_trace):
      ans = traceable.call_wrapped(*tracers)

    out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
    # Materialize as a list: the tracers are consumed twice below.
    out_tangents = [tangent_trace.to_jaxpr_tracer(t) for t in out_tangents]
    jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
    out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
    del attrs_tracked  # TODO: attrs
  return out_primals, out_tangents_pvals, jaxpr, consts
|
||||
@ -469,8 +503,8 @@ call_transpose_param_updaters: dict[core.Primitive, Callable] = {}
|
||||
|
||||
class LinearizeTrace(Trace):
|
||||
|
||||
def __init__(self, parent_trace, tangent_trace, tag=None):
  """Set up a linearize trace over a primal trace and a tangent trace.

  Args:
    parent_trace: trace stored as ``self.parent_trace``.
    tangent_trace: trace stored as ``self.tangent_trace``.
    tag: optional trace tag; a fresh ``core.TraceTag()`` is created when
      omitted, so callers need not build one themselves.
  """
  self.tag = core.TraceTag() if tag is None else tag
  self.parent_trace = parent_trace
  self.tangent_trace = tangent_trace
|
||||
|
||||
@ -509,18 +543,20 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
|
||||
return primal
|
||||
|
||||
def fallback_linearize_rule(prim, _, *args, **kwargs):
  """Linearization rule for single-result primitives without a dedicated rule.

  Runs ``linearize`` with direct linearization disabled and adapts its result
  to the ``(primal_out, nonzero_flag, residuals, linearized_fn)`` shape.

  Args:
    prim: the primitive to linearize; must not have multiple results.
    _: ignored positional argument.
    *args: primal operands forwarded to ``prim.bind`` and ``linearize``.
    **kwargs: forwarded to both ``prim.bind`` and ``linearize``.

  Returns:
    ``(out_primal, True, consts, linearized)``, where
    ``linearized(residuals, *tangents)`` evaluates the linear jaxpr and returns
    its single output tangent. The output is always reported as nonzero.
  """
  assert not prim.multiple_results

  def call_prim(*args_):
    # `linearize` works on list-valued functions, so wrap the single result.
    return [prim.bind(*args_, **kwargs)]

  with config.use_direct_linearize(False):
    # NOTE(review): out_tangent_pval is unpacked but never inspected; the
    # output tangent is assumed to be unknown -- confirm.
    (out_primal,), (out_tangent_pval,), jaxpr, consts, *_maybe_aux = linearize(
        lu.wrap_init(call_prim), *args, **kwargs)

  def linearized(residuals, *tangents):
    out_tangent, = core.eval_jaxpr(jaxpr, residuals, *tangents)
    return out_tangent

  return out_primal, True, consts, linearized
|
||||
|
||||
class LinearizeTracer(Tracer):
|
||||
__slots__ = ['primal', 'tangent']
|
||||
|
@ -1575,6 +1575,7 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
|
||||
return self if val is None else get_referent(val)
|
||||
|
||||
|
||||
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
|
||||
return x.aval
|
||||
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
|
||||
@ -1805,8 +1806,8 @@ def _inline_literals(
|
||||
|
||||
|
||||
class DynamicJaxprTrace(core.Trace):
|
||||
def __init__(self):
  """Create a trace that owns a fresh ``JaxprStackFrame``.

  The frame is reachable as ``self.frame``; callers no longer construct or
  pass one in.
  """
  self.frame = JaxprStackFrame()
|
||||
|
||||
def invalidate(self):
|
||||
# avoid cyclic refs
|
||||
@ -2068,6 +2069,9 @@ class DynamicJaxprTrace(core.Trace):
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
|
||||
def to_jaxpr(self, out_tracers: Sequence[Tracer]):
  """Build a jaxpr from this trace's frame with ``out_tracers`` as outputs.

  Delegates to ``self.frame.to_jaxpr``, passing this trace along, and returns
  its result unchanged.
  """
  frame = self.frame
  return frame.to_jaxpr(self, out_tracers)
|
||||
|
||||
|
||||
custom_staging_rules: dict[Primitive, Callable] = {}
|
||||
|
||||
@ -2166,10 +2170,8 @@ def trace_to_jaxpr_dynamic(
|
||||
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
|
||||
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
|
||||
|
||||
frame = JaxprStackFrame()
|
||||
frame.debug_info = debug_info
|
||||
|
||||
trace = DynamicJaxprTrace(frame)
|
||||
trace = DynamicJaxprTrace()
|
||||
trace.frame.debug_info = debug_info
|
||||
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
|
||||
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
|
||||
in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
|
||||
@ -2177,8 +2179,8 @@ def trace_to_jaxpr_dynamic(
|
||||
ans = fun.call_wrapped(*in_tracers)
|
||||
|
||||
out_tracers = map(trace.to_jaxpr_tracer, ans)
|
||||
jaxpr, consts, attrs_tracked = frame.to_jaxpr(trace, out_tracers)
|
||||
del trace, fun, frame, in_tracers, out_tracers, ans
|
||||
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
|
||||
del trace, fun, in_tracers, out_tracers, ans
|
||||
|
||||
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
|
||||
@ -2188,7 +2190,7 @@ def trace_to_jaxpr_dynamic2(
|
||||
fun: lu.WrappedFun, debug_info: DebugInfo | None = None
|
||||
) -> tuple[Jaxpr, OutputType, list[Any]]:
|
||||
|
||||
trace = DynamicJaxprTrace(JaxprStackFrame())
|
||||
trace = DynamicJaxprTrace()
|
||||
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
|
||||
trace.frame.debug_info = debug_info
|
||||
in_avals, keep_inputs = unzip2(fun.in_type)
|
||||
|
@ -2400,9 +2400,10 @@ def _sin_lowering(ctx, x):
|
||||
return sine(ctx, x)
|
||||
return _nary_lower_hlo(hlo.sine, ctx, x)
|
||||
|
||||
def _sin_p_lin(nzs, x):
  """Linearization rule for ``sin_p``.

  Args:
    nzs: one-element sequence holding the nonzero flag of the input tangent.
    x: primal input.

  Returns:
    ``(sin(x), nz, cos(x), linearized)``: the primal output, the output
    nonzero flag (same as the input's), the residual ``cos(x)``, and a function
    ``linearized(residual, t)`` computing ``t * residual``.
  """
  nz, = nzs
  cos_x = cos(x)  # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
  return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_))
|
||||
|
||||
sin_p = standard_unop(_float | _complex, 'sin')
|
||||
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
||||
|
@ -2107,6 +2107,52 @@ def _pjit_jvp(primals_in, tangents_in,
|
||||
ad.primitive_jvps[pjit_p] = _pjit_jvp
|
||||
|
||||
|
||||
def _pjit_linearization(nzs, *primals_in, jaxpr,
                        in_shardings, out_shardings, in_layouts, out_layouts,
                        resource_env, donated_invars, name, keep_unused, inline,
                        compiler_options_kvs):
  """Linearization rule for ``pjit_p``.

  Splits ``jaxpr`` with ``ad.linearize_jaxpr`` and binds the primal half
  immediately as a ``pjit_p`` call whose outputs are the residuals followed by
  the primal outputs. The tangent half is bound lazily by the returned
  function.

  Args:
    nzs: per-input flags; false means that input's tangent is a symbolic zero.
    *primals_in: primal operands of the original call.
    jaxpr and the remaining keyword arguments: parameters of the original
      ``pjit_p`` call.

  Returns:
    ``(primal_ans, nzs_out, residuals_ans, tangent_fun)`` where
    ``tangent_fun(residuals, *tangents)`` binds the tangent ``pjit_p`` call on
    the nonzero tangents followed by the residuals.
  """
  primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs)

  # Residuals travel as extra primal outputs / trailing tangent inputs, with
  # unspecified sharding, no layout, and no donation.
  res_shardings = (UNSPECIFIED,) * num_residuals
  res_layouts = (None,) * num_residuals
  res_donated = (False,) * num_residuals

  def _filter_zeros(is_nz_l, l):
    # Keep only the entries whose flag is truthy.
    return tuple(item for keep, item in zip(is_nz_l, l) if keep)

  def tangent_fun(consts_, *tangents):
    tangents_nz = _filter_zeros(nzs, tangents)
    assert len(consts_) == num_residuals
    tangent_params = dict(
        jaxpr=tangent_jaxpr,
        in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings,
        out_shardings=_filter_zeros(nzs_out, out_shardings),
        in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts,
        out_layouts=_filter_zeros(nzs_out, out_layouts),
        resource_env=resource_env,
        donated_invars=_filter_zeros(nzs, donated_invars) + res_donated,
        name=name,
        keep_unused=keep_unused,
        inline=inline,
        compiler_options_kvs=compiler_options_kvs,
    )
    return pjit_p.bind(*tangents_nz, *consts_, **tangent_params)

  primal_params = dict(
      jaxpr=primal_jaxpr,
      in_shardings=in_shardings,
      out_shardings=(*res_shardings, *out_shardings),
      in_layouts=in_layouts,
      out_layouts=(*res_layouts, *out_layouts),
      resource_env=resource_env,
      donated_invars=donated_invars,
      name=name,
      keep_unused=keep_unused,
      inline=inline,
      compiler_options_kvs=compiler_options_kvs,
  )
  ans = pjit_p.bind(*primals_in, **primal_params)
  residuals_ans, primal_ans = split_list(ans, [num_residuals])

  return primal_ans, nzs_out, residuals_ans, tangent_fun
|
||||
|
||||
ad.primitive_linearizations[pjit_p] = _pjit_linearization
|
||||
|
||||
|
||||
def _pjit_partial_eval(trace, *in_tracers,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, resource_env, donated_invars,
|
||||
|
@ -4818,7 +4818,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertEqual(ans1, ans2)
|
||||
|
||||
def sin_of_sin(x):
|
||||
return lax.sin(lax.sin(x))
|
||||
return lax.sin(jax.jit(lax.sin)(x))
|
||||
|
||||
check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user