246 Commits

Author SHA1 Message Date
Roy Frostig
29edfd8925 define a loop-free untrue batching rule for rng_bit_generator 2024-03-08 13:13:03 -08:00
Benjamin Kramer
5005890546 Enable more tests on H100
20895965b2 fixed these

PiperOrigin-RevId: 612907679
2024-03-05 11:22:56 -08:00
Peter Hawkins
6207977fac Disable some tests that fail on H100 in CI.
PiperOrigin-RevId: 612637375
2024-03-04 16:59:52 -08:00
Peter Hawkins
4ff84d04d9 Disable asan for torch interoperability test on CPU.
This is timing out at build time on CI.

PiperOrigin-RevId: 609350782
2024-02-22 06:29:13 -08:00
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Sergei Lebedev
37f313ab22 Fixed internal CI builds
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases

PiperOrigin-RevId: 608534008
2024-02-20 02:42:14 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
jax authors
eb03436835 Reverts 318a19a89387caebd116168c4e47592e7d71ca65
PiperOrigin-RevId: 607708463
2024-02-16 09:11:05 -08:00
Sergei Lebedev
318a19a893 Removed deprecated jax.config methods
PiperOrigin-RevId: 607675571
2024-02-16 06:49:13 -08:00
jax authors
21236f0c65 Removes unused IREE config parameters from tests.
PiperOrigin-RevId: 606590491
2024-02-13 05:45:42 -08:00
jax authors
7b05bbdda0 Merge pull request #18814 from Cjkkkk:spda
PiperOrigin-RevId: 606397276
2024-02-12 16:11:37 -08:00
Adam Paszke
1b2227283b Shard nn_test on GPU to avoid timeouts
PiperOrigin-RevId: 606224790
2024-02-12 05:49:26 -08:00
Cjkkkk
5708fb955b address some format issues 2024-02-09 09:05:08 -08:00
jax authors
07d793d7b2 Enable pmap_test under ASAN on GPU.
PiperOrigin-RevId: 605553673
2024-02-09 00:57:19 -08:00
jax authors
8bbcbb6e12 Merge pull request #19532 from mattjj:jax-attrs2
PiperOrigin-RevId: 602079647
2024-01-27 18:07:04 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
dfda948fbf [XLA:Python] Fail with an AttributeError if __cuda_array_interface__ is called on a sharded array.
Fixes https://github.com/google/jax/issues/19134

PiperOrigin-RevId: 600570354
2024-01-22 14:25:47 -08:00
Peter Hawkins
25e4acfe25 Disabled lax_scipy_special_functions_test under ASAN on GPU.
This test is slow and times out in CI.

PiperOrigin-RevId: 600527658
2024-01-22 12:02:34 -08:00
Peter Hawkins
912a5ef771 Disable some tests that time out or OOM under ASAN.
PiperOrigin-RevId: 598543036
2024-01-15 01:40:44 -08:00
Peter Hawkins
2980e8f09c [JAX] Disable some CUDA tests that fail under ASAN, due to bugs in NCCL and Triton.
PiperOrigin-RevId: 597845384
2024-01-12 08:20:36 -08:00
Peter Hawkins
35fc2ed8e0 Disable ASAN for several CUDA tests.
PiperOrigin-RevId: 597596726
2024-01-11 10:43:38 -08:00
George Necula
ed2a839884 Move export backwards compatibility tests out of jax2tf. Step 3.
The last part of moving the tests: move jax2tf/tests/back_compat_test.py to tests/export_back_compat_test.py.

PiperOrigin-RevId: 596555577
2024-01-08 04:48:10 -08:00
Yash Katariya
72fbdb2eb5 Expose shard_alike via jax.experimental. The API is x, y = shard_like(x, y).
The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.

The flow of sharding is bidirectional i.e. SPMD will choose what the sharding should be of `x` and `y` depending on it's propagation algorithm. It might end up being that the sharding chosen is not of `x` and `y` but something better. At the end of propagation `x` and `y` will be sharded alike.

The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.

Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
2023-12-19 16:31:33 -08:00
Peter Hawkins
67d5c3bdea [JAX:GPU] Add a test that verifies that the XLA_PYTHON_CLIENT_PREALLOCATE environment variable is parsed correctly.
Fixes https://github.com/google/jax/issues/19035

PiperOrigin-RevId: 592322040
2023-12-19 13:06:08 -08:00
Parker Schuh
7ba8622719 For custom_partitioning, directly emit call when inside of a shard_map.
PiperOrigin-RevId: 592011427
2023-12-18 14:32:38 -08:00
jax authors
b225c86f10 Merge pull request #18262 from jakevdp:key-reuse-jaxpr
PiperOrigin-RevId: 589913404
2023-12-11 12:46:55 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
Yash Katariya
78e0e6b058 Internal change
PiperOrigin-RevId: 588856820
2023-12-07 11:30:54 -08:00
Peter Hawkins
5b97960d31 Disable sanitizer builds of shape_poly_test.
These take a very long time and sometimes timeout so it's probably not worth running them in CI.

