267 Commits

Author SHA1 Message Date
Peter Hawkins
483f924ea1 Bump shard count for experimental_rnn_test, which is timing out in CI when built under ASAN.
PiperOrigin-RevId: 635850400
2024-05-21 10:25:24 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Sergey Kozub
84774b39e6 Fix sparse dot metadata loader
Metadata loader was using incorrect warp assignment, which resulted in incorrect addresses with num_warps>4. This was previously missed, as the autotuner rarely selected such configs.

PiperOrigin-RevId: 633513110
2024-05-14 03:08:18 -07:00
jax authors
8baa5d8180 Merge pull request #21128 from hawkinsp:loggingtest
PiperOrigin-RevId: 631839432
2024-05-08 10:07:42 -07:00
jax authors
11da3df238 Merge pull request #21096 from gspschmid:gschmid/sourcemaps
PiperOrigin-RevId: 631769572
2024-05-08 05:44:08 -07:00
Peter Hawkins
919832a63c Enable logging_test on all CI platforms.
Should catch issues like https://github.com/google/jax/issues/21121
2024-05-08 12:43:52 +00:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
jax authors
70f2ef211f Merge pull request #20971 from google:mutable-array-scan
PiperOrigin-RevId: 630130893
2024-05-02 11:40:54 -07:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
Roy Frostig
c18f7916a3 bump shard count for random_lax_test
PiperOrigin-RevId: 629495786
2024-04-30 12:35:17 -07:00
Sergey Kozub
d655cd7be9 Disable sparse_nm_test_gpu_h100 because of flakiness
PiperOrigin-RevId: 628490449
2024-04-26 13:13:15 -07:00
Jake VanderPlas
beb49af678 sparse_nm_test: skip on incompatible GPUs
PiperOrigin-RevId: 628120697
2024-04-25 10:38:07 -07:00
Sergey Kozub
aebe82a78f Add JAX API that provides sparse matmul support (2:4 structured sparsity)
Usage:
from jax.experimental.sparse import nm
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))

where:
lhs.shape = [M, K/2]
rhs.shape = [K, N]
`mask` has the same shape as `lhs` with boolean type

If batch dimensions are present, the `dimension_numbers` argument has to be set to:
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

The lowering only works on nVidia GPUs, that provide hardware support for sparse dots.

PiperOrigin-RevId: 627640553
2024-04-24 01:06:19 -07:00
Adam Paszke
fa66e731e6 Increase sharding to avoid timeouts
PiperOrigin-RevId: 626008096
2024-04-18 06:04:41 -07:00
Meekail Zain
e6508a4f47 Add __array_namespace_info__ and corresponding utilities 2024-04-11 14:20:44 +00:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Cjkkkk
85cbe05f25 add multiaccelerator tag to test 2024-03-15 13:22:45 -07:00
Peter Hawkins
02ad457bc2 Disable P100 and A100 for pytorch_interoperability_test.
These are causing CI timeouts.

PiperOrigin-RevId: 615427575
2024-03-13 08:31:28 -07:00
jax authors
60bf38bde9 Merge pull request #20128 from shuhand0:dev/shuhan/ci2
PiperOrigin-RevId: 615231412
2024-03-12 17:45:19 -07:00
Shuhan Ding
5a93e15bd7
add to tests/BUILD 2024-03-12 17:17:20 -07:00
Roy Frostig
29edfd8925 define a loop-free untrue batching rule for rng_bit_generator 2024-03-08 13:13:03 -08:00
Benjamin Kramer
5005890546 Enable more tests on H100
20895965b2 fixed these

PiperOrigin-RevId: 612907679
2024-03-05 11:22:56 -08:00
Peter Hawkins
6207977fac Disable some tests that fail on H100 in CI.
PiperOrigin-RevId: 612637375
2024-03-04 16:59:52 -08:00
Peter Hawkins
4ff84d04d9 Disable asan for torch interoperability test on CPU.
This is timing out at build time on CI.

PiperOrigin-RevId: 609350782
2024-02-22 06:29:13 -08:00
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Sergei Lebedev
37f313ab22 Fixed internal CI builds
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases

