jax authors
e8b06ccf56
Cholesky rank-1 update kernel for JAX.
...
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Sergey Kozub
84774b39e6
Fix sparse dot metadata loader
...
Metadata loader was using incorrect warp assignment, which resulted in incorrect addresses with num_warps>4. This was previously missed, as the autotuner rarely selected such configs.
PiperOrigin-RevId: 633513110
2024-05-14 03:08:18 -07:00
jax authors
8baa5d8180
Merge pull request #21128 from hawkinsp:loggingtest
...
PiperOrigin-RevId: 631839432
2024-05-08 10:07:42 -07:00
jax authors
11da3df238
Merge pull request #21096 from gspschmid:gschmid/sourcemaps
...
PiperOrigin-RevId: 631769572
2024-05-08 05:44:08 -07:00
Peter Hawkins
919832a63c
Enable logging_test on all CI platforms.
...
Should catch issues like https://github.com/google/jax/issues/21121
2024-05-08 12:43:52 +00:00
Peter Hawkins
d014f5dc5f
Compute source maps when pretty-printing jaxprs.
...
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.
This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
jax authors
70f2ef211f
Merge pull request #20971 from google:mutable-array-scan
...
PiperOrigin-RevId: 630130893
2024-05-02 11:40:54 -07:00
Dougal
e63b35d550
Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
...
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
Roy Frostig
c18f7916a3
bump shard count for random_lax_test
...
PiperOrigin-RevId: 629495786
2024-04-30 12:35:17 -07:00
Sergey Kozub
d655cd7be9
Disable sparse_nm_test_gpu_h100 because of flakiness
...
PiperOrigin-RevId: 628490449
2024-04-26 13:13:15 -07:00
Jake VanderPlas
beb49af678
sparse_nm_test: skip on incompatible GPUs
...
PiperOrigin-RevId: 628120697
2024-04-25 10:38:07 -07:00
Sergey Kozub
aebe82a78f
Add JAX API that provides sparse matmul support (2:4 structured sparsity)
...
Usage:
from jax.experimental.sparse import nm
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))
where:
lhs.shape = [M, K/2]
rhs.shape = [K, N]
`mask` has the same shape as `lhs` with boolean type
If batch dimensions are present, the `dimension_numbers` argument has to be set to:
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
The lowering only works on nVidia GPUs, that provide hardware support for sparse dots.
PiperOrigin-RevId: 627640553
2024-04-24 01:06:19 -07:00
Adam Paszke
fa66e731e6
Increase sharding to avoid timeouts
...
PiperOrigin-RevId: 626008096
2024-04-18 06:04:41 -07:00
Meekail Zain
e6508a4f47
Add __array_namespace_info__ and corresponding utilities
2024-04-11 14:20:44 +00:00
Michael Hudgins
023930decf
Fix some load orderings for buildifier
...
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Cjkkkk
85cbe05f25
add multiaccelerator tag to test
2024-03-15 13:22:45 -07:00
Peter Hawkins
02ad457bc2
Disable P100 and A100 for pytorch_interoperability_test.
...
These are causing CI timeouts.
PiperOrigin-RevId: 615427575
2024-03-13 08:31:28 -07:00
jax authors
60bf38bde9
Merge pull request #20128 from shuhand0:dev/shuhan/ci2
...
PiperOrigin-RevId: 615231412
2024-03-12 17:45:19 -07:00
Shuhan Ding
5a93e15bd7
add to tests/BUILD
2024-03-12 17:17:20 -07:00
Roy Frostig
29edfd8925
define a loop-free untrue batching rule for rng_bit_generator
2024-03-08 13:13:03 -08:00
Benjamin Kramer
5005890546
Enable more tests on H100
...
20895965b2
fixed these
PiperOrigin-RevId: 612907679
2024-03-05 11:22:56 -08:00
Peter Hawkins
6207977fac
Disable some tests that fail on H100 in CI.
...
PiperOrigin-RevId: 612637375
2024-03-04 16:59:52 -08:00
Peter Hawkins
4ff84d04d9
Disable asan for torch interoperability test on CPU.
...
This is timing out at build time on CI.
PiperOrigin-RevId: 609350782
2024-02-22 06:29:13 -08:00
Sergei Lebedev
57e59eb6c3
Removed deprecated jax.config methods and jax.config.config
...
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7
PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Sergei Lebedev
37f313ab22
Fixed internal CI builds
...
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases
PiperOrigin-RevId: 608534008
2024-02-20 02:42:14 -08:00
Thomas Köppe
dcc65e621e
Reverts b506fee9e389391efb1336bc7575dba913e75cdf
...
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3
Removed deprecated jax.config methods and jax.config.config
...
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81
PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
jax authors
eb03436835
Reverts 318a19a89387caebd116168c4e47592e7d71ca65
...
PiperOrigin-RevId: 607708463
2024-02-16 09:11:05 -08:00
Sergei Lebedev
318a19a893
Removed deprecated jax.config methods
...
PiperOrigin-RevId: 607675571
2024-02-16 06:49:13 -08:00
jax authors
21236f0c65
Removes unused IREE config parameters from tests.
...
PiperOrigin-RevId: 606590491
2024-02-13 05:45:42 -08:00
jax authors
7b05bbdda0
Merge pull request #18814 from Cjkkkk:spda
...
PiperOrigin-RevId: 606397276
2024-02-12 16:11:37 -08:00
Adam Paszke
1b2227283b
Shard nn_test on GPU to avoid timeouts
...
PiperOrigin-RevId: 606224790
2024-02-12 05:49:26 -08:00
Cjkkkk
5708fb955b
address some format issues
2024-02-09 09:05:08 -08:00
jax authors
07d793d7b2
Enable pmap_test under ASAN on GPU.
...
PiperOrigin-RevId: 605553673
2024-02-09 00:57:19 -08:00
jax authors
8bbcbb6e12
Merge pull request #19532 from mattjj:jax-attrs2
...
PiperOrigin-RevId: 602079647
2024-01-27 18:07:04 -08:00
Matthew Johnson
4a8babb101
integrate attrs in jax.jit
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
dfda948fbf
[XLA:Python] Fail with an AttributeError if __cuda_array_interface__ is called on a sharded array.
...
Fixes https://github.com/google/jax/issues/19134
PiperOrigin-RevId: 600570354
2024-01-22 14:25:47 -08:00
Peter Hawkins
25e4acfe25
Disabled lax_scipy_special_functions_test under ASAN on GPU.
...
This test is slow and times out in CI.
PiperOrigin-RevId: 600527658
2024-01-22 12:02:34 -08:00
Peter Hawkins
912a5ef771
Disable some tests that time out or OOM under ASAN.
...
PiperOrigin-RevId: 598543036
2024-01-15 01:40:44 -08:00
Peter Hawkins
2980e8f09c
[JAX] Disable some CUDA tests that fail under ASAN, due to bugs in NCCL and Triton.
...
PiperOrigin-RevId: 597845384
2024-01-12 08:20:36 -08:00
Peter Hawkins
35fc2ed8e0
Disable ASAN for several CUDA tests.
...
PiperOrigin-RevId: 597596726
2024-01-11 10:43:38 -08:00
George Necula
ed2a839884
Move export backwards compatibility tests out of jax2tf. Step 3.
...
The last part of moving the tests: move jax2tf/tests/back_compat_test.py to tests/export_back_compat_test.py.
PiperOrigin-RevId: 596555577
2024-01-08 04:48:10 -08:00
Yash Katariya
72fbdb2eb5
Expose shard_alike
via jax.experimental. The API is x, y = shard_like(x, y)
.
...
The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.
The flow of sharding is bidirectional i.e. SPMD will choose what the sharding should be of `x` and `y` depending on it's propagation algorithm. It might end up being that the sharding chosen is not of `x` and `y` but something better. At the end of propagation `x` and `y` will be sharded alike.
The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.
Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
2023-12-19 16:31:33 -08:00
Peter Hawkins
67d5c3bdea
[JAX:GPU] Add a test that verifies that the XLA_PYTHON_CLIENT_PREALLOCATE environment variable is parsed correctly.
...
Fixes https://github.com/google/jax/issues/19035
PiperOrigin-RevId: 592322040
2023-12-19 13:06:08 -08:00
Parker Schuh
7ba8622719
For custom_partitioning, directly emit call when inside of a shard_map.
...
PiperOrigin-RevId: 592011427
2023-12-18 14:32:38 -08:00
jax authors
b225c86f10
Merge pull request #18262 from jakevdp:key-reuse-jaxpr
...
PiperOrigin-RevId: 589913404
2023-12-11 12:46:55 -08:00
Jake VanderPlas
a52d18781e
Add experimental static key reuse checking
2023-12-11 12:03:48 -08:00
Yash Katariya
78e0e6b058
Internal change
...
PiperOrigin-RevId: 588856820
2023-12-07 11:30:54 -08:00
Peter Hawkins
5b97960d31
Disable sanitizer builds of shape_poly_test.
...
These take a very long time and sometimes timeout so it's probably not worth running them in CI.
PiperOrigin-RevId: 587768399
2023-12-04 10:42:56 -08:00
Peter Hawkins
1d95e79fd9
Disable export_harnesses_multi_platform_test under sanitizers.
...
This test appears to hit some sort of LLVM bug on Sapphire Rapids CPUs.
PiperOrigin-RevId: 587719850
2023-12-04 07:54:35 -08:00