PiperOrigin-RevId: 587768399
2023-12-04 10:42:56 -08:00
Peter Hawkins
1d95e79fd9 Disable export_harnesses_multi_platform_test under sanitizers.
This test appears to hit some sort of LLVM bug on Sapphire Rapids CPUs.

PiperOrigin-RevId: 587719850
2023-12-04 07:54:35 -08:00
Adam Paszke
ef65ba8f32 Internal change
PiperOrigin-RevId: 586312803
2023-11-29 05:47:50 -08:00
Adam Paszke
95f32b9bc2 Internal change
PiperOrigin-RevId: 585961348
2023-11-28 06:52:11 -08:00
Jake VanderPlas
0d073a4862 [array-api] add simple smoketest target for standard CI testing 2023-11-27 15:20:41 -08:00
George Necula
1ce90e6088 Turn off asan for shape_poly_test (times out)
PiperOrigin-RevId: 584807749
2023-11-22 23:41:37 -08:00
Peter Hawkins
fd9a1a2c36 Disable export_harnesses_test under asan.
This test times out in CI.

PiperOrigin-RevId: 584022342
2023-11-20 07:36:42 -08:00
George Necula
4fbf50dd60 [shape_poly] Copy many of the jax2tf/shape_poly_test to live outside of jax2tf.
Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf.

For now we leave the existing tests in jax2tf, because some of those tests exercise
other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515).

Since we now run these tests in GitHub and Kokoro, this has revealed a couple
of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py.

PiperOrigin-RevId: 583816243
2023-11-19 09:00:04 -08:00
George Necula
3601b25899 Move multi_platform_export_test.py out of jax2tf.
This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py.

We disable the test in GitHub CI, because it is very large, pending
some changes to ensure it parallelizes well. The test is still
running in internal CI. This is matching the current behavior, since
jax2tf tests are only run internally.

PiperOrigin-RevId: 583603863
2023-11-18 02:52:44 -08:00
Yash Katariya
5c3da219c0 Add a private API to allow setting layouts on jitted computations.
We expose 3 modes:

* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.

* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.

* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.

Public API coming soon.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
2023-11-15 08:48:53 -08:00
Sergei Lebedev
fd3a8b2cc6 Deprecated define_* and DEFINE_* methods on jax.config
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
2023-10-29 20:58:19 +00:00
Peter Hawkins
47a76df7cc [IFRT] Fix incorrect type numbers for e4m3 and e5m2 types.
These types didn't match between xla::PrimitiveType and ifrt::DType.

Add a static_assert to enforce equality.

PiperOrigin-RevId: 576288042
2023-10-24 14:38:00 -07:00
jax authors
d03bbc0d0f random_lax_test: Bump shards number for CPU config.
PiperOrigin-RevId: 574239793
2023-10-17 12:59:12 -07:00
Peter Hawkins
89b5449882 [XLA:GPU] Fix bug in all-to-all for complex data types.
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.

Fixes https://github.com/google/jax/issues/18122

PiperOrigin-RevId: 573955671
2023-10-16 16:02:22 -07:00
jax authors
8f911e1512 random_test: Split into two so that each target is small enough to fit within a medium timeout.
PiperOrigin-RevId: 571146766
2023-10-05 15:28:51 -07:00
jax authors
a2b70e3346 Bump shard_count for shard_map_test to fix timeouts.
PiperOrigin-RevId: 571109311
2023-10-05 13:18:10 -07:00
jax authors
1c37f5091c sparse_test: Split into two so that each target is small enough to fit within a medium timeout.
PiperOrigin-RevId: 570882867
2023-10-04 19:59:03 -07:00
jax authors
305efe0501 random_test: reduce num_generated_cases to avoid timeouts
PiperOrigin-RevId: 570781641
2023-10-04 13:04:44 -07:00
Tao Wang
c12929b012 Add more API set up for Mock GPU client. Also clean up previous mock GPU client
API.

PiperOrigin-RevId: 570153877
2023-10-02 13:11:29 -07:00
Yash Katariya
a32ed7e002 Bump shard_count for shard_map_test to fix the asan failures
PiperOrigin-RevId: 569520202
2023-09-29 10:02:38 -07:00
Matthew Johnson
a9dc3c1ea3 [shard_map] internal change to shard_map CI testing
PiperOrigin-RevId: 569036873
2023-09-27 20:06:24 -07:00