23755 Commits

Author SHA1 Message Date
jax authors
bdae9ac72e Merge pull request #24024 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 680669011
2024-09-30 12:09:02 -07:00
rajasekharporeddy
5904fe1563 Better doc for jnp.cbrt 2024-09-30 23:35:57 +05:30
Mathew Odden
9ff891dfa1 [ROCm] Remove broken legacy env vars
These env vars are no longer used or need and were
being set incorrectly.

[ROCm] Use specific amdgpu version for EL8 systems

We were always installing the latest driver versions
but this had some side effects when yum would try
to download index files from a URL with changing content.

[ROCm] Fix formatting on python files

Reformatted with black
2024-09-30 12:39:51 -05:00
Dan Foreman-Mackey
ff1c2ac152 Add a test for 64-bit precision of IFFT on GPU.
Fixes https://github.com/jax-ml/jax/issues/23827. The precision fix was in https://github.com/openxla/xla/pull/17598, which has now been integrated into JAX, but we add a test here based on the repro from https://github.com/jax-ml/jax/issues/23827.

PiperOrigin-RevId: 680633622
2024-09-30 10:38:16 -07:00
jax authors
504bc435c6 Update XLA dependency to use revision
8aa6da1e51.

PiperOrigin-RevId: 680631168
2024-09-30 10:33:36 -07:00
dependabot[bot]
3196b9b153
Bump actions/checkout from 4.1.1 to 4.2.0
Bumps [actions/checkout](https://github.com/actions/checkout) from 4.1.1 to 4.2.0.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](b4ffde65f4...d632683dd7)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-09-30 17:25:13 +00:00
Peter Hawkins
45cd77ad8c Simplify CI configuration.
PiperOrigin-RevId: 680607105
2024-09-30 09:32:09 -07:00
Yash Katariya
203cda6f98 Move test_aot_device_implicit_transfer to pjit_test.py
This test is not specific to compute offload and is more relevant to pjit.

PiperOrigin-RevId: 680599882
2024-09-30 09:10:17 -07:00
Sergei Lebedev
a046e21a1e [pallas:mosaic_gpu] Do not do mgpu.commit_shared if all outputs are invariant wrt sequential axes
PiperOrigin-RevId: 680565753
2024-09-30 07:25:46 -07:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
jax authors
411928b966 Rollback because of breakages
Reverts 21fea5b0db7a8d3fcd9d6918b430b0ebdd4da3e5

PiperOrigin-RevId: 680552566
2024-09-30 07:23:36 -07:00
Ilia Sergachev
b320dc2e5e Fix and reenable cudnn_fusion_test.
Disable XLA autotuning fallback to cuBLAS so that the tested fusion
always executes through cuDNN.
2024-09-30 14:03:55 +00:00
Sergei Lebedev
b3fca90434 [pallas:mosaic_gpu] Do not DCE the jaxpr in the lowering pass
There isn't an obvious reason for doing DCE there.

PiperOrigin-RevId: 680534567
2024-09-30 05:39:55 -07:00
Adam Paszke
21fea5b0db [Pallas/MGPU] Undo transforms on refs before giving them back to the users
This changes makes it so that the refs users receive inside their kernels have shapes
matching their block specs. However, the refs are not actually plain refs, but transformed
references that begin with the fully transformed abstract ref and then stack the inverse
of the transformation stack on top of it. This means that all primitives that take in refs
can also see the sequence of transforms the user applied in the block spec, which lets us
verify e.g. that the inputs to WGMMA are correctly tiled, even though their user-visible
shape remains 2D. We should be able to use the same trick in the future to propagate tiling
and better infer the layouts for loads and stores.

PiperOrigin-RevId: 680520185
2024-09-30 04:43:08 -07:00
Sergei Lebedev
38d2a573fc Exposed sequential iteration index via pl.program_id in Pallas Mosaic GPU
PiperOrigin-RevId: 680502214
2024-09-30 03:35:58 -07:00
jax authors
2cfbdb6c40 Update XLA dependency to use revision
defd9fe717.

PiperOrigin-RevId: 680271168
2024-09-29 10:17:18 -07:00
Jake VanderPlas
7cc5ea4c7d Update docs for jnp.trapezoid 2024-09-29 10:08:27 -07:00
jax authors
6790b90f91 Update XLA dependency to use revision
b9fe7ae8ea.

PiperOrigin-RevId: 680003245
2024-09-28 10:54:17 -07:00
jax authors
15024baabf Merge pull request #23982 from dfm:ffi-call-effects
PiperOrigin-RevId: 679780356
2024-09-27 17:15:38 -07:00
Dan Foreman-Mackey
d80a89d86b Add support for FFI calls with side effects via ffi_call 2024-09-27 19:46:35 -04:00
Peter Hawkins
061f435b73 Bump test tolerance on FFT test that started failing in CI after an XLA change.
PiperOrigin-RevId: 679715691
2024-09-27 13:49:58 -07:00
Peter Hawkins
366c823857 Fix test failure when shardy is not enabled.
PiperOrigin-RevId: 679713372
2024-09-27 13:42:20 -07:00
Peter Hawkins
5969e79908 Fix tests that ask for an accelerator but don't use it.
* Delete custom_object_test, since it is disabled and has been ever since jax.Array was introduced in JAX 0.4.0.
* custom_linear_solve_test was over-sharded, leading to some shards not having any test cases. Even unsharded it completes in under 65s on every platform we have.
* config_test and pallas splash attention mask test only tested helpers and didn't need a TPU.

PiperOrigin-RevId: 679711664
2024-09-27 13:36:23 -07:00
jax authors
ff0a98a2ae Merge pull request #23957 from jakevdp:choice-doc
PiperOrigin-RevId: 679657248
2024-09-27 11:08:42 -07:00
jax authors
4e8a763c43 Merge pull request #23981 from jakevdp:checkout-pin
PiperOrigin-RevId: 679654332
2024-09-27 11:02:05 -07:00
jax authors
2963cbede1 Merge pull request #23977 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 679653515
2024-09-27 11:00:19 -07:00
Jake VanderPlas
f0e27fd73b actions/checkout: pin specific hash 2024-09-27 10:51:12 -07:00
Jake VanderPlas
c0612576de Better documentation for jnp.choose 2024-09-27 10:35:19 -07:00
jax authors
20122ff7ce Update XLA dependency to use revision
ce7c0120ab.

PiperOrigin-RevId: 679643612
2024-09-27 10:33:40 -07:00
rajasekharporeddy
c17ae0fd98 Improve docs for jax.numpy: arcsinh, arccosh and arctanh 2024-09-27 23:03:11 +05:30
jax authors
df042fded2 Merge pull request #23870 from Zantares:tenglu/flush_output
PiperOrigin-RevId: 679639244
2024-09-27 10:21:25 -07:00
jax authors
a6167585af Merge pull request #23804 from hawkinsp:vis
PiperOrigin-RevId: 679633739
2024-09-27 10:08:15 -07:00
jax authors
af12e0a178 Merge pull request #23961 from apivovarov:pylint2
PiperOrigin-RevId: 679633728
2024-09-27 10:07:53 -07:00
jax authors
898f057249 Merge pull request #23980 from jakevdp:update-actions
PiperOrigin-RevId: 679633706
2024-09-27 10:06:10 -07:00
jax authors
b762291183 Merge pull request #23965 from zhenying-liu:weight-offloading-test
PiperOrigin-RevId: 679631125
2024-09-27 10:01:30 -07:00
Jake VanderPlas
cb9f26c8fc actions: remove use of ratchet tool
We no longer use ratchet manually, but rather rely on dependabot
to update these versions when necessary.
2024-09-27 09:22:47 -07:00
Peter Hawkins
5a1d0a6c26 Include the sdy MLIR dialect in jaxlib.
We're seeing test failures from tests assuming that this dialect exists. But given we plan to enable it at some point, we may as well just include it in the build.

The size impact is small (around 400K uncompressed).

PiperOrigin-RevId: 679608092
2024-09-27 08:53:31 -07:00
Peter Hawkins
e4790b634e Don't pass --nocheck_visibility to Bazel.
This no longer appears to be needed.
2024-09-27 11:19:42 -04:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Adam Paszke
5740ab3b02 [Pallas/MGPU] Skip output transfers when they don't depend on sequenital dims
Note that thanks to the previous revisiting-related checks we weren't doing the
transfers anyway, but this way we can also avoid having to pay for the checks.

PiperOrigin-RevId: 679516275
2024-09-27 03:12:16 -07:00
Sergei Lebedev
afaf8b823d Run Pallas Mosaic GPU tests on internal CI
PiperOrigin-RevId: 679508320
2024-09-27 02:43:35 -07:00
Sergei Lebedev
3ae48621dd Fixed Pallas Mosaic GPU test following recent changes
PiperOrigin-RevId: 679504036
2024-09-27 02:28:37 -07:00
jax authors
ea86251a60 [Pallas:TPU] Fix lowering of convert_element_type(int32) -> bool.
We need to add a condition on vector type since both operands of arith::CmpIOp must have same type.

PiperOrigin-RevId: 679500783
2024-09-27 02:15:35 -07:00
Sergei Lebedev
ea6ee4d7fe Removed unused imports in jax.experimental.mosaic.gpu.core
PiperOrigin-RevId: 679498378
2024-09-27 02:08:00 -07:00
Ayaka
ab4590ce0a [Pallas TPU] Add a note in the Pallas Quickstart documentation about the instructions of running the existing example on TPU
This fixes https://github.com/jax-ml/jax/issues/22817

This changes is originally proposed by @justinjfu in the comments of the above issue.

This PR is related to https://github.com/jax-ml/jax/pull/23885.

PiperOrigin-RevId: 679487218
2024-09-27 01:33:08 -07:00
Jane Liu
57bef447c6 Enable weight offloading tests that are supported on GPUs now 2024-09-26 23:26:27 -07:00
jax authors
5a1549cccf Merge pull request #23853 from zhenying-liu:remat-scan
PiperOrigin-RevId: 679365040
2024-09-26 18:12:30 -07:00
Alexander Pivovarov
69193aa6a4 Remove pylint sections from pyproject.toml.
use ruff instead
2024-09-26 23:29:56 +00:00
Justin Fu
9f4e8d0039 [XLA:Mosaic][Pallas] Enable vector.ExtractOp for non-zero indices.
PiperOrigin-RevId: 679283281
2024-09-26 13:57:45 -07:00
jax authors
46dbb6588a Merge pull request #23949 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 679264758
2024-09-26 13:09:17 -07:00