21394 Commits

Author SHA1 Message Date
jax authors
f4158ace93 Merge pull request #21949 from hawkinsp:winwheel
PiperOrigin-RevId: 644346629
rocm-jaxlib-v0.4.30
2024-06-18 05:21:34 -07:00
Peter Hawkins
b0b02682d5 Add --allow-downgrade to Windows wheel builds.
We want a specific LLVM version, downgrades are fine.
Also add --no-progress for a more compact log.
2024-06-18 08:15:26 -04:00
jax authors
5d35c99041 Merge pull request #21945 from hawkinsp:release
PiperOrigin-RevId: 644340066
2024-06-18 04:53:16 -07:00
jax authors
d7bc6b4dad Update XLA dependency to use revision
79fd5733f9.

PiperOrigin-RevId: 644336231
2024-06-18 04:35:01 -07:00
Peter Hawkins
3f4f79c83e Prepare for 0.4.30 release. 2024-06-18 07:34:45 -04:00
Sergei Lebedev
dfcfb36062 Pallas GPU no longer falls back to lax.pow for integer powers
Instead, the lowering computes the power in a loop by repeated squaring, similar
to how we do it in the StableHLO lowering.

Fixes #21928.

PiperOrigin-RevId: 644313113
2024-06-18 02:54:39 -07:00
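For context, a minimal sketch of the repeated-squaring idea at the array level (the function name and the trace-time loop below are illustrative; the actual change lives in the Pallas GPU lowering, not in user code):

```python
import jax.numpy as jnp

def integer_pow_by_squaring(x, n: int):
    # x ** n for a static, non-negative integer n via repeated squaring:
    # O(log n) multiplies instead of n - 1. The loop runs at trace time
    # because n is a plain Python int.
    acc = jnp.ones_like(x)
    base = x
    while n > 0:
        if n & 1:
            acc = acc * base
        n >>= 1
        if n:
            base = base * base
    return acc

assert jnp.array_equal(integer_pow_by_squaring(jnp.arange(4), 5), jnp.arange(4) ** 5)
```
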
Sergei Lebedev
5bfd6afa80 Removed unnecessary skip in pallas_test.py::SoftmaxTest
The Triton bug, whatever it was, seems to have been fixed.

PiperOrigin-RevId: 644293465
2024-06-18 01:40:13 -07:00
Eugene Zhulenev
3fd9326881 [jax] Enable api_test with XLA:CPU thunks
PiperOrigin-RevId: 644268375
2024-06-17 23:58:02 -07:00
Jevin Jiang
ed4958cb3e [XLA:Mosaic] Add internal scratch VMEM
- Make internal scratch size configurable.
- Pass the number of max sublanes allowed in scratch to apply-vector-layout pass.
- Create a helper function to fetch internal scratch VMEM address.

PiperOrigin-RevId: 644184896
2024-06-17 17:31:31 -07:00
Sharad Vikram
701c63e19a [Pallas/TPU] Add API for megacore partitioning of pipelines
PiperOrigin-RevId: 644184524
2024-06-17 17:28:07 -07:00
Justin Fu
fb68f3449b [Pallas] Add checkify support for pallas_call in interpret mode.
PiperOrigin-RevId: 644181742
2024-06-17 17:15:42 -07:00
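A rough sketch of what this enables, assuming `checkify.check` can appear in the kernel body as written (the kernel, shapes, and the check itself are illustrative, not taken from the commit):

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify, pallas as pl

def kernel(x_ref, o_ref):
    # Illustrative functional check inside the kernel body.
    checkify.check(jnp.all(x_ref[...] >= 0), "negative input")
    o_ref[...] = x_ref[...] * 2

f = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    interpret=True,  # the new checkify support targets interpret mode
)
err, out = checkify.checkify(f)(jnp.arange(8, dtype=jnp.float32))
err.throw()  # no-op when the check passes
```
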
Justin Fu
1d77720e9a [Pallas] Add initial DMA interpret mode rules. Currently this only supports LOGICAL device ids with one sharding axis.
PiperOrigin-RevId: 644171210
2024-06-17 16:36:21 -07:00
Justin Fu
9556a29c06 [Pallas] Fix the warning for passing None as the key in the Pallas PRNG.
PiperOrigin-RevId: 644156336
2024-06-17 15:41:42 -07:00
Sharad Vikram
9499de4358 [Pallas] Make num_programs return an int if the grid is not dynamic
PiperOrigin-RevId: 644149441
2024-06-17 15:18:40 -07:00
Dan Foreman-Mackey
1de2756c7e Fix build dependency in CUDA custom call example
PiperOrigin-RevId: 644138770
2024-06-17 14:42:46 -07:00
Yash Katariya
2daa0af496 Don't use deprecated backend=cpu in layout test
PiperOrigin-RevId: 644134679
2024-06-17 14:29:20 -07:00
jax authors
3044da3900 Merge pull request #21918 from google:dependabot/pip/scipy-1.13.1
PiperOrigin-RevId: 644131442
2024-06-17 14:18:45 -07:00
jax authors
f819c344ee Merge pull request #21645 from andportnoy:aportnoy/cuda-custom-call
PiperOrigin-RevId: 644124155
2024-06-17 13:56:52 -07:00
Junwhan Ahn
81f63ed19c Fix a bug in device_put lowering introduced by https://github.com/google/jax/pull/21754
Also adds a test that triggers the issue. Confirmed that the test fails without the fix.

PiperOrigin-RevId: 644117038
2024-06-17 13:33:35 -07:00
jax authors
683878322d Update XLA dependency to use revision
f468b46f3a.

PiperOrigin-RevId: 644110315
2024-06-17 13:12:04 -07:00
Yash Katariya
e6f26ff256 Deprecate jax.xla_computation. Use the JAX AOT APIs to get equivalent functionality.
PiperOrigin-RevId: 644107276
2024-06-17 13:02:35 -07:00
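One possible migration path, sketched with the AOT APIs (the function and the printed dialects are examples, not the only options):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

x = jnp.arange(4.0)

# Previously: jax.xla_computation(f)(x).as_hlo_text()
lowered = jax.jit(f).lower(x)                     # AOT lowering
print(lowered.as_text())                          # StableHLO module text
print(lowered.compiler_ir("hlo").as_hlo_text())   # HLO text, closest to the old output
```
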
dependabot[bot]
b26af70949
Bump scipy from 1.13.0 to 1.13.1
Bumps [scipy](https://github.com/scipy/scipy) from 1.13.0 to 1.13.1.
- [Release notes](https://github.com/scipy/scipy/releases)
- [Commits](https://github.com/scipy/scipy/compare/v1.13.0...v1.13.1)

---
updated-dependencies:
- dependency-name: scipy
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-06-17 19:58:11 +00:00
jax authors
9e1f1388c1 Adjust test tolerances for Eigh and Svd to make tests pass on all platforms.
Add a test comparing eigenvalues to "reference eigenvalues" computed in double precision by NumPy.

PiperOrigin-RevId: 644104177
2024-06-17 12:52:45 -07:00
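A sketch of the reference-eigenvalue comparison described above (matrix size and tolerances are illustrative):

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.RandomState(0)
a = rng.randn(16, 16).astype(np.float32)
a = (a + a.T) / 2                                   # symmetric input for eigh
w = jnp.linalg.eigvalsh(a)                          # eigenvalues under test (float32)
w_ref = np.linalg.eigvalsh(a.astype(np.float64))    # double-precision reference
np.testing.assert_allclose(w, w_ref, rtol=1e-4, atol=1e-4)
```
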
jax authors
b608c69b5e Merge pull request #21922 from hawkinsp:deps
PiperOrigin-RevId: 644103102
2024-06-17 12:48:22 -07:00
Peter Hawkins
160e09e235 Use NumPy 2.0.0 and SciPy 1.13.1 in builds.
Don't override the XLA repository in the nightly Windows CI builds,
which should be building JAX as it exists in the source repository.
2024-06-17 19:35:08 +00:00
Andrey Portnoy
ec5c4f5a10 Add CUDA custom call example as a JAX test 2024-06-17 15:21:49 -04:00
jax authors
c9b23f0f5a Merge pull request #21912 from superbobry:pallas
PiperOrigin-RevId: 644093480
2024-06-17 12:16:49 -07:00
Yash Katariya
6ba16e0348 Add lowering_platforms to traced.lower() to allow lowering to a different backend, as well as multi-backend lowering. In other words, enable cross-lowering!
The motivation for doing this is 2-fold:

1) This will help with deprecating and eventually deleting `jax.xla_computation`, which allows for cross-backend lowering.

2) Allow for cross-backend and multi-backend lowering via the JAX AOT APIs, which will help clean up some hacks implemented for `jax.export`.

Note that this is only available via `.trace(*args).lower(lowering_platforms=('tpu',))`; you cannot use `.lower` directly to do cross-lowering. We can introduce top-level APIs in the future to make composable AOT APIs easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.

Designed with @froystig!

PiperOrigin-RevId: 644087787
2024-06-17 11:59:10 -07:00
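A sketch of the cross-lowering usage quoted in the message (the function and platform tuples are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * 2

x = jnp.arange(8.0)

# Trace once on the current backend, then lower for another backend
# (or several platforms at once).
traced = jax.jit(f).trace(x)
tpu_lowered = traced.lower(lowering_platforms=("tpu",))
multi_lowered = traced.lower(lowering_platforms=("cpu", "tpu"))
print(tpu_lowered.as_text()[:200])
```
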
jax authors
be1f4ba380 Merge pull request #21905 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 644068464
2024-06-17 11:04:28 -07:00
Kyle Lucke
ebdafea9c8 Stop using xla/status.h, xla:status, and xla::Status now that xla::Status is just an alias for an absl::Status
PiperOrigin-RevId: 644063768
2024-06-17 10:51:55 -07:00
jax authors
21e2319e62 Merge pull request #21911 from gnecula:poly_solve_doc
PiperOrigin-RevId: 644062799
2024-06-17 10:51:41 -07:00
jax authors
f86cd6de56 Rewrite vector.multi_dim_reduction with bf16 source/accumulator/output into
a multi_dim_reduction with f32 source/accumulator/output, where the source
and accumulator are extended and the result is truncated. This addresses the
'only 32-bit reductions supported' error.

PiperOrigin-RevId: 644062786
2024-06-17 10:51:24 -07:00
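At the array level, the rewrite is roughly equivalent to widening to f32, reducing, and truncating back; a sketch (shapes illustrative, the real change is an MLIR pattern rewrite):

```python
import jax.numpy as jnp

x = jnp.ones((8, 128), dtype=jnp.bfloat16)

# Widen to f32, reduce, then truncate the result back to bf16.
y = jnp.sum(x.astype(jnp.float32), axis=0).astype(jnp.bfloat16)
```
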
jax authors
039b7c12e1 Merge pull request #21904 from tilakrayal:patch-2
PiperOrigin-RevId: 644061552
2024-06-17 10:48:05 -07:00
jax authors
2f22d3abfd Merge pull request #21901 from jjyyxx:patch-1
PiperOrigin-RevId: 644061544
2024-06-17 10:44:50 -07:00
Sergei Lebedev
01f182e772 Use `uninitialized_value` for allocating outputs for interpreted Pallas kernels
PiperOrigin-RevId: 644057616
2024-06-17 10:34:38 -07:00
Junwhan Ahn
cec796f5dc Batch pxla.shard_args calls triggered by jax.device_put
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.

The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.

PiperOrigin-RevId: 644051624
2024-06-17 10:17:25 -07:00
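A sketch of the pattern this speeds up (the pytree contents are illustrative):

```python
import jax
import jax.numpy as jnp

# One device_put over a large pytree now corresponds to a single
# device_put_p.bind() rather than one bind per leaf.
params = {f"w{i}": jnp.ones((128, 128)) for i in range(100)}
params_on_device = jax.device_put(params, jax.devices()[0])
```
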
George Necula
b1a8c65883 [shape_poly] Add documentation for workaround with dimension parameters. 2024-06-17 20:14:20 +03:00
Sergei Lebedev
550862f8c1 Added some docs to `_hoist_consts_to_refs`
I also restructured the implementation slightly, because most list allocations
were in fact unnecessary.
2024-06-17 15:33:05 +01:00
Adam Paszke
4ea73bf787 Use constant memory to pass in TMA descriptors to the kernel
To work around another buggy part of the PTX documentation. While PTX
explicitly says that TMA descriptors can be in global memory, the C++
programming guide heavily discourages this, because it can lead to
incorrect results, which is also what we've sometimes observed as
a cache coherency issue unless a TMA fence is explicitly inserted at the
beginning of the kernel.

Note that this approach has a big downside of making the kernel unsafe
for concurrent use. I don't think that XLA:GPU will ever dispatch it
concurrently so I didn't insert any extra synchronization for now, but
we should seriously consider it. My hope at the moment is that we'll
be able to start passing in TMA descs as kernel args soon (pending
upstreaming LLVM changes...) and we won't have to deal with this again.

For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays

PiperOrigin-RevId: 643972675
2024-06-17 05:31:26 -07:00
Sergei Lebedev
f67f2e06ce Fixed a `ValueError` when a Pallas GPU kernel closed over array constants
The fix idea is based on the investigation by @zhixuan-lin in #21557.

PiperOrigin-RevId: 643965836
2024-06-17 05:05:01 -07:00
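A sketch of the previously failing pattern (kernel, shapes, and the closed-over constant are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

table = jnp.arange(8.0)  # array constant captured from the enclosing scope

def kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + table  # closes over `table` instead of taking it as an argument

f = pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32))
y = f(jnp.ones(8, dtype=jnp.float32))
```
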
Sergei Lebedev
4913fff971 Rollback #21888, because it breaks multiple internal users
Reverts 193591b5c0b90ce498015b2e3d48950615253380

PiperOrigin-RevId: 643965549
2024-06-17 05:01:04 -07:00
rajasekharporeddy
b93da3873b Fix Typos 2024-06-17 13:55:46 +05:30
tilakrayal
b461846cc0
Fixing the naming conventions in lax_numpy.py 2024-06-17 11:56:52 +05:30
jax authors
595a620804 Update XLA dependency to use revision
081a8b35e8.

PiperOrigin-RevId: 643808691
2024-06-16 13:09:08 -07:00
jax authors
ae0127d696 Merge pull request #21897 from matsen:patch-1
PiperOrigin-RevId: 643735130
2024-06-16 02:01:42 -07:00
Yuxuan Jiang
cd23b2c82c
Fix CUDNN requirement inconsistency in installation.md 2024-06-16 14:38:05 +08:00
jax authors
546a3a60eb Update XLA dependency to use revision
340fbbf18f.

PiperOrigin-RevId: 643651407
2024-06-15 13:52:48 -07:00
jax authors
993fa6108d Merge pull request #21880 from gnecula:doc_poly
PiperOrigin-RevId: 643631386
2024-06-15 11:03:18 -07:00
George Necula
b58ff2ba20 [shape_poly] Add documentation for shape polymorphism
This involved writing some new content and also moving and adapting
the documentation that existed as part of the jax2tf
README file:

https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion
2024-06-15 18:20:54 +03:00
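A minimal shape-polymorphism sketch using the experimental export API (names and shapes are illustrative, and the exact export entry point may differ across versions):

```python
import jax
import jax.numpy as jnp
from jax.experimental import export

def f(x):
    return jnp.sum(x, axis=0)

# Export with a symbolic batch dimension "b", so the exported artifact
# accepts inputs of any batch size.
args = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.float32)
exp = export.export(jax.jit(f))(args)
```
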
Erick Matsen
b585eceeb0
Fixing installation docs: needed quotes; markdown fix 2024-06-15 05:12:21 -07:00