Also update the ragged_all_to_all docstring. Pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; I'll add ra2a to the shmap tutorial in the future.
PiperOrigin-RevId: 735957604
This is because `convert_element_type` returns an output on all devices of the mesh, due to the surrounding `use_mesh` context.
PiperOrigin-RevId: 735909962
My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, instead of xla_client itself, it's much more straightforward to register the custom calls in JAX.
It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future.
This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that.
PiperOrigin-RevId: 735381736
Also improve the dynamic_update_slice sharding error by printing `aval.str_short()` instead of the full sharding: it is more concise and gives more information than the current error (i.e. it also adds the shape to the error message).
Also make some formatting changes in scan lowering to make it easier to debug.
PiperOrigin-RevId: 734542862
In this case, the example boils down to:
```
inp1 = f32[16@x, 4]
inp2 = f32[4]
def f(x: f32[4], y: f32[4]):
  return jnp.concat([x, y], axis=-1)
vmap(f, in_axes=(0, None))(inp1, inp2)
```
This example was breaking in the concat batching rule because we didn't broadcast the unbatched operand with the right sharding.
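For reference, a minimal runnable sketch of the same call pattern, assuming a single device and therefore omitting the `16@x` sharding from the pseudocode above. The concrete array values are illustrative assumptions; only the shapes and `in_axes` mirror the example.

```python
import jax
import jax.numpy as jnp

def f(x, y):
  # x: f32[4], y: f32[4] -> f32[8]
  return jnp.concatenate([x, y], axis=-1)

# Batched operand (mapped over axis 0) and unbatched operand (in_axes=None).
inp1 = jnp.ones((16, 4), dtype=jnp.float32)
inp2 = jnp.zeros((4,), dtype=jnp.float32)

# The batching rule has to broadcast inp2 along the new batch axis
# before concatenating it with inp1.
out = jax.vmap(f, in_axes=(0, None))(inp1, inp2)
print(out.shape)
```

The unbatched operand is the one that must be broadcast inside the batching rule, which is where the sharding of the broadcast result matters once `inp1` is sharded.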
PiperOrigin-RevId: 733536944
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:
```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
      name=<lambda>
      jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
    ] a
  in (b,) }
```
instead of the previous:
```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
      name=<lambda>
      jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
    ] a
  in (b,) }
```
The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.
Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which is now named
`core.pp_shared_jaxpr`).
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.
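To make "index wrapping" concrete, here is a small sketch of the kind of computation that is being skipped. The helper `wrap_index` is a hypothetical name used for illustration only; it is not the code emitted by the lowering.

```python
import jax.numpy as jnp

def wrap_index(idx, dim_size):
  # Hypothetical helper: map a possibly negative index into [0, dim_size),
  # following the usual NumPy convention that -1 means the last element.
  return jnp.where(idx < 0, idx + dim_size, idx)

xs = jnp.arange(10.0)
wrapped = wrap_index(jnp.int32(-3), xs.shape[0])
unchanged = wrap_index(jnp.int32(2), xs.shape[0])
print(int(wrapped), int(unchanged))
```

When an index is known to be non-negative, as with a loop counter that starts at zero and only increases, the comparison and select are redundant work.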
PiperOrigin-RevId: 731812827
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU would insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, to cover a wide range of use cases.
There are a few remaining GPU ops that don't support partitioning, either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions in the kernel. I'm going to fix these in follow-up changes.
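A rough sketch of the use case, assuming a 1D mesh over whatever devices are available and using QR as a representative linalg op. The mesh axis name `batch` and the problem sizes are arbitrary choices, and whether an `all-gather` actually appears depends on the JAX version and backend.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all local devices; also works with a single device.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("batch",))

# A batch of small matrices, sharded along the leading (batch) dimension only.
batch = 4 * devices.size
a = jax.random.normal(jax.random.PRNGKey(0), (batch, 8, 8), dtype=jnp.float32)
a = jax.device_put(a, NamedSharding(mesh, P("batch", None, None)))

# Each matrix in the batch is factorized independently, so in principle
# every device can work on its own slice of the batch.
q, r = jax.jit(jnp.linalg.qr)(a)
print(q.shape, r.shape)
```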
PiperOrigin-RevId: 731732301
Depending on the platform and linked LAPACK library, this change seems to improve (or at least not degrade) performance across a wide range of problem and batch sizes. On Colab, the performance is not dramatically improved for most input shapes, but on my Mac, this improves the performance of batched triangular solves by a factor of a few up to an order of magnitude across all the problems that I tried.
PiperOrigin-RevId: 730971127
(Part of general cleanups of the lax.linalg submodule.)
This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward and backward compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility
We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff assuming that this change doesn't make it into the upcoming release).
Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released.
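The lowering-side decision described above can be sketched as follows. This is only an illustration: `select_operands`, its parameters, and the tuple-based version comparison are hypothetical and do not correspond to actual JAX internals.

```python
def select_operands(a, b, alpha, *, jaxlib_version, forward_compat):
  """Hypothetical sketch: decide whether to still pass `alpha` to the kernel."""
  # Old kernels (jaxlib <= 0.5.1 under the cutoff assumed above) still expect
  # alpha, and forward compatibility mode must keep producing the old call
  # signature so that already-deployed runtimes can execute it.
  needs_alpha = forward_compat or jaxlib_version <= (0, 5, 1)
  return (a, b, alpha) if needs_alpha else (a, b)

old = select_operands("a", "b", 1.0, jaxlib_version=(0, 5, 1), forward_compat=False)
new = select_operands("a", "b", 1.0, jaxlib_version=(0, 5, 2), forward_compat=False)
compat = select_operands("a", "b", 1.0, jaxlib_version=(0, 5, 2), forward_compat=True)
```

Because the updated handler accepts and ignores extra inputs, both call shapes remain valid against a new kernel, which is what makes the staged removal safe.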
PiperOrigin-RevId: 730928808
* Allow merging and splitting only if the major-most dim is sharded, since that involves no data movement. This only happens if `dimensions` is None, i.e. if the input array is in **row-major order**.
* Merging: If **only** the major-most dim of the merge block is sharded, then that sharding is propagated to the merge block output.
* Splitting: If the dimension being split is sharded, then the sharding is propagated to the major-most dimension post split, but only if the spec divides the new shape exactly.
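The "no data movement" claim can be checked on a single device by simulating shards with `jnp.split`. This is a sketch under assumed shapes (an `8x4x3` array and 2 shards); it does not use the sharding APIs themselves.

```python
import jax.numpy as jnp

num_shards = 2
x = jnp.arange(8 * 4 * 3).reshape(8, 4, 3)

# Simulate sharding the major-most dim: each "device" holds a (4, 4, 3) block.
shards = jnp.split(x, num_shards, axis=0)

# Merging (8, 4, 3) -> (32, 3): reshaping each shard locally and stacking the
# results gives the same array as reshaping globally.
merged_local = jnp.concatenate([s.reshape(-1, 3) for s in shards], axis=0)
merged_global = x.reshape(32, 3)

# Splitting (8, 4, 3) -> (2, 4, 4, 3): the new major-most dim has size 2,
# which 2 shards divide exactly, so each shard again reshapes locally.
split_local = jnp.concatenate([s.reshape(1, 4, 4, 3) for s in shards], axis=0)
split_global = x.reshape(2, 4, 4, 3)
```

In both cases every shard's row-major contents land entirely within one shard of the output, so nothing has to cross devices.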
PiperOrigin-RevId: 730291595
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.
I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
* `bitcast_convert_element_type`
* `cumsum`
* `cumlogsumexp`
* `cumprod`
* `cummax`
* `cummin`
* `reduce_window`
* `reduce_window_sum`
* `reduce_window_max`
* `reduce_window_min`
* `select_and_gather_add`
For `reduce_window_...` primitives only trivial windowing is supported along non-replicated dimensions. We can relax the other NotImplemented case in the future.
PiperOrigin-RevId: 729910108
In a recent JAX release, the SvdAlgorithm parameter was added
to the jax.lax.linalg.svd function. Currently, for CPU targets,
only the divide-and-conquer algorithm from LAPACK (gesdd) is
supported.
This commit adds the functionality to select the QR-based
algorithm on CPU as well. Mainly, it adds the wrapper code
to call the gesvd function of LAPACK using the FFI interface.
Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
In this change, we update schur, triangular_solve, tridiagonal, and tridiagonal_solve. I batched these ones since they're all pretty straightforward.
PiperOrigin-RevId: 729572705
To be consistent with other rule registration helpers, `unop_dtype_rule` should pass through its kwargs to the `result_dtype` callable.
PiperOrigin-RevId: 729483613
As part of my efforts to simplify the primitive implementations in lax.linalg, I've found that all of the primitives share some common logic when it comes to impls, abstract_evals, and batching. This change adds some helper functions and starts the process of abstracting the primitive definitions to simplify and reduce duplication. I will continue with the rest of the primitives in lax.linalg, but I didn't want to overload the first diff.
PiperOrigin-RevId: 729471970
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483