23351 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call. 2024-10-02 11:33:51 -04:00
jax authors
816947b656 Merge pull request #24071 from hawkinsp:winci
PiperOrigin-RevId: 681451098
2024-10-02 07:51:12 -07:00
jax authors
e212c77336 Merge pull request #23891 from ROCm:build-fixes-rollup
PiperOrigin-RevId: 681448694
2024-10-02 07:43:13 -07:00
George Necula
b8a066a907 [host_callback] Remove obsolete tests.
Removing tests that only work in legacy mode and with outfeed.

PiperOrigin-RevId: 681435113
2024-10-02 06:51:02 -07:00
Peter Hawkins
9400ef57a1 Fix Windows CI to install the jaxlib wheel it builds.
Currently we install jax first and then try to install jaxlib. But pip won't overwrite the jaxlib installed as a jax dependency, so our CI builds were always using the released jaxlib.
2024-10-02 09:48:14 -04:00
jax authors
8df1623eb5 Merge pull request #24065 from eltociear:patch-9
PiperOrigin-RevId: 681418901
2024-10-02 05:52:04 -07:00
Adam Paszke
e2d3bd866a [Pallas/MGPU] Add support for tiled and swizzled loads/stores + support slices
PiperOrigin-RevId: 681370464
2024-10-02 02:44:10 -07:00
jax authors
d0a9bb6c2a Merge pull request #24032 from Pixee-Bot-Python:main
PiperOrigin-RevId: 681346631
2024-10-02 01:25:50 -07:00
Ikko Eltociear Ashimine
a7c6935994
docs: update Custom_Operation_for_GPUs.md
implementaion -> implementation
2024-10-02 12:57:45 +09:00
jax authors
fcfb0b7e48 Merge pull request #23951 from dfm:ffi-examples-attrs
PiperOrigin-RevId: 681219771
2024-10-01 17:14:05 -07:00
jax authors
e1439d5abe Merge pull request #24039 from mattjj:issue23867
PiperOrigin-RevId: 681219234
2024-10-01 17:12:00 -07:00
Sharad Vikram
c34e25d6f4 [Pallas] Add state discharge rule for pallas_call
This enables us to avoid spurious copies in the cases outlined in [the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html) but not in general, since JAX and/or XLA could introduce copies because we have value semantics. For a proper solution, we need to introduce some notion of buffer semantics to XLA/HLO and preserve it through the lowering of stateful JAX (maybe by avoiding state discharge altogether).

PiperOrigin-RevId: 681206784
2024-10-01 16:30:56 -07:00
Dan Foreman-Mackey
f60c5ccdee Add support for passing array attributes via ffi_call 2024-10-01 19:22:04 -04:00
Ayaka
e361868132 [Pallas TPU] Add lowering for lax.tan_p
This is a follow-up of https://github.com/jax-ml/jax/pull/24028, which adds lowering for `lax.cos_p`

PiperOrigin-RevId: 681180835
2024-10-01 15:09:52 -07:00
Peter Hawkins
fc4f554e09 Delete jax.lib.xla_client.execute_with_python_values.
Nothing under jax.lib.xla_client is public, so there's no deprecation period required.

