377 Commits

Author SHA1 Message Date
jax authors
95e2c17b61 Update test of QDWH to use stricter tolerances and test more shapes and types.
Get rid of comparison with scipy.linalg.polar, since its outputs are significantly less accurate than QDWH. Since the polar decomposition is unique, comparing to a less accurate implementation does not add value.

PiperOrigin-RevId: 642423757
2024-06-11 16:04:38 -07:00
Ayaka
1a3a15c9e3 Implement LRU cache eviction for persistent compilation cache
Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-11 21:48:35 +04:00
Eugene Zhulenev
4cca233522 [xla:cpu] Add support for outfeed to thunk runtime
+ enabled infeed/outfeed test in jax

PiperOrigin-RevId: 640026208
2024-06-03 22:43:42 -07:00
Eugene Zhulenev
1aa2b6ee4f [jax:cpu] Add a test config to run JAX with new XLA:CPU runtime
XLA:CPU is migrating from compiling monolithic LLVM function for the whole HLO module to a thin runtime with separate functions for each kernel (fusion, individual operation, library call, etc.). While new runtime is not enabled by default we will use explicit opt-in on tests that are already compatible.

This tag will be removed after XLA:CPU will switch to the new runtime by default.

PiperOrigin-RevId: 640022517
2024-06-03 22:28:18 -07:00
Dan Foreman-Mackey
690fa1d90c Remove failing ffi test
The FFI headers aren't properly exposed during a bazel build, so these
tests are failing. I'll re-enable them when I get a chance to get that
working properly.
2024-05-31 15:36:33 -04:00
jax authors
d9f07d0350 Merge pull request #21531 from dfm:move-ffi-submodule
PiperOrigin-RevId: 639077602
2024-05-31 10:31:32 -07:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Peter Hawkins
01b4cb6de0 Bump memory limits for array_test and layout_test on TPU CI.
These use more than our CI's default memory limit (12GB) when run under tsan.

PiperOrigin-RevId: 638618718
2024-05-30 05:33:14 -07:00
Peter Hawkins
1e01fa7b0f Bump up memory requirements for pmap_test on TPU.
This test recently started exceeding the default memory requirement when run under tsan. I'm not entirely sure why, but perhaps some change pushed it just over our CI's 12GB default limit.

PiperOrigin-RevId: 636910434
2024-05-24 07:25:30 -07:00
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
2024-05-22 12:28:38 -04:00
Peter Hawkins
483f924ea1 Bump shard count for experimental_rnn_test, which is timing out in CI when built under ASAN.
PiperOrigin-RevId: 635850400
2024-05-21 10:25:24 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Sergey Kozub
84774b39e6 Fix sparse dot metadata loader
Metadata loader was using incorrect warp assignment, which resulted in incorrect addresses with num_warps>4. This was previously missed, as the autotuner rarely selected such configs.

PiperOrigin-RevId: 633513110
2024-05-14 03:08:18 -07:00
jax authors
8baa5d8180 Merge pull request #21128 from hawkinsp:loggingtest
PiperOrigin-RevId: 631839432
2024-05-08 10:07:42 -07:00
jax authors
11da3df238 Merge pull request #21096 from gspschmid:gschmid/sourcemaps
PiperOrigin-RevId: 631769572
2024-05-08 05:44:08 -07:00
Peter Hawkins
919832a63c Enable logging_test on all CI platforms.
Should catch issues like https://github.com/google/jax/issues/21121
2024-05-08 12:43:52 +00:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
jax authors
70f2ef211f Merge pull request #20971 from google:mutable-array-scan
PiperOrigin-RevId: 630130893
2024-05-02 11:40:54 -07:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
Roy Frostig
c18f7916a3 bump shard count for random_lax_test
PiperOrigin-RevId: 629495786
2024-04-30 12:35:17 -07:00
Sergey Kozub
d655cd7be9 Disable sparse_nm_test_gpu_h100 because of flakiness
PiperOrigin-RevId: 628490449
2024-04-26 13:13:15 -07:00
Jake VanderPlas
beb49af678 sparse_nm_test: skip on incompatible GPUs
PiperOrigin-RevId: 628120697
2024-04-25 10:38:07 -07:00
Sergey Kozub
aebe82a78f Add JAX API that provides sparse matmul support (2:4 structured sparsity)
Usage:
from jax.experimental.sparse import nm
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))

