356 Commits

Author SHA1 Message Date
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
jax authors
ae46b7564e Merge pull request #24593 from froystig:random-dtypes
PiperOrigin-RevId: 698268678
2024-11-19 23:04:06 -08:00
Roy Frostig
4bb81075bc represent random.key_impl of builtin RNGs by canonical string name
We do not have great reason to return specs here, and sticking to
strings instead can help with simple serialization.
2024-11-19 20:58:10 -08:00
Jake VanderPlas
83383fc717 Error on numpy array conversion of PRNG key array 2024-11-07 10:08:49 -08:00
Jake VanderPlas
0181cb396d Re-land #24589 with fixes to handle dtype that is not compatible with NumPy.
Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case.

Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d

PiperOrigin-RevId: 691568933
2024-10-30 15:13:00 -07:00
Thomas Köppe
2bed1e88e4 Reverts 6dd1417d4a0a9ee31d8a014352b3a0fb2bcfcbaf
PiperOrigin-RevId: 691417832
2024-10-30 07:54:00 -07:00
Jake VanderPlas
b9ad519a29 Implement device_get for typed PRNG keys 2024-10-29 12:34:46 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
d6394c0795 random.key_impl: improve repr of output 2024-09-04 10:10:31 -07:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Matthew Johnson
83dfed1c02 make core.as_named_shape treat int like tuple[int]
fixes #21343
2024-07-20 17:14:38 +00:00
Justin Fu
9439f63645 [Pallas] Add pallas TPU random key impls and lowering rules for basic prng ops (seed/foldin/bits/unwrap/wrap).
PiperOrigin-RevId: 642085019
2024-06-10 18:08:19 -07:00
Ruturaj4
10d7827966 [ROCm] Add hip specific checks in threefry test 2024-06-03 23:08:34 -05:00
Meekail Zain
79005c1e69 Deprecate newshape argument of jnp.reshape 2024-05-09 21:02:07 +00:00
Roy Frostig
3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Jake VanderPlas
1b3aea8205 Finalize the deprecation of the arr.device() method
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.

PiperOrigin-RevId: 623015500
2024-04-08 19:04:15 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
Peter Hawkins
de455e7003 Fix small bug in random_test.
unsafe_buffer_pointer() and on_device_size_in_bytes() are methods, not properties, so presumably the test intended to call them rather than test equality of the bound methods.

PiperOrigin-RevId: 614651090
2024-03-11 07:04:58 -07:00
Jake VanderPlas
6771a59181 [key reuse] add jax.random.clone 2024-03-08 09:06:00 -08:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Jake VanderPlas
2a775faf15 Register jax.Array device method deprecation 2024-02-09 11:18:19 -08:00
Roy Frostig
2478f311d3 remove key array's isinstance-overriding metaclass
We don't need to support `isinstance(..., PRNGKeyArray)` on tracers any longer, since `PRNGKeyArray` is no longer a public symbol.

PiperOrigin-RevId: 601815616
2024-01-26 11:16:56 -08:00
Jake VanderPlas
b0ed801661 random_test: test random.bits directly 2024-01-16 13:19:29 -08:00
Peter Hawkins
8803774647 Fix/disable two tests failing in Windows CI.
PiperOrigin-RevId: 595793411
2024-01-04 13:47:32 -08:00
Matthew Johnson
05da18ab54 tweaks to enable adding custom tangent dtypes
tweaks to enable adding custom tangent dtypes:
* fix a bug in zeros_like_shaped_array and KeyTyRules.zero to ensure `scalar_zero` is actually a scalar
* upgrade the adder handler for ShapedArray to delegate to an extended dtype rule for addition
* convert_element_type shouldnt blanket-disallow extended dtypes; actually that can be a key operation for working with them! instead, add new `convert_from` and `convert_to` rules. instead of letting these rules perform arbitrary logic, for now they can just return a bool indicating whether the conversion is legit; if false, an error is raised, and if true, the existing convert_element_type lowering rule just generates a ConvertElementType HLO from one physical type to the other

this pr also adds a test for a custom tangent dtype of interest for plumbing quantization scales out of a backward pass
2023-12-22 11:33:14 -08:00
Matthew Johnson
be3ca507db del add_any_p and zeros_like_p, replace aval-dispatched traceable 2023-12-21 17:04:21 -08:00
Matthew Johnson
ec7d28c0b2 revise logic for tangent types of extended dtypes
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change

We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
2023-12-20 14:24:52 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences 2023-12-19 06:15:30 +01:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
jax authors
196c97fa0c Merge pull request #18949 from froystig:seed-offset
PiperOrigin-RevId: 590637382
2023-12-13 10:18:40 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Roy Frostig
671790730e introduce a config flag to control a random seed offset 2023-12-12 18:31:07 -08:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Jake VanderPlas
c0f3fa00f8 [random] support key dtype in custom_jvp
To do this, we introduce a dtype for key tangents which cannot be used
to generate random values
2023-11-10 11:16:23 -08:00
Jake VanderPlas
96d9f89415 [random] better errors for unsupported operations on prng keys 2023-11-03 19:23:18 -07:00
Jake VanderPlas
8f82f2e66f [typing] regularize types of jax.random API 2023-10-20 10:33:20 -07:00
Jake VanderPlas
53c4de477e [random] deprecate jax.random.default_prng_impl() 2023-10-19 13:59:01 -07:00
Jake VanderPlas
0da4be5e2a [random] make PRNG impl attributes private 2023-10-18 11:10:47 -07:00
jax authors
2be6019f1c Rollback to fix internal breakage
Reverts 7d203aebfa6206affde207c884b50172e203d177

PiperOrigin-RevId: 574101804
2023-10-17 04:24:15 -07:00
Jake VanderPlas
a2623f2888 [random] Avoid references to PRNGKeyArray type
See https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html
2023-10-13 11:10:05 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Jake VanderPlas
b407620d1d random_test: fix deprecation warnings for key tests
Some versions of numpy on some platforms raise warnings when custom PRNG keys
are passed to np.assert_array_equal. Address this by creating a specific function
for comparing key values.
2023-10-06 13:21:23 -07:00
Roy Frostig
5158e251b6 identify PRNG schemes on key arrays, and recognize them in key constructors
Specifically:

* Introduce `jax.random.key_impl`, which accepts a key array and
  returns a hashable identifier of its PRNG implementation.

* Accept this identifier optionally as the `impl` argument to
  `jax.random.key` and `wrap_key_data`.

This now works:

```python
k1 = jax.random.key(72, impl='threefry2x32')
impl = jax.random.key_impl(k1)
k2 = jax.random.key(72, impl=impl)
assert arrays_equal(k1, k2)
assert k1.dtype == k2.dtype
```

This change also set up an internal PRNG registry and register
built-in implementations, to simplify various places where we
essentially reconstruct such a registry from scratch (such as in
tests).

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-10-06 10:15:08 -07:00
Jake VanderPlas
3d503e01dc random_test: remove unnecessary test utilities 2023-10-05 15:33:14 -07:00
jax authors
8f911e1512 random_test: Split into two so that each target is small enough to fit within a medium timeout.
PiperOrigin-RevId: 571146766
2023-10-05 15:28:51 -07:00
Jake VanderPlas
f739a888f3 jax.random: fix NaN corner-case in loggamma 2023-10-04 11:40:32 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
jax authors
46f44287bc Merge pull request #17766 from jakevdp:prng-itemsize
PiperOrigin-RevId: 568381057
2023-09-25 18:30:59 -07:00