Peter Hawkins
29d03160e3
Remove _ prefix from functions in jax._src.dtypes.
...
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
russbates
4ce88ec71f
Fix bug with integer typed-Gelu
...
Certain lines in gelu() would round down constants if called with integer types (sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype))
Cast the input array to the nearest float-like type to avoid this, as done for trigonometic functions.
2022-08-11 15:22:02 +01:00
Neil Girdhar
7869ff4964
Annotate nn.initializers
...
This was done to expose an Initializers type annotation that can be used
in other libraries.
2022-08-04 23:17:40 -04:00
Jake VanderPlas
d52017aa78
rollback of https://github.com/google/jax/pull/9596
...
Why? Shape annotations are inaccurate and cause pytype failures
PiperOrigin-RevId: 465337386
2022-08-04 09:51:18 -07:00
Neil Girdhar
1bd3784459
Annotate nn.initializers
2022-08-03 20:30:32 -04:00
jax authors
b0805a8a31
Fixes the JAX implementation of CELU returning NaN gradients for input
...
values >= 88.7229.
When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:
https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.
PiperOrigin-RevId: 461678140
2022-07-18 11:58:05 -07:00
Jake VanderPlas
80d814ab89
[x64] make nn_test pass with strict dtype promotion
2022-06-16 10:56:49 -07:00
Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
a
2022-05-17 22:14:05 +01:00
Peter Hawkins
6e6f693e6d
Use lax.broadcasted_iota in jax.nn.one_hot.
...
Minor cleanup that means we emit one fewer MHLO op, no functional changes intended.
2022-04-20 09:09:19 -04:00
Jake VanderPlas
7972b98a7b
update mypy & related package versions
2022-04-15 08:55:06 -07:00
dogeplusplus
7915c6ce27
Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning.
2022-03-23 20:55:22 +00:00
Jake VanderPlas
c762e07491
jax.nn.glu: fix static argname issue
2022-03-17 11:38:13 -07:00
Peter Hawkins
728e4fd3fa
Remove @jit decorator on gelu and softmax temporarily while debugging test failures.
...
PiperOrigin-RevId: 433630873
2022-03-09 18:48:03 -08:00
Peter Hawkins
626c03fee0
Add @jit decorators to functions in jax.nn.
2022-03-09 12:04:09 -05:00
Peter Hawkins
80aec7b25f
Documentation improvements.
2022-03-08 09:37:33 -05:00
Peter Hawkins
ad5144f59e
Skip doctests for initializer examples.
2022-03-08 08:29:43 -05:00
Peter Hawkins
5cfbb9f461
Fix arxiv reference for He initializers.
2022-03-07 19:13:01 -05:00
Peter Hawkins
d3d666d081
Document jax.nn.initializers.
2022-03-07 17:26:04 -05:00
jax authors
3948fde842
Merge pull request #9052 from jpuigcerver:main
...
PiperOrigin-RevId: 430680329
2022-02-24 05:37:02 -08:00
James Bradbury
5dd1c75969
Add batch_axis to variance scaling initializers
...
PiperOrigin-RevId: 426522731
2022-02-04 17:02:11 -08:00
Rolf Jagerman
b810e8be88
Add where=
arg to jax.nn.{softmax, log_softmax, normalize}.
...
This change adds a `where=` argument (analogous to `jnp.sum`) that can be used to specify which elements to include in the calculation.
2021-12-29 15:49:30 +01:00
Joan Puigcerver
86e8928e70
Add constant initializer
2021-12-27 12:26:37 +00:00
Peter Hawkins
4e21922055
Use imports relative to the jax
package consistently, rather than .
-relative imports.
...
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.
PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
40d6f5ed90
Tighten up dtypes across the package
2021-10-29 13:50:30 -07:00
James Bradbury
f5f0581281
update docstring
2021-10-20 09:16:55 -07:00
James Bradbury
eaf9eca617
Support multiple in/out axes in scaled inits
2021-10-20 09:12:37 -07:00
Peter Hawkins
a84426cb8f
Switch internal users of jax.ops.index_... to use x.at[x].set() APIs.
2021-09-13 19:48:29 -04:00
Dian Wu
b072ae543f
Merge branch 'main' into complex_init
2021-08-23 13:52:43 +02:00
Dian Wu
5138743e8e
Implement variance scaling initializers with complex dtype
2021-08-19 22:33:07 +02:00
Roman Novak
b65f39ca7a
Default to jnp.float_
type in nn.initializers
.
2021-08-17 20:41:09 -07:00
Peter Hawkins
6a6f13e1b0
[JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
...
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Alexandre Gramfort
f28cf83b0c
DOC: fix typo in formula of hard_tanh
2021-03-21 16:29:03 +01:00
Jake VanderPlas
067be89a0c
DOC: minor documentation & formatting fixes
2021-02-23 10:31:44 -08:00
Adam Paszke
f750969afc
Add support for axis names in jax.nn.initializers.variance_scaling
...
... as well as in a few random functions that it needs (`uniform`,
`normal` and `truncated_normal`). The interface itself doesn't change to
much with the exception of the `shape` arguments of all those functions
now accepting `jax.core.NamedShape` (I didn't move it to be part of the
API just yet, but we can do that any time), which makes it possible to
generate sharded random arrays (in particular the random bits are
different on different shards). I also haven't updated the docstrings,
because I don't know if we're ready to go fully public with this
feature.
2021-02-04 12:38:12 +00:00
Adam Paszke
f812402d37
Add support for named axes in jax.nn.one_hot
2021-02-02 15:57:00 +00:00
Jake VanderPlas
af6da229da
DOC: fix some minor formatting issues
2021-01-28 15:20:02 -08:00
Neil Girdhar
8dbb406e59
Improve type annotations
...
* Add py.typed, which makes type annotations available to users.
* Annotate register_pytree_node, tree_map, tree_multimap, and tree_reduce.
* Add a type annotation overload for vjp
* Annotate jax.scipy.special.
* Annotate lax.scan.
2021-01-13 10:26:35 -05:00
Peter Hawkins
3ac809ede3
[JAX] Move jax.util to jax._src_util.
...
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
9d2f6148ed
Call asarray() rather than array() to avoid host round-trips.
2020-11-24 16:05:48 -08:00
Peter Hawkins
b2808dc8f4
Move jax.nn implementation into jax._src.nn.
2020-10-17 13:45:01 -04:00
jax authors
e9909ce008
Copybara import of the project:
...
--
a396cfbbd414f6f21f0c7e8a68e6e89d202c0e84 by Peter Hawkins <phawkins@google.com>:
Move jax.nn implementation into jax._src.nn.
PiperOrigin-RevId: 337671917
2020-10-17 10:40:21 -07:00
Peter Hawkins
a396cfbbd4
Move jax.nn implementation into jax._src.nn.
2020-10-17 11:31:19 -04:00