[pallas] Minor cleanup to pallas_call_p JVP code.

PiperOrigin-RevId: 555862179
This commit is contained in:
Chris Jones 2023-08-11 02:29:54 -07:00 committed by jax authors
parent 7e1278c040
commit deed8b71b1

View File

@ -161,9 +161,6 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts
jvp_which_linear = which_linear + (True,) * len(tangents)
jvp_inshapes = (*in_shapes, *in_shapes)
jvp_outshapes = (*out_shapes, *out_shapes)
# `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*.
# `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input
# `Ref`s that are read from followed by output `Ref`s that are written to.
@ -171,29 +168,30 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
# jaxpr that has tangents following primals. In order for this jaxpr to be
# compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around
# the jaxpr's invars.
logical_primals, logical_tangents = split_list(
jvp_jaxpr.invars, [len(primals) + len(out_shapes)])
logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)])
logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)])
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms))
new_grid_mapping = grid_mapping.replace(block_mappings=new_bms)
jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs,
*logical_tangent_inputs,
*logical_primal_outputs,
*logical_tangent_outputs])
primal_refs, primal_out_refs, tangent_refs, tangent_out_refs = split_list(
jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)]
)
invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs)
# TODO(sharadmv): Fix state effect tracking after invar switch.
jvp_jaxpr = jvp_jaxpr.replace(invars=invars)
if debug:
print(jvp_jaxpr)
out_flat = pallas_call_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr,
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
out_flat = pallas_call_p.bind(
*primals,
*tangents,
jaxpr=jvp_jaxpr,
name=f"{name}_jvp",
in_shapes=jvp_inshapes,
out_shapes=jvp_outshapes,
grid_mapping=new_grid_mapping,
which_linear=jvp_which_linear,
in_shapes=(*in_shapes, *in_shapes),
out_shapes=(*out_shapes, *out_shapes),
grid_mapping=grid_mapping.replace(block_mappings=jvp_bms),
which_linear=which_linear + (True,) * len(tangents),
interpret=interpret,
debug=debug,
input_output_aliases=(),
**compiler_params)
**compiler_params,
)
out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule