On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:
`jax.config.update('jax_threefry_gpu_kernel_lowering', True)`
PiperOrigin-RevId: 629763763
Cursory timing of `jit(lambda key: random.bits(key, (8, 128 * 128)))` suggests that this is a slight compile-time efficiency loss, taking roughly ~1.25x the time to compile compared to the removed kernel-based lowering. This seems worth the memory improvement, and one kernel fewer to maintain.
PiperOrigin-RevId: 629282330
The only caller of `physical_op_sharding` outside of TyRules was mlir.py. This CL also changes lower_jaxpr_to_fun to only accept logical arg_shardings and result_shardings which are XLACompatiableShardings.
PiperOrigin-RevId: 616267810
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.
PiperOrigin-RevId: 613289252
Convert shardings returned by XLA (when propagation is on for input and output) for extended dtypes to user shardings which allows to remove `are_out_shardings_from_xla`.
PiperOrigin-RevId: 611246986
We don't need to support `isinstance(..., PRNGKeyArray)` on tracers any longer, since `PRNGKeyArray` is no longer a public symbol.
PiperOrigin-RevId: 601815616
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).
If your code ends up taking the python dispatch, then something is going wrong anyways.
PiperOrigin-RevId: 596081987
tweaks to enable adding custom tangent dtypes:
* fix a bug in zeros_like_shaped_array and KeyTyRules.zero to ensure `scalar_zero` is actually a scalar
* upgrade the adder handler for ShapedArray to delegate to an extended dtype rule for addition
* convert_element_type shouldnt blanket-disallow extended dtypes; actually that can be a key operation for working with them! instead, add new `convert_from` and `convert_to` rules. instead of letting these rules perform arbitrary logic, for now they can just return a bool indicating whether the conversion is legit; if false, an error is raised, and if true, the existing convert_element_type lowering rule just generates a ConvertElementType HLO from one physical type to the other
this pr also adds a test for a custom tangent dtype of interest for plumbing quantization scales out of a backward pass
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change
We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
PiperOrigin-RevId: 571932143
Specifically:
* Introduce `jax.random.key_impl`, which accepts a key array and
returns a hashable identifier of its PRNG implementation.
* Accept this identifier optionally as the `impl` argument to
`jax.random.key` and `wrap_key_data`.
This now works:
```python
k1 = jax.random.key(72, impl='threefry2x32')
impl = jax.random.key_impl(k1)
k2 = jax.random.key(72, impl=impl)
assert arrays_equal(k1, k2)
assert k1.dtype == k2.dtype
```
This change also set up an internal PRNG registry and register
built-in implementations, to simplify various places where we
essentially reconstruct such a registry from scratch (such as in
tests).
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.
PiperOrigin-RevId: 565195381
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.
We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.
To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
Instead, force the caller to explicitly canonicalize the argument if that's what they want.
The current behavior (canonicalize by default) is not the behavior we want to encourage: we want to canonicalize exactly where we need to and nowhere else.
PiperOrigin-RevId: 557806903
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:
Rename opaque dtype to extended dtype.
This includes three deprecations:
- jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
- jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
- the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
This makes it closer to numpy, with dtypes.OpaqueDtype analogous to np.dtype,
and dtypes.opaque analogous to np.numeric. This will let us replace the
dtypes.is_opaque_dtype function with jnp.issubdtype(dtype, dtypes.opaque).
Several PRNG implementations (notably partitionable threefry) support
splitting to arbitrary shapes, rather than only to a 1-D vector of
keys. This change:
* Upgrades `jax.random.split` to accept a general shape as an
argument.
* Updates the internal PRNG interface, and our various PRNG
implementations, to accept and handle such a shape argument.
This change keeps the argument name `num`. We can still think on
whether and how we'd like to upgrade to `shape`.
Note that we could have supported arbitrary shapes by reduction to the
previous API (with a flat split count), using reshapes. We'd like to
avoid that, so as not to hide this structure from the underlying
implementation. For instance, partitionable threefry hashes a *shaped*
iota in order to split keys, and we don't want to flatten and reshape
around that for no reason.
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:
* remove symbolic_equal_dim and symbolic_equal_shape in favor of the
newer definitely_equal and definitely_equal_shape
* remove is_special_dim_size, which checks that a value is a
dimension expression (not a constant). Some uses are replaced
with `not is_constant_dim` and others with `is_dim`.
* introduce concrete_dim_or_error to check that a value is
a dimension