This is in preparation for cleaning up our bcoo_dot_general GPU lowering rules: by creating private primitives that closely follow the API of the cusparse kernels, we will be able to better express lowered translation rules that preprocess that data appropriately.
PiperOrigin-RevId: 513212715
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 513086379
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 513047925
Currently, JAX is generating random 8 bit ints for bools, which usually doesn't cause any issues, but in some special cases does. One example is the HLO snapshot dumping code, which surprisingly creates unparseable protos for such inputs.
PiperOrigin-RevId: 513032802
The goal of this is to make it easier to address the out-of-bound index issue. Our current GPU logic grew somewhat organically over time, and the logic for which sub-routine is called is spread over multiple locations. This change updates the branching such that the logic for each sub-routine appears directly adjacent to its call site; the tradeoff is that other considerations (such as whether to raise a warning) have to be duplicated between the cases.
Additionally, I simplified some of the hlo operation calls to make the code easier to follow.
PiperOrigin-RevId: 513025719
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
Allow the user of native serialization to specify the platform for which
the serialization to be done. This relies on newly added support for
platform checking in XlaCallModule op (version 3).
functools.partial(jax.arrays.ArrayImpl) with the added benefit
that the new PyExecuteResults type can explode directly into
ArrayImpls if passed to explode_with_handlers().
Note that this also helps with deprecating PyBuffer as the fastpath
does not need to call the PyBuffer constructor.
PiperOrigin-RevId: 512788757
In the past we had encountered errors with sharding annotations for CPU/GPU (e.g., crashes; these have been fixed) and when executing in TF eager mode. To work around those we had decided to skip the replicated sharding annotations, which arise often now that all `jit` functions will assume by default replicated shardings. Then we have discovered that we were skipping too many sharding annotations and we made changes to include all inner sharding annotations, but still skip the replicated sharding annotations on inputs and outputs.
It is unsafe to skip annotations, and here we try to include as many sharding annotations as we can. The only case when we cannot include sharding annotations is under TF eager mode. There is should be safe to skip the replicated annotations in eager mode, counting on the fact that we will raise an error if we encounter non-replicated annotations. Such functions must be executed in tf.function mode.
Specifically under tf.function, which is the most important use case, we now include all sharding annotations.
At the same time, I added more tests and I strengthened some tests to check the presence of the sharding annotations in the TF HLO.
PiperOrigin-RevId: 512417862
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.
It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:
```
from jax._src.lib import xla_bridge
@mock.patch.object(xla_bridge, 'process_index')
...
```
A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:
```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```
However, this solution requires the `jax._src` be present in the JAX namespace.
Ideally users wouldn't mock our internals at all, but that requires significantly more work.
PiperOrigin-RevId: 512295203