73 Commits

Author SHA1 Message Date
Jake VanderPlas
d59c1f1e21 jax.nn.normalize: deprecate using standard framework 2023-11-08 09:42:23 -08:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Sergei Lebedev
5ab05e42c9 MAINT Clean up leftover Array = Any aliases in jax/_src/**.py
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of ragedness
and dynamic shapes internals to fix properly.
2023-10-01 12:19:21 +01:00
Jake VanderPlas
270cc6014c Update internal callers to avoid PRNGKeyArray 2023-09-13 14:05:42 -07:00
jax authors
311dc9cfde Add truncated normal initializer to jax.nn
PiperOrigin-RevId: 563576354
2023-09-07 16:23:42 -07:00
Peter Hawkins
5d680d4591 Improve jax.nn documentation.
Fixes https://github.com/google/jax/issues/17171
2023-08-18 15:56:35 -04:00
Jake Vanderplas
d8f799391b COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
c474de424a jax.nn.softmax: fix fill value when where is specified 2023-06-01 10:18:05 -07:00
Matthew Johnson
e0d2736e37 add custom_jvp for jax.nn.softmax
This avoids saving the jnp.exp(...) value.
2023-04-22 11:28:03 -07:00
Jake VanderPlas
1c7f8efce6 Add test framework for module attribute 2023-04-21 13:20:16 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
George Necula
00c56c27c6 [shape_poly] Add shape polymorphism support for jax.nn.one_hot 2023-04-12 13:31:22 +03:00
Ikko Eltociear Ashimine
69948eb06b
fix typo in functions.py
Indicies -> Indices
2023-04-11 11:33:27 +09:00
Matthew Johnson
9c39b6f70c update relu6 grad at 0 and 6 to match pytorch convention 2023-03-07 17:30:17 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Marcus Chiam
45c2f31887 Added shape error checking for compute_fans
Update tests/nn_test.py

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Yuxin Wu
d5a058c7a8 doc improvement on initilaizer
PiperOrigin-RevId: 491947286
2022-11-30 10:03:23 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
f3710aeb5f add paper link about grad-relu-at-zero 2022-09-14 14:16:01 -07:00
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
jax authors
9c16c83234 Rollback of upgrade logistic (sigmoid) function into a lax primitive.
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00
Peter Hawkins
6276194e1c Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469789339
2022-08-24 12:04:01 -07:00
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
russbates
4ce88ec71f
Fix bug with integer typed-Gelu
Certain lines in gelu()  would round down constants if called with integer types (sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype))

Cast the input array to the nearest float-like type to avoid this, as done for trigonometic functions.
2022-08-11 15:22:02 +01:00
Neil Girdhar
7869ff4964 Annotate nn.initializers
This was done to expose an Initializers type annotation that can be used
in other libraries.
2022-08-04 23:17:40 -04:00
Jake VanderPlas
d52017aa78 rollback of https://github.com/google/jax/pull/9596
Why? Shape annotations are inaccurate and cause pytype failures

PiperOrigin-RevId: 465337386
2022-08-04 09:51:18 -07:00
Neil Girdhar
1bd3784459 Annotate nn.initializers 2022-08-03 20:30:32 -04:00
jax authors
b0805a8a31 Fixes the JAX implementation of CELU returning NaN gradients for input
values >= 88.7229.

When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:

https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.

PiperOrigin-RevId: 461678140
2022-07-18 11:58:05 -07:00
Jake VanderPlas
80d814ab89 [x64] make nn_test pass with strict dtype promotion 2022-06-16 10:56:49 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
6e6f693e6d Use lax.broadcasted_iota in jax.nn.one_hot.
Minor cleanup that means we emit one fewer MHLO op, no functional changes intended.
2022-04-20 09:09:19 -04:00
Jake VanderPlas
7972b98a7b update mypy & related package versions 2022-04-15 08:55:06 -07:00
dogeplusplus
7915c6ce27 Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning. 2022-03-23 20:55:22 +00:00
Jake VanderPlas
c762e07491 jax.nn.glu: fix static argname issue 2022-03-17 11:38:13 -07:00
Peter Hawkins
728e4fd3fa Remove @jit decorator on gelu and softmax temporarily while debugging test failures.
PiperOrigin-RevId: 433630873
2022-03-09 18:48:03 -08:00
Peter Hawkins
626c03fee0 Add @jit decorators to functions in jax.nn. 2022-03-09 12:04:09 -05:00
Peter Hawkins
80aec7b25f Documentation improvements. 2022-03-08 09:37:33 -05:00
Peter Hawkins
ad5144f59e Skip doctests for initializer examples. 2022-03-08 08:29:43 -05:00
Peter Hawkins
5cfbb9f461 Fix arxiv reference for He initializers. 2022-03-07 19:13:01 -05:00
Peter Hawkins
d3d666d081 Document jax.nn.initializers. 2022-03-07 17:26:04 -05:00
jax authors
3948fde842 Merge pull request #9052 from jpuigcerver:main
PiperOrigin-RevId: 430680329
2022-02-24 05:37:02 -08:00