By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
@curry is opaque to pytype.
Fix a false positive type error that turns up because pytype doesn't really understand that a functools.partial is a kind of Callable.
PiperOrigin-RevId: 513697380
In a previous CL we introduced cross-lowering support without any
changes in JAX core, but at the expense of some overly complex code
in jax2tf, along with overriding a JAX core function. Plus, those
changes were not enough to handle some xmap and pmap cases.
Here we introduce a `_experimental_lowering_platform: Optional[str]` parameter
to the `.lower()` methods and then we thread the `lowering_platform`
all the way to the calls to `mlir.lower_jaxpr_to_module2`. That's it.
Note that this parameter to `.lower()` is experimental and not supposed
to be used outside jax2tf. It may also gobble user kwargs.
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
None of these appear to have public users, and this module is not included in the deprecation policy.
Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
Before:
```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```
After:
```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.
PiperOrigin-RevId: 507584264
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
vmap tries hard to give nice error messages when the mapped axes
for different arguments have different sizes, but the code to
compute the error message can run into InconsistentDimensionOperation
in presence of dimension polynomials. Ensure that the comparisons
are done symbolically.