103 Commits

Author SHA1 Message Date
Matthew Johnson
71f5f9972c skip checks in big randomness test 2020-09-23 20:15:32 -07:00
Matthew Johnson
96f5a3c402 fix test for non-omnistaging 2020-09-23 19:39:22 -07:00
Matthew Johnson
c42d736e34 remove limit on size of random arrays 2020-09-23 19:37:34 -07:00
Matthew Johnson
1d93991003
allow random.choice to accept ndarray input (#4145)
* allow random.choice to accept ndarray `a`

follow-up to #4137 to allow ndarray inputs to be passed

* add jax.random.choice tests to cover ndarray input

* don't use callables in test params

it can mess with pytest-xdist because of hashing by id
2020-08-26 10:21:56 -07:00
Matthew Johnson
56b3688db9 make random.choice error when shape isn't sequence
fixes #4124
2020-08-21 19:58:06 -07:00
Mihaela Rosca
1e8ac24863
Add rademacher, maxwell, double_sided_maxwell and weibull_min to jax.random. (#4104) 2020-08-20 07:46:55 -07:00
Ethan Luo Yicheng
6e4ec7cb81
Fix broadcasting in random.uniform and randint. (#4035) 2020-08-12 11:52:42 -07:00
David Majnemer
265c3faa40
Remove type restrictions (#4011)
We support s8, u8, s16, u16, half16 on TPU
2020-08-10 18:33:15 -07:00
Scott Linderman
ea88c55f55
Fixes and tests for jax.random.multivariate_normal (#4002)
* Fix bug #3997, change `jax.random.multivariate_normal` to handle batches of covariance matrices.  It works as long as mean and covariance are broadcast-compatible, as specified in the docstring.

* Fix bug in multivariate_normal shape checking

Minor bug: should be checking for compatibility of `shape`, `mean`, and the the last two dimensions of the _covariance_ matrix.

* Add test for multivariate_normal shapes

This test checks that `jax.random.multivariate_normal` produces the expected output shape for various combinations of event dimension and `mean`, `covariance`, and `shape` shapes.

* Fix linter issues in tests/random_test.py

Trimming trialing whitespace and 80 char limit.

* Really trimming whitespace in tests/random_test.py

Arg. Have to fix my editor to do this automatically.
2020-08-09 11:32:45 -07:00
Jake Vanderplas
5ae6043a5f
Cleanup test names in random_test.py (#3842) 2020-07-24 09:36:16 -07:00
Jake Vanderplas
6b471e2ac6
Cleanup: define type lists in test_util & use in several test files. (#3616) 2020-07-07 17:01:38 -07:00
Matthew Johnson
65c4d755de
fix bug in categorical test, disable #3611 on tpu (#3633)
* fix bug in categorical test, disable #3611 on tpu

Disabling #3611 on TPU pending a TPU compilation bug.

* unskip a test
2020-07-01 14:15:48 -07:00
Matthew Johnson
1bd04e2d2f
skip gamma tests on tpu for compilation speed (#3631) 2020-07-01 12:22:39 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jake Vanderplas
19f308b9ed
implement jax.random.choice (#3463) 2020-06-19 16:04:42 -07:00
Srinivas Vasudevan
927c209148
Add random_gamma_grad and use in jax.random.gamma (#3281) 2020-06-19 09:34:18 -04:00
Jake Vanderplas
0da7b4d1bd
Improve dtype test coverage for random_test (#3254) 2020-06-18 15:17:13 -07:00
Matthew Johnson
9c0a58a8e7
add float dtype checks to random.py (#3320)
fixes #3317
2020-06-04 10:13:15 -07:00
Jake Vanderplas
9ee4ef1107
Cleanup: de-lint tests directory & add flake8 to travis (#3304)
* Cleanup: fix lint errors in tests/*.py

* Add flake8 step to travis

* add setup.cfg
2020-06-02 19:25:47 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
joao guilherme
e48a4e012b
uses np.prod instead of jnp.prod for shapes (#3236) 2020-05-28 13:16:56 -07:00
Jake Vanderplas
b7ff305a20
fix broken TPU test (#3153) 2020-05-19 15:50:54 -07:00
Jake Vanderplas
8fe26190de
Expand type support for random uniform() & randint() (#3138) 2020-05-19 14:19:00 -07:00
Jake Vanderplas
e675f804ff
Add support for 8- and 16-bit output in _random_bits (#3090) 2020-05-15 19:09:43 -07:00
Peter Hawkins
22d14fd7dd
Remove workaround for Mac linear algebra bug that is fixed in the minimum jaxlib version. (#3080) 2020-05-13 14:00:44 -04:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
Julius Kunze
e4d8cacfc6
Fix tests for random.categorical with multi-dimensional logits (#2955) 2020-05-04 20:12:43 -07:00
Stephan Hoyer
46ce80b032
jax.random.poisson (#2805)
* jax.random.poisson

The implementation for lam < 10 was directly copied from TensorFlow probability:
https://github.com/tensorflow/probability/blob/v0.10.0-rc0/tensorflow_probability/python/internal/backend/numpy/random_generators.py#L155

I adapted the implementation for lam > 10 from TensorFlow:
https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc

The methods themselves match both TensorFlow and NumPy:
https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574

* add a check for even larger lambda

* increment iter count

* remove comment that makes no sense

* Fix chi-squared tests in random_test.py

As far as I can tell, the previous implementation of the chi-squared test
for samples from discrete probability distributions was broken. It should have
been asserting that the p-value was greater 0.01, e.g., as illustrated here:
http://hamelg.blogspot.com/2015/11/python-for-data-analysis-part-25-chi.html

This hid a few other bugs, such a miscalculation of expected frequencies.

Fortunately, the existing random tests for Bernoulli and Categorical *mostly*
still pass, which the exception of multi-dimensional logits for Categorical.
Those tests are disabled by this PR.

* Fix accept condition (based on correct chi-squared test)

* Add moment checks for Poisson

* Add batching test, more Poisson rates
2020-05-02 11:24:59 -04:00
Jake VanderPlas
d8d71407dc Deprecate random.shuffle() and implement random.permutation() for multi-dimensional matrices. 2020-05-01 15:18:24 -07:00
Stephan Hoyer
e6df98de55
Fix chi-squared tests in random_test.py (#2847)
As far as I can tell, the previous implementation of the chi-squared test
for samples from discrete probability distributions was broken. It should have
been asserting that the p-value was greater 0.01, e.g., as illustrated here:
http://hamelg.blogspot.com/2015/11/python-for-data-analysis-part-25-chi.html

This hid a few other bugs, such a miscalculation of expected frequencies.

Fortunately, the existing random tests for Bernoulli and Categorical *mostly*
still pass, which the exception of multi-dimensional logits for Categorical.
Those tests are disabled by this PR.
2020-04-27 17:24:39 -07:00
MichaelMarien
e0d42e90eb
Feature/permutation (#1568)
* added test for random.permutation

* added permutation that wraps shuffle with behaviour of np.random.permutation

* update docstring

* need to shuffle also the integer range input

* fixed test for permutation with integer

* tweak handling of random.permutation scalar case

* NotImplementedError for random.permutation on >1d

pending resolution to #2066

* address reviewer comments: improve tests

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-04-23 22:40:33 -07:00
Peter Hawkins
2dc81fb40c
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings.

Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
2020-04-12 15:35:35 -04:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
2020-04-09 13:00:33 +03:00
Matthew Johnson
86a4073a75
enable beta test on float64 values (#1177)
* enable beta test on float64 values

cf. #1123

* Enable beta test on all platforms.

It seems sufficiently fast now.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-04-01 11:25:32 -04:00
Jamie Townsend
c999a482b0 Test for PRNG consistency accross JAX versions 2020-02-11 13:19:42 +00:00
Pavel Sountsov
b2ef5bc095
Canonicalize the shape in the wrapper functions in random.py. (#2165)
* Canonicalize the shape in the wrapper functions in random.py.

This lets the user be more sloppy in using numpy arrays and statically
known DeviceArrays for shapes, and still hit the jit cache. When they
are not, the error is improved.

* Fix some errors.

* No need for the Poly workaround.

* Bypass canonicalization for None shapes in random.py.
2020-02-05 10:10:33 -08:00
Peter Hawkins
0b1d2fc3d1
Avoid accidental type promotion in gamma sampler gradient. (#2150)
Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.
2020-02-03 12:44:46 -05:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Matthew Johnson
00be20bdfa
Merge pull request #1855 from JuliusKunze/categorical
Add categorical sampler
2020-01-10 07:59:21 -08:00
Julius Kunze
f36d858c4e Require shape = sample_shape + batch_shape in random.categorical 2020-01-10 13:28:03 +00:00
fehiepsi
cdfa57dfcc merge master 2019-12-23 22:52:15 -05:00
Matthew Johnson
7175c1dfe1 fix transpose bug in multivariate normal, add test
fixes #1869
2019-12-17 15:08:08 -08:00
Julius Kunze
6178755281 Remove safe zip/map 2019-12-13 15:00:45 +00:00
Julius Kunze
9d12a24b63 Add categorical sampler 2019-12-13 12:41:26 +00:00
fehiepsi
7ec2ac58ca not use custom transform for gamma sampler 2019-12-01 09:44:45 -05:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00
Matthew Johnson
b358c27c92 replace x.shape with onp.shape(x) in random.py
fixes #1748 (thanks @vitchyr)
2019-11-22 10:59:31 -08:00
Peter Hawkins
bbf8129aa6
Change test tolerance logic not to choose tolerance values based on f… (#1701)
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).

We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.

* Fix test tolerances.

* Fix dtype canonicalization for test tolerances.

* Relax core test_vjp tolerance.
2019-11-16 13:51:42 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00