8695 Commits

Author SHA1 Message Date
jax authors
6d2c8cf5de Merge pull request #23656 from tchatow:fix-inv
PiperOrigin-RevId: 683112267
2024-10-07 03:38:04 -07:00
George Necula
db89c245ac [host_callback] Remove most of the jax.experimental.host_callback module
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682830525
2024-10-06 01:10:34 -07:00
Jake VanderPlas
45f0e9ad68 Simplify definition of jnp.isscalar
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/surprising than it needs to be, and the weak type dependence is rarely useful in practice.
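The difference between the two rules can be sketched with a toy model. This is an illustration under stated assumptions, not JAX's implementation: `ToyArray`, `isscalar_old`, and `isscalar_new` are hypothetical names that track only the two properties the message mentions (number of dimensions and weak typing).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyArray:
    """Hypothetical stand-in for an array-like value (not a JAX type)."""
    ndim: int
    weak_type: bool


def isscalar_old(x: ToyArray) -> bool:
    # Previous rule: zero-dimensional AND weakly typed.
    return x.ndim == 0 and x.weak_type


def isscalar_new(x: ToyArray) -> bool:
    # New rule: any zero-dimensional array-like value.
    return x.ndim == 0


weak_0d = ToyArray(ndim=0, weak_type=True)     # e.g. a Python scalar such as 1.0
typed_0d = ToyArray(ndim=0, weak_type=False)   # e.g. a 0-d array with a fixed dtype
vector = ToyArray(ndim=1, weak_type=False)     # e.g. a 1-d array

for name, value in [("weak_0d", weak_0d), ("typed_0d", typed_0d), ("vector", vector)]:
    print(name, "old:", isscalar_old(value), "new:", isscalar_new(value))
```

Only the zero-dimensional value with a fixed dtype changes its answer between the two rules in this model.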

PiperOrigin-RevId: 682656411
2024-10-05 07:12:20 -07:00
jax authors
e90487e906 Host Offloading: Process "MoveToHost" instructions in the order they are executed.
- This ensures we process "MoveToHost" instructions that reside at the beginning of a host memory instruction offload chain.
- This avoids processing MoveToHost instructions out of order, creating invalid instructions within a host memory instruction offload chain.

PiperOrigin-RevId: 682448060
2024-10-04 14:17:36 -07:00
Tom Natan
ed5ba633d4 Reverts 6cf09f8c24c67ff650b95d174501fff3cb59db0d
PiperOrigin-RevId: 682440543
2024-10-04 13:56:27 -07:00
jax authors
291619c291 Allow custom call computations to contain subcomputations
PiperOrigin-RevId: 682429391
2024-10-04 13:22:14 -07:00
Dan Foreman-Mackey
67f24df740 Activate FFI implementation of symmetric Eigendecomposition.
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
2024-10-04 12:38:26 -07:00
Peter Hawkins
d3f63a66b8 Remove code to support jaxlib <= 0.4.33.
2024-10-04 11:39:05 -04:00
Dan Foreman-Mackey
c0240764bc Activate FFI implementation of the QR decomposition.
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.

The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.

PiperOrigin-RevId: 682312752
2024-10-04 07:27:11 -07:00
Sergei Lebedev
aadb50905c [pallas:mosaic_gpu] Allowed indexing refs with scalars
The transforms do not yet handle this case, so only the basic indexing works.

PiperOrigin-RevId: 682273046
2024-10-04 04:54:37 -07:00
George Necula
3d389a7fb4 [host_callback] Accelerate deprecation of host_callback.barrier_wait
The jax.experimental.host_callback module has been deprecated since March 2024.

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682232942
2024-10-04 02:22:03 -07:00
Yash Katariya
79ff8e6232 Cache the iteration over jaxpr equations when extracting shardings, because most of the time it's the same jaxpr and we don't need to evaluate it again and again.
PiperOrigin-RevId: 682148975
2024-10-03 20:47:59 -07:00
jax authors
be76fb6abf Add host compute offload test: test_offload_take_host.
PiperOrigin-RevId: 682088063
2024-10-03 17:06:46 -07:00
Ayaka
cb2e0e2ced [Pallas TPU] Add lowering for lax.ceil_p
This PR uses exactly the same approach as https://github.com/jax-ml/jax/pull/24083, which adds lowering for `lax.floor_p`.

PiperOrigin-RevId: 682073765
2024-10-03 16:23:18 -07:00
jax authors
c6e5530aab Merge pull request #24081 from jakevdp:jnp-spacing
PiperOrigin-RevId: 681950616
2024-10-03 11:07:27 -07:00
Jake VanderPlas
635e29a0b9 Implement jax.numpy.spacing
Somehow we've missed this numpy API up until now.
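For readers unfamiliar with the NumPy function, the quantity it computes can be sketched with the standard library alone. This is a float64-only illustration for finite inputs and assumes Python 3.9+ for `math.nextafter`; the helper name `spacing` is local to this sketch, and the JAX function itself operates on arrays whose dtype determines the result.

```python
import math
import sys


def spacing(x: float) -> float:
    """Signed distance from a finite x to the adjacent float64 further from zero."""
    target = math.inf if x >= 0 else -math.inf
    return math.nextafter(x, target) - x


print(spacing(1.0))     # equals sys.float_info.epsilon
print(spacing(-1.0))    # same magnitude, negative sign
print(spacing(1024.0))  # the gap grows with the magnitude of x
```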
2024-10-03 10:40:39 -07:00
jax authors
7ef41ce653 Merge pull request #24095 from hawkinsp:asan
PiperOrigin-RevId: 681921579
2024-10-03 09:57:46 -07:00
Peter Hawkins
0f863b2092 Change asan build to use a self-hosted runner and Ubuntu 24.04.
2024-10-03 12:42:59 -04:00
Christos Perivolaropoulos
5800070c36 [pallas:mosaic_gpu] add logistic op and some tests for unary operations
PiperOrigin-RevId: 681889064
2024-10-03 08:25:44 -07:00
jax authors
ba4052d5ab Merge pull request #23881 from dfm:deprecate-default-vmap-callback
PiperOrigin-RevId: 681862488
2024-10-03 07:07:24 -07:00
Sergei Lebedev
905c83c781 [pallas:mosaic_gpu] Support indexing barriers
A barrier must be indexed via `.at` and not directly. I wish we could emit
an instructive error for the latter case, but I couldn't find a good place
to put it.

PiperOrigin-RevId: 681857034
2024-10-03 06:48:03 -07:00
Sergei Lebedev
5a2e5a5a94 [pallas:mosaic_gpu] Copy primitives now support slices
I decided to

* split `async_copy_p` into multiple primitives to avoid having extra
  control flow in the lowering rule;
* drop the `async_*` prefix from `async_copy_p` and the corresponding
  APIs, because the names felt a bit too long otherwise.

Note that barriers cannot be sliced at the moment. I will address that in
a follow-up CL.

PiperOrigin-RevId: 681793650
2024-10-03 02:54:20 -07:00
Ayaka
b5ce44536b [Pallas TPU] Add lowering for lax.floor_p
This is a follow-up of https://github.com/jax-ml/jax/pull/24056, which adds lowering for `lax.tan_p`.

PiperOrigin-RevId: 681793238
2024-10-03 02:52:26 -07:00
Sergei Lebedev
4cf33c0239 Added scatter_sub_p
The new primitive is used for in-place subtract and update.

Closes #23933

PiperOrigin-RevId: 681754037
2024-10-03 00:27:31 -07:00
jax authors
81d2fbe094 Merge pull request #23740 from kaixih:dbias_bwd_batcher
PiperOrigin-RevId: 681583770
2024-10-02 14:04:19 -07:00
Adam Paszke
c9f946ef57 Only thread a discharged ref value through a cond when it changes in some branch
Otherwise, we can simply pass it in as an argument, but we can avoid updating it
since it will always remain constant. Both programs have equivalent semantics,
but this one can be optimized better since it makes it more apparent that the
cond does not actually modify a ref.

PiperOrigin-RevId: 681482148
2024-10-02 09:29:07 -07:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call.
2024-10-02 11:33:51 -04:00
George Necula
b8a066a907 [host_callback] Remove obsolete tests.
Removing tests that only work in legacy mode and with outfeed.

PiperOrigin-RevId: 681435113
2024-10-02 06:51:02 -07:00
Adam Paszke
e2d3bd866a [Pallas/MGPU] Add support for tiled and swizzled loads/stores + support slices
PiperOrigin-RevId: 681370464
2024-10-02 02:44:10 -07:00
jax authors
fcfb0b7e48 Merge pull request #23951 from dfm:ffi-examples-attrs
PiperOrigin-RevId: 681219771
2024-10-01 17:14:05 -07:00
jax authors
e1439d5abe Merge pull request #24039 from mattjj:issue23867
PiperOrigin-RevId: 681219234
2024-10-01 17:12:00 -07:00
Sharad Vikram
c34e25d6f4 [Pallas] Add state discharge rule for pallas_call
This enables us to avoid spurious copies in the cases outlined in [the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html) but not in general, since JAX and/or XLA could introduce copies because we have value semantics. For a proper solution, we need to introduce some notion of buffer semantics to XLA/HLO and preserve it through the lowering of stateful JAX (maybe by avoiding state discharge altogether).

PiperOrigin-RevId: 681206784
2024-10-01 16:30:56 -07:00
Dan Foreman-Mackey
f60c5ccdee Add support for passing array attributes via ffi_call
2024-10-01 19:22:04 -04:00
Ayaka
e361868132 [Pallas TPU] Add lowering for lax.tan_p
This is a follow-up of https://github.com/jax-ml/jax/pull/24028, which adds lowering for `lax.cos_p`

PiperOrigin-RevId: 681180835
2024-10-01 15:09:52 -07:00
Peter Hawkins
fc4f554e09 Delete jax.lib.xla_client.execute_with_python_values.
Nothing under jax.lib.xla_client is public, so there's no deprecation period required.

PiperOrigin-RevId: 681166972
2024-10-01 14:32:22 -07:00
jax authors
6ded1bee66 [Pallas:TPU] Fix lowering of convert_element_type(float32) -> bool.
The original implementation doesn't handle 0 < |x| < 1 correctly. It used to be convert_element_type(x, int32) ==> 0 ==> convert_element_type(0, bool) ==> false, which is different from XLA semantics: convert_element_type(x, bool) ==> true.

Hypothesis library seems to draw values of 0.5.
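The bug described above can be reproduced in plain Python, where `int()` truncates toward zero much like the float-to-int32 step. This is only a scalar sketch of the two conversion paths; the function names are hypothetical and nothing here calls Pallas or XLA.

```python
def to_bool_via_int(x: float) -> bool:
    # Old lowering path: float -> integer (truncates toward zero) -> bool.
    return int(x) != 0


def to_bool_direct(x: float) -> bool:
    # Intended semantics: any nonzero value converts to True.
    return x != 0.0


for value in (0.0, 0.5, -0.25, 1.0, -2.5):
    print(value, to_bool_via_int(value), to_bool_direct(value))
```

The two paths disagree exactly for values with 0 < |x| < 1, such as the 0.5 mentioned above.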

While I'm here, remove some stale skip conditions. They are fixed due to recent Pallas/Mosaic changes.

PiperOrigin-RevId: 681152158
2024-10-01 13:51:47 -07:00
Blake Hechtman
ce21a12a07 [JAX] Make a one hot mode of take along axis.
PiperOrigin-RevId: 681139055
2024-10-01 13:16:26 -07:00
Yash Katariya
1efca33187 Add donate and may_alias as an argument to device_put to allow for donation and aliasing.
The end state we want to work towards is to remove `may_alias` and **always copy by default**. But there is some work needed to get to that state.

**Definition:**

* donate: The input buffer will be marked as deleted (see below for some caveats). The output buffer may or may not reuse the input buffer's underlying memory.

* may_alias: If True, we may return the original buffer depending on the implementation.

**What problem are we solving?**

Eventually, we want `device_put` to always copy, so we are introducing `may_alias` as a transitional option to help move towards that goal. We might end up deciding to keep `may_alias`, but you now have an explicit option to **always copy**, i.e. set `may_alias=False`, which is what some users want.

Adding `donate` allows users to avoid this pattern of code:

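```
inp = ...  # pytree of input arrays
out = jax.device_put(inp, sharding)
jax.block_until_ready(out)  # wait for the transfer before freeing the inputs
jax.tree.map(lambda x: x.delete(), inp)  # manually delete the input buffers
```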

Now it can just be: `jax.device_put(inp, sharding, donate=True)`

**So what are the semantics of these 2 options?** Let's create a table:

| `may_alias` (default: None) | `donate` (default: False) | Result |
| :---- | :---- | :---- |
| True | True | Error |
| True | False | May return the original buffer. Input Array marked as deleted: No. Reuses input buffer for output: Maybe |
| False | True | Original buffer deleted i.e. Donation. Input Array marked as deleted: Yes. Reuses input buffer for output: Maybe |
| False | False | Pure copy. Input Array marked as deleted: No. Reuses input buffer for output: No |
| None | True | `may_alias` will be set to False. Same as the row with `may_alias = False`, `donate = True` |
| None | False | `may_alias` will be set to True. Same as the row with `may_alias = True`, `donate = False` |
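The table can be read as a small decision function. The sketch below is a pure-Python restatement of those rows, not the code inside `jax.device_put`; the helper name `resolve_device_put_options` and the error message are hypothetical.

```python
from typing import Optional, Tuple


def resolve_device_put_options(
    may_alias: Optional[bool] = None, donate: bool = False
) -> Tuple[bool, bool]:
    """Return the effective (may_alias, donate) pair described by the table."""
    if may_alias is None:
        # Default: alias unless the caller asked for donation.
        may_alias = not donate
    if may_alias and donate:
        # The table marks this combination as an error.
        raise ValueError("may_alias=True and donate=True cannot be combined")
    return may_alias, donate


print(resolve_device_put_options())                 # defaults
print(resolve_device_put_options(donate=True))      # donation turns aliasing off
print(resolve_device_put_options(may_alias=False))  # pure copy
```

Under these rules the only way to get a guaranteed copy that leaves the input alive is to pass `may_alias=False` with `donate=False`.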

`donate` is best effort for now until we fix the following things:

 * Delete input when `donate=True` regardless of whether XLA could donate or not. This will affect `jax.jit` too but it's a good thing to do.

 * Plumb donate to PJRT/IFRT APIs so we can donate where transfers are not happening via `jit`.

PiperOrigin-RevId: 681073828
2024-10-01 10:28:23 -07:00
Peter Hawkins
1260ebbe05 Disable cudnn_fusion_test on A100.
This test only seems to pass on H100 at the moment.

PiperOrigin-RevId: 681070398
2024-10-01 10:18:41 -07:00
jax authors
9ba90741a8 Merge pull request #23984 from jakevdp:mask-indices-doc
PiperOrigin-RevId: 681053740
2024-10-01 09:35:30 -07:00
Paweł Paruzel
6e9a53690c Activate Hessenberg Decomposition to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Hessenberg Decomposition.

PiperOrigin-RevId: 681047625
2024-10-01 09:20:06 -07:00
Christos Perivolaropoulos
84fc011e27 Introducing partial discharge rules and implementations for cond_p
As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])` but each equation discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special discharge rule that is preferred to the
normal one when present that allows the op to discharge only some of the
references.

This feature is especially useful for pallas kernels because contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.

Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.

PiperOrigin-RevId: 681021324
2024-10-01 08:03:58 -07:00
George Necula
2228115cf4 [host_callback] Flip the JAX_HOST_CALLBACK_LEGACY flag to False
`jax.experimental.host_callback` has been deprecated since March 2024
 (JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `False`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.

If this breaks your code, for a very limited time, you can set `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 681004255
2024-10-01 07:07:29 -07:00
Sergei Lebedev
0cfed4efad [pallas:mosaic_gpu] Shrink max_concurrent_iteration based on the total number of steps
PiperOrigin-RevId: 680990842
2024-10-01 06:19:43 -07:00
George Necula
a644e23a4b [host_callback] Skip test that only works in legacy mode.
The jax.experimental.host_callback module is deprecated and will be removed.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 680988939
2024-10-01 06:13:29 -07:00
Adam Paszke
98b72b17f9 [Pallas/MGPU] Add support for transforms and swizzles on outputs
PiperOrigin-RevId: 680982318
2024-10-01 05:56:35 -07:00
Adam Paszke
cac2b8d5fc [Pallas/MGPU] Undo transforms before giving refs back to users
This is a second attempt at this change. The first one was rolled back because of reported failures.

Reverts 411928b9668570bbc3795522aba94cece6894881

PiperOrigin-RevId: 680943744
2024-10-01 03:32:40 -07:00
Adam Paszke
f62941d126 [Mosaic TPU] The previous change does not actually force the input offsets read by the rules, but simply disables all the checks. Reverting so that we at least regain the checks until we have a proper fix.
Reverts 4a596aee1e8920f5b51d5bd573df976390bbd437

PiperOrigin-RevId: 680925509
2024-10-01 02:23:52 -07:00
Matthew Johnson
11fdda9583 add checkify rule for remat
fixes #23867
2024-10-01 02:01:18 +00:00
Sharad Vikram
80f963c003 Fix mutable array effects not being tracked properly
PiperOrigin-RevId: 680801564
2024-09-30 18:55:15 -07:00