93 Commits

Author SHA1 Message Date
Jake VanderPlas
75f570e8b0 softmax: document NaN outputs for infinite inputs 2024-05-29 15:00:20 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
rajasekharporeddy
aaddba0c20 Fix doc Typos 2024-04-22 10:32:51 +05:30
Jake VanderPlas
1ea205be1c softmax: deprecate initial argument & always set to -inf internally 2024-04-10 10:23:21 -07:00
Matteo Hessel
0b602c5c4d Add sparse_sigmoid to jax.nn
PiperOrigin-RevId: 623108517
2024-04-09 03:10:04 -07:00
carlosgmartin
9c347b9be1 Let initial=-jnp.inf by default for nn.softmax and nn.log_softmax. 2024-04-08 15:47:29 -04:00
carlosgmartin
f0314c70e8 Add jax.nn.mish. 2024-04-03 16:37:07 -04:00
Jake VanderPlas
b48aec57ad Require array-like inputs to sparse_plus
We should not silently convert non-array inputs to arrays, because this can lead to silent performance degredation. This brings the sparse_plus API in line with other APIs in this module.

PiperOrigin-RevId: 617190413
2024-03-19 09:06:18 -07:00
Matteo Hessel
c94ea147f2 Add sparseplus activation to jax.nn.
PiperOrigin-RevId: 616087452
2024-03-15 04:40:38 -07:00
jax authors
0302e4c34d Merge pull request #17741 from froystig:new-style-key-docs
PiperOrigin-RevId: 614080080
2024-03-08 16:41:22 -08:00
jax authors
1ed58832c2 Merge pull request #20108 from selamw1:modify-nn-doc
PiperOrigin-RevId: 613770878
2024-03-07 18:37:24 -08:00
Selam Waktola
8ac2913296 minor modification for silu and swish func description
Update 'aka' only inside functions.py

modify SiLU (a.k.a. swish) activation function.
to
SiLU (aka swish) activation function.
2024-03-07 15:40:39 -08:00
Roy Frostig
98f790f5d5 update package/API reference docs to new-style typed PRNG keys 2024-03-07 12:40:09 -08:00
Anselm Levskaya
04f6bfa460 Prevent accidental upcasting in jax.nn.initializers.
Currently distribution parameters such as stddev and scale are expected to be
weakly typed scalars.  When they're passed as float32 they can cause an upcast
of the initialized arrays even when the dtype is specified as e.g. bfloat16.
Some users were surprised by this.

PiperOrigin-RevId: 611858446
2024-03-01 14:24:26 -08:00
Jake VanderPlas
a282d586b6 nn.softmax: use double-where when where is specified 2024-01-26 09:45:31 -08:00
jax authors
78b46043b0 Decorate jax.nn.initializers.Initializer as @typing.runtime_checkable
Without this decorator, we get a warning from typeguard:

```
.../typeguard/_checkers.py:474: UserWarning: Typeguard cannot check the Initializer protocol because it is a non-runtime protocol. If you would like to type check this protocol, please use @typing.runtime_checkable
```

PiperOrigin-RevId: 598588778
2024-01-15 05:44:18 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
95bc2ba1b9 Inline sigmoid, isfinite, and isnan in jaxprs.
In the common case (real values) these are all single-expression jaxprs themselves, so putting them out of line just makes things more verbose.

There's no reason to include stuff like this in a jaxpr:
```
          cxd:bool[8,16] = pjit[
            jaxpr={ lambda ; cxe:f32[8,16]. let
                cxf:bool[8,16] = is_finite cxe
              in (cxf,) }
            name=isfinite
          ] cxc
```

PiperOrigin-RevId: 587047955
2023-12-01 10:23:56 -08:00
carlosgmartin
9f8e1bc34a Add nn.squareplus. 2023-11-14 23:52:41 -05:00
jax authors
e1a04e4496 Make args in doc consistent with code
PiperOrigin-RevId: 581324250
2023-11-10 11:44:55 -08:00
Jake VanderPlas
d59c1f1e21 jax.nn.normalize: deprecate using standard framework 2023-11-08 09:42:23 -08:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Sergei Lebedev
5ab05e42c9 MAINT Clean up leftover Array = Any aliases in jax/_src/**.py
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of ragedness
and dynamic shapes internals to fix properly.
2023-10-01 12:19:21 +01:00
Jake VanderPlas
270cc6014c Update internal callers to avoid PRNGKeyArray 2023-09-13 14:05:42 -07:00
jax authors
311dc9cfde Add truncated normal initializer to jax.nn
PiperOrigin-RevId: 563576354
2023-09-07 16:23:42 -07:00
Peter Hawkins
5d680d4591 Improve jax.nn documentation.
Fixes https://github.com/google/jax/issues/17171
2023-08-18 15:56:35 -04:00
Jake Vanderplas
d8f799391b COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
c474de424a jax.nn.softmax: fix fill value when where is specified 2023-06-01 10:18:05 -07:00
Matthew Johnson
e0d2736e37 add custom_jvp for jax.nn.softmax
This avoids saving the jnp.exp(...) value.
2023-04-22 11:28:03 -07:00
Jake VanderPlas
1c7f8efce6 Add test framework for module attribute 2023-04-21 13:20:16 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
George Necula
00c56c27c6 [shape_poly] Add shape polymorphism support for jax.nn.one_hot 2023-04-12 13:31:22 +03:00
Ikko Eltociear Ashimine
69948eb06b
fix typo in functions.py
Indicies -> Indices
2023-04-11 11:33:27 +09:00
Matthew Johnson
9c39b6f70c update relu6 grad at 0 and 6 to match pytorch convention 2023-03-07 17:30:17 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Marcus Chiam
45c2f31887 Added shape error checking for compute_fans
Update tests/nn_test.py

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Yuxin Wu
d5a058c7a8 doc improvement on initilaizer
PiperOrigin-RevId: 491947286
2022-11-30 10:03:23 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
f3710aeb5f add paper link about grad-relu-at-zero 2022-09-14 14:16:01 -07:00
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
jax authors
9c16c83234 Rollback of upgrade logistic (sigmoid) function into a lax primitive.
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00