144 Commits

Author SHA1 Message Date
George Necula
1e84cbe9cc
[jax2tf] Fix random.split when jax_exable_x64 (#4208)
Since we do the threefry with signed integers when converting to TF,
we run into the type promotion 'uint32 - int32 = int64', which
then results in lax.shift_right_logical(uint32, int64), which fails.
2020-09-07 14:41:50 +03:00
George Necula
5eac47726b
[jax2tf] Implementation of random_gamma (#4192)
* [jax2tf] implementation of random_gamma

The simplest implementation is by converting the JAX own impl_rule,
which rewrites gamma into other JAX primitives.

On TPU with use_vmap=True the performance is the same for JAX and TF, provided
we use tf.function(compile=True).
2020-09-03 14:18:35 +03:00
George Necula
417c9ff764
Fix pytype error (#4158) 2020-08-27 09:41:16 +03:00
Matthew Johnson
1d93991003
allow random.choice to accept ndarray input (#4145)
* allow random.choice to accept ndarray `a`

follow-up to #4137 to allow ndarray inputs to be passed

* add jax.random.choice tests to cover ndarray input

* don't use callables in test params

it can mess with pytest-xdist because of hashing by id
2020-08-26 10:21:56 -07:00
Jake Vanderplas
6d54eb563e
Do not call asarray() on inputs of jax.random.choice (#4137) 2020-08-25 05:47:43 -07:00
Matthew Johnson
56b3688db9 make random.choice error when shape isn't sequence
fixes #4124
2020-08-21 19:58:06 -07:00
Mihaela Rosca
1e8ac24863
Add rademacher, maxwell, double_sided_maxwell and weibull_min to jax.random. (#4104) 2020-08-20 07:46:55 -07:00
Jake Vanderplas
29aa9bfc8f
Cleanup: avoid jnp.prod & np.prod on array shapes (#4086) 2020-08-18 10:17:38 -07:00
Ethan Luo Yicheng
6e4ec7cb81
Fix broadcasting in random.uniform and randint. (#4035) 2020-08-12 11:52:42 -07:00
Scott Linderman
ea88c55f55
Fixes and tests for jax.random.multivariate_normal (#4002)
* Fix bug #3997, change `jax.random.multivariate_normal` to handle batches of covariance matrices.  It works as long as mean and covariance are broadcast-compatible, as specified in the docstring.

* Fix bug in multivariate_normal shape checking

Minor bug: should be checking for compatibility of `shape`, `mean`, and the the last two dimensions of the _covariance_ matrix.

* Add test for multivariate_normal shapes

This test checks that `jax.random.multivariate_normal` produces the expected output shape for various combinations of event dimension and `mean`, `covariance`, and `shape` shapes.

* Fix linter issues in tests/random_test.py

Trimming trialing whitespace and 80 char limit.

* Really trimming whitespace in tests/random_test.py

Arg. Have to fix my editor to do this automatically.
2020-08-09 11:32:45 -07:00
John Aslanides
038c85dad0
Improve type annotations for jit and vmap. (#3938) 2020-08-08 12:22:54 -04:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Jake Vanderplas
ee7f035349
jax.random: use correct x32/x64 default dtypes. (#3841)
This is a no-op in the current package, but will make things cleaner during the x64 deprecation.
2020-07-26 08:58:37 -07:00
bion howard
74d363e552
fix extremely minor typo (#3815)
"ijnputs" -> "inputs"
2020-07-21 12:41:08 -07:00
Matthew Johnson
74ee2ef6eb
avoid value-based error check in random.choice (#3531) 2020-06-23 14:03:36 -07:00
Matthew Johnson
2f7108f78b
remove the lower_fun default multiple_results=True (#3524) 2020-06-22 17:50:33 -07:00
Jake Vanderplas
19f308b9ed
implement jax.random.choice (#3463) 2020-06-19 16:04:42 -07:00
Srinivas Vasudevan
927c209148
Add random_gamma_grad and use in jax.random.gamma (#3281) 2020-06-19 09:34:18 -04:00
fehiepsi
b680c994ae allow scalar input in poisson sampler 2020-06-12 01:42:25 -04:00
Adam Paszke
e36c72b983
Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) (#3222) 2020-06-08 17:50:14 +02:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Matthew Johnson
9c0a58a8e7
add float dtype checks to random.py (#3320)
fixes #3317
2020-06-04 10:13:15 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00
Roy Frostig
e7e4cbce5d docstring fix 2020-05-29 16:00:20 -07:00
Roy Frostig
8fbab04d4a codeblock for example usage in PRNG docstring 2020-05-29 15:41:28 -07:00
Roy Frostig
657cf1bb19 render example usage from PRNG doc 2020-05-29 15:11:00 -07:00
Jean-Baptiste Lespiau
a486f54814
Add a summary explaining the usage and context for JAX PRNG design. (#2525)
* Add a summary explaining the usage and context for JAX PRNG design.

The current design_notes do not match current JAX API, and it's a pretty
long doc to read to understand how to use it.

Closes: #2087

* Change 'should' to be more precise.

* Address comments.
2020-05-26 10:38:28 +03:00
joao guilherme
77e4d8b3b9
Updates onp -> np in random, loops, jet and in the tests of stax and optix (#3182) 2020-05-21 14:12:18 -07:00
Jake Vanderplas
8fe26190de
Expand type support for random uniform() & randint() (#3138) 2020-05-19 14:19:00 -07:00
Jake Vanderplas
e675f804ff
Add support for 8- and 16-bit output in _random_bits (#3090) 2020-05-15 19:09:43 -07:00
Jake Vanderplas
12a9af869f
Update random.logistic() to prevent infinities (#3048) 2020-05-15 14:29:02 -07:00
Tom Hennigan
abdf504e9e
Avoid recompilation of rolled loops in threefry2x32. (#3069) 2020-05-12 15:03:22 -07:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
Stephan Hoyer
46ce80b032
jax.random.poisson (#2805)
* jax.random.poisson

The implementation for lam < 10 was directly copied from TensorFlow probability:
https://github.com/tensorflow/probability/blob/v0.10.0-rc0/tensorflow_probability/python/internal/backend/numpy/random_generators.py#L155

I adapted the implementation for lam > 10 from TensorFlow:
https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc

The methods themselves match both TensorFlow and NumPy:
https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574

* add a check for even larger lambda

* increment iter count

* remove comment that makes no sense

* Fix chi-squared tests in random_test.py

As far as I can tell, the previous implementation of the chi-squared test
for samples from discrete probability distributions was broken. It should have
been asserting that the p-value was greater 0.01, e.g., as illustrated here:
http://hamelg.blogspot.com/2015/11/python-for-data-analysis-part-25-chi.html

This hid a few other bugs, such a miscalculation of expected frequencies.

Fortunately, the existing random tests for Bernoulli and Categorical *mostly*
still pass, which the exception of multi-dimensional logits for Categorical.
Those tests are disabled by this PR.

* Fix accept condition (based on correct chi-squared test)

* Add moment checks for Poisson

* Add batching test, more Poisson rates
2020-05-02 11:24:59 -04:00
Jake Vanderplas
6425ca2aed
Merge pull request #2925 from jakevdp/shuffle
Deprecate random.shuffle() and implement random.permutation() for multi-dim inputs
2020-05-02 06:32:50 -07:00
Jake VanderPlas
d8d71407dc Deprecate random.shuffle() and implement random.permutation() for multi-dimensional matrices. 2020-05-01 15:18:24 -07:00
Julius Kunze
c00e9a2a52
Reapply #2017 (Allow shapecheck of PixelCNN++), fixing #2245 (#2800)
* Unrevert "Allow shapecheck of PixelCNN++ (google#2017)"

This reverts commit ceab1e3edf1e2395035173dc50f24ce6a27475f6.

* Fix out-of-bound slices (#2245)

* Minor

* Add type annotations

* Fix Poly.__rsub__

* any -> _any

* tweaks, mostly comments/whitespace

* separate polymorphic code path, patch _slice_sizes

* put back some logic for handling Poly sizes

* improve test_slice_indices

* Remove to_index, replace with canonicalize_shape

* Fix slicing with polymorphic start/stop

* Test negative step for polymorphic slicing

* Refactor polymorphic slicing

* Simplify diff

* Fix shapecheck(iota)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-05-01 12:34:29 -07:00
MichaelMarien
e0d42e90eb
Feature/permutation (#1568)
* added test for random.permutation

* added permutation that wraps shuffle with behaviour of np.random.permutation

* update docstring

* need to shuffle also the integer range input

* fixed test for permutation with integer

* tweak handling of random.permutation scalar case

* NotImplementedError for random.permutation on >1d

pending resolution to #2066

* address reviewer comments: improve tests

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-04-23 22:40:33 -07:00
Peter Hawkins
5290c03a17
Remove usage of xla_client.{Computation,ComputationBuilder}. (#2808)
* Remove usage of xla_client.{Computation,ComputationBuilder}.

ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the the C++ XLA API. It dates back from when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.

Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
2020-04-23 18:30:47 -04:00
William C Grisaitis
2bc3c7985e
Fix distribution name in docstring (#2764) 2020-04-20 19:55:23 -07:00
John Aslanides
6cb69f5c46 Fix type annotation for uniform. 2020-04-13 13:24:08 +01:00
John Aslanides
c06fe56fc5 Add some types to jax.random and jnp.ndarray. 2020-04-12 09:14:54 +01:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
Matthew Johnson
47df7b95c4
change the xla representation of JAX's unit (#2416)
* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel /
placeholder) was an empty tuple, but by changing the representation to
something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are
three functions in xla.py that define the representation. Here are
versions that would keep the old XLA representation as an empty tuple:

```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array. An alternative
representation would be nothing at all: we don't need to generate XLA
computations that have representations of JAX units. While that
alterntaive is probably the best choice, it seemed like it would require
a bit more refactoring/bookkeeping (e.g. to allow XLA computations to
have a smaller number of outputs than the corresponding JAX function),
and would also mean the XLA representation would be a step further
removed from the jaxpr representation. So I stuck with a trivial array
for now.

The mapping from JAX types to XLA types need not be invertible. However,
XLA translation rules currently don't take as arguments the
corresponding JAX types (abstract values), and there were a few cases
where we relied on checking whether an argument's XLA type was that of
an empty tuple so as to determine if we were effectively operating on a
JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle
add two units, and get lowered to an XLA addition on the unit
representation. Previously, the translation rule for add_jaxvals_p
checked the XLA type so that adding two empty tuples didn't produce any
XLA operation; now it adds its inputs, and so if unit is represented as
a trivial array we could be inserting trivial scalar adds where we had
none before. However, if that case is ever possible, it doesn't come up
in our tests (which I checked by keeping the representation as an empty
tuple and then asserting an XLA tuple type is never seen by that
translation rule).

* add comment about JAX<->XLA array types assumption
2020-03-14 12:33:14 -07:00
Ram Rachum
f3f0abb53e
Fix exception causes all over the codebase (#2376)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-03-09 16:06:12 -04:00
George Necula
ceab1e3edf Revert "Allow shapecheck of PixelCNN++ (#2017)"
This reverts commit 8f538f4e25d039a76d99af97374e7ece8c1c63a3.

Issue: #2245
2020-02-17 17:56:56 +01:00
Colin
6d0d6fd6c7
Docstring typo (#2228) 2020-02-14 08:04:20 -08:00
Julius Kunze
8f538f4e25
Allow shapecheck of PixelCNN++ (#2017)
* Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.uniform, iota, simple cases of split

* Fix dynamic slicing

* Fix issue with float64.__index__()

* Fix np.arange with float size, _try_canonicalize_shape

* Cleanup: Make methods to create Poly internal (only use in Poly / shape spec parsing)

* Fix testReshapeWithUnusualShapes (error message)

* Fix syntax for python 3.6

* Remove Poly.__index__

* Fix tests

* Split up masking.py

* Cleanup masking

* Cleanup

* Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)

* Remove shape_rules, fix test

* Remove shapes.py, move code to abstract_arrays.py / api.py

* Remove safe_map/zip, is_instance from abstract_arrays, test + fix Poly hash, minimize import diff

* Add missing shapecheck_test.py

* Cleanup, minimize changes

* Minimize import diff

* Minor

* Allow shapecheck of np.where

* Fix np.where

* Simplify gather to allow retightening type assertion in ConcreteArray

* Remove unused imports

* Make import style consistent

* Remove is_polymorphic, special cases in sampling, split, where.

* Move back Poly, _parse_shape_spec into masking.py to simplify diff

* Move back ShapeTest into masking_test.py to simplify diff

* Minor reverts to further simplify diff

* Fix tests

* Minimize diff

* Restore copyright, cleanup imports in masking.py

* Merge branch 'master' of https://github.com/google/jax into shapecheck-pcnn

# Conflicts:
#	jax/api.py
#	jax/numpy/lax_numpy.py
2020-02-14 06:59:05 -08:00
Pavel Sountsov
b2ef5bc095
Canonicalize the shape in the wrapper functions in random.py. (#2165)
* Canonicalize the shape in the wrapper functions in random.py.

This lets the user be more sloppy in using numpy arrays and statically
known DeviceArrays for shapes, and still hit the jit cache. When they
are not, the error is improved.

* Fix some errors.

* No need for the Poly workaround.

* Bypass canonicalization for None shapes in random.py.
2020-02-05 10:10:33 -08:00