The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to cancellation.
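For illustration, a minimal sketch contrasting the two formulations; the erfc-based rewrite shown here is the standard fix for this cancellation, and the function names are ours, not the commit's:

```python
import jax.numpy as jnp
from jax.scipy.special import erf, erfc

def ndtr_old(x):
  # Old formulation: for very negative x, erf(x/sqrt(2)) is close to -1,
  # so 1 + erf(...) cancels catastrophically and can round to 0.
  return 0.5 * (1.0 + erf(x / jnp.sqrt(2.0)))

def ndtr_new(x):
  # Equivalent rewrite: erfc(-x/sqrt(2)) == 1 + erf(x/sqrt(2)) mathematically,
  # but it avoids the subtraction and preserves relative accuracy.
  return 0.5 * erfc(-x / jnp.sqrt(2.0))

x = jnp.float32(-8.0)
print(ndtr_old(x))  # 0.0 in float32: the cancellation loses everything
print(ndtr_new(x))  # ~6.2e-16: small but accurate
```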
PiperOrigin-RevId: 676944344
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every new option must be tested in combination with all existing options, so the number of tests grows exponentially; we already have 864 parameterized tests for this API. This PR reduces the number of tests by grouping them.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training. A sketch of the resulting grouping is shown below.
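As a rough illustration of the grouping idea (test names and option lists here are illustrative, not the PR's actual code), two small parameterized products replace one large one:

```python
import jax.numpy as jnp
from absl.testing import absltest
from jax._src import test_util as jtu

class DotProductAttentionTest(jtu.JaxTestCase):

  @jtu.sample_product(
      dtype=[jnp.float16, jnp.bfloat16],
      use_vmap=[False, True],
  )
  def testBasic(self, dtype, use_vmap):
    # Non-mask tests: dtypes, vmap, groups, etc., with masking at defaults.
    del dtype, use_vmap

  @jtu.sample_product(mask_mode=["causal", "padding", "sliding_window"])
  def testMask(self, mask_mode):
    # Mask tests: one per masking scenario, other options held fixed.
    del mask_mode

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
```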
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
Imported from GitHub PR https://github.com/google/jax/pull/21371
Attention plays a crucial role in modern transformer-based models. While there are many variants, they generally follow the same workflow; examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, implementations like the Flash Attention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup for training GPT-based large language models.
This PR proposes a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible; otherwise, it falls back to a JAX implementation.
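As a usage sketch (keyword names follow the current `jax.nn.dot_product_attention` API and may differ slightly from the version merged in this PR):

```python
import jax
import jax.numpy as jnp

B, T, S, N, H = 2, 128, 128, 8, 64  # batch, q length, kv length, heads, head dim
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H), jnp.bfloat16)
k = jax.random.normal(kk, (B, S, N, H), jnp.bfloat16)
v = jax.random.normal(kv, (B, S, N, H), jnp.bfloat16)

# implementation=None (the default) picks the cuDNN flash attention path
# when the configuration supports it, and otherwise falls back to XLA.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # (2, 128, 8, 64)
```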
cc. @nluehr @Cjkkkk @cliffwoolley
Copybara import of the project:
--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:
Add new SDPA API to jax.nn
Merging this change closes #21371
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
Currently, distribution parameters such as stddev and scale are expected to be
weakly typed scalars. When they are passed as float32, they can upcast the
initialized arrays even when the dtype is specified as, e.g., bfloat16. Some
users were surprised by this.
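A minimal sketch of the surprise, assuming the usual initializer call signature:

```python
import numpy as np
import jax
import jax.numpy as jnp

# A strongly-typed float32 stddev; under the old behavior this float32
# scalar could upcast the requested bfloat16 result to float32.
init = jax.nn.initializers.normal(stddev=np.float32(1e-2))
w = init(jax.random.PRNGKey(0), (4, 4), jnp.bfloat16)
print(w.dtype)  # bfloat16 after this change
```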
PiperOrigin-RevId: 611858446
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
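A minimal before/after sketch, assuming the statically-typed state objects live in the internal `jax._src.config` module:

```python
import jax
from jax._src import config  # internal module; shown for illustration

# Dynamic lookup: resolved at runtime, opaque to type checkers and IDEs.
x64_enabled = jax.config.jax_enable_x64

# Statically-typed state object: the attribute and its .value property
# are visible to type checkers and IDEs.
x64_enabled = config.enable_x64.value
```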
PiperOrigin-RevId: 572587137
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
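A minimal before/after sketch in test code (the exact call sites vary across the change):

```python
from jax._src import test_util as jtu

# Before: equality test against the device kind.
if jtu.device_under_test() == "gpu":
  pass

# After: tag matching, which this design can later extend so that
# "gpu" matches both "cuda" and "rocm" devices.
if jtu.test_device_matches(["gpu"]):
  pass
```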
PiperOrigin-RevId: 568923117
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:
Migrate more tests from jtu.cases_from_list to jtu.sample_product.
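A minimal sketch of the target style, assuming `jtu.sample_product`'s keyword form:

```python
import numpy as np
from jax._src import test_util as jtu

class ExampleTest(jtu.JaxTestCase):

  # sample_product parameterizes over (a sample of) the cartesian
  # product of the given options, replacing jtu.cases_from_list.
  @jtu.sample_product(
      dtype=[np.float32, np.float64],
      shape=[(3,), (2, 4)],
  )
  def testZeros(self, dtype, shape):
    self.assertEqual(np.zeros(shape, dtype).dtype, dtype)
```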
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks, because the
memoization cache would otherwise hold references to the main trace (needed for
the jvp_jaxpr thunk) and hence trigger the leak checker (which would see if any
references to the main trace persisted after finishing tracing of the user
function).
But after #7345, the leak checker should only trigger when actual Tracers are
leaked, so disabling the memoization when jax_check_tracer_leaks is active
should no longer be necessary. (These PR numbers seem out of order! We're not
sure why.)
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
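A minimal sketch of what the leak checker guards, assuming a toy `custom_jvp` function (this is not the regression test from the change):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (xdot,) = primals, tangents
  return f(x), jnp.cos(x) * xdot

# jax.checking_leaks() enables config.jax_check_tracer_leaks for the block.
# With _memoize left enabled, this should no longer false-positive.
with jax.checking_leaks():
  jax.grad(f)(1.0)
```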
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
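A minimal sketch of what this default means for tests, expressed with the equivalent global config option:

```python
import jax
import jax.numpy as jnp

# Equivalent global setting; JaxTestCase now applies 'raise' per test.
jax.config.update("jax_numpy_rank_promotion", "raise")

try:
  jnp.ones((3, 3)) + jnp.ones((3,))  # implicit rank promotion
except ValueError as e:
  print(e)  # disallowed under 'raise'; broadcast explicitly instead
```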
PiperOrigin-RevId: 427896974