* Adds support for any pytree inputs to Flax Module tests and enables tests for the GNNs, which take GraphTuples as inputs.
* Adds CNN example (seems we previously forgot to add this)
PiperOrigin-RevId: 500688114
Instead of smuggling them via the jaxpr, pull it out and pass them with args. This is because consts can be tracers and that fails down the stack when lowering to mlir.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 500544141
Add an "explicit_global_axis_size" arg. `global_axis` used to be set to `None`
when the user did not provide an explicit axis size. After this change,
`global_axis` should never be set to `None` internally, and always contain the
size of the global axis. It's still useful to thread the information that the
user has provided an explicit axis size so we can throw explicit errors in
`pxla` when explicit axis sizes are not allowed.
Why do we need to do this? We only go down the lowering path when calling
`pmap`s impl rule (while executing or final-style transforming), but not when
initial-style transforming. The global_axis size should be computed earlier,
such that it is available for initial-style transformations/primitives, e.g. if
we round-trip a multi-host pmap computation through make_jaxpr and eval_jaxpr.
We have tests for "initial-style transform of a `pmap`", but no such test for
_multi-host_ `pmap`! Alors, this bug went unnoticed.
#13545 makes `checkify` initial-style, and because `checkify-of-pmap` is a
valid way to check a `pmap`, an internal multi-host test uncovered this bug.
PiperOrigin-RevId: 499877003
* Previously we were creating the variables for all models, even if we did not test them. This changes ensures we only create them if we actually test the model
* It also reports when we aren't testing any models.
* Ensures we can generate markdown both from internally and externally.
* Ran all tests again and updated the g3doc with the results, which are slightly different now.
PiperOrigin-RevId: 499798630
This change enables the use of dimension polynomials wherever constaints
are used. This would arise, e.g., when tracing `lambda x: x.shape[0]`
in presence of shape polymorphism.
This won't be needed anymore once the --jax_dynamic_shapes improves
its coverage to replace shape polymorphism.
The downside of this change is that it adds a code path to Trace.full_raise.
An alternative would be to ask users to explicitly convert dimensions:
`lambda x: core.dimension_as_value(x.shape[0])`. Both of these can be
removed in the future, but the former has the advantage of being
internal to JAX.
An alternative internal change in `Trace.full_raise` would be
```
if hasattr(val, "__jax_array__"): val = val.__jax_array__()
```
but I think that using `dimension_as_value` makes it clear what
use case is addressed by this change.