where:
lhs.shape = [M, K/2]
rhs.shape = [K, N]
`mask` has the same shape as `lhs` with boolean type

If batch dimensions are present, the `dimension_numbers` argument has to be set to:
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

The lowering only works on nVidia GPUs, that provide hardware support for sparse dots.

PiperOrigin-RevId: 627640553
2024-04-24 01:06:19 -07:00
Adam Paszke
fa66e731e6 Increase sharding to avoid timeouts
PiperOrigin-RevId: 626008096
2024-04-18 06:04:41 -07:00
Meekail Zain
e6508a4f47 Add __array_namespace_info__ and corresponding utilities 2024-04-11 14:20:44 +00:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Cjkkkk
85cbe05f25 add multiaccelerator tag to test 2024-03-15 13:22:45 -07:00
Peter Hawkins
02ad457bc2 Disable P100 and A100 for pytorch_interoperability_test.
These are causing CI timeouts.

PiperOrigin-RevId: 615427575
2024-03-13 08:31:28 -07:00
jax authors
60bf38bde9 Merge pull request #20128 from shuhand0:dev/shuhan/ci2
PiperOrigin-RevId: 615231412
2024-03-12 17:45:19 -07:00
Shuhan Ding
5a93e15bd7
add to tests/BUILD 2024-03-12 17:17:20 -07:00
Roy Frostig
29edfd8925 define a loop-free untrue batching rule for rng_bit_generator 2024-03-08 13:13:03 -08:00
Benjamin Kramer
5005890546 Enable more tests on H100
20895965b2 fixed these

PiperOrigin-RevId: 612907679
2024-03-05 11:22:56 -08:00
Peter Hawkins
6207977fac Disable some tests that fail on H100 in CI.
PiperOrigin-RevId: 612637375
2024-03-04 16:59:52 -08:00
Peter Hawkins
4ff84d04d9 Disable asan for torch interoperability test on CPU.
This is timing out at build time on CI.

PiperOrigin-RevId: 609350782
2024-02-22 06:29:13 -08:00
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Sergei Lebedev
37f313ab22 Fixed internal CI builds
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases

PiperOrigin-RevId: 608534008
2024-02-20 02:42:14 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
jax authors
eb03436835 Reverts 318a19a89387caebd116168c4e47592e7d71ca65
PiperOrigin-RevId: 607708463
2024-02-16 09:11:05 -08:00
Sergei Lebedev
318a19a893 Removed deprecated jax.config methods
PiperOrigin-RevId: 607675571
2024-02-16 06:49:13 -08:00
jax authors
21236f0c65 Removes unused IREE config parameters from tests.
PiperOrigin-RevId: 606590491
2024-02-13 05:45:42 -08:00
jax authors
7b05bbdda0 Merge pull request #18814 from Cjkkkk:spda
PiperOrigin-RevId: 606397276
2024-02-12 16:11:37 -08:00
Adam Paszke
1b2227283b Shard nn_test on GPU to avoid timeouts
PiperOrigin-RevId: 606224790
2024-02-12 05:49:26 -08:00
Cjkkkk
5708fb955b address some format issues 2024-02-09 09:05:08 -08:00
jax authors
07d793d7b2 Enable pmap_test under ASAN on GPU.
PiperOrigin-RevId: 605553673
2024-02-09 00:57:19 -08:00
jax authors
8bbcbb6e12 Merge pull request #19532 from mattjj:jax-attrs2
PiperOrigin-RevId: 602079647
2024-01-27 18:07:04 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
dfda948fbf [XLA:Python] Fail with an AttributeError if __cuda_array_interface__ is called on a sharded array.
Fixes https://github.com/google/jax/issues/19134

PiperOrigin-RevId: 600570354
2024-01-22 14:25:47 -08:00
Peter Hawkins
25e4acfe25 Disabled lax_scipy_special_functions_test under ASAN on GPU.
This test is slow and times out in CI.

PiperOrigin-RevId: 600527658
2024-01-22 12:02:34 -08:00