60 Commits

Author SHA1 Message Date
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Marcus Chiam
45c2f31887 Added shape error checking for compute_fans
Update tests/nn_test.py

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
Jake VanderPlas
f09fd8a4e9 [x64] minor test-only updates for better type safety 2022-11-30 15:18:40 -08:00
Peter Hawkins
c657449528 Copybara import of the project:
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:

Migrate more tests from jtu.cases_from_list to jtu.sample_product.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Matthew Johnson
03abcc7c5c fix typo in test 2022-09-23 14:43:24 -07:00
Matthew Johnson
b6ef90ffdd fix leak checker internal error
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).

But after #7345, the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-23 12:33:45 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
russbates
5026a810a4
Update nn_test.py
Add parameter to sweep over `approximate` kwarg.
2022-08-11 15:46:34 +01:00
russbates
32b2a8ff00
Update nn_test.py
Add test for fixed integer-type Gelu behaviour.
2022-08-11 15:41:07 +01:00
dogeplusplus
7915c6ce27 Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning. 2022-03-23 20:55:22 +00:00
Jake VanderPlas
c762e07491 jax.nn.glu: fix static argname issue 2022-03-17 11:38:13 -07:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
James Bradbury
5dd1c75969 Add batch_axis to variance scaling initializers
PiperOrigin-RevId: 426522731
2022-02-04 17:02:11 -08:00
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
Rolf Jagerman
b810e8be88 Add where= arg to jax.nn.{softmax, log_softmax, normalize}.
This change adds a `where=` argument (analogous to `jnp.sum`) that can be used to specify which elements to include in the calculation.
2021-12-29 15:49:30 +01:00
Matthew Johnson
d8e28400ba fix leak checker interaction with custom_jvp
fixes #8171
2021-12-14 13:00:27 -08:00
James Bradbury
a56aee96ee
fix whitespace 2021-10-27 22:06:06 -07:00
James Bradbury
b3509c19c7
fix typo 2021-10-27 22:05:08 -07:00
James Bradbury
e5853475fe
add functionality test 2021-10-20 09:25:02 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Dian Wu
5138743e8e Implement variance scaling initializers with complex dtype 2021-08-19 22:33:07 +02:00
David Majnemer
f0c30492dc Remove stale limitations
PiperOrigin-RevId: 383652551
2021-07-08 09:42:26 -07:00
Matthew Johnson
e968672740 add tanh to jax.nn package 2021-04-29 08:26:38 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Peter Hawkins
5d6ff8b28d Readd omnistaging friendly versions of large constant tests. 2021-03-10 11:34:42 -05:00
Peter Hawkins
140c0acbbe Remove the JAX lazy sublanguage.
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.

At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) if a lazy expression is lexically captured by a `jit` computation, we can
   avoid materializing it in its expanded form.

It is not clear that laziness has sufficient power to weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.
2021-03-09 21:40:46 -05:00
Adam Paszke
f812402d37 Add support for named axes in jax.nn.one_hot 2021-02-02 15:57:00 +00:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Peter Hawkins
9b3bbe8359 Adds an approximate=... keyword argument to jax.nn.gelu to select between the approximate and exact formulations of gelu.
Default to the approximate formulation for now.
2020-10-02 09:48:07 -04:00
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Matthew Johnson
49cfe2687c
improve concreteness error message for nn.one_hot (#3656)
* improve nn.one_hot and jax.numpy.arange errors

fixes #3654

* deflake

* debug
2020-07-03 20:54:25 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. (#2969) 2020-05-05 14:59:16 -04:00
Peter Hawkins
72efa783ab
Fix spurious rank promotion warning. (#2954) 2020-05-04 14:50:08 -04:00
James Bradbury
1cc6b7dd6c
support axis argument in nn.glu (#2879)
* support axis argument in nn.glu

* also add basic correctness test

* Update nn_test.py
2020-05-02 19:33:10 -07:00
Tom Hennigan
0736679c33
Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (#2901)
At head the following fails:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
2020-05-01 10:00:38 -07:00
George Necula
2e9047d388
Add flag to enable checking, and turn on checking in tests. (#2900)
Fix an error in check_jaxpr.
2020-05-01 09:16:31 +03:00
Matthew Johnson
2d25773c21 add custom_jvp for logaddexp / logaddexp2
fixes #2107, draws from #2356 and #2357, thanks @yingted !

Co-authored-by: Ted Ying <yingted@gmail.com>
2020-04-13 11:20:16 -07:00
Matthew Johnson
1e61ba429d
improve jax.nn.relu differentiation (#2342) 2020-03-03 16:27:53 -08:00
George Necula
08eb0ee030
Disable newly added test on TPU (no float16) (#2262)
Added in #2259
2020-02-19 16:03:10 +01:00
Trevor Cai
eda91a048b
Use input dtype for constants in jax.nn.gelu (#2259) 2020-02-18 22:04:20 -08:00
Tom Hennigan
ca6df306aa
Add jax.nn.one_hot(x, num_classes, dtype). (#2240) 2020-02-15 10:32:00 -08:00
George Necula
b18a4d8583 Disabled tests known to fail on Mac, and optionally slow tests.
Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known
to be slow.
2020-02-05 18:02:56 +01:00
Lukas Prediger
ddc83e0937
Added dtype arg for NN initializer factory methods (#2034)
* Added dtype arg for NN initializer factory methods

Initializer factories in jax/nn/initializers.py (such as
uniform(), normal(), glorot_normal(), etc) now have
an optional `dtype` argument.

The value passed in that argument becomes the
default value for the same `dtype` argument
of the initializer function returned by the factory.

* fixed failed test for delta_orthogonal in d12cdc47
2020-02-04 08:38:38 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00