`jax.make_array_from_single_device_arrays` should not allow passing more than one array on the same device as that would lead to an invalid array. While some of this case is already detected by later checks (e.g., `ArrayImpl._check_and_rearrange`), this CL explicitly checks the device list before calling IFRT so that we don't create an invalid IFRT array to begin with.
PiperOrigin-RevId: 647772472
Before this information was lost in the roundtrip via `mlir.lower_fun` -> `jaxpr_subcomp`. But now since it's on the jaxpr equations, the information is preserved in jaxpr_subcomp as we enter into each eqn's ctx.
Fixes: https://github.com/google/jax/issues/21061
PiperOrigin-RevId: 636940742
Also take the batched_device_put fast path for non-jax.Array's since slicing can return arrays on multiple devices which batched_device_put doesn't support.
PiperOrigin-RevId: 624763603
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.
If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.
PiperOrigin-RevId: 564457393
This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches.
PiperOrigin-RevId: 560828993
The semantics are as follow:
* if the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings
* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.
This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.
PiperOrigin-RevId: 540705660
Rather than enumerating a list of types that don't work in the buffer protocol, call the format descriptor function and fail if it fails.
Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants.
PiperOrigin-RevId: 535315975
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 535248553
In that case, reshard the array and then create a host local array from that.
Also improve the shard mismatch error that jax.Array raises.
PiperOrigin-RevId: 531397741