96 Commits

Author SHA1 Message Date
Julius Kunze
8f538f4e25
Allow shapecheck of PixelCNN++ (#2017)
* Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.uniform, iota, simple cases of split

* Fix dynamic slicing

* Fix issue with float64.__index__()

* Fix np.arange with float size, _try_canonicalize_shape

* Cleanup: Make methods to create Poly internal (only use in Poly / shape spec parsing)

* Fix testReshapeWithUnusualShapes (error message)

* Fix syntax for python 3.6

* Remove Poly.__index__

* Fix tests

* Split up masking.py

* Cleanup masking

* Cleanup

* Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)

* Remove shape_rules, fix test

* Remove shapes.py, move code to abstract_arrays.py / api.py

* Remove safe_map/zip, is_instance from abstract_arrays, test + fix Poly hash, minimize import diff

* Add missing shapecheck_test.py

* Cleanup, minimize changes

* Minimize import diff

* Minor

* Allow shapecheck of np.where

* Fix np.where

* Simplify gather to allow retightening type assertion in ConcreteArray

* Remove unused imports

* Make import style consistent

* Remove is_polymorphic, special cases in sampling, split, where.

* Move back Poly, _parse_shape_spec into masking.py to simplify diff

* Move back ShapeTest into masking_test.py to simplify diff

* Minor reverts to further simplify diff

* Fix tests

* Minimize diff

* Restore copyright, cleanup imports in masking.py

* Merge branch 'master' of https://github.com/google/jax into shapecheck-pcnn

# Conflicts:
#	jax/api.py
#	jax/numpy/lax_numpy.py
2020-02-14 06:59:05 -08:00
Pavel Sountsov
b2ef5bc095
Canonicalize the shape in the wrapper functions in random.py. (#2165)
* Canonicalize the shape in the wrapper functions in random.py.

This lets the user be sloppier about passing NumPy arrays and statically
known DeviceArrays as shapes while still hitting the jit cache. When a
shape is not statically known, the error message is improved.

* Fix some errors.

* No need for the Poly workaround.

* Bypass canonicalization for None shapes in random.py.
2020-02-05 10:10:33 -08:00
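The canonicalization this commit describes can be sketched as follows. `canonicalize_shape` here is a hypothetical stand-in for the wrapper logic, assuming the goal is to turn any array-like shape into a tuple of Python ints so that equal shapes compare and hash identically for the jit cache:

```python
import numpy as np

def canonicalize_shape(shape):
    """Convert an array-like shape into a tuple of Python ints.

    Equal shapes then hash identically, so calls passing e.g.
    np.array([2, 3]) and (2, 3) hit the same jit cache entry.
    """
    if shape is None:  # None shapes bypass canonicalization
        return None
    return tuple(int(dim) for dim in shape)
```

With this, `canonicalize_shape(np.array([2, 3]))` and `canonicalize_shape((2, 3))` produce the same key.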
Peter Hawkins
0b1d2fc3d1
Avoid accidental type promotion in gamma sampler gradient. (#2150)
Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.
2020-02-03 12:44:46 -05:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Matthew Johnson
00be20bdfa
Merge pull request #1855 from JuliusKunze/categorical
Add categorical sampler
2020-01-10 07:59:21 -08:00
Julius Kunze
f36d858c4e Require shape = sample_shape + batch_shape in random.categorical 2020-01-10 13:28:03 +00:00
fehiepsi
edf0e61bc9 support nested vmap for gamma sampler 2019-12-26 22:43:06 -05:00
fehiepsi
c75bf4ab72 make beta sampler faster 2019-12-23 23:02:08 -05:00
fehiepsi
cdfa57dfcc merge master 2019-12-23 22:52:15 -05:00
Matthew Johnson
7175c1dfe1 fix transpose bug in multivariate normal, add test
fixes #1869
2019-12-17 15:08:08 -08:00
Julius Kunze
698327080b Clarify documentation 2019-12-15 15:43:39 +00:00
Julius Kunze
9d12a24b63 Add categorical sampler 2019-12-13 12:41:26 +00:00
fehiepsi
7ec2ac58ca not use custom transform for gamma sampler 2019-12-01 09:44:45 -05:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggressively inlines, which causes compilation times on GPU to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00
Matthew Johnson
b358c27c92 replace x.shape with onp.shape(x) in random.py
fixes #1748 (thanks @vitchyr)
2019-11-22 10:59:31 -08:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and promoting np.float64 with np.float32 yields np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
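The classic NumPy scalar rule that the weak_type mechanism mirrors can be checked directly. The example below illustrates the NumPy behavior being emulated, not JAX's internals:

```python
import numpy as np

arr = np.ones(3, dtype=np.float32)

# A Python scalar does not promote the array's dtype: NumPy treats the
# scalar as "weak" and prefers the array's type.
scalar_sum = 1.0 + arr  # stays float32

# Two arrays promote normally: float32 + float64 -> float64.
array_sum = arr + np.ones(3, dtype=np.float64)
```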
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions directly; use the wrappers in jax.types instead. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Peter Hawkins
f7a44523be
Add some type helpers to lax_numpy. (#1593)
Prefer to use jax.numpy type helpers rather than numpy type helpers in various places.
Cleanup in preparation for adding bfloat16 support to jax.
2019-10-29 20:53:20 -04:00
Matthew Johnson
1f4e45cdcd tweak shape convention again 2019-10-20 21:14:48 +00:00
Matthew Johnson
ab6ac6c876 standardize shape handling in jax.random 2019-10-20 19:37:37 +00:00
Matthew Johnson
8a132b4109 try simplifying random.multivariate_normal api 2019-10-19 22:27:43 +00:00
James Bradbury
74dc4bf72a
Merge pull request #1419 from aeftimia/multivariate-normal
Multivariate normal distribution support
2019-10-15 16:46:13 -07:00
Alex Eftimiades
2f4e88760c fix indentation 2019-10-15 14:35:40 -04:00
Alex Eftimiades
7939bad65a Merge branch 'multivariate-normal' of https://github.com/aeftimia/jax into multivariate-normal 2019-10-12 11:38:06 -04:00
Alex Eftimiades
ed1bd5fccd test cases for ValueError with wrong shape; one test case with negative mean 2019-10-10 09:33:28 -04:00
Alex Eftimiades
d344c6f5e4
Update jax/random.py
Fix typo

Co-Authored-By: James Bradbury <jekbradbury@google.com>
2019-10-04 09:57:17 -04:00
Alex Eftimiades
c1057f77ae minor readability cleanup 2019-10-02 11:39:49 -04:00
Alex Eftimiades
ea67fa8f09 fix failing tests with full covariance matrix 2019-10-02 11:08:17 -04:00
Alex Eftimiades
ad3f25b02f small cleanup and retrigger tests 2019-10-02 09:25:55 -04:00
Trevor Cai
a44c1caf21 Use finfo(dtype).tiny as uniform minval
Otherwise, the default half-open uniform on [0, 1) used for
truncated_normal introduces a slight shift in the empirical mean.
2019-10-01 22:28:31 +01:00
Alex Eftimiades
b2b10a3c96 Multivariate normal support
Tested with inverse transform and KS test
2019-10-01 16:49:57 -04:00
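The standard construction for multivariate normal sampling, which this commit's approach resembles, is an affine transform of standard normals through a Cholesky factor of the covariance. The sketch below uses plain NumPy and is not the JAX implementation:

```python
import numpy as np

def multivariate_normal_sample(rng, mean, cov, size):
    """Sample from N(mean, cov) by transforming standard normals."""
    chol = np.linalg.cholesky(cov)  # cov = chol @ chol.T
    z = rng.standard_normal(size=(size, len(mean)))
    return mean + z @ chol.T

rng = np.random.default_rng(0)
cov = np.array([[2.0, 0.5], [0.5, 1.0]])
samples = multivariate_normal_sample(rng, np.zeros(2), cov, size=100_000)
```

The empirical covariance of `samples` converges to `cov` as the sample count grows, which is the kind of property a KS-style test can check.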
Alex Eftimiades
b70a7800d7 multivariate normal support 2019-09-23 20:21:28 -04:00
James Bradbury
146b5d121a
Merge pull request #1262 from google/jb/initializers
Migrate initializers and activation functions to jax.nn
2019-09-04 14:48:20 -07:00
James Bradbury
bf28c44ada address comments 2019-09-03 17:51:29 -07:00
Matthew Johnson
2815c55bb1 switch rolled/unrolled loops in prng hash 2019-08-31 07:35:37 -07:00
James Bradbury
f4aeb363a9 add truncated normal 2019-08-29 18:14:57 -07:00
David Majnemer
1dbdaab765 Use log1p when computing log(1 + x) or log(1 - x)
Computing log(1 + x) directly loses accuracy when x is near zero, because
forming 1 + x rounds away the low-order bits of x; log1p computes the same
result without that accuracy loss.
2019-08-16 13:44:09 -07:00
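The accuracy difference is easy to demonstrate for small x; the same applies to log(1 - x) via log1p(-x):

```python
import math

x = 1e-10

# Forming 1 + x first rounds away most of x's low-order bits, so the
# naive form is only accurate to about 8 digits here.
naive = math.log(1 + x)

# log1p evaluates log(1 + x) directly from x, at full precision
# (log(1 + x) ~ x for tiny x).
accurate = math.log1p(x)
```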
Peter Hawkins
719e17ba8e Avoid dynamic slicing in threefry implementation.
The dynamic slice, when batched, currently turns into an expensive gather because vmap(fori_loop(...)) always batches the loop counter.
2019-08-15 16:37:04 -04:00
Matthew Johnson
75bb38e741 address reviewer comments, no op-by-op in threefry 2019-08-13 11:30:24 -07:00
Matthew Johnson
d857d3f595 improve prng compile times with outer-loop rolling 2019-08-13 09:43:44 -07:00
Matthew Johnson
275bf9da6d improve prng compile times with loop rolling
cf. #1172
2019-08-13 07:20:09 -07:00
Jamie Townsend
f351dedbb7 Add logistic distribution to jax.random 2019-08-06 12:19:05 +01:00
Matthew Johnson
ec456a181b
Update random.py
Update `jax.random.fold_in` docstring to specify that `data` is treated as a 32-bit integer.
2019-07-23 12:21:28 +03:00
Peter Zhokhov
8985d684a1 remove static argnums from random.fold_in 2019-07-19 12:04:33 -07:00
David Majnemer
5607e46922 Avoid generating non-finite values from cauchy
If uniform generates 0, -pi/2 is passed to tan, yielding a non-finite
value. Instead, generate values on the open interval (0, 1).
2019-07-06 21:11:02 -07:00
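A minimal sketch of the technique, assuming an inverse-CDF Cauchy sampler: draw uniforms on the open interval (0, 1), here using np.nextafter as an illustrative minval, so that tan never receives -pi/2 exactly:

```python
import numpy as np

def cauchy_sample(rng, size):
    # The smallest double above 0 keeps the interval open at 0, so
    # u - 0.5 never reaches -0.5 and tan stays away from the pole.
    low = np.nextafter(0.0, 1.0)
    u = rng.uniform(low, 1.0, size=size)  # uniform on (0, 1)
    return np.tan(np.pi * (u - 0.5))      # inverse CDF of standard Cauchy
```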
David Majnemer
52fa63af48 Avoid generating non-finite values from gumbel and laplace
In the case of gumbel, we take the log(-log(x)), as such we would not want to let x be 0 or 1 as we would get a non-finite number.

In the case of laplace, we take the log1p(-abs(x)), as such we would not want to let x be -1 or 1 as we would get a non-finite number.

This was found by inspection, I have no evidence that this happens in practice.
2019-07-05 11:01:28 -07:00
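The two transforms above can be sketched with uniforms restricted away from the offending endpoints. The endpoint handling via np.nextafter is illustrative, not the JAX implementation:

```python
import numpy as np

eps = np.nextafter(0.0, 1.0)  # smallest positive double

def gumbel_sample(rng, size):
    # log(-log(u)) is non-finite at u = 0 or u = 1, so keep u in (0, 1).
    u = rng.uniform(eps, 1.0, size=size)
    return -np.log(-np.log(u))

def laplace_sample(rng, size):
    # log1p(-abs(u)) is non-finite at abs(u) = 1, so keep u in (-1, 1).
    u = rng.uniform(-1.0 + eps, 1.0, size=size)
    return np.sign(u) * np.log1p(-np.abs(u))
```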
fehiepsi
dc91d00afd use split 2 instead of split 1 2019-06-27 17:28:36 -04:00
fehiepsi
d284388577 pass tests locally 2019-06-27 13:07:26 -04:00
fehiepsi
907925fb30 improve compile time of gamma 2019-06-27 13:07:26 -04:00
fehiepsi
9252d596eb implement gamma grad 2019-06-27 13:07:25 -04:00