PiperOrigin-RevId: 681166972
2024-10-01 14:32:22 -07:00
pixeeai
4f8c1277fe Add timeout to requests calls (#1) 2024-10-01 17:26:54 -04:00
Jake VanderPlas
49ad220e57 Finalize deprecation of XLACompatibleSharding
PiperOrigin-RevId: 681156145
2024-10-01 14:02:34 -07:00
jax authors
6ded1bee66 [Pallas:TPU] Fix lowering of convert_element_type(float32) -> bool.
The original implementation doesn't handle 0 < |x| < 1 correctly. It used to be convert_element_type(x, int32) ==> 0 ==> convert_element_type(0, bool) ==> false, which is different from XLA semantics: convert_element_type(x, bool) ==> true.

Hypothesis library seems to draw values of 0.5.

While I'm here, remove some stale skip conditions. They are fixed due to recent Pallas/Mosaic changes.

PiperOrigin-RevId: 681152158
2024-10-01 13:51:47 -07:00
Blake Hechtman
ce21a12a07 [JAX] Make a one hot mode of take along axis.
PiperOrigin-RevId: 681139055
2024-10-01 13:16:26 -07:00
jax authors
afed9f44bd Merge pull request #24020 from jax-ml:dependabot/github_actions/actions/checkout-4.2.0
PiperOrigin-RevId: 681131782
2024-10-01 12:56:53 -07:00
Yash Katariya
0093ba29d8 Fix jax2tf error after changing the signature of device_put_p
PiperOrigin-RevId: 681126203
2024-10-01 12:40:23 -07:00
jax authors
9ad7e2eb42 Merge pull request #24048 from jakevdp:packbits-doc
PiperOrigin-RevId: 681093674
2024-10-01 11:15:19 -07:00
jax authors
23c46ff44e Update XLA dependency to use revision
869808b12d.

PiperOrigin-RevId: 681075946
2024-10-01 10:33:43 -07:00
Yash Katariya
1efca33187 Add donate and may_alias as an argument to device_put to allow for donation and aliasing.
The end state we want to work towards is to remove `may_alias` and **always copy by default**. But there is some work needed to get to that state.

**Definition:**

* donate: The input buffer will be marked as deleted (see below for some caveats). The output buffer may or may not reuse the input buffer's underlying memory.

* may_alias: If True, we may return the original buffer depending on the implementation.

**What problem are we solving?**

Eventually, we want `device_put` to always copy so introducing `may_alias` as a transition state to help towards that goal. We might end up deciding to keep `may_alias` but now you have an explicit option to **always copy** i.e. set `may_alias=False` which is what some users want.

Adding `donate` allows users to avoid this pattern of code:

```
inp = ...
out = device_put(inp, sharding)
jax.block_until_ready(out)
jax.tree.map(lambda x: x.delete(), inp)
```

Now it can just be: `jax.device_put(inp, sharding, donate=True)`

**So what are the semantics of these 2 options?** Let's create a table:

| may-alias \= None (default) | donate \= False (default) | Result |
| :---- | :---- | :---- |
| True | True | Error |
| True | False | May return the original buffer. Input Array marked as deleted: No. Reuses input buffer for output: Maybe |
| False | True | Original buffer deleted i.e. Donation. Input Array marked as deleted: Yes. Reuses input buffer for output: Maybe |
| False | False | Pure copy. Input Array marked as deleted: No. Reuses input buffer for output: No |
| None | True | `may_alias` will be marked as False. See Row 2 i.e. may\_alias \= False, donate \= True |
| None | False | `may_alias` will be marked as True. See Row 1 i.e. may\_alias \= True, donate \= False |

`donate` is best effort for now until we fix the following things:

 * Delete input when `donate=True` regardless of whether XLA could donate or not. This will affect `jax.jit` too but it's a good thing to do.

 * Plumb donate to PJRT/IFRT APIs so we can donate where transfers are not happening via `jit`.

PiperOrigin-RevId: 681073828
2024-10-01 10:28:23 -07:00
Justin Fu
350afaa7b6 [Pallas] Clean up lowering exceptions.
PiperOrigin-RevId: 681073628
2024-10-01 10:26:40 -07:00
Peter Hawkins
1260ebbe05 Disable cudnn_fusion_test on A100.
This test only seems to pass on H100 at the moment.

PiperOrigin-RevId: 681070398
2024-10-01 10:18:41 -07:00
jax authors
28098bef93 Merge pull request #24034 from jakevdp:ldexp-doc
PiperOrigin-RevId: 681058537
2024-10-01 09:47:56 -07:00
jax authors
741115a0fc Update XLA dependency to use revision
e1b38f898e.

PiperOrigin-RevId: 681055200
2024-10-01 09:39:34 -07:00
jax authors
9ba90741a8 Merge pull request #23984 from jakevdp:mask-indices-doc
PiperOrigin-RevId: 681053740
2024-10-01 09:35:30 -07:00
Paweł Paruzel
6e9a53690c Activate Hessenberg Decomposition to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Hessenberg Decomposition.

PiperOrigin-RevId: 681047625
2024-10-01 09:20:06 -07:00
Christos Perivolaropoulos
84fc011e27 Introducing partial discharge rules and implementations for cond_p
As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])` but each equation is discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special discharge rule that is preferred to the
normal one when present that allows the op to discharge only some of the
references.

This feature is especially useful for pallas kernels because contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.

Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.

PiperOrigin-RevId: 681021324
2024-10-01 08:03:58 -07:00
George Necula
2228115cf4 [host_callback] Flip the JAX_HOST_CALLBACK_LEGACY flag to False
`jax.experimental.host_callback` has been deprecated since March 2024
 (JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.

If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 681004255
2024-10-01 07:07:29 -07:00
Sergei Lebedev
0cfed4efad [pallas:mosaic_gpu] Shrink max_concurrent_iteration based on the total number of steps
PiperOrigin-RevId: 680990842
2024-10-01 06:19:43 -07:00
George Necula
a644e23a4b [host_callback] Skip test that only works in legacy mode.
The jax.experimental.host_callback module is deprecated and will be removed.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 680988939
2024-10-01 06:13:29 -07:00
Adam Paszke
98b72b17f9 [Pallas/MGPU] Add support for transforms and swizzles on outputs
PiperOrigin-RevId: 680982318
2024-10-01 05:56:35 -07:00
Jake VanderPlas
ae374e0096 Document jnp.packbits & jnp.unpackbits 2024-10-01 05:33:25 -07:00
Adam Paszke
7f655972c4 [Pallas/MGPU] Make swizzle a Pallas transform
This will be useful in that we'll be able to read the ref swizzling when it will be passed to
load/store ops.

PiperOrigin-RevId: 680955632
2024-10-01 04:15:31 -07:00
Adam Paszke
da5f2a3c13 [Pallas/MGPU] Check for trivial indexers in get/swap lowering rules
PiperOrigin-RevId: 680949406
2024-10-01 03:53:24 -07:00
Adam Paszke
cac2b8d5fc [Pallas/MGPU] Undo transforms before giving refs back to users
This is a second attempt at this change. The first one was rolled back because of reported failures.

Reverts 411928b9668570bbc3795522aba94cece6894881

PiperOrigin-RevId: 680943744
2024-10-01 03:32:40 -07:00
Sergei Lebedev
14ef2b6a21 [pallas:mosaic_gpu] Removed a stale TODO
PiperOrigin-RevId: 680931423
2024-10-01 02:44:54 -07:00
Adam Paszke
f62941d126 [Mosaic TPU] The previous change does not actually force the input offsets read by the rules, but simply disables all the checks. Reverting so that we at least regain the checks until we have a proper fix.
Reverts 4a596aee1e8920f5b51d5bd573df976390bbd437

PiperOrigin-RevId: 680925509
2024-10-01 02:23:52 -07:00
Matthew Johnson
11fdda9583 add checkify rule for remat
fixes #23867
2024-10-01 02:01:18 +00:00
Sharad Vikram
80f963c003 Fix mutable array effects not being tracked properly
PiperOrigin-RevId: 680801564
2024-09-30 18:55:15 -07:00
Jake VanderPlas
22906b06ed Improve docs for jnp.ldexp and jnp.frexp 2024-09-30 16:38:21 -07:00
jax authors
31cb3fd36e Merge pull request #23923 from carlosgmartin:ldexp_custom_jvp
PiperOrigin-RevId: 680757259
2024-09-30 16:21:57 -07:00
Ayaka
a24420e76b [Pallas TPU] Add lowering for lax.cos_p
Fixes https://github.com/jax-ml/jax/issues/24026

PiperOrigin-RevId: 680754948
2024-09-30 16:12:11 -07:00
Ayaka
23ce5a11cc [Pallas TPU] Consolidate OpsExtraTest into OpsTest
Historically, tests that only ran on GPUs were placed in `OpsExtraTest`, while general tests were in `OpsTest`. However, this separation may cause us to miss issues that should be addressed on TPUs as well. Going forward, all tests will be unified in `OpsTest`, and any tests that fail on TPUs will be skipped individually using `skipTest`. This will help us better track and address TPU-specific failures.

PiperOrigin-RevId: 680747902
2024-09-30 15:50:23 -07:00
carlosgmartin
65a58d622c Edit implementation of jax.numpy.ldexp to get correct gradient. 2024-09-30 18:27:39 -04:00
Jake VanderPlas
36782e8319 jnp.mask_indices: add docs & tests 2024-09-30 15:13:41 -07:00
jax authors
c557db0bd8 Merge pull request #23995 from jakevdp:trapezoid-doc
PiperOrigin-RevId: 680734292
2024-09-30 15:10:16 -07:00