The problem is that we want to generically jvp and transpose over any reduction_fn. JAX already handles some of the hard parts for us, namely ensuring that the user-provided fn is JAX-capturable. All that is left, then, is to write a jvp and transpose fn that use the JAX utilities correctly.
However, this is not so straightforward: to get the transpose of a reduction window, we need access to both the tangents and the primals. The current reduce_fn operates on (x, y), but under jvp it actually needs to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down the notion of a jvp-specific reduction_fn (captured via the usual machinery of as_fun, i.e. `as_fun(jvp_fn(closed(user_reduction_jaxpr)))`).
For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together. This requires a good deal of changes to safety checks and assurances downstream (as well as to unpacking), since the shape of the operand changes from [K, ..., Kt] to [K, ..., Kt, 2], where the last dim holds the stacked primal and tangent values.
In follow-up CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition.
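To illustrate the lifting described above, here is a minimal sketch (hypothetical names, not the actual implementation) of how a binary reducer on (x, y) can be lifted to a jvp reducer that operates on stacked primal/tangent values, with the last dim of size 2:

```python
import jax
import jax.numpy as jnp

def user_reduction(x, y):
  # The original binary reducer over window elements.
  return x + y

def jvp_reduction(x_stacked, y_stacked):
  # The last dim stacks primal and tangent: shape [..., 2].
  x_primal, x_tangent = x_stacked[..., 0], x_stacked[..., 1]
  y_primal, y_tangent = y_stacked[..., 0], y_stacked[..., 1]
  out_primal, out_tangent = jax.jvp(
      user_reduction, (x_primal, y_primal), (x_tangent, y_tangent))
  return jnp.stack([out_primal, out_tangent], axis=-1)

# Example: reduce two stacked (primal, tangent) scalars.
out = jvp_reduction(jnp.array([1.0, 0.1]), jnp.array([2.0, 0.2]))
# out == [3.0, 0.3]: primals 1+2, tangents 0.1+0.2.
```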
PiperOrigin-RevId: 631916764
Building on Enrique's work, this CL refactors the emit_pipeline abstraction:
1) factors out the VMEM double-buffering bookkeeping into a helper class.
2) concentrates the intricate copy/wait scheduling logic in one place, inside a scheduler helper. Manual overrides are still allowed, but callbacks no longer control scheduling; instead we have explicit loop scheduling.
3) minimizes callbacks and simplifies the "defaults" for fusing pipelines together.
Examples of fully overlapped versions of latency- and throughput-optimized AG-matmuls and matmul-RSs are included in new tests.
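For intuition, here is a minimal sketch (hypothetical class, not the actual helper) of the kind of double-buffering bookkeeping being factored out: each buffered ref has two VMEM slots, one being computed on while the other receives the in-flight copy, with roles swapped each pipeline step:

```python
class DoubleBufferIndices:
  """Tracks which of two VMEM slots is the compute vs. copy target."""

  def __init__(self):
    self.compute_slot = 0  # slot currently being computed on

  @property
  def copy_slot(self):
    # The slot the next async copy writes into.
    return 1 - self.compute_slot

  def advance(self):
    # Swap roles at the end of each pipeline step.
    self.compute_slot = 1 - self.compute_slot
```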
PiperOrigin-RevId: 631865641
Previously `model['some_array'][:,0,0,:]` would generate a `slice`, while `model['some_array'][...,0,0,:]` would generate a `gather`. Now both of these generate `slice` eqns.
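One way to inspect which eqns an indexing expression lowers to (illustrative only; plain jnp indexing stands in here for the `model['some_array']` access, and exhibits the same slice-vs-gather distinction):

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((2, 3, 4, 5))
# Both patterns should now show `slice` (plus `squeeze`) rather than `gather`.
print(jax.make_jaxpr(lambda a: a[:, 0, 0, :])(x))
print(jax.make_jaxpr(lambda a: a[..., 0, 0, :])(x))
```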
PiperOrigin-RevId: 631469837
This change is in preparation for adding support for emitting source map information (https://tc39.es/source-map/) for jaxprs, so that the relationship between a jaxpr and its Python code can be visualized using source-map tooling.
This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
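As a toy illustration of the side-output idea (hypothetical code, not the actual pretty-printer API): while rendering each equation, record which output line it landed on, yielding a map from printed lines back to Python source locations:

```python
def print_eqns_with_source_map(eqns):
  lines, source_map = [], {}
  for eqn in eqns:
    # Map the output line number to this eqn's Python source location
    # (an eqn.source_info attribute is assumed here for illustration).
    source_map[len(lines)] = eqn.source_info
    lines.append(str(eqn))
  return "\n".join(lines), source_map
```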
The fix in #21032 was not correct because it assumed that the set of all mesh
axis names appearing in in_specs was an upper bound on the set of mesh axes
over which residuals could be device-varying. But collectives can introduce
device variance (e.g. axis_index produces a value that varies over a mesh axis
even when no input does)! So it's not an upper bound.
We track device variance when check_rep=True, but often people set
check_rep=False (e.g. when using pallas_call in a shard_map). So relying on our
device variance tracking would be limiting. That may be a decent long-term
solution, if we can make it easy to annotate pallas_calls with device variance
information, but it's not a great short-term way to unblock things.
So instead I temporarily went with context sensitivity: instead of sharding
residuals over all mesh.axis_names (as we did before these patches), we shard
them over all mesh axis names _excluding_ any spmd_axis_names in our dynamic
context (found by looking at the traces in our trace stack). It's illegal to
mention any spmd_axis_names in collectives (indeed anywhere in the body of the
function being vmapped), but I don't think we currently check for it.
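Schematically (hypothetical helper and trace attributes, not the actual code), the rule is:

```python
def residual_mesh_axes(mesh, trace_stack):
  # Collect spmd_axis_names from any vmap traces in the dynamic context.
  spmd_axis_names = set()
  for trace in trace_stack:
    name = getattr(trace, 'spmd_axis_name', None)
    if name is not None:
      spmd_axis_names.add(name)
  # Residuals are sharded over every mesh axis *except* those names.
  return tuple(a for a in mesh.axis_names if a not in spmd_axis_names)
```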
TODO(mattjj): add more testing (maybe in follow-ups)