This aligns rocm with cuda when using jax.distributed in combination
with one of the mechanisms for cluster-autodetection that set visible
devices in the "jax_rocm_visible_devices" flag.
Fixes #26298
Those APIs don't support that right now anyway, and they raise an ugly KeyError. Instead, we raise a better error here.
I have added a TODO to get the mesh from args so that computation-follows-data works, but we can decide to do that in the future if a lot of users request it and don't want to use `use_mesh`.
PiperOrigin-RevId: 730687231
* Allow merging and splitting only if the major-most dim is sharded, since that involves no data movement. This only happens if `dimensions` is None, i.e. if the input array is in **row-major order**.
* Merging: If **only** the major-most dim of the merge block is sharded, then that sharding is propagated to the merge block output.
* Splitting: If the dimension being split is sharded, then the sharding is propagated to the major-most dimension post-split only if the spec divides the new shape exactly.
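
The two rules above can be sketched in plain Python. This is only an illustration, not the JAX implementation: the helper names `merge_block_sharding` and `split_dim_sharding` are hypothetical, a sharding is modeled as a mesh axis name or `None`, "the spec divides the new shape exactly" is read as "the shard count divides the new major-most dim size", and raising `ValueError` for the unsupported cases is an assumption.

```python
def merge_block_sharding(block_spec):
    """Sharding of the merged dim, given the specs of the dims being merged.

    `block_spec` lists the merged dims major-most first; each entry is a
    mesh axis name or None (hypothetical representation).
    """
    major, *minor = block_spec
    if any(axis is not None for axis in minor):
        # Sharding on a non-major dim would require data movement.
        raise ValueError("only the major-most dim of a merge block may be sharded")
    return major


def split_dim_sharding(axis, num_shards, new_shape):
    """Specs of the dims produced by splitting one dim into `new_shape`."""
    if axis is None:
        return (None,) * len(new_shape)
    if new_shape[0] % num_shards != 0:
        # Assumed reading of "the spec divides the new shape exactly".
        raise ValueError("shard count must divide the new major-most dim")
    # The sharding lands on the major-most new dim; the rest are unsharded.
    return (axis,) + (None,) * (len(new_shape) - 1)
```

For example, merging dims with specs `("x", None)` keeps `"x"` on the merged dim, while splitting a dim sharded 2 ways over `"x"` into shape `(4, 3)` gives `("x", None)`.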
PiperOrigin-RevId: 730291595
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.
I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
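
As a rough illustration of the new encoding, the sketch below labels each leaf of a result with a "result"-prefixed path. The helper `result_paths` is hypothetical and handles only nested tuples, lists, and string-keyed dicts; it is not the code used in JAX.

```python
def result_paths(out):
    """Return one "result..." path string per leaf of `out` (hypothetical helper)."""

    def walk(node, prefix):
        if isinstance(node, (tuple, list)):
            for i, child in enumerate(node):
                yield from walk(child, f"{prefix}[{i}]")
        elif isinstance(node, dict):
            # Assumes string keys, visited in sorted order.
            for key in sorted(node):
                yield from walk(node[key], f"{prefix}[{key!r}]")
        else:
            # A leaf: the path accumulated so far names it.
            yield prefix

    return tuple(walk(out, "result"))
```

With this sketch, a single-array result maps to `("result",)` and a pair maps to `("result[0]", "result[1]")`, so neither can be confused with the empty string used for an unknown path.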
Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
* `bitcast_convert_element_type`
* `cumsum`
* `cumlogsumexp`
* `cumprod`
* `cummax`
* `cummin`
* `reduce_window`
* `reduce_window_sum`
* `reduce_window_max`
* `reduce_window_min`
* `select_and_gather_add`
For `reduce_window_...` primitives, only trivial windowing is supported along non-replicated dimensions. We can relax the other NotImplemented case in the future.
PiperOrigin-RevId: 729910108
In a recent jax release the SvdAlgorithm parameter was added
to the jax.lax.linalg.svd function. Currently, CPU targets
still support only the divide-and-conquer algorithm from
LAPACK (gesdd).
This commit adds the functionality to select the QR-based
algorithm on CPU as well. Mainly, it adds the wrapper code
to call the gesvd function of LAPACK using the FFI interface.
Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
In this change, we update schur, triangular_solve, tridiagonal, and tridiagonal_solve. I batched these together since they're all pretty straightforward.
PiperOrigin-RevId: 729572705
- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.
PiperOrigin-RevId: 729561359
To be consistent with other rule registration helpers, `unop_dtype_rule` should pass through its kwargs to the `result_dtype` callable.
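
A minimal sketch of the pass-through pattern, in plain Python. All names here (`unop_dtype_rule_sketch`, `keep_or_override`, the `preferred` keyword) are hypothetical and the signature is simplified; it does not reproduce the real `unop_dtype_rule`.

```python
from functools import partial


def unop_dtype_rule_sketch(result_dtype, name, operand_dtype, **kwargs):
    """Simplified dtype rule: forwards extra params to `result_dtype`."""
    # `name` would normally be used in error messages; unused in this sketch.
    return result_dtype(operand_dtype, **kwargs)


def keep_or_override(dtype, *, preferred=None):
    """Hypothetical `result_dtype` callable that accepts an extra keyword."""
    return preferred if preferred is not None else dtype


# Registration-style usage: bind the callable and the primitive name once.
rule = partial(unop_dtype_rule_sketch, keep_or_override, "my_op")
```

Calling `rule("float32")` returns `"float32"`, and `rule("float32", preferred="float64")` returns `"float64"`; without the `**kwargs` forwarding, the second call could not reach the `preferred` keyword.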
PiperOrigin-RevId: 729483613
As part of my efforts to simplify the primitive implementations in lax.linalg, I've found that all of the primitives share some common logic when it comes to impls, abstract_evals, and batching. This change adds some helper functions and starts the process of abstracting the primitive definitions to simplify and reduce duplication. I will continue with the rest of the primitives in lax.linalg, but I didn't want to overload the first diff.
PiperOrigin-RevId: 729471970
Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering; otherwise things go horribly wrong. This is actually an XLA bug, but we can work around it in JAX for now.
PiperOrigin-RevId: 729307115
If a mesh axis is Explicit, we don't canonicalize closed-over values yet, since that may require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map.
PiperOrigin-RevId: 728956512