This PR is similar to https://github.com/jax-ml/jax/pull/24284
Note that `np.testing.assert_allclose()` is changed to `self.assertAllClose()` because the latter is a wrapper with bfloat16 support.
PiperOrigin-RevId: 688581914
Of course no communication can happen across grid dimensions (unlike over the WG dim),
but we need to be able to launch multiple blocks somehow.
PiperOrigin-RevId: 688488660
Apple GPUs and Mac x86_64 is a non-existent combination.
Mac x86_64 with AMD GPU is supported.
It's a bit of a confusing situation so hard to summarize, but hopefully this is more accurate and less confusing
Fixes: #24408
These aliases were added to not break existing presubmit builds. Now that the presubmit builds have been updated, these aliases can be removed.
Also, corrects some comments.
PiperOrigin-RevId: 688096364
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.
Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).
When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserializaton a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than then function that was serialized. This can be confusing.
The Python pickle mode does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
importing automatically the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.
Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.
Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
It's not enough that we have the physical transpose between the order
of tiled dimensions, we also need the user to explicitly transpose the
logical dimensions. This fixes a shape error that was previously hidden
because the RHS was square.
PiperOrigin-RevId: 687350270