jax_prng.PRNGKeyArray is not exposed to the public jax API, resulting in type check errors when sampling outside of tests.
PiperOrigin-RevId: 731008883
Depending on the platform and linked LAPACK library, this change seems to improve (or at least not degrade) performance across a wide range of problem and batch sizes. On colab, the performance is not dramatically improved for most input shapes, but on my Mac, this improves the performance of batched triangular solves by a factor of a few up to an order of magnitude across all the problems that I tried.
PiperOrigin-RevId: 730971127
contextdecorator turns out to be slower than just writing a decorator class explicitly. Since we use many decorators per-equation, this causes a measurable speed difference in certain benchmarks.
PiperOrigin-RevId: 730939406
(Part of general cleanups of the lax.linalg submodule.)
This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward and backward compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility
We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff assuming that this change doesn't make it into the upcoming release).
Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released.
PiperOrigin-RevId: 730928808
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)
Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).
You can still build the `jax` wheel with `python3 -m build` command.
Bazel `jax` wheel target: `//:jax_wheel`
Environment variables combinations for creating wheels with different versions:
* self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
* release: `--repo_env=ML_WHEEL_TYPE=release`
* release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
* nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
PiperOrigin-RevId: 730916743
This shaves off a lot of complexity from our lowering code, while retaining
all of the functionality, except the arrive_tx optimization: `emit_pipeline`
arrives once per buffer, whereas the pipelining in the lowering used to
arrive once for all buffers.
PiperOrigin-RevId: 730824239
This aligns rocm with cuda when using jax.distributed in combination
with one of the mechanisms for cluster-autodetection that set visible
devices in the "jax_rocm_visible_devices" flag.
Fixes#26298
Those APIs don't support that right now anyways and they raise an ugly KeyError. Instead we raise a better error here.
I have added a TODO to get the mesh from args so that computation follows data works but we can decide to do that in the future if a lot of users request that and don't want to use `use_mesh`.
PiperOrigin-RevId: 730687231
* Allow merging and splitting only if major most dim is sharded since that involves no data movement. This only happens if `dimensions` is None i.e. if the input array is in **row-major order**.
* Merging: If **only** the major most dim is sharded of the merge block then that sharding is propagated to the merge block output
* Splitting: If the dimension being split is sharded, then the sharding is propagated to the major most dimension post split only if the spec divides the new shape exactly.
PiperOrigin-RevId: 730291595
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.
I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
* `bitcast_convert_element_type`
* `cumsum`
* `cumlogsumexp`
* `cumprod`
* `cummax`
* `cummin`
* `reduce_window`
* `reduce_window_sum`
* `reduce_window_max`
* `reduce_window_min`
* `select_and_gather_add`
For `reduce_window_...` primitives only trivial windowing is supported along non-replicated dimensions. We can relax the other NotImplemented case in the future.
PiperOrigin-RevId: 729910108
In a recent jax release the SvdAlgorithm parameter has been added
to the jax.lax.linalg.svd function. Currently, for CPU targets
still only the divide and conquer algorithm from LAPACK is
supported (gesdd).
This commits adds the functionality to select the QR based
algorithm on CPU as well. Mainly it addes the wrapper code
to call the gesvd function of LAPACK using the FFI interface.
Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>