23755 Commits

Author SHA1 Message Date
dependabot[bot]
eaa099b9b7
Bump actions/cache from 4.0.2 to 4.1.0
Bumps [actions/cache](https://github.com/actions/cache) from 4.0.2 to 4.1.0.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](0c45773b62...2cdf405574)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-10-07 17:58:32 +00:00
dependabot[bot]
ab41f4871b
Bump actions/upload-artifact from 4.4.0 to 4.4.1
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.4.0 to 4.4.1.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](50769540e7...604373da63)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-10-07 17:58:27 +00:00
dependabot[bot]
d083c52ba6
Bump actions/checkout from 4.2.0 to 4.2.1
Bumps [actions/checkout](https://github.com/actions/checkout) from 4.2.0 to 4.2.1.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](d632683dd7...eef61447b9)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-10-07 17:58:24 +00:00
Christos Perivolaropoulos
9ac6723561 [pallas:mosaic_gpu] Dereferencing the accumulator now supports slicing
PiperOrigin-RevId: 683235013
2024-10-07 10:33:08 -07:00
Kristian Hartikainen
1ea8e3c29d Update _cuda_path
- Remove jax-relative module path test
- Use `$CUDA_ROOT` environment variable if available
- Use `cuda_nvcc` module's path if installed
2024-10-07 20:32:05 +03:00
Christos Perivolaropoulos
e8cea0d7a4 [mosaic_gpu] Load/store using ref indices rather than 1D to accomodate for strided refs.
PiperOrigin-RevId: 683223800
2024-10-07 10:03:31 -07:00
Peter Hawkins
a9926f0f01 Remove classic HLO lowering rule support from JAX.
(JAX uses StableHLO always, now, with the exception of one use case in jax2tf.)

PiperOrigin-RevId: 683205145
2024-10-07 09:06:20 -07:00
George Necula
b172a074b8 [jax2tf] Improve the non-native serialization deprecation warnings
PiperOrigin-RevId: 683197146
2024-10-07 08:41:16 -07:00
Peter Hawkins
145304a0e0 Remove reference to outfeed_receiver.pyi, which was deleted.
PiperOrigin-RevId: 683195999
2024-10-07 08:37:14 -07:00
jax authors
871ef43aa9 Merge pull request #24154 from apaszke:bump-oldest
PiperOrigin-RevId: 683191614
2024-10-07 08:22:45 -07:00
jax authors
a0ab2a79fb Merge pull request #24122 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 683167116
2024-10-07 06:58:07 -07:00
jax authors
4cb3a6d30d Merge pull request #24118 from jakevdp:fft-docs
PiperOrigin-RevId: 683158665
2024-10-07 06:29:35 -07:00
jax authors
eae7bb731b Merge pull request #24146 from jjyyxx:patch-4
PiperOrigin-RevId: 683158550
2024-10-07 06:28:08 -07:00
Sergei Lebedev
b289e3a2a9 [pallas:mosaic_gpu] Updated some of the docstrings
PiperOrigin-RevId: 683151058
2024-10-07 06:04:37 -07:00
Sergei Lebedev
06c08bd118 Renamed :pallas_gpu to :pallas_triton
:pallas_gpu is now an umbrella target for Triton and (hopefully soon)
Mosaic GPU backends.

PiperOrigin-RevId: 683145270
2024-10-07 05:44:00 -07:00
Adam Paszke
6508b52dab Bump "oldest supported" libtpu version to fall within the 12w window in CI 2024-10-07 12:27:55 +00:00
George Necula
5fabd34e7e [jax2tf] Remove non-native serialization test from jax_to_ir_test
PiperOrigin-RevId: 683124315
2024-10-07 04:21:38 -07:00
Sergei Lebedev
95631a7d92 Added jax.experimental.pallas.mosaic_gpu
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.

PiperOrigin-RevId: 683119193
2024-10-07 04:05:08 -07:00
jax authors
6d2c8cf5de Merge pull request #23656 from tchatow:fix-inv
PiperOrigin-RevId: 683112267
2024-10-07 03:38:04 -07:00
jax authors
896cf5a855 Remove reference cycle in MosaicGridMapping
This allows `MosaicGridMapping` instances to be deleted immediately
when their reference counts drop to 0, rather than leaving them to
be deleted only by the Python GC.

PiperOrigin-RevId: 682997031
2024-10-06 18:33:32 -07:00
Yash Katariya
c6f7316d43 Add a private _extremely_unsafe_enter_tracing_context to enter abstractMesh into tracing context. This is a temporary workaround for internal use cases.
PiperOrigin-RevId: 682960902
2024-10-06 14:50:24 -07:00
jax authors
e16fac67da Update XLA dependency to use revision
f5bcd00ab5.

PiperOrigin-RevId: 682920396
2024-10-06 10:34:22 -07:00
rajasekharporeddy
9832a11c50 Better doc for jnp.hypot 2024-10-06 21:58:03 +05:30
Yuxuan Jiang
757a77ede0
Fix wrong date in changelog 2024-10-06 23:16:30 +08:00
George Necula
db89c245ac [host_callback] Remove most of the jax.experimental.host_callback module
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682830525
2024-10-06 01:10:34 -07:00
jax authors
81a31f6adf Update XLA dependency to use revision
9610649d64.

PiperOrigin-RevId: 682683663
2024-10-05 10:08:03 -07:00
Jake VanderPlas
45f0e9ad68 Simplify definition of jnp.isscalar
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/suprising than it needs to be, and the weak type dependence is rarely useful in practice.

PiperOrigin-RevId: 682656411
2024-10-05 07:12:20 -07:00
jax authors
e90487e906 Host Offloading: Process "MoveToHost" instructions in the order they are executed.
- This ensures we process "MoveToHost" instructions that reside at the beginning of a host memory instruction offload chain.
- This avoids processing MoveToHost instructions out of order, creating invalid instructions within a host memory instruction offload chain.

PiperOrigin-RevId: 682448060
2024-10-04 14:17:36 -07:00
Tom Natan
ed5ba633d4 Reverts 6cf09f8c24c67ff650b95d174501fff3cb59db0d
PiperOrigin-RevId: 682440543
2024-10-04 13:56:27 -07:00
jax authors
291619c291 Allow custom call computations to contain subcomputations
PiperOrigin-RevId: 682429391
2024-10-04 13:22:14 -07:00
Dan Foreman-Mackey
67f24df740 Activate FFI implementation of symmetric Eigendecomposition.
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
2024-10-04 12:38:26 -07:00
jax authors
18f48bd52a Merge pull request #24129 from hawkinsp:asan
PiperOrigin-RevId: 682378527
2024-10-04 10:53:22 -07:00
jax authors
d48d96c157 Merge pull request #24124 from hawkinsp:shims
PiperOrigin-RevId: 682363064
2024-10-04 10:11:05 -07:00
jax authors
7b5842c355 Update XLA dependency to use revision
9e28b00207.

PiperOrigin-RevId: 682362646
2024-10-04 10:09:28 -07:00
Yash Katariya
83b0a932bd Update device info printed in print_environment_info to print the TPU kind also along with global and local device count.
PiperOrigin-RevId: 682345577
2024-10-04 09:20:27 -07:00
jax authors
46b7bfae91 Merge pull request #24113 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 682344324
2024-10-04 09:15:44 -07:00
Peter Hawkins
44998609e4 Remove extraneous backslash in asan.yaml build. 2024-10-04 11:58:24 -04:00
jax authors
8f423e0bd9 Merge pull request #24127 from gnecula:tfx_jax2tf
PiperOrigin-RevId: 682338478
2024-10-04 08:57:42 -07:00
George Necula
dcd75b4315 [jax2tf] Disable test involving tan on older TF versions 2024-10-04 18:48:59 +03:00
Peter Hawkins
d3f63a66b8 Remove code to support jaxlib <= 0.4.33. 2024-10-04 11:39:05 -04:00
jax authors
2556f9308a Merge pull request #24123 from hawkinsp:postrelease
PiperOrigin-RevId: 682325691
2024-10-04 08:13:29 -07:00
Peter Hawkins
b0b7a60e63 Merge branch 'release/0.4.34' 2024-10-04 10:56:18 -04:00
Dan Foreman-Mackey
c0240764bc Activate FFI implementation of the QR decomposition.
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.

The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.

PiperOrigin-RevId: 682312752
2024-10-04 07:27:11 -07:00
Jake VanderPlas
162322fc70 Better docs for fftshift & ifftshift 2024-10-04 06:12:02 -07:00
Brian Wieder
633ac7eaa9 Clear caches on jax exit.
PiperOrigin-RevId: 682288160
2024-10-04 05:55:30 -07:00
Sergei Lebedev
aadb50905c [pallas:mosaic_gpu] Allowed indexing refs with scalars
The transforms do not yet handle this case, so only the basic indexing works.

PiperOrigin-RevId: 682273046
2024-10-04 04:54:37 -07:00
jax authors
ad6604d249 Merge pull request #24116 from jakevdp:fix-typos
PiperOrigin-RevId: 682271471
2024-10-04 04:48:33 -07:00
Jake VanderPlas
0a8f46a6ca Fix some typos in lax_numpy.py 2024-10-04 03:55:19 -07:00
George Necula
3d389a7fb4 [host_callback] Accelerate deprecation of host_callback.barrier_wait
The jax.experimental.host_callback module has been deprecated since March 2024.

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682232942
2024-10-04 02:22:03 -07:00
rajasekharporeddy
321b9bc3f4 Better doc for jnp.heaviside 2024-10-04 14:23:13 +05:30