125 Commits

Author SHA1 Message Date
jax authors
872e6c0ec4 Merge pull request #25766 from carlosgmartin:nn_initializers_variance_scaling_mode_fan_geo_avg
PiperOrigin-RevId: 721928532
2025-01-31 15:41:50 -08:00
carlosgmartin
96d3447e89 Add mode='fan_geo_avg' to nn.initializers.variance_scaling. 2025-01-31 17:52:22 -05:00
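The commit above adds a new fan mode to the variance-scaling initializer family. A minimal sketch of how the fan denominator could work, assuming (from the commit title only) that `fan_geo_avg` is the geometric mean of `fan_in` and `fan_out`, alongside the existing arithmetic-mean `fan_avg`:

```python
import math

def fan_for_mode(fan_in, fan_out, mode):
    # Sketch of the denominator used by variance-scaling initializers.
    # 'fan_geo_avg' is assumed (per the commit title) to be the geometric
    # mean of fan_in and fan_out; 'fan_avg' is the arithmetic mean.
    if mode == "fan_in":
        return fan_in
    if mode == "fan_out":
        return fan_out
    if mode == "fan_avg":
        return (fan_in + fan_out) / 2
    if mode == "fan_geo_avg":
        return math.sqrt(fan_in * fan_out)
    raise ValueError(f"unknown mode: {mode}")

def stddev(scale, fan_in, fan_out, mode):
    # Variance scaling draws samples with variance = scale / fan.
    return math.sqrt(scale / fan_for_mode(fan_in, fan_out, mode))
```

For `fan_in=4, fan_out=16`, the geometric-average fan is 8 while the arithmetic average is 10, so the new mode sits between `fan_in` and `fan_avg` for skewed shapes.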
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding in types mode.
Instead, users should use `tracer.aval.sharding` so that code behaves the same under jit and in eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
Yash Katariya
3848f0d2ac [sharding_in_types] Functions such as einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type, and sharding_cast that take out_sharding as an argument in their signature should also accept a PartitionSpec, not just a NamedSharding.
If a PartitionSpec is passed, the mesh is read from the context. The primitives themselves still take `NamedSharding` only; the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error if the `PartitionSpec` contains mesh axis names that are of type Auto or Collective for the above functions.

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
Jake VanderPlas
cb10710c92 Remove casting from jax.nn.one_hot
This change was made after the most recent release, so it is safe
to remove. Casting float to int potentially changes intentional
behavior: e.g. NaN casts to 0. Some downstream users currently
use NaN to mark rows that should have no one-hot entry.
2024-12-23 07:33:49 -08:00
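Why the removed cast matters can be seen in a small sketch (assumptions: `one_hot` is modeled here by its equality-comparison semantics, `x[..., None] == arange(num_classes)`; this is not the actual jax source):

```python
import math

def one_hot(x, num_classes):
    # Equality-based one-hot: NaN compares unequal to every class index,
    # so a NaN input naturally yields an all-zero row.
    return [[1.0 if xi == k else 0.0 for k in range(num_classes)] for xi in x]

rows = one_hot([0.0, 2.0, math.nan], 3)
# rows[2] is [0.0, 0.0, 0.0]: the NaN row has no hot entry.
# Casting to int first (NaN -> 0, per the commit message) would
# instead produce a spurious hot entry at index 0.
```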
Jake VanderPlas
8c3c441ee4 jax.nn.one_hot: deprecate non-integer inputs 2024-12-19 07:11:31 -08:00
carlosgmartin
08801147f1 Add test of relu grad at zero. Update paper links. 2024-12-10 19:39:47 -05:00
Jake VanderPlas
fee272e550 Remove internal KeyArray alias
This was useful during the transition to typed PRNG keys, but
is no longer necessary. It also makes generated HTML docs
confusing: it's better to just use Array as we expect users to.
2024-11-20 10:30:12 -08:00
Jake VanderPlas
e9acaa8484 Remove the initial argument to jax.nn.softmax and jax.nn.log_softmax.
This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later.

PiperOrigin-RevId: 693023366
2024-11-04 10:55:21 -08:00
Yash Katariya
4db212d2c6 Add _sharding argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode.
This is required because `jax.nn.one_hot` calls into `broadcasted_iota`.

PiperOrigin-RevId: 687152343
2024-10-17 21:16:51 -07:00
jax authors
81d2fbe094 Merge pull request #23740 from kaixih:dbias_bwd_batcher
PiperOrigin-RevId: 681583770
2024-10-02 14:04:19 -07:00
jax authors
ca97af9d43 Change the default implementation of GeLU to a numerically stable formulation.
The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to cancellation.

PiperOrigin-RevId: 676944344
2024-09-20 13:06:31 -07:00
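The cancellation described above is easy to demonstrate. A sketch of both formulations (assumption: the stable rewrite shown here uses `erfc`, a standard fix for this cancellation; the commit does not state the exact formulation it adopted):

```python
import math

def gelu_naive(x):
    # Old-style formulation: 0.5 * x * (1 + erf(x / sqrt(2))).
    # For large negative x, erf(x/sqrt(2)) is -1 to machine precision,
    # and 1 + erf(...) cancels catastrophically to 0.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_stable(x):
    # One numerically stable rewrite: erfc(-z) == 1 + erf(z), computed
    # without the subtraction that loses all significant digits.
    return 0.5 * x * math.erfc(-x / math.sqrt(2.0))
```

At `x = -10`, the naive form returns exactly `0.0` in double precision, while the `erfc` form keeps a tiny but correctly signed negative value.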
kaixih
b7e26ba3ee fix dbias in bwd_batcher 2024-09-20 18:07:55 +00:00
kaixih
541b3a3f75 New feature 2024-09-11 19:56:20 +00:00
kaixih
2d2cbbc5fb Relax q_seqlen and kv_seqlen 2024-09-05 17:43:22 +00:00
jax authors
b9e6eb59be Merge pull request #22516 from kaixih:support_variable_seqlen
PiperOrigin-RevId: 666394369
2024-08-22 10:08:08 -07:00
kaixih
558000df7c Support variable sequence lengths 2024-08-21 18:25:55 +00:00
Roy Frostig
371935cc10 update README and several docs to typed RNG keys 2024-08-11 08:09:47 -07:00
Gleb Pobudzey
d28d14917e Fix error message in dot_product_attention
PiperOrigin-RevId: 660960409
2024-08-08 13:30:21 -07:00
Jake VanderPlas
53af0d4d90 CI: fix mypy errors 2024-08-07 15:15:45 -07:00
kaixih
9f9e3e6d4e Address comments 2024-08-02 19:55:28 +00:00
kaixih
6ff6501aa2 Init commit 2024-08-01 19:39:34 +00:00
jax authors
7d8b8578b5 Merge pull request #22477 from kaixih:support_gqa
PiperOrigin-RevId: 658130108
2024-07-31 13:50:49 -07:00
kaixih
cf5bcc7ad8 Support GQA and MQA 2024-07-29 17:17:22 +00:00
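In grouped-query attention (GQA), each contiguous group of query heads shares one key/value head; multi-query attention (MQA) is the special case with a single KV head. A sketch of the head mapping (a hypothetical helper, not jax's actual implementation):

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    # Each group of (num_q_heads // num_kv_heads) query heads shares
    # one key/value head; the query head count must divide evenly.
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

# MHA: 8 query heads, 8 kv heads -> identity mapping.
# GQA: 8 query heads, 2 kv heads -> query heads 0-3 share kv head 0.
# MQA: 8 query heads, 1 kv head  -> all query heads share kv head 0.
```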
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Michal Kazmierski
61374c92ad Fix error message in jax.nn.dot_product_attention when the inputs have different dtypes.
PiperOrigin-RevId: 655553414
2024-07-24 07:13:15 -07:00
Dan Foreman-Mackey
556cc23fa5 Fix lint at head.
It looks like https://github.com/google/jax/pull/22330 introduced some
mypy lint. This PR fixes it.
2024-07-16 10:53:49 -04:00
Kaixi Hou
09531d2ff8 PR #22330: [NVIDIA] Remove logic of combining bias and mask
Imported from GitHub PR https://github.com/google/jax/pull/22330

The cudnn API already supports combining bias and mask, as of [this PR](https://github.com/google/jax/pull/22078). We are removing that logic from the public sdpa API and passing the mask through directly.

cc. @Cjkkkk
Copybara import of the project:

--
0f75f58a9d81c0ae0a83701a71998c940318732a by kaixih <kaixih@nvidia.com>:

Remove logic of combining bias and mask

Merging this change closes #22330

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22330 from kaixih:remove_combine_bias_mask 0f75f58a9d81c0ae0a83701a71998c940318732a
PiperOrigin-RevId: 652830016
2024-07-16 07:19:01 -07:00
Sebastian Bodenstein
d219f450a0 Fix documentation for dot_product_attention.
PiperOrigin-RevId: 650631839
2024-07-09 08:09:15 -07:00
Jake VanderPlas
7ec87892b5 fix lint errors at HEAD 2024-07-08 08:12:30 -07:00
Kaixi Hou
df6080f346 PR #21371: [NVIDIA] Add new SDPA API to jax.nn
Imported from GitHub PR https://github.com/google/jax/pull/21371

Attention plays a crucial role in modern transformer-based models. While various variants exist, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, implementations like the Flash Attention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup when training GPT-style large language models.

This PR proposes a new API in the `jax.nn` module to handle attention. It first tries the cudnn flash-attention execution path when the config is compatible; otherwise, it falls back to a JAX implementation.

cc. @nluehr @Cjkkkk @cliffwoolley

Copybara import of the project:

--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:

Add new SDPA API to jax.nn

Merging this change closes #21371

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
2024-07-08 06:16:04 -07:00
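The math this API computes can be sketched in a few lines (a reference sketch only: single head, no batching, no bias/mask, and none of the cudnn dispatch logic the PR adds):

```python
import math

def softmax(xs):
    # Max-subtracted softmax for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def dot_product_attention(q, k, v):
    # Reference semantics of scaled dot-product attention:
    # softmax(Q K^T / sqrt(d)) V, for one head, no mask or bias.
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d)
                  for kj in k]
        weights = softmax(scores)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out
```

With a zero query, every key scores equally, so the output is the uniform average of the values.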
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
Jake VanderPlas
75f570e8b0 softmax: document NaN outputs for infinite inputs 2024-05-29 15:00:20 -07:00
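The NaN behavior documented above follows directly from the max-subtraction step: if any input is +inf, the max is +inf, and `inf - inf` is NaN, which propagates through the exponentials and the sum. A sketch of the core computation:

```python
import math

xs = [1.0, math.inf]
m = max(xs)                       # +inf
shifted = [x - m for x in xs]     # [-inf, nan]: inf - inf is nan
exps = [math.exp(x) for x in shifted]
total = sum(exps)                 # nan
out = [e / total for e in exps]   # every entry is nan
```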
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
rajasekharporeddy
aaddba0c20 Fix doc Typos 2024-04-22 10:32:51 +05:30
Jake VanderPlas
1ea205be1c softmax: deprecate initial argument & always set to -inf internally 2024-04-10 10:23:21 -07:00
Matteo Hessel
0b602c5c4d Add sparse_sigmoid to jax.nn
PiperOrigin-RevId: 623108517
2024-04-09 03:10:04 -07:00
carlosgmartin
9c347b9be1 Let initial=-jnp.inf by default for nn.softmax and nn.log_softmax. 2024-04-08 15:47:29 -04:00
carlosgmartin
f0314c70e8 Add jax.nn.mish. 2024-04-03 16:37:07 -04:00
Jake VanderPlas
b48aec57ad Require array-like inputs to sparse_plus
We should not silently convert non-array inputs to arrays, because this can lead to silent performance degradation. This brings the sparse_plus API in line with other APIs in this module.

PiperOrigin-RevId: 617190413
2024-03-19 09:06:18 -07:00
Matteo Hessel
c94ea147f2 Add sparseplus activation to jax.nn.
PiperOrigin-RevId: 616087452
2024-03-15 04:40:38 -07:00
jax authors
0302e4c34d Merge pull request #17741 from froystig:new-style-key-docs
PiperOrigin-RevId: 614080080
2024-03-08 16:41:22 -08:00
jax authors
1ed58832c2 Merge pull request #20108 from selamw1:modify-nn-doc
PiperOrigin-RevId: 613770878
2024-03-07 18:37:24 -08:00
Selam Waktola
8ac2913296 minor modification for silu and swish func description
Update 'aka' only inside functions.py: change
"SiLU (a.k.a. swish) activation function."
to
"SiLU (aka swish) activation function."
2024-03-07 15:40:39 -08:00
Roy Frostig
98f790f5d5 update package/API reference docs to new-style typed PRNG keys 2024-03-07 12:40:09 -08:00
Anselm Levskaya
04f6bfa460 Prevent accidental upcasting in jax.nn.initializers.
Currently, distribution parameters such as stddev and scale are expected to be
weakly typed scalars. When they're passed as float32, they can cause an upcast
of the initialized arrays even when the dtype is specified as e.g. bfloat16.
Some users were surprised by this.

PiperOrigin-RevId: 611858446
2024-03-01 14:24:26 -08:00
Jake VanderPlas
a282d586b6 nn.softmax: use double-where when where is specified 2024-01-26 09:45:31 -08:00
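The "double-where" pattern referenced above guards a masked reduction at two points: first replace masked-out entries with a safe finite value so nothing pathological enters `exp`, then zero out the masked positions in the result. A sketch of the general pattern (not the actual jax implementation; under autodiff, the first `where` is what keeps gradients finite):

```python
import math

def masked_softmax(xs, where):
    # Double-where sketch: sanitize inputs first, mask outputs second.
    # Assumes at least one entry of `where` is True.
    safe = [x if keep else 0.0 for x, keep in zip(xs, where)]
    m = max(x for x, keep in zip(safe, where) if keep)
    exps = [math.exp(x - m) if keep else 0.0
            for x, keep in zip(safe, where)]
    s = sum(exps)
    return [e / s for e in exps]
```

Even if a masked-out slot holds NaN or -inf, the first `where` keeps it out of the arithmetic, and the kept entries still sum to 1.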
jax authors
78b46043b0 Decorate jax.nn.initializers.Initializer as @typing.runtime_checkable
Without this decorator, we get a warning from typeguard:

```
.../typeguard/_checkers.py:474: UserWarning: Typeguard cannot check the Initializer protocol because it is a non-runtime protocol. If you would like to type check this protocol, please use @typing.runtime_checkable
```

PiperOrigin-RevId: 598588778
2024-01-15 05:44:18 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
95bc2ba1b9 Inline sigmoid, isfinite, and isnan in jaxprs.
In the common case (real values) these are all single-expression jaxprs themselves, so putting them out of line just makes things more verbose.

There's no reason to include stuff like this in a jaxpr:
```
          cxd:bool[8,16] = pjit[
            jaxpr={ lambda ; cxe:f32[8,16]. let
                cxf:bool[8,16] = is_finite cxe
              in (cxf,) }
            name=isfinite
          ] cxc
```

PiperOrigin-RevId: 587047955
2023-12-01 10:23:56 -08:00