206 Commits

Author SHA1 Message Date
Roy Frostig
3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Roy Frostig
69878c4924 remove Threefry GPU kernel
Cursory timing of `jit(lambda key: random.bits(key, (8, 128 * 128)))` suggests that this is a slight compile-time efficiency loss, taking roughly ~1.25x the time to compile compared to the removed kernel-based lowering. This seems worth the memory improvement, and one kernel fewer to maintain.

PiperOrigin-RevId: 629282330
2024-04-29 21:29:38 -07:00
Jake VanderPlas
7e60331cd1 [key reuse] print information about key reuse location 2024-03-21 13:34:26 -07:00
Yash Katariya
ab2e906323 Fix the indentation of the physical_hlo_sharding function
PiperOrigin-RevId: 616280971
2024-03-15 16:59:20 -07:00
Yash Katariya
cd1e55a351 Remove physical_hlo_sharding from TyRules.
The only caller of `physical_op_sharding` outside of TyRules was mlir.py. This CL also changes lower_jaxpr_to_fun to only accept logical arg_shardings and result_shardings which are XLACompatiableShardings.

PiperOrigin-RevId: 616267810
2024-03-15 16:02:13 -07:00
Jake VanderPlas
6771a59181 [key reuse] add jax.random.clone 2024-03-08 09:06:00 -08:00
Yash Katariya
1cb8d31c66 Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.

PiperOrigin-RevId: 613289252
2024-03-06 11:42:40 -08:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00
Yash Katariya
550ce44afd Move the replicated trailing dims check inside logical_op_sharding
PiperOrigin-RevId: 611277405
2024-02-28 17:03:37 -08:00
Yash Katariya
217f08236e Allow sharding propagation to input for prng keys whose sharding is not specified.
Convert shardings returned by XLA (when propagation is on for input and output) for extended dtypes to user shardings which allows to remove `are_out_shardings_from_xla`.

PiperOrigin-RevId: 611246986
2024-02-28 15:22:16 -08:00
Yash Katariya
2f7c36c763 Contrain the trailing dims of prng key array to REPLICATED and keep other dims as unconstrained.
PiperOrigin-RevId: 611232967
2024-02-28 14:38:43 -08:00
Jake VanderPlas
49eb7008c0 Define reuse_key primitive in jax._src.prng 2024-02-14 14:01:08 -08:00
Roy Frostig
a04332504b remove PRNGKeyArray ABC
We don't expose the `PRNGKeyArray` symbol publicly any longer and we only implement the interface in one place.

PiperOrigin-RevId: 602470550
2024-01-29 12:41:26 -08:00
Roy Frostig
2478f311d3 remove key array's isinstance-overriding metaclass
We don't need to support `isinstance(..., PRNGKeyArray)` on tracers any longer, since `PRNGKeyArray` is no longer a public symbol.

PiperOrigin-RevId: 601815616
2024-01-26 11:16:56 -08:00
Jake VanderPlas
78f27dfa9d Remove unnecessary Array.register 2024-01-24 14:59:25 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
Jake VanderPlas
fff5ea579a Remove deprecated unsafe_raw_array method from PRNG keys
PiperOrigin-RevId: 595190146
2024-01-02 13:03:21 -08:00
Matthew Johnson
05da18ab54 tweaks to enable adding custom tangent dtypes
tweaks to enable adding custom tangent dtypes:
* fix a bug in zeros_like_shaped_array and KeyTyRules.zero to ensure `scalar_zero` is actually a scalar
* upgrade the adder handler for ShapedArray to delegate to an extended dtype rule for addition
* convert_element_type shouldnt blanket-disallow extended dtypes; actually that can be a key operation for working with them! instead, add new `convert_from` and `convert_to` rules. instead of letting these rules perform arbitrary logic, for now they can just return a bool indicating whether the conversion is legit; if false, an error is raised, and if true, the existing convert_element_type lowering rule just generates a ConvertElementType HLO from one physical type to the other

this pr also adds a test for a custom tangent dtype of interest for plumbing quantization scales out of a backward pass
2023-12-22 11:33:14 -08:00
Matthew Johnson
be3ca507db del add_any_p and zeros_like_p, replace aval-dispatched traceable 2023-12-21 17:04:21 -08:00
Matthew Johnson
ec7d28c0b2 revise logic for tangent types of extended dtypes
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change

We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
2023-12-20 14:24:52 -08:00
Roy Frostig
671790730e introduce a config flag to control a random seed offset 2023-12-12 18:31:07 -08:00
Peter Hawkins
8e8dc263bc Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result.
In most cases these are more succinct.

This change does not update Pallas/Mosaic.

