249 Commits

Author SHA1 Message Date
Roy Frostig
90af597786 remove inaccurate inline comment in PRNGKeyArray constructor
PiperOrigin-RevId: 748085747
2025-04-15 17:39:40 -07:00
Roy Frostig
47bc2f55dc convert NumPy RNG key data to uncommitted default-device-backed jax.Array data
Generally, we want to maintain that key data backing a `PRNGKeyArray` is a `jax.Array`. This change converts NumPy arrays on construction.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 748077900
2025-04-15 17:11:25 -07:00
Yash Katariya
6e00b5e02d [NFC] Rename standard_insert_pbroadcast to standard_insert_pvary
PiperOrigin-RevId: 747943230
2025-04-15 11:02:45 -07:00
George Necula
ce7dc85104 [export] Add support for serializing functions with PRNG keys as inputs/outputs
This introduces version 4 of serialization, fully backwards compatible
with versions 2 and 3.

Fixes: #24143
2025-04-07 11:53:20 +02:00
jax authors
47876bb3dc Merge pull request #27579 from ZacCranko:nbytes
PiperOrigin-RevId: 741636333
2025-03-28 13:50:40 -07:00
Zac Cranko
d4c42d7199 implement nbytes for PRNGKeyArray 2025-03-28 10:54:48 -07:00
Peter Hawkins
ecd9f5ded8 Move aval_to_xla_shape into callback.py, which is its only user.
Specialize it to one shape per aval, since that's the only case that exists.
Remove some pointless assertions using this code.

PiperOrigin-RevId: 741569024
2025-03-28 10:28:04 -07:00
Yash Katariya
5950e722e2 Make sure vma on ShapedArray exists by default to make development easier. The field is populated inside shard_map guarded on the varying_axes_in_types config though.
PiperOrigin-RevId: 741554623
2025-03-28 09:44:03 -07:00
Yash Katariya
563c3e2244 Add standard pbroadcast rules to more primitives. This should cover all primitives from which shard_map registered standard_rewrite rules
PiperOrigin-RevId: 741516445
2025-03-28 07:20:12 -07:00
Yash Katariya
289fa625e5 [sharding_in_types] Add fold_in support
PiperOrigin-RevId: 740505750
2025-03-25 15:29:32 -07:00
Dan Foreman-Mackey
d7d0aa943e Move PRNG GPU lowering from jaxlib into JAX.
PiperOrigin-RevId: 738398099
2025-03-19 07:57:10 -07:00
Yash Katariya
53494ade2d PRNGKeyArray.aval should have the correct logical sharding. This required refactoring code so that we don't hit recursion errors.
PiperOrigin-RevId: 732536521
2025-03-01 18:18:19 -08:00
Yash Katariya
aeac6b0383 Fix pmap with sharded typed prng key
PiperOrigin-RevId: 714293671
2025-01-10 18:20:09 -08:00
George Necula
bc3306c8bc [shape_poly] Improve threefry with symbolic shapes
Previously, we could only handle threefry for the case when
it was possible to tell statically that the size of the `count`
array is even or odd. This meant that often we had to add a constraint
that one of the dimensions is even.

Here we rewrite the handling of threefry to not require a Python-level
conditional about evenness of the size of the count array. We use
a couple of `lax.dynamic_slice` rather than a `lax.split`.

We also generalize the tests to cases where the size if fully symbolic,
and we cannot tell statically that it is even.
2025-01-07 09:10:04 +02:00
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
Jake VanderPlas
676070f4cd Refactor: move shaped_abstractify to core 2024-12-18 19:14:46 -08:00
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
Jake VanderPlas
2c722d9b13 Cleanup: toward merging core.concrete_aval & xla.abstractify 2024-12-17 09:27:00 -08:00
Peter Hawkins
23e9142d28 Lower threefry as an out-of-line MLIR function on TPU.
On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size.
2024-11-15 08:49:35 -08:00
jax authors
4363bb65d7 Merge pull request #24770 from jakevdp:extended-device-get
PiperOrigin-RevId: 695671688
2024-11-12 03:58:23 -08:00
jax authors
c8f5b2bb13 Merge pull request #24481 from jakevdp:key-array-error
PiperOrigin-RevId: 694626415
2024-11-08 13:47:05 -08:00
Jake VanderPlas
58dee3ea33 jax.device_get: handle generic extended dtypes 2024-11-07 16:01:22 -08:00
Yash Katariya
0bb30f0777 Propagate CopySemantics from python to C++ transfer APIs so that device_put works correctly in presence of copy/donate options that user specified.
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.

