137 Commits

Author SHA1 Message Date
pizzud
0292f5d0a6 lax_scipy_test: Revert split into three targets.
Somehow the spectral_dac functionality is flaky on its own when run on CPU.

PiperOrigin-RevId: 512195860
2023-02-24 16:56:40 -08:00
pizzud
09afbac6ff lax_scipy_test: Split into three so that each target is small enough to fit within a medium timeout.
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.

This allows them to run in ASAN in more situations.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 511829646
2023-02-23 10:51:58 -08:00
pizzud
631e4ed7e0 lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Peter Hawkins
43b615c0a0 Move global_device_array into its own BUILD target.
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
Jake VanderPlas
6608242f95 sparse_test: reduce num_generated_cases to avoid timeouts
PiperOrigin-RevId: 509941080
2023-02-15 15:00:28 -08:00
Peter Hawkins
69b8a03400 Disable some slow tests under asan.
PiperOrigin-RevId: 509828659
2023-02-15 07:41:33 -08:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
Peter Hawkins
6ee67639e2 Split PyTorch interoperability tests into their own test.
PiperOrigin-RevId: 508722180
2023-02-10 12:17:11 -08:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Ashish Shenoy
f71a55c554 Rename tensorflow core target variable to tensorflow_core
PiperOrigin-RevId: 508148106
2023-02-08 12:11:59 -08:00
jax authors
b8d6efe22f Merge pull request #14273 from mattjj:shard-map
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
jax authors
795c14b388 Merge pull request #14252 from jakevdp:sparse-conv
PiperOrigin-RevId: 506641181
2023-02-02 09:21:26 -08:00
Yash Katariya
e5b2c5ea44 Remove the jit_pjit_api_merge disable for api_test now that it is passing
PiperOrigin-RevId: 506508148
2023-02-01 21:03:30 -08:00
Jake VanderPlas
038798ed25 [sparse] add support for simple 1D convolutions 2023-02-01 18:53:49 -08:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Yash Katariya
1ee21d121c Add pjit support in jax.experimental.jet
PiperOrigin-RevId: 504102287
2023-01-23 15:51:47 -08:00
Skye Wanderman-Milne
953910ab45 Disable timing out sparse_test.py on msan
PiperOrigin-RevId: 503475670
2023-01-20 10:41:20 -08:00
Skye Wanderman-Milne
068423bb96 Increase sharding on checkify_test.py to avoid asan timeouts
PiperOrigin-RevId: 503472266
2023-01-20 10:26:37 -08:00
Yash Katariya
4add3b8cee Make pjit an AxisPrimitive so that it can run the batching rules even if the argument is not batched but there is a axis_index/named shapes inside the pjitted function.
PiperOrigin-RevId: 502955369
2023-01-18 12:56:07 -08:00
Skye Wanderman-Milne
6d0e22eaf9 Don't run FP8 dtype test on TPU.
This change makes dtypes_test.py pass even when not using Bazel (e.g. with
pytest). It also improves TPU coverage when using Bazel.

PiperOrigin-RevId: 502930531
2023-01-18 11:22:17 -08:00
Yash Katariya
05e1ddd4ea Make error_test a jax_test so that we can test other configs and fix it with jit/pjit merge.
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
jax authors
8da6c89c7b Merge pull request #13759 from sharadmv:io-callback
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716 Add IO callback 2023-01-17 13:55:05 -08:00
Yash Katariya
85654ceeab Default dynamic_api_test and custom_object_test to take the old jit path and not the merged path since there is no pjit support for it yet.
PiperOrigin-RevId: 502620662
2023-01-17 10:19:39 -08:00
Yash Katariya
4601928277 Enable jit_pjit_api_merge by default "in tests" and disable the current failing tests.
PiperOrigin-RevId: 502088044
2023-01-14 11:15:03 -08:00
Qiao Zhang
d203926c16 Expose fp8 in jax dtypes and mlir builder.
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Yash Katariya
5eb23a7615 Fix name_stack usage of pjit. Now all the metadata of transformations in hlo are correct.
PiperOrigin-RevId: 501918212
2023-01-13 12:54:12 -08:00
Yash Katariya
649ee1be34 Make pickle_test.py pass with jit/pjit api merge. Also rename and move some functions around
PiperOrigin-RevId: 501878555
2023-01-13 10:16:01 -08:00
Sharad Vikram
c9a57e1b44 Delete jax.experimental.callback
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
e21c29476d Add batch_jaxpr2 which tells the caller where batch dims are.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5 Fix debugging primitives for pjit. This came up during jit/pjit merge
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358 Make jit a thin wrapper around pjit which ignores the mesh context manager (just like how it is today)
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.

This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.

PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
Sharad Vikram
f729da4a36 Add shards for checkify_test on GPU
PiperOrigin-RevId: 501430172
2023-01-11 18:28:37 -08:00
Yash Katariya
66aafb6e16 Don't take the cpp dispatch path for pjit if it contains ordered effects just like jit.
PiperOrigin-RevId: 501141750
2023-01-10 18:07:23 -08:00
Yash Katariya
c447e987e1 Skip custom_object_test and dynamic_api_test for pjit/jit merge since it doesn't work with jax.Array's too.
PiperOrigin-RevId: 501129056
2023-01-10 16:55:51 -08:00
Yash Katariya
e02c1da4c7 Fix debug nans test after merging jit and pjit codepaths
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
Yash Katariya
849af498d1 Make jaxpr_util_test work with jit/pjit merge
PiperOrigin-RevId: 500841015
2023-01-09 16:50:04 -08:00
Adam Paszke
904cd4b98d Internal change
PiperOrigin-RevId: 499812920
2023-01-05 04:13:34 -08:00
Yash Katariya
c3bb26050c Add pjit rule to sparse_rules to support pjit. This is done to merge the jit and pjit API.
PiperOrigin-RevId: 499311841
2023-01-03 14:13:19 -08:00
Peter Hawkins
401fbb61a9 Disable xmap_test on TPU under asan due to CI timeouts.
PiperOrigin-RevId: 492994226
2022-12-05 06:52:09 -08:00
Peter Hawkins
7495a9e370 [JAX] Enable/disable tests that timed out in CI.
Reenable pmap_test since it was recently sped up.

PiperOrigin-RevId: 491650701
2022-11-29 09:02:16 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Adam Paszke
fe56a19904 Shard fft tests to avoid timeouts
PiperOrigin-RevId: 490486632
2022-11-23 06:33:13 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Peter Hawkins
61aa415356 Disable sparse_test_cpu under msan due to CI timeouts.
PiperOrigin-RevId: 490312188
2022-11-22 12:48:34 -08:00
jax authors
518fe6656c Pickling of Sharding classes: use module level functions when deserializing.
This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s.

PiperOrigin-RevId: 490249959
2022-11-22 08:31:16 -08:00
Peter Hawkins
42e367af9c Fix typo in "nomsan" tag on pmap_test.
PiperOrigin-RevId: 489978468
2022-11-21 07:46:13 -08:00