PiperOrigin-RevId: 583448254
2023-11-17 11:47:14 -08:00
Lukas Geiger
52d7f4911c Prefer expand_dims over reshape 2023-11-16 01:15:48 +00:00
Jake VanderPlas
c0f3fa00f8 [random] support key dtype in custom_jvp
To do this, we introduce a dtype for key tangents which cannot be used
to generate random values
2023-11-10 11:16:23 -08:00
Jake VanderPlas
3e9c50290f Allow array-like inputs to random.seed_impl 2023-10-24 11:23:49 -07:00
Jake VanderPlas
8f82f2e66f [typing] regularize types of jax.random API 2023-10-20 10:33:20 -07:00
jax authors
dfcbfc3915 Merge pull request #18161 from jakevdp:prng-private-impl
PiperOrigin-RevId: 574679979
2023-10-18 18:57:02 -07:00
jax authors
6aff74e7ff Merge pull request #18162 from jakevdp:physical-aval
PiperOrigin-RevId: 574627738
2023-10-18 15:49:44 -07:00
Jake VanderPlas
0da4be5e2a [random] make PRNG impl attributes private 2023-10-18 11:10:47 -07:00
Jake VanderPlas
563673576e [random] cleanup internal implementation 2023-10-17 15:47:32 -07:00
Jake VanderPlas
6da4750c3b [random] remove internal uses of deprecated prng.seed_with_impl() 2023-10-17 13:18:08 -07:00
Jake VanderPlas
e5c2a2c0a3 [random] add shaped_abstractify handler for custom PRNG key 2023-10-10 16:15:19 -07:00
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Roy Frostig
5158e251b6 identify PRNG schemes on key arrays, and recognize them in key constructors
Specifically:

* Introduce `jax.random.key_impl`, which accepts a key array and
  returns a hashable identifier of its PRNG implementation.

* Accept this identifier optionally as the `impl` argument to
  `jax.random.key` and `wrap_key_data`.

This now works:

```python
k1 = jax.random.key(72, impl='threefry2x32')
impl = jax.random.key_impl(k1)
k2 = jax.random.key(72, impl=impl)
assert arrays_equal(k1, k2)
assert k1.dtype == k2.dtype
```

This change also set up an internal PRNG registry and register
built-in implementations, to simplify various places where we
essentially reconstruct such a registry from scratch (such as in
tests).

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-10-06 10:15:08 -07:00
Jake VanderPlas
70e0098a87 [random] add itemsize property to custom PRNG 2023-09-25 08:52:26 -07:00
Roy Frostig
1f8cc44f4e deprecate PRNGKeyArray.unsafe_raw_array in favor of jax.random.key_data
The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.

PiperOrigin-RevId: 565195381
2023-09-13 16:33:56 -07:00
Roy Frostig
6abefa1977 fast dispatch for functions over typed PRNG key arrays
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.

We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.

To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
2023-09-13 09:43:58 -07:00
Jake VanderPlas
ea5f126e85 [custom prng] make PRNGKeyArray a subclass of jax.Array 2023-09-12 13:48:12 -07:00
Roy Frostig
009932760c get aval directly via attribute in key array shard arg handler
No need to go through `core.get_aval` here.

PiperOrigin-RevId: 559945841
2023-08-24 19:47:35 -07:00
Peter Hawkins
889489206b Remove the canonicalize_dtypes argument from mlir.ir_constant(s).
Instead, force the caller to explicitly canonicalize the argument if that's what they want.

The current behavior (canonicalize by default) is not the behavior we want to encourage: we want to canonicalize exactly where we need to and nowhere else.

PiperOrigin-RevId: 557806903
2023-08-17 06:44:12 -07:00
Roy Frostig
e58f5d283a de-emphasize internal array implementation type in key array repr 2023-08-14 12:48:19 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
Jake VanderPlas
7d7a536b55 custom prng: introduce mechanism to identify key arrays by dtype 2023-07-21 12:27:32 -07:00
jax authors
1b33a4eb05 Merge pull request #16815 from hawkinsp:py39
PiperOrigin-RevId: 550014612
2023-07-21 12:12:47 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jake VanderPlas
2ffa9bd8df Refactor opaque dtype implementation.
This makes it closer to numpy, with dtypes.OpaqueDtype analogous to np.dtype,
and dtypes.opaque analogous to np.numeric. This will let us replace the
dtypes.is_opaque_dtype function with jnp.issubdtype(dtype, dtypes.opaque).
2023-07-20 19:51:52 -07:00
Jake VanderPlas
b9c7b9bb4f Remove obsolete jaxlib version checks 2023-07-12 11:53:55 -07:00
Roy Frostig
df2891ff13 accept general shape option in jax.random.split
Several PRNG implementations (notably partitionable threefry) support
splitting to arbitrary shapes, rather than only to a 1-D vector of
keys. This change:

* Upgrades `jax.random.split` to accept a general shape as an
  argument.
* Updates the internal PRNG interface, and our various PRNG
  implementations, to accept and handle such a shape argument.

This change keeps the argument name `num`. We can still think on
whether and how we'd like to upgrade to `shape`.

Note that we could have supported arbitrary shapes by reduction to the
previous API (with a flat split count), using reshapes. We'd like to
avoid that, so as not to hide this structure from the underlying
implementation. For instance, partitionable threefry hashes a *shaped*
iota in order to split keys, and we don't want to flatten and reshape
around that for no reason.

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-07-11 14:24:01 -07:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Roy Frostig
690e626312 outline jitted threefry split and fold_in subroutines
We may want to continue to inline these in Jaxpr, but it's useful to
outline them in HLO for visualization and debugging.
2023-06-26 11:52:55 -07:00