135 Commits

Author SHA1 Message Date
Adam Paszke
e0357283a6 Speed up test generation 7x
There are a few test cases that generate millions of configurations,
only to have a handful of them selected by `cases_form_list`. I've
found all tests that spend over 100ms in case generation and
converted them to a new "test sampler" approach. The result: test
generation time drops from 15s to around 2s. Doesn't sound like much,
but I expect that we all run tests many times daily, so it seems like a
useful thing to have.

The rough idea is that the sampling generators get parameterized by a
sampler function that should be applied to the range of every `for` loop.
This allows us to sample runs of the generator through different
configurations by restricting each loop to a smaller subset. Right now
we always narrow it down to a single randomly selected instance. But,
we still retain the possibility of adding exhaustive testing in the
future, which can be achieved by passing in an identity sampling
function that wouldn't modify any loop ranges.
2021-04-14 15:58:05 +00:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
jax authors
cf9b77f1de Merge pull request #5998 from zhangqiaorjc:dev_put_count
PiperOrigin-RevId: 362143966
2021-03-10 14:36:55 -08:00
Qiao Zhang
9577860169 Add jtu.count_device_put for tests to count device_put. 2021-03-09 14:45:01 -08:00
Jake VanderPlas
dbdb189de1 jnp.piecewise: support scalar inputs 2021-03-09 13:25:38 -08:00
jax authors
144a633e42 Merge pull request #5405 from shoyer:check-grads-err-msg
PiperOrigin-RevId: 356818579
2021-02-10 13:51:23 -08:00
Stephan Hoyer
db6405c746 Better error messages for test_util.check_grads()
Rather than merely reporting a failure in check_grads(), we now
report the *specific* check that failed, e.g., "JVP tangent" or
"VJP of JVP cotangent projection". Gradient tests often fail for
spurious reasons (e.g., due to insufficient precision), so this should
be helpful for debugging.

I tested this manually by relaxing the tolerance for a test in
`linalg_test.py`.
2021-02-10 12:19:33 -08:00
Adam Paszke
0b7febea39 Add argument donation for xmap
Also, pass the body to XLA JIT when no parallel resources are used.
There is no reason to not do that given that we already require users to
pay the price of making their code jittable.
2021-02-08 12:46:45 +00:00
Jake VanderPlas
5140426c7c Support non-scalar fill values in jnp.full() & jnp.full_like() 2021-02-05 10:07:41 -08:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
George Necula
20be478a6e [host_callback] Add support for pmap and for passing the device to tap
* Adds support for jit of pmap and pmap of pmap.
* Also adds a `tap_with_device` optional argument to `id_print` and
  `id_tap`, to have the tap function invoked with a device keyword argument.
* Added multiple tests involving pmap

