19 Commits

Author SHA1 Message Date
George Necula
2e9047d388
Add flag to enable checking, and turn on checking in tests. (#2900)
Fix an error in check_jaxpr.
2020-05-01 09:16:31 +03:00
Matthew Johnson
2d25773c21 add custom_jvp for logaddexp / logaddexp2
fixes #2107, draws from #2356 and #2357, thanks @yingted !

Co-authored-by: Ted Ying <yingted@gmail.com>
2020-04-13 11:20:16 -07:00
Matthew Johnson
1e61ba429d
improve jax.nn.relu differentiation (#2342) 2020-03-03 16:27:53 -08:00
George Necula
08eb0ee030
Disable newly added test on TPU (no float16) (#2262)
Added in #2259
2020-02-19 16:03:10 +01:00
Trevor Cai
eda91a048b
Use input dtype for constants in jax.nn.gelu (#2259) 2020-02-18 22:04:20 -08:00
Tom Hennigan
ca6df306aa
Add jax.nn.one_hot(x, num_classes, dtype). (#2240) 2020-02-15 10:32:00 -08:00
George Necula
b18a4d8583 Disabled tests known to fail on Mac, and optionally slow tests.
Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known
to be slow.
2020-02-05 18:02:56 +01:00
Lukas Prediger
ddc83e0937
Added dtype arg for NN initializer factory methods (#2034)
* Added dtype arg for NN initializer factory methods

Initializer factories in jax/nn/initializers.py (such as
uniform(), normal(), glorot_normal(), etc) now have
an optional `dtype` argument.

The value passed in that argument becomes the
default value for the same `dtype` argument
of the initializer function returned by the factory.

* fixed failed test for delta_orthogonal in d12cdc47
2020-02-04 08:38:38 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Matthew Johnson
82dbf91311 add tests for #1640, adapt make_jaxpr staging 2019-12-31 11:53:02 -08:00
Lechao Xiao
9a0ed06647 Add Delta orthogonal initialization (#1838) 2019-12-17 16:38:32 -08:00
Peter Hawkins
bbf8129aa6
Change test tolerance logic not to choose tolerance values based on f… (#1701)
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).

We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.

* Fix test tolerances.

* Fix dtype canonicalization for test tolerances.

* Relax core test_vjp tolerance.
2019-11-16 13:51:42 -05:00
James Bradbury
ce34cb73c7 fix orthogonal initializer and reenable test 2019-10-29 11:34:20 -07:00
Peter Hawkins
0dd720cd8a
Disable some tests that fail. (#1587)
Add a BUILD rule for experimental/vectorize.py.
2019-10-29 11:04:55 -04:00
root
43be8d8ef8 Fix NaNs in grad(jax.nn.elu) for large inputs. 2019-10-21 11:48:58 +00:00
James Bradbury
23f06b417f add initializer tests 2019-10-03 12:01:21 -07:00
James Bradbury
b82673dcdb add check_dtypes 2019-09-27 12:11:18 -04:00
James Bradbury
fa0a684af6 address comments 2019-09-21 10:11:53 -04:00
James Bradbury
9ef53120d2 add regression tests 2019-09-21 01:04:26 -04:00