162 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
c662fd216d Disable tsan CI for random_test_with_custom_prng to avoid timeouts.
asan is already disabled, and the comment and "cpu" case indicates
that tsan should already have been disabled as well.

PiperOrigin-RevId: 528000458
2023-04-28 15:26:46 -07:00
Skye Wanderman-Milne
67d80c21cb Increase sharding count on nn_test and svd_test to avoid ASAN timeouts.
PiperOrigin-RevId: 527387005
2023-04-26 15:11:29 -07:00
Jake VanderPlas
1c7f8efce6 Add test framework for module attribute 2023-04-21 13:20:16 -07:00
jax authors
db2cbd4ae8 Merge pull request #15665 from hawkinsp:sourceinfo
PiperOrigin-RevId: 525581713
2023-04-19 16:30:23 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
a3b262c379 Use the traceback of the call site when assigning a source location to an inlined function.
Improves but does not completely fix https://github.com/google/jax/issues/15663 . The non-inlined case still has similar problems.
2023-04-19 13:56:53 -04:00
Peter Hawkins
017548c40b Move implementation of compilation cache out of jax/experimental and into jax/_src.
Use a Protocol instead of an abstract base class for the CacheInterface since it allows us to use one fewer file.

No functional change intended.

PiperOrigin-RevId: 524855263
2023-04-17 08:35:53 -07:00
Jake VanderPlas
8c8f50f688 Fix tolerance and shard_count for experimental_rnn_test
This should fix the current GPU test timeout.

PiperOrigin-RevId: 522167894
2023-04-05 15:19:19 -07:00
jax authors
3c1f3abba2 Merge pull request #15149 from sharadmv:runstate
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
Parker Schuh
c2b15a1eb8 Break out aot_test from array_test (for serialization and other aot APIs).
PiperOrigin-RevId: 521568985
2023-04-03 14:47:53 -07:00
Peter Hawkins
47177e1417 Split more targets out the main JAX Bazel target.
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Skye Wanderman-Milne
4cb3b011a0 Remove PJRT C API bypass.
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.

PiperOrigin-RevId: 519018438
2023-03-23 18:39:14 -07:00
Jieying Luo
b403c2a083 [PJRT C API] Add parsing PJRT client create options from json file.
PiperOrigin-RevId: 518418760
2023-03-21 16:57:34 -07:00
Yash Katariya
207cc10058 Error if jax_array or jax_jit_pjit_api_merge is set to False.
PiperOrigin-RevId: 517485597
2023-03-17 12:57:57 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
88584290aa Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
PiperOrigin-RevId: 516881635
2023-03-15 11:34:57 -07:00
Peter Hawkins
a32a7ff903 Move _src/tree_util.py into a separate Bazel target.
Fix a type error in api.py revealed by the split.

PiperOrigin-RevId: 515745227
2023-03-10 14:51:52 -08:00
pizzud
04def0b6ab lazy_loader_module: Move to new internal_test_util directory.
Now we no longer need to mess with sys.path in lazy_loader_test.

PiperOrigin-RevId: 515674188
2023-03-10 10:29:33 -08:00
Peter Hawkins
01b00c4821 Increase sharding of shard_map test on CPU.
This test is timing out in CI with sanitizers enabled (asan/tsan).

PiperOrigin-RevId: 515369731
2023-03-09 10:13:26 -08:00
jax authors
59bf2061c4 Merge pull request #14565 from pizzud:deprecation-module
PiperOrigin-RevId: 515172435
2023-03-08 16:23:53 -08:00
jax authors
9c4db8c962 Merge pull request #14633 from mattjj:shmap-test-vmap
PiperOrigin-RevId: 515117185
2023-03-08 12:56:54 -08:00
pizzud
22cbf95e07 lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.

PiperOrigin-RevId: 514745247
2023-03-07 08:49:42 -08:00
jax authors
00f1abe401 Disable 2 failing jax tests.
PiperOrigin-RevId: 514515343
2023-03-06 13:50:40 -08:00
pizzud
ef28dcf091 lax_scipy_test: Split into three targets, take 2.
The goal is to ensure that all shards fit into a medium timeout in sanitizer
configurations.

