109 Commits

Author SHA1 Message Date
Jake Vanderplas
6b471e2ac6
Cleanup: define type lists in test_util & use in several test files. (#3616) 2020-07-07 17:01:38 -07:00
Matthew Johnson
ba1b5ce8de
skip some ode tests on gpu for speed (#3629) 2020-07-01 11:26:44 -07:00
Jake Vanderplas
09d128edb3
Cleanup: remove some test interdependence (#3600) 2020-06-29 16:22:05 -07:00
Jake VanderPlas
e9aac7bbee Use re.search and include test class name 2020-06-29 11:08:57 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
32e419d189
Fix eigh JVP to ensure that both the primal and tangents of the eigen… (#3550)
* Fix eigh JVP to ensure that both the primal and tangents of the eigenvalues are real.

Add test to jax.test_util.check_jvp that ensure the primals and both the primals and tangents produced by a JVP rule have identical types.

* Cast input to static indexing grad tests to a JAX array so new type check passes.
2020-06-25 08:14:54 -04:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Peter Hawkins
972c7fda67
Fix bug where jnp.array returned a classic NumPy array, sometimes wit… (#3283)
* Fix bug where jnp.array returned a classic NumPy array, sometimes with the wrong type.

Unconditionally calls `device_put`, because `lax.convert_element_type` has a fast path that sometimes fails to lead to a `device_put`.

Improve the test for `jnp.array` and its test harness.
2020-06-01 19:29:26 -04:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Jake Vanderplas
bb2127cebd
Future-proof view test against signaling NaNs (#3178) 2020-05-21 09:20:59 -07:00
Jake Vanderplas
6e3c8b1d9b
Fix arr.view() on TPU & improve tests (#3141) 2020-05-21 06:40:24 -07:00
Jake Vanderplas
e675f804ff
Add support for 8- and 16-bit output in _random_bits (#3090) 2020-05-15 19:09:43 -07:00
Peter Hawkins
22d14fd7dd
Remove workaround for Mac linear algebra bug that is fixed in the minimum jaxlib version. (#3080) 2020-05-13 14:00:44 -04:00
joao guilherme
d2f84d635b
Change instances of onp to np and np to jnp (#3044) 2020-05-12 20:37:05 -04:00
Peter Hawkins
7116cc5b41
Improve JAX test PRNG APIs to fix correlations between test cases. (#2957)
* Improve JAX test PRNG APIs to fix correlations between test cases.

In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.

This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.

* Fix some failing tests.

* Fix more test failures.

Simplify ediff1d implementation and make it more permissive when casting.

* Relax test tolerance of laplace CDF test.
2020-05-04 23:00:20 -04:00
Peter Hawkins
9174684253
Cache test_utils.format_shape_and_dtype_string. (#2959)
A significant fraction of time when collecting test cases is spent building shape and dtype strings (which are usually similar and usually thrown away.)
2020-05-04 21:08:34 -04:00
Peter Hawkins
d61d6f44dc
Fix a number of flaky tests. (#2953)
* relax some test tolerances.
* disable 'random' preconditioner in CG test (#2951).
* ensure that scatter and top-k tests don't create ties.
2020-05-04 14:34:08 -04:00
Peter Hawkins
9802d7321c
Update XLA. (#2927) 2020-05-01 21:08:56 -04:00
George Necula
2e9047d388
Add flag to enable checking, and turn on checking in tests. (#2900)
Fix an error in check_jaxpr.
2020-05-01 09:16:31 +03:00
Anselm Levskaya
dddad2a3dc Add top_k jvp and batching rules 2020-04-28 07:19:58 -07:00
Matthew Johnson
2d25773c21 add custom_jvp for logaddexp / logaddexp2
fixes #2107, draws from #2356 and #2357, thanks @yingted !

Co-authored-by: Ted Ying <yingted@gmail.com>
2020-04-13 11:20:16 -07:00
Peter Hawkins
2dc81fb40c
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings.

Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
2020-04-12 15:35:35 -04:00
Stephan Hoyer
dd92a03713
Docstring for test_util.check_grads (#2656)
Fixes https://github.com/google/jax/issues/2648
2020-04-09 10:18:07 -07:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
2020-04-09 13:00:33 +03:00
Matthew Johnson
f2de1bf345 add trace state check tearDown to JaxTestCase 2020-04-02 22:01:43 -07:00
George Necula
0c53ce9def Disable test with float16 on TPU 2020-04-02 12:15:25 +03:00
Tzu-Wei Sung
8c4a938cfa
Implement np.ldexp and np.frexp. (#1529)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-04-01 15:29:48 -04:00
Matthew Johnson
7b0ee9a5ac
improve implementation of MVN logpdf (#2481)
fixes #2314

I also added a bit more test coverage, but not a ton: scipy has
different batch shape semantics and default arguments than I might
expect, so I didn't bother to implement those (and left some test cases
commented out).

I ran into this surprising scipy bug:

```python
In [1]: from scipy.stats import multivariate_normal

In [2]: import numpy as np

In [3]: args = [np.array(1., np.float32), np.array(2., np.float64), np.array(3., np.float64)]

In [4]: print([x.shape for x in args])
[(), (), ()]

In [5]: multivariate_normal.logpdf(*args)
Out[5]: -1.6349113442053944

In [6]: print([x.shape for x in args])
[(), (1,), (1, 1)]
```

Mutated arguments! But it depends on dtype promotion:

```python
In [7]: args = [np.array(1., np.float32), np.array(2., np.float32), np.array(3., np.float32)]

In [8]: print([x.shape for x in args])
[(), (), ()]

In [9]: multivariate_normal.logpdf(*args)
Out[9]: -1.6349113442053944

In [10]: print([x.shape for x in args])
[(), (), ()]
```
2020-03-21 15:42:59 -07:00
George Necula
78c1f6b08d
Increased tolerance for testScipySpecialFun (#2454)
Prevent failures on TPU
2020-03-19 08:54:37 +01:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Daniel Johnson
2dfeaeb63f
Allow zero tolerance for jax.test_util.tolerance (#2393)
Currently, if a user passes any falsy value to jax.test_util.tolerance,
it is changed to the default value. This makes sense when the value
passed is None, but not when the value passed is 0 (which indicates
a desired tolerance of exactly 0).

Disables failing tests for now.
2020-03-11 13:19:46 -07:00
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00
George Necula
20f9230f6e Simplify Jaxpr: remove the bound_subjaxpr field, all subjaxprs are in params.
The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).

For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.
2020-02-11 10:06:08 +01:00
George Necula
86984b37dd
Merge pull request #2169 from gnecula/bug_fix
Disabled tests known to fail on Mac, and optionally slow tests.
2020-02-06 18:17:04 +01:00
George Necula
b79c7948ee Removed dependency on distutils.strtobool 2020-02-06 17:27:46 +01:00
George Necula
ae3003e9d4 Simplify bound_subjaxprs.
Before, bound_subjaxprs was a tuple (0 or 1 values) of
a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs
such that they do not take constvars and their constant values are part of the
arguments.

We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr)

This is first part of a simplification. In a subsequent PR I will move
the bound_subjaxpr into params, as for most higher-order primitives.
2020-02-06 09:34:53 +01:00
George Necula
b18a4d8583 Disabled tests known to fail on Mac, and optionally slow tests.
Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known
to be slow.
2020-02-05 18:02:56 +01:00
George Necula
4f5987ccd9 Simplify Jaxpr: remove freevars.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure.We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.

The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
2020-02-03 18:58:05 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Skye Wanderman-Milne
891aecb941
Add test utilities for counting compilations. (#1895)
Also uses the new utilities to check that pmap doesn't compile constant computations.
2019-12-19 11:19:58 -08:00
Peter Hawkins
d8d3a7bc87
Allow scalar numpy arrays as shapes in np.{zeros,ones,full}. (#1881) 2019-12-17 17:20:51 -05:00
Stephan Hoyer
6ac1c569e8
Use HIGHEST precision for dot_general in linalg JVP rules (#1835) 2019-12-10 00:38:18 -08:00
Peter Hawkins
687b9050df
Prepare to switch default dtypes in JAX to be 32-bit types. (#1827)
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.
2019-12-09 21:18:39 -05:00
Peter Hawkins
fb79d56ace
Fixes to type handling. (#1824)
* Fixes to type handling.

* Specify exactly which types to test in lax_test.py, rather than relying on non-x64 mode to squash unsupported types.
* Fix some excessive promotions in jax.numpy.
* Fix some buggy RNGs that returned the wrong type for complex inputs.
2019-12-06 14:49:27 -05:00
Peter Hawkins
ff94b4442a
Remove np._promote_args_like, and replace its users with a newer _pro… (#1802)
* Remove np._promote_args_like, and replace its users with a newer _promote_args_inexact.

We no longer want to promote arguments exactly like NumPy; NumPy has a bad habit of promoting integer types to float64, whereas we want to promote to jax.numpy.float_, which may not be the same.

For example
```
import numpy as onp
onp.sin(3).dtype
```
returns `onp.dtype(float64)`.

However, it turns out that all of the users of `_promote_args_like` are using it for exactly one behavior: promoting integers or bools to inexact types like float. Implement that behavior explicitly rather than mimicing the behavior of NumPy.

* Relax test tolerances.
2019-12-03 10:05:51 -05:00
Peter Hawkins
f3c8af49e7
Fix bugs in handling of convolutions whose LHS has spatial size 0. (#1794)
* Fix bugs in handling of convolutions whose LHS has spatial size 0.

* Use onp.shape to compute shapes.
2019-12-02 14:43:43 -05:00
George Necula
2bb74b627e Ensure jaxpr.eqn.params are printed sorted, so we get deterministic output 2019-11-26 14:05:08 +01:00
George Necula
603258ebb8 Fixed a couple of tests 2019-11-26 13:56:58 +01:00
George Necula
5c15dda2c9 Changed api.make_jaxpr to return a TypedJaxpr
* A TypedJaxpr contains more useful information (consts, types)
* Also forced the instantiation of constants when producing the jaxpr.
  Before:
  >>>print(api.make_jaxpr(lambda x: 1.)(0.))
     lambda ; ; a.
     let
     in [*]}
  After this change:
  >>>print(api.make_jaxpr(lambda x: 1.)(0.))
     lambda ; ; a.
     let
     in [1.0]}
2019-11-26 09:17:03 +01:00