211 Commits

Author SHA1 Message Date
Peter Hawkins
fd9a1a2c36 Disable export_harnesses_test under asan.
This test times out in CI.

PiperOrigin-RevId: 584022342
2023-11-20 07:36:42 -08:00
George Necula
4fbf50dd60 [shape_poly] Copy many of the jax2tf/shape_poly_test to live outside of jax2tf.
Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf.

For now we leave the existing tests in jax2tf, because some of those tests exercise
other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515).

Since we now run these tests in GitHub and Kokoro, this has revealed a couple
of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py.

PiperOrigin-RevId: 583816243
2023-11-19 09:00:04 -08:00
George Necula
3601b25899 Move multi_platform_export_test.py out of jax2tf.
This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py.

We disable the test in GitHub CI, because it is very large, pending
some changes to ensure it parallelizes well. The test is still
running in internal CI. This is matching the current behavior, since
jax2tf tests are only run internally.

PiperOrigin-RevId: 583603863
2023-11-18 02:52:44 -08:00
Yash Katariya
5c3da219c0 Add a private API to allow setting layouts on jitted computations.
We expose 3 modes:

* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.

* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.

* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.

Public API coming soon.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
2023-11-15 08:48:53 -08:00
Sergei Lebedev
fd3a8b2cc6 Deprecated define_* and DEFINE_* methods on jax.config
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
2023-10-29 20:58:19 +00:00
Peter Hawkins
47a76df7cc [IFRT] Fix incorrect type numbers for e4m3 and e5m2 types.
These types didn't match between xla::PrimitiveType and ifrt::DType.

Add a static_assert to enforce equality.

PiperOrigin-RevId: 576288042
2023-10-24 14:38:00 -07:00
jax authors
d03bbc0d0f random_lax_test: Bump shards number for CPU config.
PiperOrigin-RevId: 574239793
2023-10-17 12:59:12 -07:00
Peter Hawkins
89b5449882 [XLA:GPU] Fix bug in all-to-all for complex data types.
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.

Fixes https://github.com/google/jax/issues/18122

PiperOrigin-RevId: 573955671
2023-10-16 16:02:22 -07:00
jax authors
8f911e1512 random_test: Split into two so that each target is small enough to fit within a medium timeout.
PiperOrigin-RevId: 571146766
2023-10-05 15:28:51 -07:00
jax authors
a2b70e3346 Bump shard_count for shard_map_test to fix timeouts.
PiperOrigin-RevId: 571109311
2023-10-05 13:18:10 -07:00
jax authors
1c37f5091c sparse_test: Split into two so that each target is small enough to fit within a medium timeout.
PiperOrigin-RevId: 570882867
2023-10-04 19:59:03 -07:00
jax authors
305efe0501 random_test: reduce num_generated_cases to avoid timeouts
PiperOrigin-RevId: 570781641
2023-10-04 13:04:44 -07:00
Tao Wang
c12929b012 Add more API set up for Mock GPU client. Also clean up previous mock GPU client
API.

PiperOrigin-RevId: 570153877
2023-10-02 13:11:29 -07:00
Yash Katariya
a32ed7e002 Bump shard_count for shard_map_test to fix the asan failures
PiperOrigin-RevId: 569520202
2023-09-29 10:02:38 -07:00
Matthew Johnson
a9dc3c1ea3 [shard_map] internal change to shard_map CI testing
PiperOrigin-RevId: 569036873
2023-09-27 20:06:24 -07:00
Hyeontaek Lim
f0bde75dd3 [JAX] Export shard_map_test for testing on additional JAX backends
PiperOrigin-RevId: 567522898
2023-09-21 22:52:36 -07:00
Sharad Vikram
afb6691885 Disable msan/tsan for xmap_tests thanks to timeouts
PiperOrigin-RevId: 567412260
2023-09-21 14:24:08 -07:00
Junwhan Ahn
6a551a1efa Add memories_test.py to the list of exported tests
PiperOrigin-RevId: 567375604
2023-09-21 11:57:09 -07:00
Peter Hawkins
f863cfbaad Relax some test tolerances to fix failures on Linux aarch64.
PiperOrigin-RevId: 565930178
2023-09-16 06:55:22 -07:00
Peter Hawkins
bbfba9ace8 Remove code that disabled tests on "stream_executor" backends.
These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more.

Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices.

PiperOrigin-RevId: 565367453
2023-09-14 07:52:43 -07:00
Peter Hawkins
306c60d4c7 Remove references to deprecated "tpu_se" build configuration.
PiperOrigin-RevId: 565156675
2023-09-13 14:10:30 -07:00
Yash Katariya
76a5dc3cac Move memories_test.py to JAX
PiperOrigin-RevId: 564551723
2023-09-11 17:41:55 -07:00
George Necula
660a015652 [export] Move jax_export and shape_poly out of jax2tf.
Those modules have been developed initially for jax2tf
but they do not depend on TF anymore. They are used for JAX
native serialization. We move them under
jax.experimental.export (also renaming jax_export.py to export.py) so that we can use them without depending on TF.

We are leaving behind stub modules jax2tf.jax_export and jax2tf.shape_poly that just redirect some of the public APIs. To be cleaned later.