Running 256 entry vectors in spectral_dac is too slow, so let's replace that
with a smaller vector that isn't a power of 2. Avoiding a power of 2 requires
us to widen the tolerance a bit due to vectorization changes.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 514440062
2023-03-06 09:53:23 -08:00
pizzud
0292f5d0a6 lax_scipy_test: Revert split into three targets.
Somehow the spectral_dac functionality is flaky on its own when run on CPU.

PiperOrigin-RevId: 512195860
2023-02-24 16:56:40 -08:00
pizzud
09afbac6ff lax_scipy_test: Split into three so that each target is small enough to fit within a medium timeout.
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.

This allows them to run in ASAN in more situations.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 511829646
2023-02-23 10:51:58 -08:00
David Pizzuto
a8f2d9a186 deprecation_module: Move to new internal_test_util directory.
Now we no longer need to mess with sys.path in deprecation_test.
2023-02-17 10:55:04 -08:00
pizzud
631e4ed7e0 lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Peter Hawkins
43b615c0a0 Move global_device_array into its own BUILD target.
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
Jake VanderPlas
6608242f95 sparse_test: reduce num_generated_cases to avoid timeouts
PiperOrigin-RevId: 509941080
2023-02-15 15:00:28 -08:00
Peter Hawkins
69b8a03400 Disable some slow tests under asan.
PiperOrigin-RevId: 509828659
2023-02-15 07:41:33 -08:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
Peter Hawkins
6ee67639e2 Split PyTorch interoperability tests into their own test.
PiperOrigin-RevId: 508722180
2023-02-10 12:17:11 -08:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Ashish Shenoy
f71a55c554 Rename tensorflow core target variable to tensorflow_core
PiperOrigin-RevId: 508148106
2023-02-08 12:11:59 -08:00
jax authors
b8d6efe22f Merge pull request #14273 from mattjj:shard-map
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
jax authors
795c14b388 Merge pull request #14252 from jakevdp:sparse-conv
PiperOrigin-RevId: 506641181
2023-02-02 09:21:26 -08:00
Yash Katariya
e5b2c5ea44 Remove the jit_pjit_api_merge disable for api_test now that it is passing
PiperOrigin-RevId: 506508148
2023-02-01 21:03:30 -08:00
Jake VanderPlas
038798ed25 [sparse] add support for simple 1D convolutions 2023-02-01 18:53:49 -08:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Yash Katariya
1ee21d121c Add pjit support in jax.experimental.jet
PiperOrigin-RevId: 504102287
2023-01-23 15:51:47 -08:00
Skye Wanderman-Milne
953910ab45 Disable timing out sparse_test.py on msan
PiperOrigin-RevId: 503475670
2023-01-20 10:41:20 -08:00
Skye Wanderman-Milne
068423bb96 Increase sharding on checkify_test.py to avoid asan timeouts
PiperOrigin-RevId: 503472266
2023-01-20 10:26:37 -08:00
Yash Katariya
4add3b8cee Make pjit an AxisPrimitive so that it can run the batching rules even if the argument is not batched but there is a axis_index/named shapes inside the pjitted function.
PiperOrigin-RevId: 502955369
2023-01-18 12:56:07 -08:00
Skye Wanderman-Milne
6d0e22eaf9 Don't run FP8 dtype test on TPU.
This change makes dtypes_test.py pass even when not using Bazel (e.g. with
pytest). It also improves TPU coverage when using Bazel.

PiperOrigin-RevId: 502930531
2023-01-18 11:22:17 -08:00
Yash Katariya
05e1ddd4ea Make error_test a jax_test so that we can test other configs and fix it with jit/pjit merge.
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
jax authors
8da6c89c7b Merge pull request #13759 from sharadmv:io-callback
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716 Add IO callback 2023-01-17 13:55:05 -08:00
Yash Katariya
85654ceeab Default dynamic_api_test and custom_object_test to take the old jit path and not the merged path since there is no pjit support for it yet.
PiperOrigin-RevId: 502620662
2023-01-17 10:19:39 -08:00