We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.
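The decomposition that XLA:CPU applies can be sketched in plain Python (illustrative only; `reduce_scatter_via_allreduce` is a hypothetical name, not an XLA pass): a reduce-scatter over N devices is an all-reduce (sum) followed by each device slicing out its own shard, which is what the AllReduce + DynamicSlice lowering does.

```python
def reduce_scatter_via_allreduce(per_device_values):
    # All-reduce step: elementwise sum across all devices' inputs.
    summed = [sum(vals) for vals in zip(*per_device_values)]
    n = len(per_device_values)
    shard = len(summed) // n
    # DynamicSlice step: device d keeps elements [d*shard, (d+1)*shard).
    return [summed[d * shard:(d + 1) * shard] for d in range(n)]
```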
PiperOrigin-RevId: 586311031
As far as I can tell this is no longer necessary on GPU, which handles arbitrary allgather dimensions (by making the gathered dimension the major-most dimension during layout assignment), nor on CPU, where at present XLA would do the same lowering JAX would.
I'm planning to improve the XLA:CPU lowering in a subsequent change.
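The GPU strategy mentioned above can be sketched with numpy (illustrative only; `all_gather_via_major_dim` is a hypothetical name): an all-gather along an arbitrary dimension is equivalent to moving that dimension to the major-most position, gathering there, and moving it back.

```python
import numpy as np

def all_gather_via_major_dim(shards, dim):
    # Move the gather dimension to the front (major-most) on each shard.
    moved = [np.moveaxis(s, dim, 0) for s in shards]
    # Gathering along the major-most dimension is a plain concatenation.
    gathered = np.concatenate(moved, axis=0)
    # Restore the original dimension order.
    return np.moveaxis(gathered, 0, dim)
```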
PiperOrigin-RevId: 586291911
The logs related to the compilation cache ended up being quite chatty,
unlike the other logs in JAX. This downgrades a bunch of them to
debug level; they can still be enabled independently via the JAX
config. This should also fix the recent failures in logging_test.py.
Before, we had `export.poly_spec` to create a `jax.ShapeDtypeStruct`
given a polymorphic shape specification. This function was
invoked as `poly_spec(arg_shape, arg_dtype, polymorphic_shape)`.
The `arg_shape` was only needed when the polymorphic shape spec
contained placeholders.
We break out `export.symbolic_shape`, which is just a parser
of polymorphic shape specs, and ask the user to invoke
`jax.ShapeDtypeStruct` directly:
`jax.ShapeDtypeStruct(export.symbolic_shape(polymorphic_shape, like=arg_shape), arg_dtype)`.
We also rename `export.poly_specs` to `export.arg_specs`.
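The parsing semantics described above can be sketched in plain Python (illustrative only; `parse_symbolic_shape` is a hypothetical stand-in for `export.symbolic_shape`, not the real JAX implementation): each dimension in the spec string is a literal integer, a symbolic variable name, or a `_` placeholder filled in from `like`.

```python
def parse_symbolic_shape(spec, like=None):
    # Parse a spec such as "b, _, 4"; "_" placeholders take their
    # concrete size from the corresponding dimension of `like`.
    dims = []
    for i, d in enumerate(s.strip() for s in spec.split(",")):
        if d == "_":
            dims.append(like[i])   # placeholder: use the concrete size
        elif d.isdigit():
            dims.append(int(d))    # literal dimension
        else:
            dims.append(d)         # symbolic dimension variable, e.g. "b"
    return tuple(dims)
```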
By default, JAX DCE's unused arguments, which changes the in_layouts available on the `executable`. This breaks when we try to unflatten those in_layouts with the original in_tree (because the in_tree contains both DCE'd and non-DCE'd args).
The in_layouts that we return to the user should contain layouts for both DCE'd and non-DCE'd args. So fill the DCE'd positions with None, which means the default layout. This does not affect the actual HLO computation, because JAX discards the DCE'd layouts anyway, and consequently discards the jax.Arrays created with those layouts.
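The fix described above amounts to scattering the executable's kept layouts back into a full-length list (a minimal sketch; `fill_dced_layouts` and its parameter names are hypothetical, not the actual JAX internals):

```python
def fill_dced_layouts(kept_layouts, kept_indices, num_args):
    # Start with None (the default layout) for every argument position,
    # then place the executable's layouts at the non-DCE'd indices.
    layouts = [None] * num_args
    for idx, layout in zip(kept_indices, kept_layouts):
        layouts[idx] = layout
    return layouts
```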
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 585790912
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.
Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.
Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
The flag ends up in libtpu in the OSS build, so we will need to find a different
mechanism for this. This is a quick fix to get things working again.
PiperOrigin-RevId: 584835560
Also check all the jax2tf tests to ensure that each one
has at least one TPU configuration marked for TAP continuous.
Without this we would only notice failures on TPU post-submit, as was
the case here.
PiperOrigin-RevId: 584253387
XLA-AutoFDO is supported only for TPUs, so requesting the latest
profile version for non-TPU workloads is unnecessary and can delay
the completion of initialization.
Testing: test workload.
PiperOrigin-RevId: 584148686