302 Commits

Author SHA1 Message Date
Ilia Sergachev
85d792a92d Add a cudnn_fusion decorator that lowers computations to XLA cuDNN fusions. 2024-09-05 01:25:54 +02:00
jax authors
2f3990d13c Remove CPU test variant.
PiperOrigin-RevId: 669359594
2024-08-30 09:58:32 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Bart Chrzaszcz
71b7e78916 Add jax_test configs for shardy, enable it for pjit_test.py, and fix the failing tests.
Tests fixed include:

- `test_globally_sharded_key_array_8x4_multi_device`
  - Issue was in `replicate_trailing_dims` where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding.
- `test_aot_out_info`
  - The issue was that there was no mesh, since there weren't any NamedShardings. Fixed by not asserting that a mesh tuple exists in `lower_jaxpr_to_module` when adding the SDY MeshOp (there won't be any propagation).
- `test_concurrent_pjit`
  - In Shardy, if there was a tensor dimension of size 0, we'd emit a verification error if the dimension was sharded on an axis. But if the axis is of size 1, then JAX says this is okay, so have Shardy assume the same.
- `test_globally_sharded_key_array_result_8x4_single_device`
  - This test adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we should create a mesh named `mesh` with a single device ID in case it doesn't exist.
- `testLowerCostAnalysis`
  - This calls into `mlir_module_to_xla_computation` which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. Needed to register the SDY dialect in it.
- `testShardingConstraintWithArray`
  - This calls `.compiler_ir(dialect="hlo")`, which calls `PyMlirModuleToXlaComputation`, which converts the MLIR to HLO, but the SDY dialect is still inside. Export it before converting it to HLO.

PiperOrigin-RevId: 666777167
2024-08-23 06:51:13 -07:00
George Necula
dbd6aeebb7 Disable some ASAN tests; they time out
PiperOrigin-RevId: 662774152
2024-08-13 22:03:29 -07:00
Sergei Lebedev
d8eafc8ee3 Disabled nn_test under ASAN on TPU as well, since it also times out
PiperOrigin-RevId: 660950262
2024-08-08 13:02:31 -07:00
Sergei Lebedev
3a1567f57a Do not run nn_test under ASAN -- it times out
PiperOrigin-RevId: 660377176
2024-08-07 07:14:27 -07:00
Vladimir Belitskiy
7f96b263d4 Un-skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on ASAN.
It should have been fixed via
https://github.com/pytorch/pytorch/issues/117058#issuecomment-1973020150

PiperOrigin-RevId: 656464550
2024-07-26 11:10:41 -07:00
Vladimir Belitskiy
282ebf4882 Skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on MSAN.
MSAN has issues with using `-c opt` in some cases, which prevents this
test from running properly.

PiperOrigin-RevId: 656454585
2024-07-26 10:44:19 -07:00
Vladimir Belitskiy
ba50e77407 Increase shard count for //third_party/py/jax/tests:lax_numpy_ufuncs_test_cpu.
PiperOrigin-RevId: 655946922
2024-07-25 07:26:06 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3-month deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Eugene Zhulenev
e3fc63cafb [xla:cpu] Support for up to 16 sorted inputs
+ enable more jax/lax tests for XLA CPU thunks

PiperOrigin-RevId: 655249641
2024-07-23 11:54:31 -07:00
Vladimir Belitskiy
a1f2a50cfa Increase shard count under TPU for //third_party/py/jax/tests:lax_numpy_test.
PiperOrigin-RevId: 654847718
2024-07-22 12:08:04 -07:00
Eugene Zhulenev
e23de7d790 [xla:cpu] Add support for XLA:CPU+Thunks compilation cache
+ fix custom call thunk for tuple result of size 1

PiperOrigin-RevId: 653381010
2024-07-17 15:29:13 -07:00
Gleb Pobudzey
46103f6ff3 Updated the repr of GPU devices to be more consistent with TPUs/CPUs.
For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
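
A minimal illustration of the new repr (assuming a CUDA-enabled jaxlib; on other backends the device type differs):

import jax

dev = jax.devices()[0]
# Previously this printed something like cuda(id=0); it now follows the
# TPU/CPU naming scheme, e.g. CudaDevice(id=0).
print(repr(dev))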

PiperOrigin-RevId: 651393690
2024-07-11 06:54:20 -07:00
Kaixi Hou
df6080f346 PR #21371: [NVIDIA] Add new SDPA API to jax.nn
Imported from GitHub PR https://github.com/google/jax/pull/21371

Attention plays a crucial role in modern transformer-based models. While various variants exist, they generally follow the same workflow; examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can yield a 1.3x end-to-end speedup for training GPT-based large language models.

This PR proposes introducing a new API in the `jax.nn` module to handle attention. It will first try the cuDNN flash-attention execution path when the config is compatible; otherwise it falls back to a JAX implementation.

cc. @nluehr @Cjkkkk @cliffwoolley
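
A minimal sketch of the resulting call, assuming the `jax.nn.dot_product_attention` entry point and a (batch, seq_len, num_heads, head_dim) layout:

import jax
import jax.numpy as jnp

B, T, N, H = 2, 128, 8, 64  # batch, sequence length, heads, head dim
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(k1, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (B, T, N, H), dtype=jnp.bfloat16)

# implementation=None lets JAX pick the cuDNN flash-attention path when the
# config is compatible and fall back to the XLA implementation otherwise;
# implementation="cudnn" would request the cuDNN path explicitly.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation=None)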

Copybara import of the project:

--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:

Add new SDPA API to jax.nn

Merging this change closes #21371

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
2024-07-08 06:16:04 -07:00
Ayaka
6c05aa2f32 Clean up 2024-07-04 17:16:32 +04:00
Adam Paszke
727d120401 Bump up the shard_count for GPU FFT tests
They seem to be timing out with ASAN and no sharding.

PiperOrigin-RevId: 648301571
2024-07-01 02:58:23 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
George Necula
24b42eed5e [export] Clean up BUILD targets for jax.experimental.export
jax.experimental.export is deprecated and will be removed in a future version of JAX.

See migration guide at: https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export

PiperOrigin-RevId: 647562073
2024-06-27 23:08:48 -07:00
jax authors
8b22b5fbcc Merge pull request #22087 from jakevdp:validate-config
PiperOrigin-RevId: 646963494
2024-06-26 08:51:34 -07:00
Jake VanderPlas
fa73077146 jax.config: validate on set() 2024-06-25 09:02:32 -07:00
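A hypothetical illustration of the stricter behavior this adds (the option names below are only examples; the exact error raised on an invalid value is an assumption):

import jax

jax.config.update("jax_enable_x64", True)  # valid name and value: accepted as before

try:
    # Presumably an invalid value is now rejected at update() time rather than
    # at first use (assumption; behavior may depend on the option).
    jax.config.update("jax_default_matmul_precision", "not-a-precision")
except Exception as err:
    print(f"rejected: {err}")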
Justin Fu
8ba8f3bf65 [Pallas] Implement block-invariant sampling.
PiperOrigin-RevId: 646161271
2024-06-24 11:20:39 -07:00
Eugene Zhulenev
3fd9326881 [jax] Enable api_test with XLA:CPU thunks
PiperOrigin-RevId: 644268375
2024-06-17 23:58:02 -07:00
Peter Hawkins
339027d7ab [JAX] Disable qdwh_test in asan/msan/tsan configurations on TPU.
This test is flakily timing out in CI; the sanitizers probably push the test over its time bound.

PiperOrigin-RevId: 642695381
2024-06-12 12:16:50 -07:00
jax authors
95e2c17b61 Update test of QDWH to use stricter tolerances and test more shapes and types.
Get rid of comparison with scipy.linalg.polar, since its outputs are significantly less accurate than QDWH. Since the polar decomposition is unique, comparing to a less accurate implementation does not add value.

PiperOrigin-RevId: 642423757
2024-06-11 16:04:38 -07:00
Ayaka
1a3a15c9e3 Implement LRU cache eviction for persistent compilation cache
Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-11 21:48:35 +04:00
Eugene Zhulenev
4cca233522 [xla:cpu] Add support for outfeed to thunk runtime
+ enabled infeed/outfeed test in jax

PiperOrigin-RevId: 640026208
2024-06-03 22:43:42 -07:00
Eugene Zhulenev
1aa2b6ee4f [jax:cpu] Add a test config to run JAX with new XLA:CPU runtime
XLA:CPU is migrating from compiling a monolithic LLVM function for the whole HLO module to a thin runtime with separate functions for each kernel (fusion, individual operation, library call, etc.). While the new runtime is not enabled by default, we will use explicit opt-in on tests that are already compatible.

This tag will be removed after XLA:CPU switches to the new runtime by default.

PiperOrigin-RevId: 640022517
2024-06-03 22:28:18 -07:00
Dan Foreman-Mackey
690fa1d90c Remove failing ffi test
The FFI headers aren't properly exposed during a bazel build, so these
tests are failing. I'll re-enable them when I get a chance to get that
working properly.
2024-05-31 15:36:33 -04:00
jax authors
d9f07d0350 Merge pull request #21531 from dfm:move-ffi-submodule
PiperOrigin-RevId: 639077602
2024-05-31 10:31:32 -07:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Peter Hawkins
01b4cb6de0 Bump memory limits for array_test and layout_test on TPU CI.
These use more than our CI's default memory limit (12GB) when run under tsan.

PiperOrigin-RevId: 638618718
2024-05-30 05:33:14 -07:00
Peter Hawkins
1e01fa7b0f Bump up memory requirements for pmap_test on TPU.
This test recently started exceeding the default memory requirement when run under tsan. I'm not entirely sure why, but perhaps some change pushed it just over our CI's 12GB default limit.

PiperOrigin-RevId: 636910434
2024-05-24 07:25:30 -07:00
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
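
A minimal sketch of locating the packaged headers from Python; the `include_dir` helper under `jax.extend.ffi` is an assumption about where this surfaced:

from jax.extend import ffi  # assumed location of the FFI helpers

# If available, this returns the directory containing the bundled XLA FFI
# headers, which can be passed to a C++ compiler as an -I flag when building
# custom calls against jaxlib.
print(ffi.include_dir())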
2024-05-22 12:28:38 -04:00
Peter Hawkins
483f924ea1 Bump shard count for experimental_rnn_test, which is timing out in CI when built under ASAN.
PiperOrigin-RevId: 635850400
2024-05-21 10:25:24 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The currently supported values for compute type are `device_host` and `device`; `device_sparse` will be allowed in a follow-up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU-only computations that run without a device's PJRT client orchestrating the computation.
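
A minimal sketch, assuming the decorator form exported from `jax.experimental.compute_on` and a backend that supports host offload:

import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@compute_on("device_host")  # orchestrate this computation on the host
@jax.jit
def host_double(x):
  return x * 2

@jax.jit
def f(x):
  y = host_double(x)  # runs on the host via the device's PJRT client
  return y + 1        # runs back on the device

print(f(jnp.arange(4.0)))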

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Sergey Kozub
84774b39e6 Fix sparse dot metadata loader
The metadata loader was using an incorrect warp assignment, which resulted in incorrect addresses with num_warps > 4. This was previously missed, as the autotuner rarely selected such configs.

PiperOrigin-RevId: 633513110
2024-05-14 03:08:18 -07:00
jax authors
8baa5d8180 Merge pull request #21128 from hawkinsp:loggingtest
PiperOrigin-RevId: 631839432
2024-05-08 10:07:42 -07:00
jax authors
11da3df238 Merge pull request #21096 from gspschmid:gschmid/sourcemaps
PiperOrigin-RevId: 631769572
2024-05-08 05:44:08 -07:00
Peter Hawkins
919832a63c Enable logging_test on all CI platforms.
Should catch issues like https://github.com/google/jax/issues/21121
2024-05-08 12:43:52 +00:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
jax authors
70f2ef211f Merge pull request #20971 from google:mutable-array-scan
PiperOrigin-RevId: 630130893
2024-05-02 11:40:54 -07:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to a separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
Roy Frostig
c18f7916a3 bump shard count for random_lax_test
PiperOrigin-RevId: 629495786
2024-04-30 12:35:17 -07:00
Sergey Kozub
d655cd7be9 Disable sparse_nm_test_gpu_h100 because of flakiness
PiperOrigin-RevId: 628490449
2024-04-26 13:13:15 -07:00
Jake VanderPlas
beb49af678 sparse_nm_test: skip on incompatible GPUs
PiperOrigin-RevId: 628120697
2024-04-25 10:38:07 -07:00
Sergey Kozub
aebe82a78f Add JAX API that provides sparse matmul support (2:4 structured sparsity)
Usage:
from jax.experimental.sparse import nm
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))

where:
lhs.shape = [M, K/2]
rhs.shape = [K, N]
`mask` has the same shape as `lhs` with boolean type

If batch dimensions are present, the `dimension_numbers` argument has to be set to:
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

The lowering only works on NVIDIA GPUs that provide hardware support for sparse dots.

PiperOrigin-RevId: 627640553
2024-04-24 01:06:19 -07:00
Adam Paszke
fa66e731e6 Increase sharding to avoid timeouts
PiperOrigin-RevId: 626008096
2024-04-18 06:04:41 -07:00