102 Commits

Author SHA1 Message Date
shuw
17088e9025 Improve after review #2 2025-02-26 04:48:25 +00:00
shuw
681ee18436 Fix CI 2025-02-25 17:15:31 +00:00
Shu Wang
08012e9c01 Conditionally create mxfp8_configs. 2025-02-21 23:08:22 -06:00
shuw
bfb9d3ca4b Improve based on comment #1 2025-02-21 17:32:57 +00:00
shuw
332af58765 block_scale_config 2025-02-13 04:35:06 +00:00
jax authors
872e6c0ec4 Merge pull request #25766 from carlosgmartin:nn_initializers_variance_scaling_mode_fan_geo_avg
PiperOrigin-RevId: 721928532
2025-01-31 15:41:50 -08:00
carlosgmartin
96d3447e89 Add mode='fan_geo_avg' to nn.initializers.variance_scaling. 2025-01-31 17:52:22 -05:00
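
Editor's sketch of using the new mode (assuming, per the mode's name, that `fan_geo_avg` scales the variance by the geometric mean `sqrt(fan_in * fan_out)` of fan-in and fan-out):

```python
import jax
import jax.numpy as jnp

init = jax.nn.initializers.variance_scaling(
    scale=1.0, mode='fan_geo_avg', distribution='truncated_normal')
w = init(jax.random.PRNGKey(0), (128, 256), jnp.float32)
```
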
cjkkkk
fe37afe9dd fix error tolerance 2025-01-22 22:00:16 +00:00
Jake VanderPlas
1ee015674f [internal] add deprecation test utilities 2025-01-10 11:54:09 -08:00
George Necula
dd0447a7c6 [aot] Add support for as_text(debug_info=True).
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
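
A short sketch of the AOT API this commit describes, using `debug_info=True` to get source locations in the StableHLO dump:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) ** 2

lowered = jax.jit(f).lower(jnp.ones((4,)))
# StableHLO annotated with source locations, per this change:
print(lowered.as_text(debug_info=True))
```
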
Jake VanderPlas
8c3c441ee4 jax.nn.one_hot: deprecate non-integer inputs 2024-12-19 07:11:31 -08:00
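
For context, a sketch of what the deprecation distinguishes:

```python
import jax
import jax.numpy as jnp

jax.nn.one_hot(jnp.array([0, 2, 1]), 3)     # integer indices: supported
jax.nn.one_hot(jnp.array([0.0, 2.0]), 3)    # float indices: deprecated by this change
```
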
carlosgmartin
08801147f1 Add test of relu grad at zero. Update paper links. 2024-12-10 19:39:47 -05:00
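
The convention the new test checks is that JAX pins the relu subgradient at zero to 0:

```python
import jax

print(jax.grad(jax.nn.relu)(0.0))  # 0.0
```
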
kaixih
7409bae64c Adjusted atol/rtol for jax sdpa tests 2024-10-21 17:00:04 +00:00
jax authors
81d2fbe094 Merge pull request #23740 from kaixih:dbias_bwd_batcher
PiperOrigin-RevId: 681583770
2024-10-02 14:04:19 -07:00
jax authors
ca97af9d43 Change the default implementation of GeLU to a numerically stable formulation.
The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to cancellation.

PiperOrigin-RevId: 676944344
2024-09-20 13:06:31 -07:00
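
Editor's illustration of the cancellation and one stable rewrite, via the identity `1 + erf(z) == erfc(-z)`; a sketch of the idea, not necessarily the exact formulation this commit adopted:

```python
import jax.numpy as jnp
from jax.scipy.special import erf, erfc

def gelu_naive(x):
  # For large negative x, erf(x/sqrt(2)) -> -1, so (1 + erf(...))
  # cancels catastrophically and loses almost all precision.
  return x * (1 + erf(x / jnp.sqrt(2.0))) / 2

def gelu_stable(x):
  # erfc(-z) == 1 + erf(z), computed directly without the cancellation.
  return x * erfc(-x / jnp.sqrt(2.0)) / 2
```
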
kaixih
b7e26ba3ee fix dbias in bwd_batcher 2024-09-20 18:07:55 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub URL in the JAX codebase to reflect the move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
kaixih
541b3a3f75 New feature 2024-09-11 19:56:20 +00:00
Kaixi Hou
8ccc439d4a PR #23223: [NVIDIA] Reduce number of tests for jax.nn.dot_product_attention
Imported from GitHub PR https://github.com/google/jax/pull/23223

While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Every new option currently has to be tested in combination with all existing options, so the number of tests grows exponentially; we already have 864 parameterized tests for this API. This PR addresses the issue by reducing the number of tests through grouping.

For the new tests, we categorize them as follows:

1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.

Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:

--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:

Reduce attn tests

Merging this change closes #23223

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:12:11 -07:00
kaixih
558000df7c Support variable sequence lengths 2024-08-21 18:25:55 +00:00
kaixih
6ff6501aa2 Init commit 2024-08-01 19:39:34 +00:00
kaixih
cf5bcc7ad8 Support GQA and MQA 2024-07-29 17:17:22 +00:00
Jake VanderPlas
7ec87892b5 fix lint errors at HEAD 2024-07-08 08:12:30 -07:00
Kaixi Hou
df6080f346 PR #21371: [NVIDIA] Add new SDPA API to jax.nn
Imported from GitHub PR https://github.com/google/jax/pull/21371

Attention plays a crucial role in modern transformer-based models. While many variants exist, they generally follow the same workflow: examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, implementations like the Flash Attention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup when training GPT-style large language models.

This PR proposes introducing a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible; otherwise it falls back to a pure-JAX implementation.

cc. @nluehr @Cjkkkk @cliffwoolley

Copybara import of the project:

--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:

Add new SDPA API to jax.nn

Merging this change closes #21371

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
2024-07-08 06:16:04 -07:00
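
A hedged usage sketch of the resulting API, assuming the (batch, seq_len, num_heads, head_dim) layout; parameter names here follow the current docs and may differ from the version in this PR. `implementation=None` lets JAX choose the backend as described above:

```python
import jax
import jax.numpy as jnp

B, T, N, H = 2, 128, 8, 64  # batch, sequence length, heads, head dim
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (B, T, N, H), dtype=jnp.bfloat16)

# Tries cuDNN flash attention when compatible, else the XLA fallback:
out = jax.nn.dot_product_attention(q, k, v, implementation=None)
```
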
Jake VanderPlas
1ea205be1c softmax: deprecate initial argument & always set to -inf internally 2024-04-10 10:23:21 -07:00
Matteo Hessel
0b602c5c4d Add sparse_sigmoid to jax.nn
PiperOrigin-RevId: 623108517
2024-04-09 03:10:04 -07:00
carlosgmartin
f0314c70e8 Add jax.nn.mish. 2024-04-03 16:37:07 -04:00
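
For reference, the definition this commit adds, mish(x) = x * tanh(softplus(x)); a self-contained sketch:

```python
import jax
import jax.numpy as jnp

def mish(x):
  # mish(x) = x * tanh(softplus(x)), matching jax.nn.mish.
  return x * jnp.tanh(jax.nn.softplus(x))
```
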
Matteo Hessel
c94ea147f2 Add sparseplus activation to jax.nn.
PiperOrigin-RevId: 616087452
2024-03-15 04:40:38 -07:00
Anselm Levskaya
04f6bfa460 Prevent accidental upcasting in jax.nn.initializers.
Currently distribution parameters such as stddev and scale are expected to be
weakly typed scalars.  When they're passed as float32 they can cause an upcast
of the initialized arrays even when the dtype is specified as e.g. bfloat16.
Some users were surprised by this.

PiperOrigin-RevId: 611858446
2024-03-01 14:24:26 -08:00
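
An illustration of the surprise described; the guard added here is internal, so this only shows the user-facing pattern:

```python
import jax
import jax.numpy as jnp
import numpy as np

init = jax.nn.initializers.normal(stddev=1e-2)   # weakly typed Python scalar
w = init(jax.random.PRNGKey(0), (8, 8), jnp.bfloat16)
print(w.dtype)  # bfloat16, as requested

# Before this change, a strongly typed float32 parameter could upcast
# the initialized array away from the requested dtype:
init32 = jax.nn.initializers.normal(stddev=np.float32(1e-2))
```
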
Jake VanderPlas
a282d586b6 nn.softmax: use double-where when where is specified 2024-01-26 09:45:31 -08:00
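
A sketch of the double-where pattern (editor's illustration, not JAX's internal code; assumes each row has at least one unmasked entry):

```python
import jax.numpy as jnp

def masked_softmax(x, where, axis=-1):
  # Inner where: substitute a harmless value so the reduction and its
  # gradient never see masked entries; outer where: re-mask the output.
  safe_x = jnp.where(where, x, 0.0)
  unnorm = jnp.exp(safe_x - jnp.max(safe_x, axis=axis, keepdims=True))
  unnorm = jnp.where(where, unnorm, 0.0)
  return unnorm / jnp.sum(unnorm, axis=axis, keepdims=True)
```
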
carlosgmartin
9f8e1bc34a Add nn.squareplus. 2023-11-14 23:52:41 -05:00
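
The definition, as given in the squareplus paper; a sketch mirroring what `jax.nn.squareplus` computes (default `b` assumed to be 4):

```python
import jax.numpy as jnp

def squareplus(x, b=4.0):
  # squareplus(x) = (x + sqrt(x**2 + b)) / 2: a smooth, sqrt-based
  # alternative to softplus.
  return (x + jnp.sqrt(jnp.square(x) + b)) / 2
```
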
Peter Hawkins
e7f1d29716 Relax some test tolerances for TPU.
PiperOrigin-RevId: 576192162
2023-10-24 10:45:40 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
jax authors
311dc9cfde Add truncated normal initializer to jax.nn
PiperOrigin-RevId: 563576354
2023-09-07 16:23:42 -07:00
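
A minimal usage sketch of the new initializer (the `stddev` parameter name is assumed from the sibling `normal` initializer):

```python
import jax
import jax.numpy as jnp

init = jax.nn.initializers.truncated_normal(stddev=0.02)
w = init(jax.random.PRNGKey(0), (256, 256), jnp.float32)
```
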
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Jake VanderPlas
c474de424a jax.nn.softmax: fix fill value when where is specified 2023-06-01 10:18:05 -07:00
Matthew Johnson
d42350f879 disable custom_jvp for softmax by default
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
2023-05-23 11:56:50 -07:00
Matthew Johnson
e0d2736e37 add custom_jvp for jax.nn.softmax
This avoids saving the jnp.exp(...) value.
2023-04-22 11:28:03 -07:00
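
Editor's sketch of the idea: writing the JVP in terms of the output `y` means the `jnp.exp(...)` intermediate need not be saved:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def softmax(x):
  ex = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
  return ex / jnp.sum(ex, axis=-1, keepdims=True)

@softmax.defjvp
def softmax_jvp(primals, tangents):
  (x,), (dx,) = primals, tangents
  y = softmax(x)
  # dy_i = y_i * (dx_i - sum_j y_j dx_j), expressed via the output only.
  return y, y * (dx - jnp.sum(y * dx, axis=-1, keepdims=True))
```
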
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Matthew Johnson
9c39b6f70c update relu6 grad at 0 and 6 to match pytorch convention 2023-03-07 17:30:17 -08:00
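
A sketch of the adopted convention (assuming, per the commit title, that PyTorch sets the derivative to 0 at both kinks):

```python
import jax

print(jax.grad(jax.nn.relu6)(0.0), jax.grad(jax.nn.relu6)(6.0))  # 0.0 0.0
```
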
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Marcus Chiam
45c2f31887 Added shape error checking for compute_fans
Update tests/nn_test.py

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
Jake VanderPlas
f09fd8a4e9 [x64] minor test-only updates for better type safety 2022-11-30 15:18:40 -08:00
Peter Hawkins
c657449528 Copybara import of the project:
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:

Migrate more tests from jtu.cases_from_list to jtu.sample_product.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Matthew Johnson
03abcc7c5c fix typo in test 2022-09-23 14:43:24 -07:00
Matthew Johnson
b6ef90ffdd fix leak checker internal error
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function when config.jax_check_tracer_leaks was enabled, so as
to avoid keeping references to the main trace (needed for the jvp_jaxpr thunk),
which would trigger the leak checker (which checks whether any references to
the main trace persist after tracing of the user function has finished).

But after #7345, the leak checker only triggers when actual Tracers are leaked,
so disabling the memoization when jax_check_tracer_leaks is active should no
longer be necessary. (These PR numbers seem out of order! We're not sure why.)

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-23 12:33:45 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
russbates
5026a810a4 Update nn_test.py
Add parameter to sweep over `approximate` kwarg.
2022-08-11 15:46:34 +01:00