PiperOrigin-RevId: 608534008
2024-02-20 02:42:14 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
jax authors
eb03436835 Reverts 318a19a89387caebd116168c4e47592e7d71ca65
PiperOrigin-RevId: 607708463
2024-02-16 09:11:05 -08:00
Sergei Lebedev
318a19a893 Removed deprecated jax.config methods
PiperOrigin-RevId: 607675571
2024-02-16 06:49:13 -08:00
jax authors
21236f0c65 Removes unused IREE config parameters from tests.
PiperOrigin-RevId: 606590491
2024-02-13 05:45:42 -08:00
jax authors
7b05bbdda0 Merge pull request #18814 from Cjkkkk:spda
PiperOrigin-RevId: 606397276
2024-02-12 16:11:37 -08:00
Adam Paszke
1b2227283b Shard nn_test on GPU to avoid timeouts
PiperOrigin-RevId: 606224790
2024-02-12 05:49:26 -08:00
Cjkkkk
5708fb955b address some format issues 2024-02-09 09:05:08 -08:00
jax authors
07d793d7b2 Enable pmap_test under ASAN on GPU.
PiperOrigin-RevId: 605553673
2024-02-09 00:57:19 -08:00
jax authors
8bbcbb6e12 Merge pull request #19532 from mattjj:jax-attrs2
PiperOrigin-RevId: 602079647
2024-01-27 18:07:04 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
dfda948fbf [XLA:Python] Fail with an AttributeError if __cuda_array_interface__ is called on a sharded array.
Fixes https://github.com/google/jax/issues/19134

PiperOrigin-RevId: 600570354
2024-01-22 14:25:47 -08:00
Peter Hawkins
25e4acfe25 Disabled lax_scipy_special_functions_test under ASAN on GPU.
This test is slow and times out in CI.

PiperOrigin-RevId: 600527658
2024-01-22 12:02:34 -08:00
Peter Hawkins
912a5ef771 Disable some tests that time out or OOM under ASAN.
PiperOrigin-RevId: 598543036
2024-01-15 01:40:44 -08:00
Peter Hawkins
2980e8f09c [JAX] Disable some CUDA tests that fail under ASAN, due to bugs in NCCL and Triton.
PiperOrigin-RevId: 597845384
2024-01-12 08:20:36 -08:00
Peter Hawkins
35fc2ed8e0 Disable ASAN for several CUDA tests.
PiperOrigin-RevId: 597596726
2024-01-11 10:43:38 -08:00
George Necula
ed2a839884 Move export backwards compatibility tests out of jax2tf. Step 3.
The last part of moving the tests: move jax2tf/tests/back_compat_test.py to tests/export_back_compat_test.py.

PiperOrigin-RevId: 596555577
2024-01-08 04:48:10 -08:00
Yash Katariya
72fbdb2eb5 Expose shard_alike via jax.experimental. The API is x, y = shard_like(x, y).
The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.

The flow of sharding is bidirectional i.e. SPMD will choose what the sharding should be of `x` and `y` depending on it's propagation algorithm. It might end up being that the sharding chosen is not of `x` and `y` but something better. At the end of propagation `x` and `y` will be sharded alike.

The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.

Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
2023-12-19 16:31:33 -08:00
Peter Hawkins
67d5c3bdea [JAX:GPU] Add a test that verifies that the XLA_PYTHON_CLIENT_PREALLOCATE environment variable is parsed correctly.
Fixes https://github.com/google/jax/issues/19035

PiperOrigin-RevId: 592322040
2023-12-19 13:06:08 -08:00
Parker Schuh
7ba8622719 For custom_partitioning, directly emit call when inside of a shard_map.
PiperOrigin-RevId: 592011427
2023-12-18 14:32:38 -08:00
jax authors
b225c86f10 Merge pull request #18262 from jakevdp:key-reuse-jaxpr
PiperOrigin-RevId: 589913404
2023-12-11 12:46:55 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
Yash Katariya
78e0e6b058 Internal change
PiperOrigin-RevId: 588856820
2023-12-07 11:30:54 -08:00