PiperOrigin-RevId: 562988740
2023-09-05 22:15:59 -07:00
George Necula
e0a6230214 [host_callback] Delete unused code paths.
This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561851494
2023-08-31 22:08:23 -07:00
Roy Frostig
a71c0e6ecc create jax.extend.random as a copy of jax.prng
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
2023-08-24 14:41:56 -07:00
Richard Levasseur
f891cbf64b Load Python rules from rules_python
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
George Necula
26f091e446 [callback] Disable stream_executor tests.
PiperOrigin-RevId: 559252832
2023-08-22 16:15:00 -07:00
Peter Hawkins
a259df0d76 Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00
Jake VanderPlas
ad8e719b82 Add jnp.ufunc and jnp.frompyfunc 2023-08-10 14:58:18 -07:00
Peter Hawkins
afd56c15d9 Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.

PiperOrigin-RevId: 554861284
2023-08-08 10:09:09 -07:00
Adam Paszke
0228bf7d3c Fix MSAN errors in cache_key_test
The device_assignment array was never initialized, causing MSAN errors.
Replacing it with np.arange fixes the issue.

PiperOrigin-RevId: 553469463
2023-08-03 07:28:32 -07:00
jax authors
9e7502ce60 Disable MSAN testing for cache-key unit tests.
This is an existing issue with the compilation cache tests.
The refactoring of the cache key generation part into a separate
file requires tagging the refactored tests also.

PiperOrigin-RevId: 551972670
2023-07-28 16:02:12 -07:00
Skye Wanderman-Milne
8b58e38ec5 Add jax_debug_log_modules config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).

Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-07-28 18:11:12 +00:00
jax authors
3b28d4e180 Refactor Jax compilation cache key generation.
This is in preparation for introducing a more robust key-generation
algorithm.

This refactoring does not introduce any change in behavior.

Testing: refactored unit tests and test workload.
PiperOrigin-RevId: 551744892
2023-07-27 23:01:00 -07:00
Sharad Vikram
3d556b7a19 Add Mosaic to Jaxlib and expose bindings in jax.experimental.mosaic
PiperOrigin-RevId: 549801858
2023-07-20 18:28:51 -07:00
Tao Wang
b7686f41aa Enable passing fdo_profile in compiler_options in pxla.py
PiperOrigin-RevId: 549109629
2023-07-18 14:18:28 -07:00
Peter Hawkins
3fcf72af32 Disable tests that are flaky in CI.
PiperOrigin-RevId: 547464866
2023-07-12 05:13:38 -07:00
jax authors
2fa6a9c9bf Allow other backends to run the array_test.py test.
PiperOrigin-RevId: 547191886
2023-07-11 08:05:25 -07:00
Chris Flesher
5be17ed90c Added scipy.spatial.transform Rotation and Slerp classes 2023-06-08 07:51:32 -05:00
Peter Hawkins
32026ad18b Disable random_test_with_custom_prng on CPU under msan.
This test flakily times out in CI.

PiperOrigin-RevId: 535293997
2023-05-25 10:10:01 -07:00
Peter Hawkins
1d20d2f301 Increase sharding of host_callback_test on TPU to fix CI flakiness.
PiperOrigin-RevId: 533451822
2023-05-19 07:44:53 -07:00
Peter Hawkins
e6628e2e72 Disable tests that time out in CI.
PiperOrigin-RevId: 532792740
2023-05-17 08:16:07 -07:00
Peter Hawkins
9471bb3045 Disable sparsify_test on CPU under tsan.
Under tsan this test times out in CI.

PiperOrigin-RevId: 531210930
2023-05-11 08:33:35 -07:00
Peter Hawkins
e8c735125c Disable more tests that are flaky in CI.
PiperOrigin-RevId: 529724306
2023-05-05 08:33:33 -07:00
pizzud
40d730be49 aot_test: Stop forcing XLA to assume a certain number of devices.
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.

PiperOrigin-RevId: 529437419
2023-05-04 09:53:26 -07:00
Peter Hawkins
09fce87f54 Increase sharding of or disable some flaky CI tests.
PiperOrigin-RevId: 529405705
2023-05-04 07:41:56 -07:00
Peter Hawkins
57e62ca03c Reenable scipy_stats_test in CI.
Disable testTruncNormPdf on CPU, which is failing after an LLVM update.

PiperOrigin-RevId: 528884880
2023-05-02 14:11:08 -07:00
Skye Wanderman-Milne
70cac773f7 Exclude scipy_fft_test from msan as well as t/asan.
PiperOrigin-RevId: 528562775
2023-05-01 13:42:24 -07:00
Skye Wanderman-Milne
fa68c1f882 Bump up lax_test TPU sharding to avoid asan timeouts
PiperOrigin-RevId: 528559870
2023-05-01 13:31:22 -07:00
Skye Wanderman-Milne
c662fd216d Disable tsan CI for random_test_with_custom_prng to avoid timeouts.
asan is already disabled, and the comment and "cpu" case indicates
that tsan should already have been disabled as well.

PiperOrigin-RevId: 528000458
2023-04-28 15:26:46 -07:00