29 Commits

Author SHA1 Message Date
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Matthew Johnson
49cfe2687c
improve concreteness error message for nn.one_hot (#3656)
* improve nn.one_hot and jax.numpy.arange errors

fixes #3654

* deflake

* debug
2020-07-03 20:54:25 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. (#2969) 2020-05-05 14:59:16 -04:00
Peter Hawkins
72efa783ab
Fix spurious rank promotion warning. (#2954) 2020-05-04 14:50:08 -04:00
James Bradbury
1cc6b7dd6c
support axis argument in nn.glu (#2879)
* support axis argument in nn.glu

* also add basic correctness test

* Update nn_test.py
2020-05-02 19:33:10 -07:00
Tom Hennigan
0736679c33
Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (#2901)
At head the following fails:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
2020-05-01 10:00:38 -07:00
George Necula
2e9047d388
Add flag to enable checking, and turn on checking in tests. (#2900)
Fix an error in check_jaxpr.
2020-05-01 09:16:31 +03:00
Matthew Johnson
2d25773c21 add custom_jvp for logaddexp / logaddexp2
fixes #2107, draws from #2356 and #2357, thanks @yingted !

Co-authored-by: Ted Ying <yingted@gmail.com>
2020-04-13 11:20:16 -07:00
Matthew Johnson
1e61ba429d
improve jax.nn.relu differentiation (#2342) 2020-03-03 16:27:53 -08:00
George Necula
08eb0ee030
Disable newly added test on TPU (no float16) (#2262)
Added in #2259
2020-02-19 16:03:10 +01:00
Trevor Cai
eda91a048b
Use input dtype for constants in jax.nn.gelu (#2259) 2020-02-18 22:04:20 -08:00
Tom Hennigan
ca6df306aa
Add jax.nn.one_hot(x, num_classes, dtype). (#2240) 2020-02-15 10:32:00 -08:00
George Necula
b18a4d8583 Disabled tests known to fail on Mac, and optionally slow tests.
Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known
to be slow.
2020-02-05 18:02:56 +01:00
Lukas Prediger
ddc83e0937
Added dtype arg for NN initializer factory methods (#2034)
* Added dtype arg for NN initializer factory methods

Initializer factories in jax/nn/initializers.py (such as
uniform(), normal(), glorot_normal(), etc) now have
an optional `dtype` argument.

The value passed in that argument becomes the
default value for the same `dtype` argument
of the initializer function returned by the factory.

* fixed failed test for delta_orthogonal in d12cdc47
2020-02-04 08:38:38 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Matthew Johnson
82dbf91311 add tests for #1640, adapt make_jaxpr staging 2019-12-31 11:53:02 -08:00
Lechao Xiao
9a0ed06647 Add Delta orthogonal initialization (#1838) 2019-12-17 16:38:32 -08:00
Peter Hawkins
bbf8129aa6
Change test tolerance logic not to choose tolerance values based on f… (#1701)
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).

We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.

* Fix test tolerances.

* Fix dtype canonicalization for test tolerances.

* Relax core test_vjp tolerance.
2019-11-16 13:51:42 -05:00
James Bradbury
ce34cb73c7 fix orthogonal initializer and reenable test 2019-10-29 11:34:20 -07:00
Peter Hawkins
0dd720cd8a
Disable some tests that fail. (#1587)
Add a BUILD rule for experimental/vectorize.py.
2019-10-29 11:04:55 -04:00
root
43be8d8ef8 Fix NaNs in grad(jax.nn.elu) for large inputs. 2019-10-21 11:48:58 +00:00
James Bradbury
23f06b417f add initializer tests 2019-10-03 12:01:21 -07:00
James Bradbury
b82673dcdb add check_dtypes 2019-09-27 12:11:18 -04:00
James Bradbury
fa0a684af6 address comments 2019-09-21 10:11:53 -04:00
James Bradbury
9ef53120d2 add regression tests 2019-09-21 01:04:26 -04:00