mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas] Minor cleanup to pallas_call_p
JVP code.
PiperOrigin-RevId: 555862179
This commit is contained in:
parent
7e1278c040
commit
deed8b71b1
@ -161,9 +161,6 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
|
||||
closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
|
||||
jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
|
||||
jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts
|
||||
jvp_which_linear = which_linear + (True,) * len(tangents)
|
||||
jvp_inshapes = (*in_shapes, *in_shapes)
|
||||
jvp_outshapes = (*out_shapes, *out_shapes)
|
||||
# `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*.
|
||||
# `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input
|
||||
# `Ref`s that are read from followed by output `Ref`s that are written to.
|
||||
@ -171,29 +168,30 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
|
||||
# jaxpr that has tangents following primals. In order for this jaxpr to be
|
||||
# compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around
|
||||
# the jaxpr's invars.
|
||||
logical_primals, logical_tangents = split_list(
|
||||
jvp_jaxpr.invars, [len(primals) + len(out_shapes)])
|
||||
logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)])
|
||||
logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)])
|
||||
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
|
||||
new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms))
|
||||
new_grid_mapping = grid_mapping.replace(block_mappings=new_bms)
|
||||
jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs,
|
||||
*logical_tangent_inputs,
|
||||
*logical_primal_outputs,
|
||||
*logical_tangent_outputs])
|
||||
primal_refs, primal_out_refs, tangent_refs, tangent_out_refs = split_list(
|
||||
jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)]
|
||||
)
|
||||
invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs)
|
||||
# TODO(sharadmv): Fix state effect tracking after invar switch.
|
||||
jvp_jaxpr = jvp_jaxpr.replace(invars=invars)
|
||||
if debug:
|
||||
print(jvp_jaxpr)
|
||||
out_flat = pallas_call_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr,
|
||||
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
|
||||
jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
|
||||
out_flat = pallas_call_p.bind(
|
||||
*primals,
|
||||
*tangents,
|
||||
jaxpr=jvp_jaxpr,
|
||||
name=f"{name}_jvp",
|
||||
in_shapes=jvp_inshapes,
|
||||
out_shapes=jvp_outshapes,
|
||||
grid_mapping=new_grid_mapping,
|
||||
which_linear=jvp_which_linear,
|
||||
in_shapes=(*in_shapes, *in_shapes),
|
||||
out_shapes=(*out_shapes, *out_shapes),
|
||||
grid_mapping=grid_mapping.replace(block_mappings=jvp_bms),
|
||||
which_linear=which_linear + (True,) * len(tangents),
|
||||
interpret=interpret,
|
||||
debug=debug,
|
||||
input_output_aliases=(),
|
||||
**compiler_params)
|
||||
**compiler_params,
|
||||
)
|
||||
out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
|
||||
return out_primals, out_tangents
|
||||
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
|
||||
|
Loading…
x
Reference in New Issue
Block a user