73 Commits

Author SHA1 Message Date
Jake VanderPlas
2b9c73d10d Remove a number of expired deprecations.
These APIs were all removed 3 or more months ago; the registrations here only
cause them to raise informative AttributeErrors. Enough time has passed that
we can now remove these registrations as well.
2024-10-31 15:40:54 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub URL in the JAX codebase to reflect the move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Kaixi Hou
df6080f346 PR #21371: [NVIDIA] Add new SDPA API to jax.nn
Imported from GitHub PR https://github.com/google/jax/pull/21371

Attention plays a crucial role in modern transformer-based models. While many variants exist, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to improve utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can yield a 1.3x end-to-end speedup for training GPT-style large language models.

This PR proposes introducing a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible; otherwise it falls back to a pure-JAX implementation.

cc. @nluehr @Cjkkkk @cliffwoolley

Copybara import of the project:

--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:

Add new SDPA API to jax.nn

Merging this change closes #21371

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
2024-07-08 06:16:04 -07:00
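As a rough illustration of the fallback behavior described above, here is a minimal sketch assuming the API landed as `jax.nn.dot_product_attention` with `[batch, seq_len, num_heads, head_dim]` inputs; the exact name and signature may differ between JAX versions:

```python
import jax
import jax.numpy as jnp

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 128, 8, 64))  # [batch, seq_len, num_heads, head_dim]
k = jax.random.normal(kk, (2, 128, 8, 64))
v = jax.random.normal(kv, (2, 128, 8, 64))

# With no implementation explicitly requested, JAX is expected to pick the
# cuDNN flash attention path when the hardware/config allow it, and otherwise
# use the XLA (pure-JAX) implementation.
out = jax.nn.dot_product_attention(q, k, v)
print(out.shape)  # (2, 128, 8, 64)
```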
Jake VanderPlas
bb5787da09 Finalize deprecations of several APIs
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
Matteo Hessel
0b602c5c4d Add sparse_sigmoid to jax.nn
PiperOrigin-RevId: 623108517
2024-04-09 03:10:04 -07:00
carlosgmartin
f0314c70e8 Add jax.nn.mish. 2024-04-03 16:37:07 -04:00
Matteo Hessel
c94ea147f2 Add sparseplus activation to jax.nn.
PiperOrigin-RevId: 616087452
2024-03-15 04:40:38 -07:00
carlosgmartin
9f8e1bc34a Add nn.squareplus. 2023-11-14 23:52:41 -05:00
Jake VanderPlas
d59c1f1e21 jax.nn.normalize: deprecate using standard framework 2023-11-08 09:42:23 -08:00
jax authors
311dc9cfde Add truncated normal initializer to jax.nn
PiperOrigin-RevId: 563576354
2023-09-07 16:23:42 -07:00
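For reference, a minimal usage sketch of the initializer added here; the `stddev` value is arbitrary, and initializers follow the usual `(key, shape, dtype)` calling convention:

```python
import jax
import jax.numpy as jnp

# Illustrative stddev; returns an init function of the form init(key, shape, dtype).
init = jax.nn.initializers.truncated_normal(stddev=1e-2)
w = init(jax.random.PRNGKey(0), (256, 128), jnp.float32)
print(w.shape)  # (256, 128)
```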
Jake VanderPlas
1c7f8efce6 Add test framework for module attribute 2023-04-21 13:20:16 -07:00
Peter Hawkins
fd24f976e1 Fix __module__ of jax.nn.initializers.* to be jax.nn.initializers.
Serialization tooling for JAX sometimes prints these names, and we want users to prefer the public names.
2023-02-22 14:13:15 -05:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Neil Girdhar
7869ff4964 Annotate nn.initializers
This was done to expose an Initializer type annotation that can be used
in other libraries.
2022-08-04 23:17:40 -04:00
Jake VanderPlas
d52017aa78 rollback of https://github.com/google/jax/pull/9596
Why? Shape annotations are inaccurate and cause pytype failures

PiperOrigin-RevId: 465337386
2022-08-04 09:51:18 -07:00
Neil Girdhar
1bd3784459 Annotate nn.initializers 2022-08-03 20:30:32 -04:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
dogeplusplus
7915c6ce27 Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning. 2022-03-23 20:55:22 +00:00
Joan Puigcerver
86e8928e70 Add constant initializer 2021-12-27 12:26:37 +00:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
33e2bed1b4 Fix package exports 2021-09-14 13:55:55 -07:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Matthew Johnson
e968672740 add tanh to jax.nn package 2021-04-29 08:26:38 -07:00
Peter Hawkins
c9f4b27427 Delete jax.nn.functions.
jax.nn.functions was an accidental export; its contents are available directly in the jax.nn namespace.
2020-11-09 19:48:54 -05:00
Peter Hawkins
b07848359b Import jax.nn.functions by default to fix breakage. 2020-10-17 14:51:39 -04:00
Peter Hawkins
b2808dc8f4 Move jax.nn implementation into jax._src.nn. 2020-10-17 13:45:01 -04:00
jax authors
e9909ce008 Copybara import of the project:
--
a396cfbbd414f6f21f0c7e8a68e6e89d202c0e84 by Peter Hawkins <phawkins@google.com>:

Move jax.nn implementation into jax._src.nn.

PiperOrigin-RevId: 337671917
2020-10-17 10:40:21 -07:00
Peter Hawkins
a396cfbbd4 Move jax.nn implementation into jax._src.nn. 2020-10-17 11:31:19 -04:00
Alex Alemi
f7d4063e55 Remove expit, add logsumexp to docs 2020-10-15 11:06:18 -04:00
Alex Alemi
3c69b2d6ab Fixing lint error 2020-10-11 11:35:26 -04:00
Alex Alemi
00e70492e4 Export expit and logsumexp in jax.nn.functions 2020-10-11 11:30:45 -04:00
Peter Hawkins
9b3bbe8359 Adds an approximate=... keyword argument to jax.nn.gelu to select between the approximate and exact formulations of gelu.
Default to the approximate formulation for now.
2020-10-02 09:48:07 -04:00
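A short sketch of the keyword described above (input values are illustrative; the approximate form has been the default since this change):

```python
import jax.numpy as jnp
from jax import nn

x = jnp.linspace(-3.0, 3.0, 7)
print(nn.gelu(x, approximate=True))   # tanh-based approximation (default)
print(nn.gelu(x, approximate=False))  # exact erf-based formulation
```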
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
Jake Vanderplas
29aa9bfc8f
Cleanup: avoid jnp.prod & np.prod on array shapes (#4086) 2020-08-18 10:17:38 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory and cause less memory fragmentation,
since large constants are no longer instantiated at trace time. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 for more information.
2020-07-30 12:59:36 -07:00
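A small sketch of the trace-time-constant point, assuming omnistaging is enabled (the array size is arbitrary):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # This constant has no data dependence on `x`. With omnistaging it is
  # staged into the jitted computation rather than materialized as a
  # concrete ~64 MB array at trace time.
  big = jnp.ones((4096, 4096), dtype=jnp.float32)
  return x + jnp.sum(big)

print(f(1.0))
```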
Peter Hawkins
a2fb798dc3
Rename swish to silu, add swish as an alias to silu. (#3673) 2020-07-06 18:08:16 -04:00
Matthew Johnson
49cfe2687c
improve concreteness error message for nn.one_hot (#3656)
* improve nn.one_hot and jax.numpy.arange errors

fixes #3654

* deflake

* debug
2020-07-03 20:54:25 -07:00
Jake Vanderplas
a63b9cc256
Cleanup: deflake interpreters, lib, nn, third_party, and tools (#3327) 2020-06-04 15:27:48 -07:00
James Bradbury
4f5547dd85
Don't AD through max-subtraction in softmax (#2260)
* Don't AD through max-subtraction in softmax

* Also stop-grad the max in logsumexp
2020-06-03 17:00:54 -07:00
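A simplified sketch of the idea (not the library's exact code): the subtracted max only provides numerical stability, so wrapping it in `stop_gradient` keeps it out of the backward pass.

```python
import jax
import jax.numpy as jnp
from jax import lax

def softmax(x, axis=-1):
  # Stabilizing shift; stop_gradient means we do not differentiate
  # through the (non-smooth) max reduction.
  x_max = lax.stop_gradient(jnp.max(x, axis=axis, keepdims=True))
  unnormalized = jnp.exp(x - x_max)
  return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

print(jax.grad(lambda x: softmax(x)[0])(jnp.array([1.0, 2.0, 3.0])))
```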
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00
Peter Hawkins
36e7fad1e2
Add a primitive integer_pow() for values raised to a fixed integer scalar. (#3140)
* Add a primitive integer_pow() for values raised to fixed integer scalar.

Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal().

Fixes #3136

```
In [1]: from jax import grad, make_jaxpr
In [2]: def inv(x): return 1/x
In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.))
0.043945312

In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.)
Out[4]:
{ lambda  ; a.
  let b = integer_pow[ y=-7 ] a
      c = mul -6.0 b
      d = mul -120.0 c
  in (d,) }

In [5]:
```

* Use x ** 3 in gelu definition.
2020-05-18 17:54:20 -04:00
Ed Schmerling
510af1de64
Fix documentation for nn.elu, nn.celu, and lax.expm1. (#3116) 2020-05-15 20:51:53 -07:00
Yusuke Oda
ccb8d45975
Uses jnp.square instead of power. (#3036)
* Uses multiplication instead of power.

* Uses jnp.square instead of mul and adds a check that jnp.square is implemented via multiplication.
2020-05-12 11:04:53 -04:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
James Bradbury
1cc6b7dd6c
support axis argument in nn.glu (#2879)
* support axis argument in nn.glu

* also add basic correctness test

* Update nn_test.py
2020-05-02 19:33:10 -07:00
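A brief sketch of the `axis` argument (the split axis must have even length; shapes are illustrative):

```python
import jax.numpy as jnp
from jax import nn

x = jnp.arange(24.0).reshape(2, 3, 4)
print(nn.glu(x, axis=-1).shape)  # (2, 3, 2): last axis split in half and gated
print(nn.glu(x, axis=0).shape)   # (1, 3, 4): same idea along the leading axis
```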
Tom Hennigan
0736679c33
Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (#2901)
At head the following fails:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
2020-05-01 10:00:38 -07:00
Martin Sotir
52c69e88c5
Fix slices in Gated Linear Unit activation (#2341) 2020-04-29 00:16:49 -07:00
Vaibhav Balloli
ef963f06ae
Add ReLU6, Hard sigmoid, swish (#2709) 2020-04-29 00:07:18 -07:00
Matthew Johnson
7a4c4d555c use custom_jvp for internal functions 2020-03-29 20:48:08 -07:00