42 Commits

Author SHA1 Message Date
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
russbates
4ce88ec71f
Fix bug with integer typed-Gelu
Certain lines in gelu()  would round down constants if called with integer types (sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype))

Cast the input array to the nearest float-like type to avoid this, as done for trigonometic functions.
2022-08-11 15:22:02 +01:00
Neil Girdhar
7869ff4964 Annotate nn.initializers
This was done to expose an Initializers type annotation that can be used
in other libraries.
2022-08-04 23:17:40 -04:00
Jake VanderPlas
d52017aa78 rollback of https://github.com/google/jax/pull/9596
Why? Shape annotations are inaccurate and cause pytype failures

PiperOrigin-RevId: 465337386
2022-08-04 09:51:18 -07:00
Neil Girdhar
1bd3784459 Annotate nn.initializers 2022-08-03 20:30:32 -04:00
jax authors
b0805a8a31 Fixes the JAX implementation of CELU returning NaN gradients for input
values >= 88.7229.

When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:

https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.

PiperOrigin-RevId: 461678140
2022-07-18 11:58:05 -07:00
Jake VanderPlas
80d814ab89 [x64] make nn_test pass with strict dtype promotion 2022-06-16 10:56:49 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
6e6f693e6d Use lax.broadcasted_iota in jax.nn.one_hot.
Minor cleanup that means we emit one fewer MHLO op, no functional changes intended.
2022-04-20 09:09:19 -04:00
Jake VanderPlas
7972b98a7b update mypy & related package versions 2022-04-15 08:55:06 -07:00
dogeplusplus
7915c6ce27 Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning. 2022-03-23 20:55:22 +00:00
Jake VanderPlas
c762e07491 jax.nn.glu: fix static argname issue 2022-03-17 11:38:13 -07:00
Peter Hawkins
728e4fd3fa Remove @jit decorator on gelu and softmax temporarily while debugging test failures.
PiperOrigin-RevId: 433630873
2022-03-09 18:48:03 -08:00
Peter Hawkins
626c03fee0 Add @jit decorators to functions in jax.nn. 2022-03-09 12:04:09 -05:00
Peter Hawkins
80aec7b25f Documentation improvements. 2022-03-08 09:37:33 -05:00
Peter Hawkins
ad5144f59e Skip doctests for initializer examples. 2022-03-08 08:29:43 -05:00
Peter Hawkins
5cfbb9f461 Fix arxiv reference for He initializers. 2022-03-07 19:13:01 -05:00
Peter Hawkins
d3d666d081 Document jax.nn.initializers. 2022-03-07 17:26:04 -05:00
jax authors
3948fde842 Merge pull request #9052 from jpuigcerver:main
PiperOrigin-RevId: 430680329
2022-02-24 05:37:02 -08:00
James Bradbury
5dd1c75969 Add batch_axis to variance scaling initializers
PiperOrigin-RevId: 426522731
2022-02-04 17:02:11 -08:00
Rolf Jagerman
b810e8be88 Add where= arg to jax.nn.{softmax, log_softmax, normalize}.
This change adds a `where=` argument (analogous to `jnp.sum`) that can be used to specify which elements to include in the calculation.
2021-12-29 15:49:30 +01:00
Joan Puigcerver
86e8928e70 Add constant initializer 2021-12-27 12:26:37 +00:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
40d6f5ed90 Tighten up dtypes across the package 2021-10-29 13:50:30 -07:00
James Bradbury
f5f0581281
update docstring 2021-10-20 09:16:55 -07:00
James Bradbury
eaf9eca617
Support multiple in/out axes in scaled inits 2021-10-20 09:12:37 -07:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
Dian Wu
b072ae543f
Merge branch 'main' into complex_init 2021-08-23 13:52:43 +02:00
Dian Wu
5138743e8e Implement variance scaling initializers with complex dtype 2021-08-19 22:33:07 +02:00
Roman Novak
b65f39ca7a Default to jnp.float_ type in nn.initializers. 2021-08-17 20:41:09 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Alexandre Gramfort
f28cf83b0c DOC: fix typo in formula of hard_tanh 2021-03-21 16:29:03 +01:00
Jake VanderPlas
067be89a0c DOC: minor documentation & formatting fixes 2021-02-23 10:31:44 -08:00
Adam Paszke
f750969afc Add support for axis names in jax.nn.initializers.variance_scaling
... as well as in a few random functions that it needs (`uniform`,
`normal` and `truncated_normal`). The interface itself doesn't change to
much with the exception of the `shape` arguments of all those functions
now accepting `jax.core.NamedShape` (I didn't move it to be part of the
API just yet, but we can do that any time), which makes it possible to
generate sharded random arrays (in particular the random bits are
different on different shards). I also haven't updated the docstrings,
because I don't know if we're ready to go fully public with this
feature.
2021-02-04 12:38:12 +00:00
Adam Paszke
f812402d37 Add support for named axes in jax.nn.one_hot 2021-02-02 15:57:00 +00:00
Jake VanderPlas
af6da229da DOC: fix some minor formatting issues 2021-01-28 15:20:02 -08:00
Neil Girdhar
8dbb406e59 Improve type annotations
* Add py.typed, which makes type annotations available to users.
* Annotate register_pytree_node, tree_map, tree_multimap, and tree_reduce.
* Add a type annotation overload for vjp
* Annotate jax.scipy.special.
* Annotate lax.scan.
2021-01-13 10:26:35 -05:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
9d2f6148ed Call asarray() rather than array() to avoid host round-trips. 2020-11-24 16:05:48 -08:00
Peter Hawkins
b2808dc8f4 Move jax.nn implementation into jax._src.nn. 2020-10-17 13:45:01 -04:00
jax authors
e9909ce008 Copybara import of the project:
--
a396cfbbd414f6f21f0c7e8a68e6e89d202c0e84 by Peter Hawkins <phawkins@google.com>:

Move jax.nn implementation into jax._src.nn.

PiperOrigin-RevId: 337671917
2020-10-17 10:40:21 -07:00
Peter Hawkins
a396cfbbd4 Move jax.nn implementation into jax._src.nn. 2020-10-17 11:31:19 -04:00