This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.

Fixes https://github.com/jax-ml/jax/issues/24521

PiperOrigin-RevId: 694274616
2024-11-07 15:51:54 -08:00
Jake VanderPlas
83383fc717 Error on numpy array conversion of PRNG key array 2024-11-07 10:08:49 -08:00
Dougal Maclaurin
f355dcf34b Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Jake VanderPlas
0181cb396d Re-land #24589 with fixes to handle dtype that is not compatible with NumPy.
Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case.

Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d

PiperOrigin-RevId: 691568933
2024-10-30 15:13:00 -07:00
Thomas Köppe
2bed1e88e4 Reverts 6dd1417d4a0a9ee31d8a014352b3a0fb2bcfcbaf
PiperOrigin-RevId: 691417832
2024-10-30 07:54:00 -07:00
Jake VanderPlas
b9ad519a29 Implement device_get for typed PRNG keys 2024-10-29 12:34:46 -07:00
George Necula
e5f4be5564 [shape_poly] Expands support for random.choice
`random.choice` uses `np.insert(arr.shape, new_shape)` which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. Replace use of `np.insert` with tuple slicing and
concatenation.

The case when the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` on
arrays with non-constant size is not supported.

Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
2024-10-24 17:20:09 +03:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Matthew Johnson
0a73d74a4e simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).

This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-25 00:10:01 +00:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
George Necula
65450d165e Remove forward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th.
Also add backwards compatibility test for the new custom call.

PiperOrigin-RevId: 658011228
2024-07-31 08:10:17 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
George Necula
2f808e9da9 Fix error in custom call registration for some FFI functions
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```

It seems that with the ffi registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is not possible anymore to
register a call target twice.

The fix here is to rollback the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.

PiperOrigin-RevId: 647993991
2024-06-29 12:18:34 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Junwhan Ahn
5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00
Yash Katariya
1273028018 Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes.
PiperOrigin-RevId: 639920049
2024-06-03 14:52:50 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Roy Frostig
3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Roy Frostig
69878c4924 remove Threefry GPU kernel
Cursory timing of `jit(lambda key: random.bits(key, (8, 128 * 128)))` suggests that this is a slight compile-time efficiency loss, taking roughly ~1.25x the time to compile compared to the removed kernel-based lowering. This seems worth the memory improvement, and one kernel fewer to maintain.

PiperOrigin-RevId: 629282330
2024-04-29 21:29:38 -07:00
Jake VanderPlas
7e60331cd1 [key reuse] print information about key reuse location 2024-03-21 13:34:26 -07:00
Yash Katariya
ab2e906323 Fix the indentation of the physical_hlo_sharding function
PiperOrigin-RevId: 616280971
2024-03-15 16:59:20 -07:00
Yash Katariya
cd1e55a351 Remove physical_hlo_sharding from TyRules.
The only caller of `physical_op_sharding` outside of TyRules was mlir.py. This CL also changes lower_jaxpr_to_fun to only accept logical arg_shardings and result_shardings which are XLACompatiableShardings.

PiperOrigin-RevId: 616267810
2024-03-15 16:02:13 -07:00
Jake VanderPlas
6771a59181 [key reuse] add jax.random.clone 2024-03-08 09:06:00 -08:00
Yash Katariya
1cb8d31c66 Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.

PiperOrigin-RevId: 613289252
2024-03-06 11:42:40 -08:00