268 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
2024-05-22 12:28:38 -04:00
Peter Hawkins
483f924ea1 Bump shard count for experimental_rnn_test, which is timing out in CI when built under ASAN.
PiperOrigin-RevId: 635850400
2024-05-21 10:25:24 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Sergey Kozub
84774b39e6 Fix sparse dot metadata loader
Metadata loader was using incorrect warp assignment, which resulted in incorrect addresses with num_warps>4. This was previously missed, as the autotuner rarely selected such configs.

PiperOrigin-RevId: 633513110
2024-05-14 03:08:18 -07:00
jax authors
8baa5d8180 Merge pull request #21128 from hawkinsp:loggingtest
PiperOrigin-RevId: 631839432
2024-05-08 10:07:42 -07:00
jax authors
11da3df238 Merge pull request #21096 from gspschmid:gschmid/sourcemaps
PiperOrigin-RevId: 631769572
2024-05-08 05:44:08 -07:00
Peter Hawkins
919832a63c Enable logging_test on all CI platforms.
Should catch issues like https://github.com/google/jax/issues/21121
2024-05-08 12:43:52 +00:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
jax authors
70f2ef211f Merge pull request #20971 from google:mutable-array-scan
PiperOrigin-RevId: 630130893
2024-05-02 11:40:54 -07:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
Roy Frostig
c18f7916a3 bump shard count for random_lax_test
PiperOrigin-RevId: 629495786
2024-04-30 12:35:17 -07:00
Sergey Kozub
d655cd7be9 Disable sparse_nm_test_gpu_h100 because of flakiness
PiperOrigin-RevId: 628490449
2024-04-26 13:13:15 -07:00
Jake VanderPlas
beb49af678 sparse_nm_test: skip on incompatible GPUs
PiperOrigin-RevId: 628120697
2024-04-25 10:38:07 -07:00
Sergey Kozub
aebe82a78f Add JAX API that provides sparse matmul support (2:4 structured sparsity)
Usage:
from jax.experimental.sparse import nm
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))

where:
lhs.shape = [M, K/2]
rhs.shape = [K, N]
`mask` has the same shape as `lhs` with boolean type

If batch dimensions are present, the `dimension_numbers` argument has to be set to:
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

The lowering only works on nVidia GPUs, that provide hardware support for sparse dots.

PiperOrigin-RevId: 627640553
2024-04-24 01:06:19 -07:00
Adam Paszke
fa66e731e6 Increase sharding to avoid timeouts
PiperOrigin-RevId: 626008096
2024-04-18 06:04:41 -07:00
Meekail Zain
e6508a4f47 Add __array_namespace_info__ and corresponding utilities 2024-04-11 14:20:44 +00:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Cjkkkk
85cbe05f25 add multiaccelerator tag to test 2024-03-15 13:22:45 -07:00
Peter Hawkins
02ad457bc2 Disable P100 and A100 for pytorch_interoperability_test.
These are causing CI timeouts.

PiperOrigin-RevId: 615427575
2024-03-13 08:31:28 -07:00
jax authors
60bf38bde9 Merge pull request #20128 from shuhand0:dev/shuhan/ci2
PiperOrigin-RevId: 615231412
2024-03-12 17:45:19 -07:00
Shuhan Ding
5a93e15bd7
add to tests/BUILD 2024-03-12 17:17:20 -07:00
Roy Frostig
29edfd8925 define a loop-free untrue batching rule for rng_bit_generator 2024-03-08 13:13:03 -08:00
Benjamin Kramer
5005890546 Enable more tests on H100
20895965b2 fixed these

PiperOrigin-RevId: 612907679
2024-03-05 11:22:56 -08:00
Peter Hawkins
6207977fac Disable some tests that fail on H100 in CI.
PiperOrigin-RevId: 612637375
2024-03-04 16:59:52 -08:00
Peter Hawkins
4ff84d04d9 Disable asan for torch interoperability test on CPU.
This is timing out at build time on CI.

