functools.partial(jax.arrays.ArrayImpl) with the added benefit
that the new PyExecuteResults type can explode directly into
ArrayImpls if passed to explode_with_handlers().
Note that this also helps with deprecating PyBuffer as the fastpath
does not need to call the PyBuffer constructor.
PiperOrigin-RevId: 512788757
In the past we had encountered errors with sharding annotations for CPU/GPU (e.g., crashes; these have been fixed) and when executing in TF eager mode. To work around those we had decided to skip the replicated sharding annotations, which arise often now that all `jit` functions will assume by default replicated shardings. Then we have discovered that we were skipping too many sharding annotations and we made changes to include all inner sharding annotations, but still skip the replicated sharding annotations on inputs and outputs.
It is unsafe to skip annotations, and here we try to include as many sharding annotations as we can. The only case when we cannot include sharding annotations is under TF eager mode. There is should be safe to skip the replicated annotations in eager mode, counting on the fact that we will raise an error if we encounter non-replicated annotations. Such functions must be executed in tf.function mode.
Specifically under tf.function, which is the most important use case, we now include all sharding annotations.
At the same time, I added more tests and I strengthened some tests to check the presence of the sharding annotations in the TF HLO.
PiperOrigin-RevId: 512417862
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.
It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:
```
from jax._src.lib import xla_bridge
@mock.patch.object(xla_bridge, 'process_index')
...
```
A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:
```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```
However, this solution requires the `jax._src` be present in the JAX namespace.
Ideally users wouldn't mock our internals at all, but that requires significantly more work.
PiperOrigin-RevId: 512295203
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.
This allows them to run in ASAN in more situations.
While here, specify deps a little more precisely as well.
PiperOrigin-RevId: 511829646