57 Commits

Author SHA1 Message Date
Matthew Johnson
9c39b6f70c update relu6 grad at 0 and 6 to match pytorch convention 2023-03-07 17:30:17 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Marcus Chiam
45c2f31887 Added shape error checking for compute_fans
Update tests/nn_test.py

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Yuxin Wu
d5a058c7a8 doc improvement on initilaizer
PiperOrigin-RevId: 491947286
2022-11-30 10:03:23 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
f3710aeb5f add paper link about grad-relu-at-zero 2022-09-14 14:16:01 -07:00
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
jax authors
9c16c83234 Rollback of upgrade logistic (sigmoid) function into a lax primitive.
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00
Peter Hawkins
6276194e1c Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469789339
2022-08-24 12:04:01 -07:00
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
russbates
4ce88ec71f
Fix bug with integer typed-Gelu
Certain lines in gelu()  would round down constants if called with integer types (sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype))

Cast the input array to the nearest float-like type to avoid this, as done for trigonometic functions.
2022-08-11 15:22:02 +01:00
Neil Girdhar
7869ff4964 Annotate nn.initializers
This was done to expose an Initializers type annotation that can be used
in other libraries.
2022-08-04 23:17:40 -04:00
Jake VanderPlas
d52017aa78 rollback of https://github.com/google/jax/pull/9596
Why? Shape annotations are inaccurate and cause pytype failures

PiperOrigin-RevId: 465337386
2022-08-04 09:51:18 -07:00
Neil Girdhar
1bd3784459 Annotate nn.initializers 2022-08-03 20:30:32 -04:00
jax authors
b0805a8a31 Fixes the JAX implementation of CELU returning NaN gradients for input
values >= 88.7229.

When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:

https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.

PiperOrigin-RevId: 461678140
2022-07-18 11:58:05 -07:00
Jake VanderPlas
80d814ab89 [x64] make nn_test pass with strict dtype promotion 2022-06-16 10:56:49 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
6e6f693e6d Use lax.broadcasted_iota in jax.nn.one_hot.
Minor cleanup that means we emit one fewer MHLO op, no functional changes intended.
2022-04-20 09:09:19 -04:00
Jake VanderPlas
7972b98a7b update mypy & related package versions 2022-04-15 08:55:06 -07:00
dogeplusplus
7915c6ce27 Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning. 2022-03-23 20:55:22 +00:00
Jake VanderPlas
c762e07491 jax.nn.glu: fix static argname issue 2022-03-17 11:38:13 -07:00
Peter Hawkins
728e4fd3fa Remove @jit decorator on gelu and softmax temporarily while debugging test failures.
PiperOrigin-RevId: 433630873
2022-03-09 18:48:03 -08:00
Peter Hawkins
626c03fee0 Add @jit decorators to functions in jax.nn. 2022-03-09 12:04:09 -05:00
Peter Hawkins
80aec7b25f Documentation improvements. 2022-03-08 09:37:33 -05:00
Peter Hawkins
ad5144f59e Skip doctests for initializer examples. 2022-03-08 08:29:43 -05:00
Peter Hawkins
5cfbb9f461 Fix arxiv reference for He initializers. 2022-03-07 19:13:01 -05:00
Peter Hawkins
d3d666d081 Document jax.nn.initializers. 2022-03-07 17:26:04 -05:00
jax authors
3948fde842 Merge pull request #9052 from jpuigcerver:main
PiperOrigin-RevId: 430680329
2022-02-24 05:37:02 -08:00
James Bradbury
5dd1c75969 Add batch_axis to variance scaling initializers
PiperOrigin-RevId: 426522731
2022-02-04 17:02:11 -08:00
Rolf Jagerman
b810e8be88 Add where= arg to jax.nn.{softmax, log_softmax, normalize}.
This change adds a `where=` argument (analogous to `jnp.sum`) that can be used to specify which elements to include in the calculation.
2021-12-29 15:49:30 +01:00
Joan Puigcerver
86e8928e70 Add constant initializer 2021-12-27 12:26:37 +00:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
40d6f5ed90 Tighten up dtypes across the package 2021-10-29 13:50:30 -07:00
James Bradbury
f5f0581281
update docstring 2021-10-20 09:16:55 -07:00
James Bradbury
eaf9eca617
Support multiple in/out axes in scaled inits 2021-10-20 09:12:37 -07:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
Dian Wu
b072ae543f
Merge branch 'main' into complex_init 2021-08-23 13:52:43 +02:00
Dian Wu
5138743e8e Implement variance scaling initializers with complex dtype 2021-08-19 22:33:07 +02:00
Roman Novak
b65f39ca7a Default to jnp.float_ type in nn.initializers. 2021-08-17 20:41:09 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Alexandre Gramfort
f28cf83b0c DOC: fix typo in formula of hard_tanh 2021-03-21 16:29:03 +01:00
Jake VanderPlas
067be89a0c DOC: minor documentation & formatting fixes 2021-02-23 10:31:44 -08:00
Adam Paszke
f750969afc Add support for axis names in jax.nn.initializers.variance_scaling
... as well as in a few random functions that it needs (`uniform`,
`normal` and `truncated_normal`). The interface itself doesn't change to
much with the exception of the `shape` arguments of all those functions
now accepting `jax.core.NamedShape` (I didn't move it to be part of the
API just yet, but we can do that any time), which makes it possible to
generate sharded random arrays (in particular the random bits are
different on different shards). I also haven't updated the docstrings,
because I don't know if we're ready to go fully public with this
feature.
2021-02-04 12:38:12 +00:00
Adam Paszke
f812402d37 Add support for named axes in jax.nn.one_hot 2021-02-02 15:57:00 +00:00