Of course no communication can happen across grid dimensions (unlike over the WG dim),
but we need to be able to launch multiple blocks somehow.
PiperOrigin-RevId: 688488660
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.
Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).
When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserializaton a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than then function that was serialized. This can be confusing.
The Python pickle mode does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
importing automatically the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.
Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.
Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
It's not enough that we have the physical transpose between the order
of tiled dimensions, we also need the user to explicitly transpose the
logical dimensions. This fixes a shape error that was previously hidden
because the RHS was square.
PiperOrigin-RevId: 687350270
This is to avoid issues when small buffers throw off the alignment for large TMA and WGMMA
operands. We should make this more refined in the future, but this should be enough for now.
PiperOrigin-RevId: 687264994
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.
This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:
1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.
2. This implementation supports shape polymorphism in all dimensions with some caveats. By default, we do use some heuristics to based on the matrix sizes to select the algorithm that is used, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But, I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.
Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.
PiperOrigin-RevId: 687106965
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.
PiperOrigin-RevId: 687003464