224 Commits

Author SHA1 Message Date
Rahul Batra
3391a5e385 [ROCm]: Disable some tests on ROCm platform 2022-12-19 21:33:13 +00:00
Peter Hawkins
73de02d5ce Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
2022-12-08 19:46:10 +00:00
Jake VanderPlas
7e3f6748ec random_test: skip singular covariance test on accelerators 2022-12-06 09:26:23 -08:00
Jake VanderPlas
58d6a3b164 random.multivariate_normal: add note about singular covariance 2022-12-05 12:43:05 -08:00
Jake VanderPlas
fdf5894c75 [x64] make random_test more type-safe 2022-12-01 15:51:37 -08:00
Roy Frostig
671e91d02d reduce relative tolerance in small-alpha Dirichlet test 2022-11-22 14:10:14 -08:00
Roy Frostig
f8ecab8f9a fix Threefry split/fold_in symmetry test under key arrays mode 2022-11-22 09:59:13 -08:00
Roy Frostig
a412d27519 test threefry split consistency with vmapped fold_in of lax.axis_index 2022-11-21 15:24:48 -08:00
Roy Frostig
dab2909a31 make threefry split and fold_in symmetric
Namely, make it so that `split(key, n)[i]` equals `fold_in(key, i)`
for any key and for `0 <= i < n`.

This change affects the observed random bits for a fixed key (indirectly
through splits and folds), so here we guard it behind
`jax.config.jax_threefry_partitionable`. It's not described very well
by the flag name, but it makes for a simple way to bundle together
several random-bit-altering changes as part of the same upgrade cycle.
2022-11-21 15:24:48 -08:00
Peter Hawkins
99e1c3dd66 [JAX] Opt into high precision matrix multiplications in JAX tests that fail on A100.
With these changes the JAX test suite passes on A100, which uses TF32 math by default. As a side effect, we can also remove a number of TPU-specific tolerances once we have opted into high precision.

Fixes https://github.com/google/jax/issues/12008

PiperOrigin-RevId: 488749199
2022-11-15 13:50:21 -08:00
Patrick Kidger
d2afa84a6e PRNGKeyArray is now a virtual subclass of ndarray 2022-11-11 08:04:38 -08:00
Matthew Johnson
190204ff7d fix jax.random.logits shape argument
fixes #13124
2022-11-04 19:51:39 -07:00
Matthew Johnson
213d2c8592 integrate new (partitionable, count-space-exhaustive) counts generation 2022-10-29 00:05:49 -07:00
Jake VanderPlas
077294555d Fix future numpy warning 2022-10-13 11:14:21 -07:00
Peter Hawkins
72f4f389be Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.
2022-10-12 15:20:53 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
28741b8e0d Some miscellaneous changes to make tests pass when jax.Array is enabled by default.
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some tests changes to make them pass

PiperOrigin-RevId: 474642889
2022-09-15 13:27:40 -07:00
jax authors
bf7525e121 Merge pull request #12170 from froystig:just-dtype
PiperOrigin-RevId: 471409020
2022-08-31 18:36:47 -07:00
Roy Frostig
0d3630b349 add key_data to jax.random for key array unwrapping
This is often useful in testing and debugging. Its more dangerous
inverse, wrapping, remains internal only.
2022-08-31 09:23:11 -07:00
Roy Frostig
8f045b12d6 internal rename: swap mentions of "custom eltypes" for "opaque dtypes"
Also, avoid direct set membership tests on `core.opaque_dtypes`. Update
callers to use `core.{is,has}_opaque_dtype` predicates instead.
2022-08-30 16:52:08 -07:00
Roy Frostig
077bfac544 add dtype property to key arrays 2022-08-30 14:06:01 -07:00
Yash Katariya
6340952e2a Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.
A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.

Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.

