mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Even more linearize fixes
This commit is contained in:
parent
4f106b8a27
commit
96769f96c2
@ -86,14 +86,15 @@ def jvpfun(f, instantiate, transform_stack, primals, tangents):
|
||||
def linearize_subtrace(_f, _store, _tag, nzs_in, *primals, **params):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
tangent_trace = pe.DynamicJaxprTrace()
|
||||
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval())
|
||||
for (p, nz) in zip(primals, nzs_in) if nz]
|
||||
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
|
||||
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
|
||||
tracers = [LinearizeTracer(linearize_trace, p,
|
||||
tangent_trace.new_arg(get_aval(p).to_tangent_aval()))
|
||||
if nz else p
|
||||
for p, nz in zip(primals, nzs_in)]
|
||||
with core.set_current_trace(linearize_trace):
|
||||
ans = _f(*tracers)
|
||||
out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
|
||||
nzs_out = [type(t) is not Zero for t in out_tangents]
|
||||
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
|
||||
out_tangents = [t for t, nz in zip(out_tangents, nzs_out) if nz]
|
||||
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
|
||||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
|
||||
@ -135,6 +136,10 @@ def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
effects=jaxpr.effects, debug_info=dbg)
|
||||
|
||||
def linearize_jaxpr(jaxpr, nonzeros):
|
||||
return _linearize_jaxpr(jaxpr, tuple(nonzeros))
|
||||
|
||||
@weakref_lru_cache
|
||||
def _linearize_jaxpr(jaxpr, nonzeros):
|
||||
primal_trace = pe.DynamicJaxprTrace()
|
||||
tangent_trace = pe.DynamicJaxprTrace()
|
||||
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
|
||||
@ -154,11 +159,13 @@ def linearize_jaxpr(jaxpr, nonzeros):
|
||||
out_tangents = [tangent_trace.to_jaxpr_tracer(t)
|
||||
for (nz, t) in zip(nzs_out, out_tangents) if nz]
|
||||
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
|
||||
tangent_trace.invalidate()
|
||||
if attrs_tracked:
|
||||
raise NotImplementedError("TODO: attrs")
|
||||
residuals_and_primals = (*tangent_consts, *out_primals)
|
||||
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals)
|
||||
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
|
||||
primal_trace.invalidate()
|
||||
num_residuals = len(tangent_consts)
|
||||
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
|
||||
if attrs_tracked:
|
||||
@ -187,6 +194,7 @@ def direct_linearize(traceable, primals, kwargs, *, has_aux=False, tag=None):
|
||||
out_tangents = map(instantiate_zeros, out_tangents)
|
||||
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
|
||||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
|
||||
tangent_trace.invalidate()
|
||||
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
|
||||
if attrs_tracked:
|
||||
raise NotImplementedError("TODO: attrs")
|
||||
@ -551,6 +559,7 @@ def _primal_tangent_shapes_match(primal, tangent):
|
||||
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
|
||||
|
||||
call_param_updaters: dict[core.Primitive, Callable] = {}
|
||||
call_linearize_param_updaters: dict[core.Primitive, Callable] = {}
|
||||
call_transpose_param_updaters: dict[core.Primitive, Callable] = {}
|
||||
|
||||
# -------------------- Linearize trace --------------------
|
||||
@ -637,13 +646,42 @@ class LinearizeTrace(Trace):
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
assert call_primitive.multiple_results
|
||||
primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
|
||||
nzs_in = [type(t) is not Zero for t in tangents]
|
||||
nzs_in = tuple(type(t) is not Zero for t in tangents)
|
||||
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in)
|
||||
all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), params)
|
||||
if isinstance(call_primitive, core.MapPrimitive):
|
||||
@as_hashable_function(closure=(linearize_outs_thunk))
|
||||
def new_out_axes_thunk():
|
||||
num_residuals, _, _ = linearize_outs_thunk()
|
||||
out_axes = params['out_axes_thunk']()
|
||||
return (*(0 for _ in range(num_residuals)), *out_axes)
|
||||
primal_params = dict(params, out_axes_thunk=new_out_axes_thunk)
|
||||
else:
|
||||
primal_params = params
|
||||
|
||||
all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params)
|
||||
num_residuals, nzs_out, lin_jaxpr = linearize_outs_thunk()
|
||||
residuals = all_primal_results[:num_residuals]
|
||||
primals_out = all_primal_results[num_residuals:]
|
||||
|
||||
if isinstance(call_primitive, core.MapPrimitive):
|
||||
in_axes = params['in_axes']
|
||||
out_axes = params['out_axes_thunk']()
|
||||
residual_avals = map(get_aval, residuals)
|
||||
new_in_axes = (*(0 for _ in residual_avals),
|
||||
*(ax for ax, nz in zip(in_axes, nzs_in) if nz))
|
||||
new_out_axes = (*(ax for ax, nz in zip(out_axes, nzs_out) if nz),)
|
||||
# NOTE: This assumes that the output tangents being zero is a
|
||||
# deterministic function of which input tangents were zero.
|
||||
@as_hashable_function(closure=(new_out_axes))
|
||||
def new_out_axes_thunk():
|
||||
return new_out_axes
|
||||
params = dict(params,
|
||||
in_axes=new_in_axes,
|
||||
out_axes_thunk=new_out_axes_thunk)
|
||||
|
||||
update_params = call_linearize_param_updaters.get(call_primitive)
|
||||
new_params = update_params(params, residual_avals, nzs_in) if update_params else params
|
||||
|
||||
def f_tangent(*args):
|
||||
residuals = args[:num_residuals]
|
||||
nz_tangents = args[num_residuals:]
|
||||
@ -651,12 +689,17 @@ class LinearizeTrace(Trace):
|
||||
|
||||
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
|
||||
nz_tangents_out = call_primitive.bind_with_trace(
|
||||
self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), params)
|
||||
self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), new_params)
|
||||
nz_tangents_out_iter = iter(nz_tangents_out)
|
||||
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
|
||||
for nz, primal in zip(nzs_out, primals_out)]
|
||||
return map(partial(maybe_linearize_tracer, self), primals_out, nzs_out, tangents_out)
|
||||
|
||||
# The only difference between process_map and process_call is that
|
||||
# the `in_axes` and `out_axes_thunk` params must be updated;
|
||||
# that's handled in process_call.
|
||||
process_map = process_call
|
||||
|
||||
def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
|
||||
if is_nonzero:
|
||||
assert not type(tangent) is Zero
|
||||
@ -692,8 +735,8 @@ def linearize_from_jvp(jvp, multiple_results, nonzeros,
|
||||
else:
|
||||
zero_type = Zero
|
||||
|
||||
tangent_args = [trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
|
||||
for aval, nz in zip(tangent_avals, nonzeros)]
|
||||
tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
|
||||
for aval, nz in zip(tangent_avals, nonzeros))
|
||||
with core.set_current_trace(trace):
|
||||
out_primals, out_tangents = jvp(primals, tangent_args, **params)
|
||||
|
||||
|
@ -1396,6 +1396,12 @@ def xla_call_jvp_update_params(params, nz_tangents):
|
||||
new_donated_invars = (*donated_invars, *donated_tangents)
|
||||
return dict(params, donated_invars=new_donated_invars)
|
||||
|
||||
def _xla_call_linearize_update_params(params, residual_avals, nz_tangents):
|
||||
donated_invars_prev = params['donated_invars']
|
||||
donated_invars = (*(False for _ in residual_avals),
|
||||
*(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz))
|
||||
return dict(params, donated_invars=donated_invars)
|
||||
|
||||
def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
|
||||
donated_invars = params['donated_invars']
|
||||
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
|
||||
@ -1411,6 +1417,7 @@ pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
|
||||
res_aval=_pmap_partial_eval_custom_res_maker)
|
||||
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
|
||||
ad.call_param_updaters[xla_pmap_p] = xla_call_jvp_update_params
|
||||
ad.call_linearize_param_updaters[xla_pmap_p] = _xla_call_linearize_update_params
|
||||
ad.call_transpose_param_updaters[xla_pmap_p] = _xla_call_transpose_update_params
|
||||
|
||||
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
|
||||
|
@ -148,6 +148,8 @@ def _getattr_jvp(trace, obj, attr):
|
||||
return getattr(obj, attr)
|
||||
ad.JVPTrace.process_getattr = _getattr_jvp
|
||||
|
||||
ad.LinearizeTrace.process_setattr = _setattr_jvp
|
||||
ad.LinearizeTrace.process_getattr = _getattr_jvp
|
||||
|
||||
def linearize(f, *primals, attrs: list[tuple[Any, str]] = []):
|
||||
attr_primals = [jax_getattr(o, a) for o, a in attrs]
|
||||
|
@ -1531,6 +1531,55 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
return pe.merge_lists(out_knowns, out_tracers, out_consts)
|
||||
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
|
||||
|
||||
def _shard_map_linearize(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
out_names_thunk, check_rep, rewrite, auto):
|
||||
primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
|
||||
nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
|
||||
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in)
|
||||
tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz]
|
||||
all_names = _all_mesh_names_except_spmd(mesh, trace)
|
||||
|
||||
@as_hashable_function(closure=(linearize_outs_thunk))
|
||||
def primal_out_names_thunk():
|
||||
num_residuals, _, _ = linearize_outs_thunk()
|
||||
out_names = out_names_thunk()
|
||||
return (*({0: all_names} for _ in range(num_residuals)), *out_names)
|
||||
primal_params = dict(
|
||||
mesh=mesh, in_names=in_names,
|
||||
out_names_thunk=primal_out_names_thunk, check_rep=check_rep,
|
||||
rewrite=rewrite, auto=auto)
|
||||
all_primal_results = shard_map_p.bind_with_trace(
|
||||
trace.parent_trace, (f_primal,) + tuple(primals), primal_params)
|
||||
num_residuals, nzs_out, lin_jaxpr = linearize_outs_thunk()
|
||||
residuals = all_primal_results[:num_residuals]
|
||||
primals_out = all_primal_results[num_residuals:]
|
||||
residual_avals = map(core.get_aval, residuals)
|
||||
out_names = out_names_thunk()
|
||||
new_in_names = (*({0: all_names} for _ in residual_avals),
|
||||
*(ax for ax, nz in zip(in_names, nzs_in) if nz))
|
||||
new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),)
|
||||
@as_hashable_function(closure=(new_out_names))
|
||||
def tangent_out_names_thunk():
|
||||
return new_out_names
|
||||
tangent_params = dict(
|
||||
mesh=mesh, in_names=new_in_names,
|
||||
out_names_thunk=tangent_out_names_thunk, check_rep=check_rep,
|
||||
rewrite=rewrite, auto=auto)
|
||||
|
||||
def f_tangent(*args):
|
||||
residuals = args[:num_residuals]
|
||||
nz_tangents = args[num_residuals:]
|
||||
return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)
|
||||
|
||||
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
|
||||
nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace,
|
||||
(lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), tangent_params)
|
||||
nz_tangents_out_iter = iter(nz_tangents_out)
|
||||
tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
|
||||
for nz, primal in zip(nzs_out, primals_out)]
|
||||
return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out)
|
||||
ad.LinearizeTrace.process_shard_map = _shard_map_linearize
|
||||
|
||||
@lu.transformation2
|
||||
def _promote_scalar_residuals(f, *args, **kwargs):
|
||||
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user