235 Commits

Author SHA1 Message Date
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
Jake VanderPlas
676070f4cd Refactor: move shaped_abstractify to core 2024-12-18 19:14:46 -08:00
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
Jake VanderPlas
2c722d9b13 Cleanup: toward merging core.concrete_aval & xla.abstractify 2024-12-17 09:27:00 -08:00
Peter Hawkins
23e9142d28 Lower threefry as an out-of-line MLIR function on TPU.
On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size.
2024-11-15 08:49:35 -08:00
jax authors
4363bb65d7 Merge pull request #24770 from jakevdp:extended-device-get
PiperOrigin-RevId: 695671688
2024-11-12 03:58:23 -08:00
jax authors
c8f5b2bb13 Merge pull request #24481 from jakevdp:key-array-error
PiperOrigin-RevId: 694626415
2024-11-08 13:47:05 -08:00
Jake VanderPlas
58dee3ea33 jax.device_get: handle generic extended dtypes 2024-11-07 16:01:22 -08:00
Yash Katariya
0bb30f0777 Propagate CopySemantics from python to C++ transfer APIs so that device_put works correctly in presence of copy/donate options that user specified.
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.

This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.

Fixes https://github.com/jax-ml/jax/issues/24521

PiperOrigin-RevId: 694274616
2024-11-07 15:51:54 -08:00
Jake VanderPlas
83383fc717 Error on numpy array conversion of PRNG key array 2024-11-07 10:08:49 -08:00
Dougal Maclaurin
f355dcf34b Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Jake VanderPlas
0181cb396d Re-land #24589 with fixes to handle dtype that is not compatible with NumPy.
Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case.

Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d

PiperOrigin-RevId: 691568933
2024-10-30 15:13:00 -07:00
Thomas Köppe
2bed1e88e4 Reverts 6dd1417d4a0a9ee31d8a014352b3a0fb2bcfcbaf
PiperOrigin-RevId: 691417832
2024-10-30 07:54:00 -07:00
Jake VanderPlas
b9ad519a29 Implement device_get for typed PRNG keys 2024-10-29 12:34:46 -07:00
George Necula
e5f4be5564 [shape_poly] Expands support for random.choice
`random.choice` uses `np.insert(arr.shape, new_shape)` which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. Replace use of `np.insert` with tuple slicing and
concatenation.

The case when the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` on
arrays with non-constant size is not supported.

Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
2024-10-24 17:20:09 +03:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Matthew Johnson
0a73d74a4e simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).

This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-25 00:10:01 +00:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
George Necula
65450d165e Remove forward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th.
Also add backwards compatibility test for the new custom call.

PiperOrigin-RevId: 658011228
2024-07-31 08:10:17 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
George Necula
2f808e9da9 Fix error in custom call registration for some FFI functions
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```

It seems that with the ffi registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is not possible anymore to
register a call target twice.

The fix here is to rollback the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.

PiperOrigin-RevId: 647993991
2024-06-29 12:18:34 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Junwhan Ahn
5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00
Yash Katariya
1273028018 Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes.
PiperOrigin-RevId: 639920049
2024-06-03 14:52:50 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Roy Frostig
3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Roy Frostig
69878c4924 remove Threefry GPU kernel
Cursory timing of `jit(lambda key: random.bits(key, (8, 128 * 128)))` suggests that this is a slight compile-time efficiency loss, taking roughly ~1.25x the time to compile compared to the removed kernel-based lowering. This seems worth the memory improvement, and one kernel fewer to maintain.

PiperOrigin-RevId: 629282330
2024-04-29 21:29:38 -07:00
Jake VanderPlas
7e60331cd1 [key reuse] print information about key reuse location 2024-03-21 13:34:26 -07:00
Yash Katariya
ab2e906323 Fix the indentation of the physical_hlo_sharding function
PiperOrigin-RevId: 616280971
2024-03-15 16:59:20 -07:00
Yash Katariya
cd1e55a351 Remove physical_hlo_sharding from TyRules.
The only caller of `physical_op_sharding` outside of TyRules was mlir.py. This CL also changes lower_jaxpr_to_fun to only accept logical arg_shardings and result_shardings which are XLACompatiableShardings.

PiperOrigin-RevId: 616267810
2024-03-15 16:02:13 -07:00
Jake VanderPlas
6771a59181 [key reuse] add jax.random.clone 2024-03-08 09:06:00 -08:00
Yash Katariya
1cb8d31c66 Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.

PiperOrigin-RevId: 613289252
2024-03-06 11:42:40 -08:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00
Yash Katariya
550ce44afd Move the replicated trailing dims check inside logical_op_sharding
PiperOrigin-RevId: 611277405
2024-02-28 17:03:37 -08:00
Yash Katariya
217f08236e Allow sharding propagation to input for prng keys whose sharding is not specified.
Convert shardings returned by XLA (when propagation is on for input and output) for extended dtypes to user shardings which allows to remove `are_out_shardings_from_xla`.

PiperOrigin-RevId: 611246986
2024-02-28 15:22:16 -08:00
Yash Katariya
2f7c36c763 Contrain the trailing dims of prng key array to REPLICATED and keep other dims as unconstrained.
PiperOrigin-RevId: 611232967
2024-02-28 14:38:43 -08:00
Jake VanderPlas
49eb7008c0 Define reuse_key primitive in jax._src.prng 2024-02-14 14:01:08 -08:00
Roy Frostig
a04332504b remove PRNGKeyArray ABC
We don't expose the `PRNGKeyArray` symbol publicly any longer and we only implement the interface in one place.

PiperOrigin-RevId: 602470550
2024-01-29 12:41:26 -08:00
Roy Frostig
2478f311d3 remove key array's isinstance-overriding metaclass
We don't need to support `isinstance(..., PRNGKeyArray)` on tracers any longer, since `PRNGKeyArray` is no longer a public symbol.

PiperOrigin-RevId: 601815616
2024-01-26 11:16:56 -08:00
Jake VanderPlas
78f27dfa9d Remove unnecessary Array.register 2024-01-24 14:59:25 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
Jake VanderPlas
fff5ea579a Remove deprecated unsafe_raw_array method from PRNG keys
PiperOrigin-RevId: 595190146
2024-01-02 13:03:21 -08:00
Matthew Johnson
05da18ab54 tweaks to enable adding custom tangent dtypes
tweaks to enable adding custom tangent dtypes:
* fix a bug in zeros_like_shaped_array and KeyTyRules.zero to ensure `scalar_zero` is actually a scalar
* upgrade the adder handler for ShapedArray to delegate to an extended dtype rule for addition
* convert_element_type shouldnt blanket-disallow extended dtypes; actually that can be a key operation for working with them! instead, add new `convert_from` and `convert_to` rules. instead of letting these rules perform arbitrary logic, for now they can just return a bool indicating whether the conversion is legit; if false, an error is raised, and if true, the existing convert_element_type lowering rule just generates a ConvertElementType HLO from one physical type to the other

this pr also adds a test for a custom tangent dtype of interest for plumbing quantization scales out of a backward pass
2023-12-22 11:33:14 -08:00
Matthew Johnson
be3ca507db del add_any_p and zeros_like_p, replace aval-dispatched traceable 2023-12-21 17:04:21 -08:00
Matthew Johnson
ec7d28c0b2 revise logic for tangent types of extended dtypes
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change

We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
2023-12-20 14:24:52 -08:00
Roy Frostig
671790730e introduce a config flag to control a random seed offset 2023-12-12 18:31:07 -08:00