Previously, output streaming took a top-down approach which indiscriminately checks if a MoveToHost custom call would trace down to an output marked with host memory space. This did not work when a dynamic-update-slice existed between the MTH call and the output. This CL fixes this problem by handling output streaming before other MTH calls, while also improving efficiency with the bottoms-up approach so we only trace a single path in the graph.
PiperOrigin-RevId: 632318740
The problem is that we want to generically jvp and tranpose over any reduction_fn. Jax already handles some of the hard parts for us, namely, ensuring that the user provided fn is jax capturable. All that is left then, is to write a jvp and tranpose fn that utilize the jax utils correctly.
However, this is not so straightforward because in order to get the transpose of a reduction window, we need to be able to use both the tangents and primals. The current reduce_fn operates on (x, y) - but we actually need is, under jvp, to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down notions of a jvp-specific reduction_fn (captured via the usual machinery of as_fun `as_fun(jvp_fn(closed(user_reduction_jaxp)))`).
For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together - this means a good deal of changes to safety checks and assurances downstream (as well as unpacking) as the shape of the operand has changed from [K,...Kt] to [K, ...Kt, 2] where the last dim is the stacked primal and tangent values.
In following CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition
PiperOrigin-RevId: 631916764
Building on enrique's work, this CL refactors the emit_pipeline abstraction:
1) factors out the VMEM double-buffering bookkeeping into a helper class.
2) concentrate the intricate copy/wait scheduling logic into one place inside a scheduler helper
while allowing manual overrides, callbacks don't control scheduling anymore, rather we have
explicit loop scheduling.
3) minimize callbacks and simplify the "defaults" for fusing pipelines together.
Examples of fully overlapped versions of latency- and throughput- optimized AG-matmuls and
matmul-RSs are included in new tests.
PiperOrigin-RevId: 631865641