If a `PartitionSpec` is passed, the mesh is read from the context. The primitives themselves take `NamedSharding` only, so the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
For the above functions, we also raise an error if the `PartitionSpec` contains mesh axis names of type Auto or Collective.
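A minimal sketch of that conversion, assuming a mesh is already in hand (how the mesh is read from the context is elided here):
```
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative mesh; in the functions above it would be read from the context.
mesh = Mesh(jax.devices(), axis_names=("x",))

spec = P("x")                           # what the caller passes
sharding = NamedSharding(mesh, spec)    # what the primitive receives at `.bind`
```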
PiperOrigin-RevId: 713352542
This change was made after the most recent release, so it is safe
to remove. Casting float to int can silently change intentional
behavior: e.g. NaN casts to 0. Some downstream users currently
use NaN to mark rows which should have no one-hot entry.
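As a sketch of why the cast matters (using the underlying index comparison directly; that `one_hot` reduces to exactly this comparison is an assumption):
```
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan])     # NaN marks a row with no one-hot entry
classes = jnp.arange(3)

# NaN compares unequal to every index, so its row stays all False/zero.
print(x[:, None] == classes)                      # [[F, T, F], [F, F, F]]

# Casting to int first turns NaN into 0 (as noted above), which would
# incorrectly light up class 0 for that row.
print(x.astype(jnp.int32)[:, None] == classes)    # [[F, T, F], [T, F, F]]
```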
This was useful during the transition to typed PRNG keys, but
is no longer necessary. It also makes the generated HTML docs
confusing; it's better to just use `Array`, as we expect users to.
The old formulation explicitly computed `1 + erf(x/sqrt(2))`, which can be extremely inaccurate for negative x due to cancellation.
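For illustration, here is the cancellation at a moderately negative argument, alongside the standard erfc-based rewrite (whether this matches the new formulation exactly is an assumption):
```
import jax.numpy as jnp
from jax.scipy.special import erf, erfc

x = jnp.float32(-8.0)

# erf(x/sqrt(2)) is within float32 rounding of -1, so adding 1 cancels to 0.
old = 0.5 * (1.0 + erf(x / jnp.sqrt(2.0)))

# The erfc form avoids subtracting nearly equal numbers and keeps the tiny tail.
new = 0.5 * erfc(-x / jnp.sqrt(2.0))

print(old, new)   # old collapses to 0.0; new retains a tiny nonzero value
```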
PiperOrigin-RevId: 676944344
Imported from GitHub PR https://github.com/google/jax/pull/22330
The cudnn API already supports the combination of bias and mask as of [this PR](https://github.com/google/jax/pull/22078). We are removing that logic from the public SDPA API and passing the mask through directly.
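For example, with the public API the two inputs are simply passed side by side and any combining is left to the backend (shapes and dtypes below are illustrative):
```
import jax
import jax.numpy as jnp

B, T, N, H = 2, 128, 8, 64   # batch, sequence, heads, head dim
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(k1, (B, T, N, H), jnp.bfloat16)
k = jax.random.normal(k2, (B, T, N, H), jnp.bfloat16)
v = jax.random.normal(k3, (B, T, N, H), jnp.bfloat16)

bias = jnp.zeros((B, N, T, T), jnp.bfloat16)                # additive bias
mask = jnp.tril(jnp.ones((T, T), dtype=bool))[None, None]   # causal mask, (1, 1, T, T)

# bias and mask go in as separate arguments; they are no longer folded
# together in the public API before being handed to the cudnn path.
out = jax.nn.dot_product_attention(q, k, v, bias=bias, mask=mask)
```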
cc. @Cjkkkk
Copybara import of the project:
--
0f75f58a9d81c0ae0a83701a71998c940318732a by kaixih <kaixih@nvidia.com>:
Remove logic of combining bias and mask
Merging this change closes #22330
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22330 from kaixih:remove_combine_bias_mask 0f75f58a9d81c0ae0a83701a71998c940318732a
PiperOrigin-RevId: 652830016
Imported from GitHub PR https://github.com/google/jax/pull/21371
Attention plays a crucial role in modern transformer-based models. While various variants exist, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can by itself deliver a 1.3x end-to-end speedup for training GPT-based large language models.
This PR proposes introducing a new API in the `jax.nn` module to handle attention. It first tries the cudnn flash-attention execution path when the configuration is compatible, and otherwise falls back to a JAX implementation.
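A minimal usage sketch of that dispatch behavior (the `implementation` argument name follows the current `jax.nn.dot_product_attention`; the exact compatibility checks are internal):
```
import jax
import jax.numpy as jnp

B, T, N, H = 2, 256, 8, 64   # batch, sequence, heads, head dim
q, k, v = (jax.random.normal(key, (B, T, N, H), jnp.float16)
           for key in jax.random.split(jax.random.key(0), 3))

# Default: try the cudnn flash-attention path when supported, otherwise
# fall back to the pure-JAX implementation.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)

# The backend can also be pinned explicitly.
out_xla = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation="xla")
```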
cc. @nluehr @Cjkkkk @cliffwoolley
Copybara import of the project:
--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:
Add new SDPA API to jax.nn
Merging this change closes #21371
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
We should not silently convert non-array inputs to arrays, because this can lead to silent performance degradation. This brings the sparse_plus API in line with other APIs in this module.
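A short illustration of the intent (the exact error behavior for non-array inputs is an assumption):
```
import jax.numpy as jnp
from jax.nn import sparse_plus

sparse_plus(jnp.array([-2.0, 0.0, 2.0]))   # explicit array input: fine

# A plain Python list is no longer silently converted to an array; it is
# expected to be rejected (e.g. with a TypeError), like other jax.nn APIs.
# sparse_plus([-2.0, 0.0, 2.0])
```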
PiperOrigin-RevId: 617190413
Currently distribution parameters such as stddev and scale are expected to be
weakly typed scalars. When they're passed as float32 they can cause an upcast
of the initialized arrays even when the dtype is specified as e.g. bfloat16.
Some users were surprised by this.
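A sketch of the surprise, using `jax.nn.initializers.normal`; the second call illustrates the upcast this change addresses:
```
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.key(0)

# Weakly typed Python float: the requested bfloat16 dtype is preserved.
init = jax.nn.initializers.normal(stddev=1e-2)
print(init(key, (2, 2), jnp.bfloat16).dtype)     # bfloat16

# Strongly typed float32 scalar: previously this promoted the result to
# float32 even though bfloat16 was requested.
init32 = jax.nn.initializers.normal(stddev=np.float32(1e-2))
print(init32(key, (2, 2), jnp.bfloat16).dtype)
```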
PiperOrigin-RevId: 611858446
Without this decorator, we get a warning from typeguard:
```
.../typeguard/_checkers.py:474: UserWarning: Typeguard cannot check the Initializer protocol because it is a non-runtime protocol. If you would like to type check this protocol, please use @typing.runtime_checkable
```
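For reference, the decorator the warning asks for looks like this on a protocol of the Initializer shape (the signature shown here is a simplification):
```
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class Initializer(Protocol):
    # Simplified signature: a PRNG key, a shape, and an optional dtype,
    # returning an array.
    def __call__(self, key: Any, shape: tuple[int, ...], dtype: Any = ...) -> Any:
        ...
```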
PiperOrigin-RevId: 598588778
This PR is a follow-up to #18881.
The changes were generated by adding
```
from __future__ import annotations
```
to the files which did not already have it and then running
```
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
```
In the common case (real values) these are all single-expression jaxprs themselves, so putting them out of line just makes things more verbose.
There's no reason to include stuff like this in a jaxpr:
```
cxd:bool[8,16] = pjit[
jaxpr={ lambda ; cxe:f32[8,16]. let
cxf:bool[8,16] = is_finite cxe
in (cxf,) }
name=isfinite
] cxc
```
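One way to see the effect is to print the jaxpr directly; after this change the real-valued case should show a single inline `is_finite` equation rather than the nested `pjit` call quoted above (exact output depends on the JAX version):
```
import jax
import jax.numpy as jnp

x = jnp.zeros((8, 16), jnp.float32)
print(jax.make_jaxpr(jnp.isfinite)(x))
```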
PiperOrigin-RevId: 587047955