157 Commits

Author SHA1 Message Date
Jake VanderPlas
18228f45cb PRNGSeed: ensure JIT invariance for valid inputs. 2020-11-12 09:06:24 -08:00
Matthew Johnson
9ba28d2634 Copybara import of the project:
--
ced333d1d4aec2825e9afd81c2ca9721b7e3cc67 by Matthew Johnson <mattjj@google.com>:

redo #4535 lazy simplification

PiperOrigin-RevId: 338670328
2020-10-23 07:35:01 -07:00
Matthew Johnson
fcaced32aa Copybara import of the project:
--
ced333d1d4aec2825e9afd81c2ca9721b7e3cc67 by Matthew Johnson <mattjj@google.com>:

redo #4535 lazy simplification

PiperOrigin-RevId: 338606348
2020-10-22 21:18:22 -07:00
Matthew Johnson
ced333d1d4 redo #4535 lazy simplification 2020-10-22 16:56:29 -07:00
Peter Hawkins
aa107cf1f4 Move jax.numpy internals into jax._src.numpy. 2020-10-16 20:35:19 -04:00
Matthew Johnson
f553ed24e1 Temporary rollback of #4535 pending a possible XLA bug it exposed in an internal test.
PiperOrigin-RevId: 337219426
2020-10-14 18:52:13 -07:00
jax authors
fb01f59020 Merge pull request #4535 from google:lazy-simplification
PiperOrigin-RevId: 337183224
2020-10-14 15:16:50 -07:00
Matthew Johnson
990dc57deb Merge remote-tracking branch 'origin/master' into lazy-simplification 2020-10-14 14:52:16 -07:00
Peter Hawkins
1d4c53cef9 Fix CUDA launch error when generating an empty PRNG array. 2020-10-14 14:33:56 -04:00
Peter Hawkins
0b8eb92d59 Add stop_gradients around lax.nextafter to fix TFP gradient errors for jax.random.truncated_normal. 2020-10-13 09:16:29 -04:00
Peter Hawkins
080007ab82 Ensure values returned by jax.random.truncated_normal() are in range.
A user observed -inf values being returned by truncated_normal(), which occur if the uniform random value passed to erfinv() is out of range, e.g., due to rounding. Do more of the computation using jax.random.uniform(), which promises correct behavior in the face of rounding.

