George Necula
9261edaf94
[shape_poly] Cleanups for the shape polymorphism APIs.
...
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:
* remove symbolic_equal_dim and symbolic_equal_shape in favor of the
newer definitely_equal and definitely_equal_shape
* remove is_special_dim_size, which checks that a value is a
dimension expression (not a constant). Some uses are replaced
with `not is_constant_dim` and others with `is_dim`.
* introduce concrete_dim_or_error to check that a value is
a dimension
2023-06-30 15:56:57 +03:00
Peter Hawkins
816ba91263
Use lower-case PEP 585 names for types.
...
Issue https://github.com/google/jax/issues/16537
PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
c474de424a
jax.nn.softmax: fix fill value when where is specified
2023-06-01 10:18:05 -07:00
Matthew Johnson
e0d2736e37
add custom_jvp for jax.nn.softmax
...
This avoids saving the jnp.exp(...) value.
2023-04-22 11:28:03 -07:00
Jake VanderPlas
1c7f8efce6
Add test framework for module attribute
2023-04-21 13:20:16 -07:00
Jake VanderPlas
5521423d92
Change np.prod->math.prod
...
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
George Necula
00c56c27c6
[shape_poly] Add shape polymorphism support for jax.nn.one_hot
2023-04-12 13:31:22 +03:00
Ikko Eltociear Ashimine
69948eb06b
fix typo in functions.py
...
Indicies -> Indices
2023-04-11 11:33:27 +09:00
Matthew Johnson
9c39b6f70c
update relu6 grad at 0 and 6 to match pytorch convention
2023-03-07 17:30:17 -08:00
Peter Hawkins
8fb1fd318d
Replace jax._src.util.prod with math.prod.
...
math.prod() was added in Python 3.8, so we can assume it is always present.
PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Marcus Chiam
45c2f31887
Added shape error checking for compute_fans
...
Update tests/nn_test.py
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
Jake VanderPlas
4389216d0c
Remove typing_extensions dependency
2022-12-05 15:42:26 -08:00
Yuxin Wu
d5a058c7a8
doc improvement on initilaizer
...
PiperOrigin-RevId: 491947286
2022-11-30 10:03:23 -08:00
Yash Katariya
a419e1917a
Use jax.Array by default for doctests
...
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
9ab88071a7
Avoid loading scipy eagerly.
...
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
f3710aeb5f
add paper link about grad-relu-at-zero
2022-09-14 14:16:01 -07:00
Peter Hawkins
57b5acf1b6
Roll forward: Upgrade logistic into a primitive.
...
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.
PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
jax authors
9c16c83234
Rollback of upgrade logistic (sigmoid) function into a lax primitive.
...
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0
Upgrade logistic (sigmoid) function into a lax primitive.
...
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.
PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6
Upgrade logistic (sigmoid) function into a lax primitive.
...
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.
PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00
Peter Hawkins
6276194e1c
Upgrade logistic (sigmoid) function into a lax primitive.
...
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.
PiperOrigin-RevId: 469789339
2022-08-24 12:04:01 -07:00
Peter Hawkins
29d03160e3
Remove _ prefix from functions in jax._src.dtypes.
...
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
russbates
4ce88ec71f
Fix bug with integer typed-Gelu
...
Certain lines in gelu() would round down constants if called with integer types (sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype))
Cast the input array to the nearest float-like type to avoid this, as done for trigonometic functions.
2022-08-11 15:22:02 +01:00
Neil Girdhar
7869ff4964
Annotate nn.initializers
...
This was done to expose an Initializers type annotation that can be used
in other libraries.
2022-08-04 23:17:40 -04:00
Jake VanderPlas
d52017aa78
rollback of https://github.com/google/jax/pull/9596
...
Why? Shape annotations are inaccurate and cause pytype failures
PiperOrigin-RevId: 465337386
2022-08-04 09:51:18 -07:00
Neil Girdhar
1bd3784459
Annotate nn.initializers
2022-08-03 20:30:32 -04:00
jax authors
b0805a8a31
Fixes the JAX implementation of CELU returning NaN gradients for input
...
values >= 88.7229.
When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:
https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.
PiperOrigin-RevId: 461678140
2022-07-18 11:58:05 -07:00
Jake VanderPlas
80d814ab89
[x64] make nn_test pass with strict dtype promotion
2022-06-16 10:56:49 -07:00
Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
a
2022-05-17 22:14:05 +01:00
Peter Hawkins
6e6f693e6d
Use lax.broadcasted_iota in jax.nn.one_hot.
...
Minor cleanup that means we emit one fewer MHLO op, no functional changes intended.
2022-04-20 09:09:19 -04:00
Jake VanderPlas
7972b98a7b
update mypy & related package versions
2022-04-15 08:55:06 -07:00
dogeplusplus
7915c6ce27
Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning.
2022-03-23 20:55:22 +00:00
Jake VanderPlas
c762e07491
jax.nn.glu: fix static argname issue
2022-03-17 11:38:13 -07:00
Peter Hawkins
728e4fd3fa
Remove @jit decorator on gelu and softmax temporarily while debugging test failures.
...
PiperOrigin-RevId: 433630873
2022-03-09 18:48:03 -08:00
Peter Hawkins
626c03fee0
Add @jit decorators to functions in jax.nn.
2022-03-09 12:04:09 -05:00
Peter Hawkins
80aec7b25f
Documentation improvements.
2022-03-08 09:37:33 -05:00
Peter Hawkins
ad5144f59e
Skip doctests for initializer examples.
2022-03-08 08:29:43 -05:00
Peter Hawkins
5cfbb9f461
Fix arxiv reference for He initializers.
2022-03-07 19:13:01 -05:00
Peter Hawkins
d3d666d081
Document jax.nn.initializers.
2022-03-07 17:26:04 -05:00
jax authors
3948fde842
Merge pull request #9052 from jpuigcerver:main
...
PiperOrigin-RevId: 430680329
2022-02-24 05:37:02 -08:00
James Bradbury
5dd1c75969
Add batch_axis to variance scaling initializers
...
PiperOrigin-RevId: 426522731
2022-02-04 17:02:11 -08:00
Rolf Jagerman
b810e8be88
Add where=
arg to jax.nn.{softmax, log_softmax, normalize}.
...
This change adds a `where=` argument (analogous to `jnp.sum`) that can be used to specify which elements to include in the calculation.
2021-12-29 15:49:30 +01:00
Joan Puigcerver
86e8928e70
Add constant initializer
2021-12-27 12:26:37 +00:00
Peter Hawkins
4e21922055
Use imports relative to the jax
package consistently, rather than .
-relative imports.
...
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.
PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
40d6f5ed90
Tighten up dtypes across the package
2021-10-29 13:50:30 -07:00
James Bradbury
f5f0581281
update docstring
2021-10-20 09:16:55 -07:00
James Bradbury
eaf9eca617
Support multiple in/out axes in scaled inits
2021-10-20 09:12:37 -07:00
Peter Hawkins
a84426cb8f
Switch internal users of jax.ops.index_... to use x.at[x].set() APIs.
2021-09-13 19:48:29 -04:00