PiperOrigin-RevId: 470896270
2022-08-29 22:03:21 -07:00
Roy Frostig
af1871450c test jax.eval_shape with key array inputs/outputs 2022-08-24 13:13:17 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Roy Frostig
6071a8f875 roll-forward #11952, take 2
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`

PiperOrigin-RevId: 469276609
2022-08-22 13:57:31 -07:00
jax authors
3a2f25ff31 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468840334
2022-08-19 21:02:18 -07:00
Roy Frostig
9789e83b26 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468835674
2022-08-19 20:12:32 -07:00
jax authors
a6c6416872 Internal change
PiperOrigin-RevId: 468712508
2022-08-19 08:56:49 -07:00
Roy Frostig
82243d06fc enable several non-threefry RNG tests without config.jax_enable_custom_prng 2022-08-18 21:46:55 -07:00
Roy Frostig
7f06df1ea1 introduce key-element-type arrays and overhaul the Python PRNG key array type
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.

We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays such element type.

This leans heavily on our recently-introduced extended element type
capabilities.

As a consequence, `vmap`, `scan`, etc. now work.

A sample of changes made to introduce key-element-type arrays:

* Introduce a new element type (`prng.KeyTy`), with the requisite IR
  type mapping and device result handlers, as well as lowering rules
  for dtype-polymorphic primitive operations.

* Introduce primitives for basic RNG operations: `random_seed`,
  `random_bits`, `random_split`, `random_fold_in`. These primitives
  essentially delegate to the underlying PRNG implementation (directly
  so in their impl rules, and by translating their staged-out form in
  lowering rules).

* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
  conversion from/to the base `uint32` array. We need this backwards
  compatibility, and it's useful for tests.

* Introduce some `vmap`-based helpers to adapt PRNG impls (which
  define basic `random_bits`, `split`, etc. on scalars) to the above
  batch-polymorphic primitives. Most of the primitives are vectorized,
  but `random_fold_in` is a broadcasting binary op.

* Update the `gamma` primitive rules to account for key-element-type
  abstract arrays (nice simplification here).

* Give PRNG implementation short string names ("tags") for IR
  pretty-printing.

* Update `lax.stop_gradient` to handle opaque dtypes.

* Fix up loop MLIR lowering, which assumed that shaped arrays of all
  dtypes have the same physical shape.

* Add new tests (exercising staging, jaxprs, lowerings, ...)

A sample of changes made to rework Python-level PRNG key arrays:

* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
  tracers that carry them.

* Patch (only a subset of) standard device array attributes onto PRNG
  key arrays.

* Implement various conversion handlers (sharding, constant-creation,
  `device_put`).

* Accept PRNG key arrays as input to `lax_numpy.transpose`.

* Update tests and rename some internals.

A sample of extra changes along the way:

* Disallow AD on key-typed arrays in the main API.

* Hoist `random_bits`'s named-shape-handling logic, which used to only
  take place in the threefry PRNG's `random_bits` implementation, up
  to the new `random_bits` traceable, so that we apply it consistently
  across PRNG implementations.

This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.

Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
2022-08-18 21:46:55 -07:00
Neil Girdhar
ad38a6bb28 Fix common typo: Tuple[X] -> Tuple[X, ...] 2022-08-16 11:47:22 -04:00
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
Jake VanderPlas
bcb45557c4 random.choice: make compatible with strict promotion 2022-06-28 10:41:30 -07:00
Peter Hawkins
fc659d5308 Reduce size of double-sided maxwell random test.
It appears that for some inputs this triggers an integer overflow in scipy.stats.maxwell().cdf.
2022-06-24 12:01:20 -04:00
Jake VanderPlas
eec1225d74 TST: skip tests on numpy 1.23.0 due to regressions in that release 2022-06-23 11:46:51 -07:00
George Necula
4b03ebf4f5 Fix overflow of large prng computation
Fixes: #11010
2022-06-20 10:48:15 +02:00
carlosgmartin
ca83a80f95 Added random.generalized_normal and random.ball. 2022-06-03 15:11:29 -04:00
Jake VanderPlas
2b6387f83f [x64] make jax.random compatible with jax_numpy_dtype_promotion=strict 2022-05-27 11:12:39 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Rohit Santhanam
8d9f17df19 Disabled one and enabled several unit tests for ROCm. 2022-05-10 19:47:26 +00:00
jax authors
227e525de2 Merge pull request #10458 from carlosgmartin:random_orthogonal_unitary
PiperOrigin-RevId: 445522278
2022-04-29 15:40:16 -07:00
Carlos Martin
b276c31b75 Added random.orthogonal. 2022-04-29 14:20:50 -04:00
YouJiacheng
4ff6b1fbca Fix PRNGKeyArray.broadcast_to with scalar shape 2022-04-16 00:27:30 +08:00
Joan Puigcerver
0c02f7935a Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests.
PiperOrigin-RevId: 440300882
2022-04-08 01:08:55 -07:00
Jake VanderPlas
e68e87ac22 random_test: add tests of random values for distributions 2022-03-23 10:34:26 -07:00
Jake VanderPlas
69969ef803 add random.loggamma and improve dirichlet & beta implementation 2022-03-21 08:33:11 -07:00
Joan Puigcerver
caf094d06b Support gamma distribution with PRNGKeys other than threefry2x32.
PiperOrigin-RevId: 433614014
2022-03-09 17:06:02 -08:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Jake VanderPlas
2c2773a5f1 jax.random.poisson: fix corner cases 2022-02-28 12:10:47 -08:00
Roy Frostig
88c6b84daf in tests, compare jnp operations on PRNGKeyArrays to the same on jnp arrays 2022-02-17 11:43:09 -08:00