Issue: #5134
Fixes: #5169
2020-12-15 10:46:23 +02:00
Adam Paszke
ca8028950e Fix pmap compilation cache regressions from #4904.
AD didn't use `HashableFunction` enough, tripping up the compilation
cache. I've also used the occasion to make function hashing a little
safer by including the Python bytecode of the wrapped function as part
of the key.
2020-12-02 14:40:45 +00:00
jax authors
9a8ee95c08 Merge pull request #4419 from rsepassi:master
PiperOrigin-RevId: 334846970
2020-10-01 10:39:51 -07:00
Ryan Sepassi
97592c86f5 Add --exclude_test_targets to test_util 2020-09-29 07:57:20 -07:00
Srijan Saurav
40e20242db
Fix code quality issues (#4302)
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
2020-09-17 09:21:18 -07:00
Stephan Hoyer
877053d8ab
Add jax.linear_transpose (#3398)
* Add jax.linear_transpose

Co-authored-by: Matthew Johnson <mattjj@google.com>

* add failing test for complex numbers

* Add picky dtype check for linear_transpose

* Lint fix

* Allow truncating dtypes to match inputs in linear_transpose

* Fix typo in shape check error

* improve docstring

* Don't support integer inputs; better docstring

* fixup

* Fix doctest

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-09-16 20:29:19 -07:00
Stephan Hoyer
6bd3216b26
Simplify the interface for host_callback.id_tap (#4101)
* Simplify the internal interface for host_callback.id_tap

This is a breaking change for `id_tap` users (but not `id_print` users).

This makes it easier to use (and type check)  ``tap_func``, because the
expected signature is now ``tap_func(arg, transforms)`` vs
``tap_func(arg, *, transforms, **kwargs)``.

Most of the test changes are just adding whitespace/indentation, but I've
also slightly changed the way transformations are printed.
2020-09-14 12:47:28 +03:00
Jake Vanderplas
29aa9bfc8f
Cleanup: avoid jnp.prod & np.prod on array shapes (#4086) 2020-08-18 10:17:38 -07:00
Peter Hawkins
a169743f64
Enable s8/u8/s16/u16 types on TPU in tests. (#4032) 2020-08-12 10:02:35 -04:00
Jake Vanderplas
0cbb4279ee
Cleanup: make skip_if_unsupported_type more robust (#3912) 2020-07-30 11:07:56 -07:00
David Majnemer
33faf6a46e
TPUs support half precision arithmetic (#3878)
* TPUs support half precision arithmetic

* update jax2tf tests to handle fp16

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-07-28 18:07:38 -07:00
Peter Hawkins
f6221a663e
Enable int{8,16} and uint{8,16} tests in lax_test and lax_numpy_test. (#3833) 2020-07-23 16:17:55 -04:00
Peter Hawkins
a6e2d20b31
Add support for base dilation and window dilation to reduce window op… (#3803) 2020-07-20 17:27:24 -04:00
Peter Hawkins
e4d5eade54 Use iteration over equations to test for "transpose" and "broadcast". 2020-07-17 08:44:47 -04:00
Jake Vanderplas
6b471e2ac6
Cleanup: define type lists in test_util & use in several test files. (#3616) 2020-07-07 17:01:38 -07:00
Matthew Johnson
ba1b5ce8de
skip some ode tests on gpu for speed (#3629) 2020-07-01 11:26:44 -07:00
Jake Vanderplas
09d128edb3
Cleanup: remove some test interdependence (#3600) 2020-06-29 16:22:05 -07:00
Jake VanderPlas
e9aac7bbee Use re.search and include test class name 2020-06-29 11:08:57 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
32e419d189
Fix eigh JVP to ensure that both the primal and tangents of the eigen… (#3550)
* Fix eigh JVP to ensure that both the primal and tangents of the eigenvalues are real.

Add test to jax.test_util.check_jvp that ensure the primals and both the primals and tangents produced by a JVP rule have identical types.

* Cast input to static indexing grad tests to a JAX array so new type check passes.
2020-06-25 08:14:54 -04:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Peter Hawkins
972c7fda67
Fix bug where jnp.array returned a classic NumPy array, sometimes wit… (#3283)
* Fix bug where jnp.array returned a classic NumPy array, sometimes with the wrong type.

Unconditionally calls `device_put`, because `lax.convert_element_type` has a fast path that sometimes fails to lead to a `device_put`.

Improve the test for `jnp.array` and its test harness.
2020-06-01 19:29:26 -04:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Jake Vanderplas
bb2127cebd
Future-proof view test against signaling NaNs (#3178) 2020-05-21 09:20:59 -07:00
Jake Vanderplas
6e3c8b1d9b
Fix arr.view() on TPU & improve tests (#3141) 2020-05-21 06:40:24 -07:00
Jake Vanderplas
e675f804ff
Add support for 8- and 16-bit output in _random_bits (#3090) 2020-05-15 19:09:43 -07:00
Peter Hawkins
22d14fd7dd
Remove workaround for Mac linear algebra bug that is fixed in the minimum jaxlib version. (#3080) 2020-05-13 14:00:44 -04:00
joao guilherme
d2f84d635b
Change instances of onp to np and np to jnp (#3044) 2020-05-12 20:37:05 -04:00
Peter Hawkins
7116cc5b41
Improve JAX test PRNG APIs to fix correlations between test cases. (#2957)
* Improve JAX test PRNG APIs to fix correlations between test cases.

In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.

This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.

* Fix some failing tests.

* Fix more test failures.

Simplify ediff1d implementation and make it more permissive when casting.

* Relax test tolerance of laplace CDF test.
2020-05-04 23:00:20 -04:00
Peter Hawkins
9174684253
Cache test_utils.format_shape_and_dtype_string. (#2959)
A significant fraction of time when collecting test cases is spent building shape and dtype strings (which are usually similar and usually thrown away.)
2020-05-04 21:08:34 -04:00
Peter Hawkins
d61d6f44dc
Fix a number of flaky tests. (#2953)
* relax some test tolerances.
* disable 'random' preconditioner in CG test (#2951).
* ensure that scatter and top-k tests don't create ties.
2020-05-04 14:34:08 -04:00
Peter Hawkins
9802d7321c
Update XLA. (#2927) 2020-05-01 21:08:56 -04:00
George Necula
2e9047d388
Add flag to enable checking, and turn on checking in tests. (#2900)
Fix an error in check_jaxpr.
2020-05-01 09:16:31 +03:00
Anselm Levskaya
dddad2a3dc Add top_k jvp and batching rules 2020-04-28 07:19:58 -07:00
Matthew Johnson
2d25773c21 add custom_jvp for logaddexp / logaddexp2
fixes #2107, draws from #2356 and #2357, thanks @yingted !

Co-authored-by: Ted Ying <yingted@gmail.com>
2020-04-13 11:20:16 -07:00
Peter Hawkins
2dc81fb40c
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings.

Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
2020-04-12 15:35:35 -04:00
Stephan Hoyer
dd92a03713
Docstring for test_util.check_grads (#2656)
Fixes https://github.com/google/jax/issues/2648
2020-04-09 10:18:07 -07:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
2020-04-09 13:00:33 +03:00