As an added security measure, also clamp the outputs of the function to the open interval.
2020-10-12 16:27:54 -04:00
Matthew Johnson
4e65a6f0a9 don't generate lazy iota/eye/tri/delta omnistaging 2020-10-10 21:08:52 -07:00
Matthew Johnson
c42d736e34 remove limit on size of random arrays 2020-09-23 19:37:34 -07:00
George Necula
1e84cbe9cc
[jax2tf] Fix random.split when jax_exable_x64 (#4208)
Since we do the threefry with signed integers when converting to TF,
we run into the type promotion 'uint32 - int32 = int64', which
then results in lax.shift_right_logical(uint32, int64), which fails.
2020-09-07 14:41:50 +03:00
George Necula
5eac47726b
[jax2tf] Implementation of random_gamma (#4192)
* [jax2tf] implementation of random_gamma

The simplest implementation is by converting the JAX own impl_rule,
which rewrites gamma into other JAX primitives.

On TPU with use_vmap=True the performance is the same for JAX and TF, provided
we use tf.function(compile=True).
2020-09-03 14:18:35 +03:00
George Necula
417c9ff764
Fix pytype error (#4158) 2020-08-27 09:41:16 +03:00
Matthew Johnson
1d93991003
allow random.choice to accept ndarray input (#4145)
* allow random.choice to accept ndarray `a`

follow-up to #4137 to allow ndarray inputs to be passed

* add jax.random.choice tests to cover ndarray input

* don't use callables in test params

it can mess with pytest-xdist because of hashing by id
2020-08-26 10:21:56 -07:00
Jake Vanderplas
6d54eb563e
Do not call asarray() on inputs of jax.random.choice (#4137) 2020-08-25 05:47:43 -07:00
Matthew Johnson
56b3688db9 make random.choice error when shape isn't sequence
fixes #4124
2020-08-21 19:58:06 -07:00
Mihaela Rosca
1e8ac24863
Add rademacher, maxwell, double_sided_maxwell and weibull_min to jax.random. (#4104) 2020-08-20 07:46:55 -07:00
Jake Vanderplas
29aa9bfc8f
Cleanup: avoid jnp.prod & np.prod on array shapes (#4086) 2020-08-18 10:17:38 -07:00
Ethan Luo Yicheng
6e4ec7cb81
Fix broadcasting in random.uniform and randint. (#4035) 2020-08-12 11:52:42 -07:00
Scott Linderman
ea88c55f55
Fixes and tests for jax.random.multivariate_normal (#4002)
* Fix bug #3997, change `jax.random.multivariate_normal` to handle batches of covariance matrices.  It works as long as mean and covariance are broadcast-compatible, as specified in the docstring.

* Fix bug in multivariate_normal shape checking

Minor bug: should be checking for compatibility of `shape`, `mean`, and the the last two dimensions of the _covariance_ matrix.

* Add test for multivariate_normal shapes

This test checks that `jax.random.multivariate_normal` produces the expected output shape for various combinations of event dimension and `mean`, `covariance`, and `shape` shapes.

* Fix linter issues in tests/random_test.py

Trimming trialing whitespace and 80 char limit.

* Really trimming whitespace in tests/random_test.py

Arg. Have to fix my editor to do this automatically.
2020-08-09 11:32:45 -07:00
John Aslanides
038c85dad0
Improve type annotations for jit and vmap. (#3938) 2020-08-08 12:22:54 -04:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Jake Vanderplas
ee7f035349
jax.random: use correct x32/x64 default dtypes. (#3841)
This is a no-op in the current package, but will make things cleaner during the x64 deprecation.
2020-07-26 08:58:37 -07:00
bion howard
74d363e552
fix extremely minor typo (#3815)
"ijnputs" -> "inputs"
2020-07-21 12:41:08 -07:00
Matthew Johnson
74ee2ef6eb
avoid value-based error check in random.choice (#3531) 2020-06-23 14:03:36 -07:00
Matthew Johnson
2f7108f78b
remove the lower_fun default multiple_results=True (#3524) 2020-06-22 17:50:33 -07:00
Jake Vanderplas
19f308b9ed
implement jax.random.choice (#3463) 2020-06-19 16:04:42 -07:00
Srinivas Vasudevan
927c209148
Add random_gamma_grad and use in jax.random.gamma (#3281) 2020-06-19 09:34:18 -04:00
fehiepsi
b680c994ae allow scalar input in poisson sampler 2020-06-12 01:42:25 -04:00
Adam Paszke
e36c72b983
Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) (#3222) 2020-06-08 17:50:14 +02:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Matthew Johnson
9c0a58a8e7
add float dtype checks to random.py (#3320)
fixes #3317
2020-06-04 10:13:15 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00
Roy Frostig
e7e4cbce5d docstring fix 2020-05-29 16:00:20 -07:00
Roy Frostig
8fbab04d4a codeblock for example usage in PRNG docstring 2020-05-29 15:41:28 -07:00
Roy Frostig
657cf1bb19 render example usage from PRNG doc 2020-05-29 15:11:00 -07:00
Jean-Baptiste Lespiau
a486f54814
Add a summary explaining the usage and context for JAX PRNG design. (#2525)
* Add a summary explaining the usage and context for JAX PRNG design.

The current design_notes do not match current JAX API, and it's a pretty
long doc to read to understand how to use it.

Closes: #2087

* Change 'should' to be more precise.

* Address comments.
2020-05-26 10:38:28 +03:00
joao guilherme
77e4d8b3b9
Updates onp -> np in random, loops, jet and in the tests of stax and optix (#3182) 2020-05-21 14:12:18 -07:00
Jake Vanderplas
8fe26190de
Expand type support for random uniform() & randint() (#3138) 2020-05-19 14:19:00 -07:00
Jake Vanderplas
e675f804ff
Add support for 8- and 16-bit output in _random_bits (#3090) 2020-05-15 19:09:43 -07:00
Jake Vanderplas
12a9af869f
Update random.logistic() to prevent infinities (#3048) 2020-05-15 14:29:02 -07:00
Tom Hennigan
abdf504e9e
Avoid recompilation of rolled loops in threefry2x32. (#3069) 2020-05-12 15:03:22 -07:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
Stephan Hoyer
46ce80b032
jax.random.poisson (#2805)
* jax.random.poisson

The implementation for lam < 10 was directly copied from TensorFlow probability:
https://github.com/tensorflow/probability/blob/v0.10.0-rc0/tensorflow_probability/python/internal/backend/numpy/random_generators.py#L155

I adapted the implementation for lam > 10 from TensorFlow:
https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc

The methods themselves match both TensorFlow and NumPy:
https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574

* add a check for even larger lambda

* increment iter count

* remove comment that makes no sense

* Fix chi-squared tests in random_test.py

As far as I can tell, the previous implementation of the chi-squared test
for samples from discrete probability distributions was broken. It should have
been asserting that the p-value was greater 0.01, e.g., as illustrated here:
http://hamelg.blogspot.com/2015/11/python-for-data-analysis-part-25-chi.html

This hid a few other bugs, such a miscalculation of expected frequencies.

Fortunately, the existing random tests for Bernoulli and Categorical *mostly*
still pass, which the exception of multi-dimensional logits for Categorical.
Those tests are disabled by this PR.

* Fix accept condition (based on correct chi-squared test)

* Add moment checks for Poisson

* Add batching test, more Poisson rates
2020-05-02 11:24:59 -04:00
Jake Vanderplas
6425ca2aed
Merge pull request #2925 from jakevdp/shuffle
Deprecate random.shuffle() and implement random.permutation() for multi-dim inputs
2020-05-02 06:32:50 -07:00
Jake VanderPlas
d8d71407dc Deprecate random.shuffle() and implement random.permutation() for multi-dimensional matrices. 2020-05-01 15:18:24 -07:00