Matthew Johnson
9c39b6f70c
update relu6 grad at 0 and 6 to match pytorch convention
2023-03-07 17:30:17 -08:00
Peter Hawkins
8fb1fd318d
Replace jax._src.util.prod with math.prod.
...
math.prod() was added in Python 3.8, so we can assume it is always present.
PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
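As a minimal illustration of the replacement described in the commit above, the standard-library `math.prod` can compute the number of elements in a shape. The helper name `num_elements` is hypothetical and is not taken from the JAX code base.

```python
import math


def num_elements(shape):
    """Return the product of the dimensions in `shape`."""
    # math.prod exists from Python 3.8 onward. The product of an empty
    # iterable is 1, which matches the size of a scalar (shape ()).
    return math.prod(shape)


print(num_elements((2, 3, 4)))  # 24
print(num_elements(()))         # 1
```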
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Marcus Chiam
45c2f31887
Added shape error checking for compute_fans
...
Update tests/nn_test.py
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
Jake VanderPlas
4389216d0c
Remove typing_extensions dependency
2022-12-05 15:42:26 -08:00
Yuxin Wu
d5a058c7a8
doc improvement on initializer
...
PiperOrigin-RevId: 491947286
2022-11-30 10:03:23 -08:00
Yash Katariya
a419e1917a
Use jax.Array by default for doctests
...
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
9ab88071a7
Avoid loading scipy eagerly.
...
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
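The commit above does not say how the lazy loading is implemented. One common way to defer a heavy import is sketched below, under the assumption that a small wrapper object is acceptable. The `LazyModule` class is hypothetical, and `json` merely stands in for a heavy dependency such as scipy.

```python
import importlib


class LazyModule:
    """Defer importing a module until one of its attributes is first used."""

    def __init__(self, name):
        self._name = name
        self._module = None  # nothing is imported at construction time

    def __getattr__(self, attr):
        # __getattr__ runs only when normal lookup fails, so the
        # instance attributes _name and _module never reach this method.
        if self._module is None:
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)


# `json` is a stand-in for an expensive dependency.
lazy_json = LazyModule("json")
print(lazy_json._module is None)  # True: the wrapper has not imported anything yet
print(lazy_json.dumps([1, 2]))    # first attribute access triggers the import
```

With this pattern the import cost is paid on first use rather than at program start.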
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
f3710aeb5f
add paper link about grad-relu-at-zero
2022-09-14 14:16:01 -07:00
Peter Hawkins
57b5acf1b6
Roll forward: Upgrade logistic into a primitive.
...
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.
PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
jax authors
9c16c83234
Rollback of upgrade logistic (sigmoid) function into a lax primitive.
...
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0
Upgrade logistic (sigmoid) function into a lax primitive.
...
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.
PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6
Upgrade logistic (sigmoid) function into a lax primitive.
...
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.
PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00
Peter Hawkins
6276194e1c
Upgrade logistic (sigmoid) function into a lax primitive.
...
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.
PiperOrigin-RevId: 469789339
2022-08-24 12:04:01 -07:00
Peter Hawkins
29d03160e3
Remove _ prefix from functions in jax._src.dtypes.
...
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
russbates
4ce88ec71f
Fix bug with integer typed-Gelu
...
Certain lines in gelu() would round down constants if called with integer types (sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype))
Cast the input array to the nearest float-like type to avoid this, as done for trigonometric functions.
2022-08-11 15:22:02 +01:00
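The rounding problem described in the commit above can be reproduced with plain NumPy. This is only a sketch of the failure mode, not the actual `gelu()` code. The variable names are illustrative, and `float32` is an assumed choice of floating-point type.

```python
import numpy as np

x = np.array([1, 2, 3], dtype=np.int32)

# Casting the constant to the integer input dtype truncates it to zero.
truncated = np.sqrt(2 / np.pi).astype(x.dtype)
print(truncated)  # 0

# Promoting the input to a floating-point dtype first keeps the constant.
x_float = x.astype(np.float32)
preserved = np.sqrt(2 / np.pi).astype(x_float.dtype)
print(preserved)  # approximately 0.798
```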
Neil Girdhar
7869ff4964
Annotate nn.initializers
...
This was done to expose an Initializers type annotation that can be used
in other libraries.
2022-08-04 23:17:40 -04:00
Jake VanderPlas
d52017aa78
rollback of https://github.com/google/jax/pull/9596
...
Why? Shape annotations are inaccurate and cause pytype failures
PiperOrigin-RevId: 465337386
2022-08-04 09:51:18 -07:00
Neil Girdhar
1bd3784459
Annotate nn.initializers
2022-08-03 20:30:32 -04:00
jax authors
b0805a8a31
Fixes the JAX implementation of CELU returning NaN gradients for input
...
values >= 88.7229.
When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:
https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.
PiperOrigin-RevId: 461678140
2022-07-18 11:58:05 -07:00
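The NaN described in the commit above can be traced by hand. The NumPy sketch below imitates the chain rule manually instead of calling JAX's autodiff. It assumes `alpha = 1`, float32 inputs, and the formulation `max(x, 0) + alpha * expm1(min(x, 0) / alpha)` for the fixed version. The variable names are illustrative.

```python
import numpy as np

alpha = np.float32(1.0)
x = np.float32(89.0)  # just above the threshold quoted in the commit message

with np.errstate(over="ignore", invalid="ignore"):
    # where-based form: the unselected negative branch still has local
    # derivative exp(x / alpha), which overflows to inf in float32.
    negative_branch_grad = np.exp(x / alpha)
    # Reverse differentiation sends a zero cotangent to the unselected
    # branch, and 0 * inf is NaN.
    where_grad = np.float32(1.0) + np.float32(0.0) * negative_branch_grad

    # maximum/minimum form: the exponential only ever sees min(x, 0) <= 0,
    # so its derivative stays finite.
    clamped = np.minimum(x, np.float32(0.0))
    minmax_grad = np.float32(x > 0) + np.float32(x < 0) * np.exp(clamped / alpha)

print(where_grad)   # nan
print(minmax_grad)  # 1.0
```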
Jake VanderPlas
80d814ab89
[x64] make nn_test pass with strict dtype promotion
2022-06-16 10:56:49 -07:00
Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
2022-05-17 22:14:05 +01:00
Peter Hawkins
6e6f693e6d
Use lax.broadcasted_iota in jax.nn.one_hot.
...
Minor cleanup that means we emit one fewer MHLO op, no functional changes intended.
2022-04-20 09:09:19 -04:00
Jake VanderPlas
7972b98a7b
update mypy & related package versions
2022-04-15 08:55:06 -07:00
dogeplusplus
7915c6ce27
Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning.
2022-03-23 20:55:22 +00:00
Jake VanderPlas
c762e07491
jax.nn.glu: fix static argname issue
2022-03-17 11:38:13 -07:00
Peter Hawkins
728e4fd3fa
Remove @jit decorator on gelu and softmax temporarily while debugging test failures.
...
PiperOrigin-RevId: 433630873
2022-03-09 18:48:03 -08:00
Peter Hawkins
626c03fee0
Add @jit decorators to functions in jax.nn.
2022-03-09 12:04:09 -05:00
Peter Hawkins
80aec7b25f
Documentation improvements.
2022-03-08 09:37:33 -05:00
Peter Hawkins
ad5144f59e
Skip doctests for initializer examples.
2022-03-08 08:29:43 -05:00
Peter Hawkins
5cfbb9f461
Fix arxiv reference for He initializers.
2022-03-07 19:13:01 -05:00
Peter Hawkins
d3d666d081
Document jax.nn.initializers.
2022-03-07 17:26:04 -05:00
jax authors
3948fde842
Merge pull request #9052 from jpuigcerver:main
...
PiperOrigin-RevId: 430680329
2022-02-24 05:37:02 -08:00
James Bradbury
5dd1c75969
Add batch_axis to variance scaling initializers
...
PiperOrigin-RevId: 426522731
2022-02-04 17:02:11 -08:00
Rolf Jagerman
b810e8be88
Add where= arg to jax.nn.{softmax, log_softmax, normalize}.
...
This change adds a `where=` argument (analogous to `jnp.sum`) that can be used to specify which elements to include in the calculation.
2021-12-29 15:49:30 +01:00
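The semantics of a `where=` mask can be sketched with NumPy, whose `np.sum` and `np.max` reductions accept the same keyword. The function `masked_softmax` below is a hypothetical illustration and is not the `jax.nn.softmax` implementation. Returning 0 for excluded elements is an assumption of this sketch.

```python
import numpy as np


def masked_softmax(x, where, axis=-1):
    """Softmax over `axis` that only counts elements where `where` is True."""
    # Subtract the maximum of the included elements for numerical stability.
    x_max = np.max(x, axis=axis, keepdims=True, where=where, initial=-np.inf)
    exp_shifted = np.exp(x - x_max)
    # Only included elements contribute to the normalizer.
    total = np.sum(exp_shifted, axis=axis, keepdims=True, where=where)
    # Assumption of this sketch: excluded positions are reported as 0.
    return np.where(where, exp_shifted / total, 0.0)


x = np.array([1.0, 2.0, 3.0])
mask = np.array([True, True, False])
print(masked_softmax(x, mask))  # third entry is 0; the first two sum to 1
```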
Joan Puigcerver
86e8928e70
Add constant initializer
2021-12-27 12:26:37 +00:00
Peter Hawkins
4e21922055
Use imports relative to the jax package consistently, rather than .-relative imports.
...
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.
PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
40d6f5ed90
Tighten up dtypes across the package
2021-10-29 13:50:30 -07:00
James Bradbury
f5f0581281
update docstring
2021-10-20 09:16:55 -07:00
James Bradbury
eaf9eca617
Support multiple in/out axes in scaled inits
2021-10-20 09:12:37 -07:00
Peter Hawkins
a84426cb8f
Switch internal users of jax.ops.index_... to use x.at[x].set() APIs.
2021-09-13 19:48:29 -04:00
Dian Wu
b072ae543f
Merge branch 'main' into complex_init
2021-08-23 13:52:43 +02:00
Dian Wu
5138743e8e
Implement variance scaling initializers with complex dtype
2021-08-19 22:33:07 +02:00
Roman Novak
b65f39ca7a
Default to jnp.float_ type in nn.initializers.
2021-08-17 20:41:09 -07:00
Peter Hawkins
6a6f13e1b0
[JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
...
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Alexandre Gramfort
f28cf83b0c
DOC: fix typo in formula of hard_tanh
2021-03-21 16:29:03 +01:00
Jake VanderPlas
067be89a0c
DOC: minor documentation & formatting fixes
2021-02-23 10:31:44 -08:00
Adam Paszke
f750969afc
Add support for axis names in jax.nn.initializers.variance_scaling
...
... as well as in a few random functions that it needs (`uniform`,
`normal` and `truncated_normal`). The interface itself doesn't change too
much with the exception of the `shape` arguments of all those functions
now accepting `jax.core.NamedShape` (I didn't move it to be part of the
API just yet, but we can do that any time), which makes it possible to
generate sharded random arrays (in particular the random bits are
different on different shards). I also haven't updated the docstrings,
because I don't know if we're ready to go fully public with this
feature.
2021-02-04 12:38:12 +00:00
Adam Paszke
f812402d37
Add support for named axes in jax.nn.one_hot
2021-02-02 15:57:00 +00:00