This is similar to the support in lax.reduce(), where the operands and init_values become pytrees. This is a strict superset of the current API, so users should not need updates.
Variadic lax.reduce_window() is only supported on CPU and TPU at the moment, not GPU.
PiperOrigin-RevId: 411632993
https://github.com/google/jax/pull/8606 introduced a runtime error where as a consequence of the move, a reference to `slice` became a reference to the builtin slice operator instead of `lax.slice`.
After fixing that and while added a test, I noticed that the gradient was wrong before: we should have been slicing the result, not the operand in the transpose rule's handling of base dilation.
Also enable some TPU tests that now pass since we have variadic reduce-window support on TPU.
PiperOrigin-RevId: 411579650
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).
The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)
PiperOrigin-RevId: 411565432
When structures are very large, users can end up with pages and pages describing the two structures, and finding exactly where they differ can be tricky. This change makes these differences more obvious.
PiperOrigin-RevId: 411131921
also:
* fix jit invariance bug around weak types
* elide trivial broadcasts
This started as an attempt to simplify some jaxpr pretty-prints, by (1)
eliding some convert_element_type applications that I thought were
unnecessary and (2) eliding some trivial broadcasts.
But it turned out that we were actually pruning more
convert_element_types than we should! In particular, see
test_weak_type_jit_invariance; that test fails on the main branch even
if we add the fixes in DynamicJaxprTrace.new_const, because [this
logic](b53a174042/jax/interpreters/partial_eval.py (L1225))
was not paying attention to weak types and hence clobbered them.
In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this
PR generalizes the jaxpr simplification machinery so as not to be a
couple special cases on convert_element_type_p. Insetad, we have tables
of rules! How we love them.
These rule signatures should let us add simplifications like forwarding
variables through calls and other higher-order primitives. That's all
future work though.
This was a bad bug! Unfortunately our tests didn't catch it, in part
because permutations on size-two axes are either trivial or not. The
simplest test might have a size-three axis.