23314 Commits

Author SHA1 Message Date
George Necula
2228115cf4 [host_callback] Flip the JAX_HOST_CALLBACK_LEGACY flag to False
`jax.experimental.host_callback` has been deprecated since March 2024 (JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `False`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.

If this breaks your code, for a very limited time, you can set `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 681004255
2024-10-01 07:07:29 -07:00
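
For readers migrating off `jax.experimental.host_callback`, the following is a minimal sketch of calling `jax.experimental.io_callback` directly. The function names `log_and_double` and `f` are illustrative and do not come from the commit; the sketch assumes a JAX version in which `io_callback` is available.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def log_and_double(x):
    # Runs on the host. Convert to NumPy first so no JAX ops run in the callback.
    x = np.asarray(x)
    print("host saw:", x)
    return x * 2


@jax.jit
def f(x):
    # The callback's result shape and dtype must be declared up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # ordered=True asks JAX to preserve the relative order of such callbacks.
    return io_callback(log_and_double, result_shape, x, ordered=True)


out = f(jnp.arange(3, dtype=jnp.float32))
```

Unlike the legacy API, the result shape is stated explicitly, which is what lets the call be traced under `jax.jit`.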
Sergei Lebedev
0cfed4efad [pallas:mosaic_gpu] Shrink max_concurrent_iteration based on the total number of steps
PiperOrigin-RevId: 680990842
2024-10-01 06:19:43 -07:00
George Necula
a644e23a4b [host_callback] Skip test that only works in legacy mode.
The jax.experimental.host_callback module is deprecated and will be removed.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 680988939
2024-10-01 06:13:29 -07:00
Adam Paszke
98b72b17f9 [Pallas/MGPU] Add support for transforms and swizzles on outputs
PiperOrigin-RevId: 680982318
2024-10-01 05:56:35 -07:00
Adam Paszke
7f655972c4 [Pallas/MGPU] Make swizzle a Pallas transform
This is useful because we'll be able to read the ref's swizzle when it is passed to
load/store ops.

PiperOrigin-RevId: 680955632
2024-10-01 04:15:31 -07:00
Adam Paszke
da5f2a3c13 [Pallas/MGPU] Check for trivial indexers in get/swap lowering rules
PiperOrigin-RevId: 680949406
2024-10-01 03:53:24 -07:00
Adam Paszke
cac2b8d5fc [Pallas/MGPU] Undo transforms before giving refs back to users
This is a second attempt at this change. The first one was rolled back because of reported failures.

Reverts 411928b9668570bbc3795522aba94cece6894881

PiperOrigin-RevId: 680943744
2024-10-01 03:32:40 -07:00
Sergei Lebedev
14ef2b6a21 [pallas:mosaic_gpu] Removed a stale TODO
PiperOrigin-RevId: 680931423
2024-10-01 02:44:54 -07:00
Adam Paszke
f62941d126 [Mosaic TPU] The previous change does not actually force the input offsets read by the rules, but simply disables all the checks. Reverting so that we at least regain the checks until we have a proper fix.
Reverts 4a596aee1e8920f5b51d5bd573df976390bbd437

PiperOrigin-RevId: 680925509
2024-10-01 02:23:52 -07:00
Sharad Vikram
80f963c003 Fix mutable array effects not being tracked properly
PiperOrigin-RevId: 680801564
2024-09-30 18:55:15 -07:00
jax authors
31cb3fd36e Merge pull request #23923 from carlosgmartin:ldexp_custom_jvp
PiperOrigin-RevId: 680757259
2024-09-30 16:21:57 -07:00
Ayaka
a24420e76b [Pallas TPU] Add lowering for lax.cos_p
Fixes https://github.com/jax-ml/jax/issues/24026

PiperOrigin-RevId: 680754948
2024-09-30 16:12:11 -07:00
Ayaka
23ce5a11cc [Pallas TPU] Consolidate OpsExtraTest into OpsTest
Historically, tests that only ran on GPUs were placed in `OpsExtraTest`, while general tests were in `OpsTest`. However, this separation may cause us to miss issues that should be addressed on TPUs as well. Going forward, all tests will be unified in `OpsTest`, and any tests that fail on TPUs will be skipped individually using `skipTest`. This will help us better track and address TPU-specific failures.

PiperOrigin-RevId: 680747902
2024-09-30 15:50:23 -07:00
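
The per-test skipping pattern described in this commit can be sketched with the standard `unittest` module. This is an illustration only: `current_platform` is a hypothetical stand-in for however the real suite detects the backend, and the test names are invented.

```python
import unittest


def current_platform() -> str:
    # Hypothetical helper: a real suite would query the active backend.
    return "tpu"


class OpsTest(unittest.TestCase):
    def test_elementwise_add(self):
        # A general test that runs on every platform.
        self.assertEqual(1 + 2, 3)

    def test_gpu_only_op(self):
        # Skip individually rather than moving the test to a separate class,
        # so the TPU gap stays visible in the test report.
        if current_platform() == "tpu":
            self.skipTest("Not yet supported on TPU")
        self.assertTrue(True)


def run_suite() -> unittest.TestResult:
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(OpsTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Because the skipped test still lives in `OpsTest`, it shows up as a skip in every TPU run instead of silently never being collected.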
carlosgmartin
65a58d622c Edit implementation of jax.numpy.ldexp to get correct gradient.
2024-09-30 18:27:39 -04:00
jax authors
c557db0bd8 Merge pull request #23995 from jakevdp:trapezoid-doc
PiperOrigin-RevId: 680734292
2024-09-30 15:10:16 -07:00
Sergei Lebedev
d74d3daa0e [pallas:triton] Do not DCE the jaxpr in the lowering pass
There isn't an obvious reason for doing DCE there, and the Mosaic TPU backend
in fact doesn't DCE.

PiperOrigin-RevId: 680710736
2024-09-30 14:04:37 -07:00
Jevin Jiang
4a596aee1e [Mosaic TPU] Force offset to 0 when inferring input has offset out of the first tile.
We still have this temporary check in apply vector layout, but in infer vector layout, instead of throwing an error, we should just reset the offset to zero, because some ops that have relaxed this restriction might be passed as inputs to un-relaxed ops and cause failures.

PiperOrigin-RevId: 680706301
2024-09-30 13:52:48 -07:00
jax authors
cdc72787fc Merge pull request #24025 from jakevdp:gradient-doc
PiperOrigin-RevId: 680703792
2024-09-30 13:48:09 -07:00
Jake VanderPlas
36d6bb9013 Better docs for jnp.gradient
Also remove skip_params option from util.implements, as this was its last usage.
2024-09-30 13:07:52 -07:00
jax authors
3766f887d3 Merge pull request #23505 from sergachev:fix_cudnn_fusion_test
PiperOrigin-RevId: 680685919
2024-09-30 12:58:44 -07:00
Jevin Jiang
7e2f487ada [Mosaic TPU] Canonicalize arith.select's condition to vector if other types are vector.
This fixes the failure in the elementwise rule of the apply vector layout pass.

If the condition scalar is static, MLIR will simplify the select to the corresponding vector, chosen from the true value and the false value.

If the condition scalar is dynamic, we want to use vselect over scf.if anyway, because the latter creates an inner region.

PiperOrigin-RevId: 680674560
2024-09-30 12:26:44 -07:00
jax authors
bdae9ac72e Merge pull request #24024 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 680669011
2024-09-30 12:09:02 -07:00
rajasekharporeddy
5904fe1563 Better doc for jnp.cbrt
2024-09-30 23:35:57 +05:30
Dan Foreman-Mackey
ff1c2ac152 Add a test for 64-bit precision of IFFT on GPU.
Fixes https://github.com/jax-ml/jax/issues/23827. The precision fix was in https://github.com/openxla/xla/pull/17598, which has now been integrated into JAX, but we add a test here based on the repro from https://github.com/jax-ml/jax/issues/23827.

PiperOrigin-RevId: 680633622
2024-09-30 10:38:16 -07:00
jax authors
504bc435c6 Update XLA dependency to use revision
8aa6da1e51.

PiperOrigin-RevId: 680631168
2024-09-30 10:33:36 -07:00
Peter Hawkins
45cd77ad8c Simplify CI configuration.
PiperOrigin-RevId: 680607105
2024-09-30 09:32:09 -07:00
Yash Katariya
203cda6f98 Move test_aot_device_implicit_transfer to pjit_test.py
This test is not specific to compute offload and is more relevant to pjit.

PiperOrigin-RevId: 680599882
2024-09-30 09:10:17 -07:00
Sergei Lebedev
a046e21a1e [pallas:mosaic_gpu] Do not do mgpu.commit_shared if all outputs are invariant wrt sequential axes
PiperOrigin-RevId: 680565753
2024-09-30 07:25:46 -07:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
jax authors
411928b966 Rollback because of breakages
Reverts 21fea5b0db7a8d3fcd9d6918b430b0ebdd4da3e5

PiperOrigin-RevId: 680552566
2024-09-30 07:23:36 -07:00
Ilia Sergachev
b320dc2e5e Fix and reenable cudnn_fusion_test.
Disable XLA autotuning fallback to cuBLAS so that the tested fusion
always executes through cuDNN.
2024-09-30 14:03:55 +00:00
Sergei Lebedev
b3fca90434 [pallas:mosaic_gpu] Do not DCE the jaxpr in the lowering pass
There isn't an obvious reason for doing DCE there.

PiperOrigin-RevId: 680534567
2024-09-30 05:39:55 -07:00
Adam Paszke
21fea5b0db [Pallas/MGPU] Undo transforms on refs before giving them back to the users
This change makes it so that the refs users receive inside their kernels have shapes
matching their block specs. However, the refs are not actually plain refs, but transformed
references that begin with the fully transformed abstract ref and then stack the inverse
of the transformation stack on top of it. This means that all primitives that take in refs
can also see the sequence of transforms the user applied in the block spec, which lets us
verify e.g. that the inputs to WGMMA are correctly tiled, even though their user-visible
shape remains 2D. We should be able to use the same trick in the future to propagate tiling
and better infer the layouts for loads and stores.

PiperOrigin-RevId: 680520185
2024-09-30 04:43:08 -07:00
Sergei Lebedev
38d2a573fc Exposed sequential iteration index via pl.program_id in Pallas Mosaic GPU
PiperOrigin-RevId: 680502214
2024-09-30 03:35:58 -07:00
jax authors
2cfbdb6c40 Update XLA dependency to use revision
defd9fe717.

PiperOrigin-RevId: 680271168
2024-09-29 10:17:18 -07:00
Jake VanderPlas
7cc5ea4c7d Update docs for jnp.trapezoid
2024-09-29 10:08:27 -07:00
jax authors
6790b90f91 Update XLA dependency to use revision
b9fe7ae8ea.

PiperOrigin-RevId: 680003245
2024-09-28 10:54:17 -07:00
jax authors
15024baabf Merge pull request #23982 from dfm:ffi-call-effects
PiperOrigin-RevId: 679780356
2024-09-27 17:15:38 -07:00
Dan Foreman-Mackey
d80a89d86b Add support for FFI calls with side effects via ffi_call
2024-09-27 19:46:35 -04:00
Peter Hawkins
061f435b73 Bump test tolerance on FFT test that started failing in CI after an XLA change.
PiperOrigin-RevId: 679715691
2024-09-27 13:49:58 -07:00
Peter Hawkins
366c823857 Fix test failure when shardy is not enabled.
PiperOrigin-RevId: 679713372
2024-09-27 13:42:20 -07:00
Peter Hawkins
5969e79908 Fix tests that ask for an accelerator but don't use it.
* Delete custom_object_test, since it is disabled and has been ever since jax.Array was introduced in JAX 0.4.0.
* custom_linear_solve_test was over-sharded, leading to some shards not having any test cases. Even unsharded it completes in under 65s on every platform we have.
* config_test and pallas splash attention mask test only tested helpers and didn't need a TPU.

PiperOrigin-RevId: 679711664
2024-09-27 13:36:23 -07:00
jax authors
ff0a98a2ae Merge pull request #23957 from jakevdp:choice-doc
PiperOrigin-RevId: 679657248
2024-09-27 11:08:42 -07:00
jax authors
4e8a763c43 Merge pull request #23981 from jakevdp:checkout-pin
PiperOrigin-RevId: 679654332
2024-09-27 11:02:05 -07:00
jax authors
2963cbede1 Merge pull request #23977 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 679653515
2024-09-27 11:00:19 -07:00
Jake VanderPlas
f0e27fd73b actions/checkout: pin specific hash
2024-09-27 10:51:12 -07:00
Jake VanderPlas
c0612576de Better documentation for jnp.choose
2024-09-27 10:35:19 -07:00
jax authors
20122ff7ce Update XLA dependency to use revision
ce7c0120ab.

PiperOrigin-RevId: 679643612
2024-09-27 10:33:40 -07:00
rajasekharporeddy
c17ae0fd98 Improve docs for jax.numpy: arcsinh, arccosh and arctanh
2024-09-27 23:03:11 +05:30
jax authors
df042fded2 Merge pull request #23870 from Zantares:tenglu/flush_output
PiperOrigin-RevId: 679639244
2024-09-27 10:21:25 -07:00