PiperOrigin-RevId: 609350782
2024-02-22 06:29:13 -08:00
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Sergei Lebedev
37f313ab22 Fixed internal CI builds
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases

PiperOrigin-RevId: 608534008
2024-02-20 02:42:14 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
jax authors
eb03436835 Reverts 318a19a89387caebd116168c4e47592e7d71ca65
PiperOrigin-RevId: 607708463
2024-02-16 09:11:05 -08:00
Sergei Lebedev
318a19a893 Removed deprecated jax.config methods
PiperOrigin-RevId: 607675571
2024-02-16 06:49:13 -08:00
jax authors
21236f0c65 Removes unused IREE config parameters from tests.
PiperOrigin-RevId: 606590491
2024-02-13 05:45:42 -08:00
jax authors
7b05bbdda0 Merge pull request #18814 from Cjkkkk:spda
PiperOrigin-RevId: 606397276
2024-02-12 16:11:37 -08:00
Adam Paszke
1b2227283b Shard nn_test on GPU to avoid timeouts
PiperOrigin-RevId: 606224790
2024-02-12 05:49:26 -08:00
Cjkkkk
5708fb955b address some format issues 2024-02-09 09:05:08 -08:00
jax authors
07d793d7b2 Enable pmap_test under ASAN on GPU.
PiperOrigin-RevId: 605553673
2024-02-09 00:57:19 -08:00
jax authors
8bbcbb6e12 Merge pull request #19532 from mattjj:jax-attrs2
PiperOrigin-RevId: 602079647
2024-01-27 18:07:04 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
dfda948fbf [XLA:Python] Fail with an AttributeError if __cuda_array_interface__ is called on a sharded array.
Fixes https://github.com/google/jax/issues/19134

PiperOrigin-RevId: 600570354
2024-01-22 14:25:47 -08:00
Peter Hawkins
25e4acfe25 Disabled lax_scipy_special_functions_test under ASAN on GPU.
This test is slow and times out in CI.

PiperOrigin-RevId: 600527658
2024-01-22 12:02:34 -08:00
Peter Hawkins
912a5ef771 Disable some tests that time out or OOM under ASAN.
PiperOrigin-RevId: 598543036
2024-01-15 01:40:44 -08:00
Peter Hawkins
2980e8f09c [JAX] Disable some CUDA tests that fail under ASAN, due to bugs in NCCL and Triton.
PiperOrigin-RevId: 597845384
2024-01-12 08:20:36 -08:00
Peter Hawkins
35fc2ed8e0 Disable ASAN for several CUDA tests.
PiperOrigin-RevId: 597596726
2024-01-11 10:43:38 -08:00
George Necula
ed2a839884 Move export backwards compatibility tests out of jax2tf. Step 3.
The last part of moving the tests: move jax2tf/tests/back_compat_test.py to tests/export_back_compat_test.py.

PiperOrigin-RevId: 596555577
2024-01-08 04:48:10 -08:00
Yash Katariya
72fbdb2eb5 Expose shard_alike via jax.experimental. The API is x, y = shard_like(x, y).
The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.

The flow of sharding is bidirectional i.e. SPMD will choose what the sharding should be of `x` and `y` depending on it's propagation algorithm. It might end up being that the sharding chosen is not of `x` and `y` but something better. At the end of propagation `x` and `y` will be sharded alike.

The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.

Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
2023-12-19 16:31:33 -08:00
Peter Hawkins
67d5c3bdea [JAX:GPU] Add a test that verifies that the XLA_PYTHON_CLIENT_PREALLOCATE environment variable is parsed correctly.
Fixes https://github.com/google/jax/issues/19035

PiperOrigin-RevId: 592322040
2023-12-19 13:06:08 -08:00
Parker Schuh
7ba8622719 For custom_partitioning, directly emit call when inside of a shard_map.
PiperOrigin-RevId: 592011427
2023-12-18 14:32:38 -08:00
jax authors
b225c86f10 Merge pull request #18262 from jakevdp:key-reuse-jaxpr
PiperOrigin-RevId: 589913404
2023-12-11 12:46:55 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00