Peter Hawkins
a32a7ff903
Move _src/tree_util.py into a separate Bazel target.
...
Fix a type error in api.py revealed by the split.
PiperOrigin-RevId: 515745227
2023-03-10 14:51:52 -08:00
pizzud
04def0b6ab
lazy_loader_module: Move to new internal_test_util directory.
...
Now we no longer need to mess with sys.path in lazy_loader_test.
PiperOrigin-RevId: 515674188
2023-03-10 10:29:33 -08:00
Peter Hawkins
01b00c4821
Increase sharding of shard_map test on CPU.
...
This test is timing out in CI with sanitizers enabled (asan/tsan).
PiperOrigin-RevId: 515369731
2023-03-09 10:13:26 -08:00
jax authors
59bf2061c4
Merge pull request #14565 from pizzud:deprecation-module
...
PiperOrigin-RevId: 515172435
2023-03-08 16:23:53 -08:00
jax authors
9c4db8c962
Merge pull request #14633 from mattjj:shmap-test-vmap
...
PiperOrigin-RevId: 515117185
2023-03-08 12:56:54 -08:00
pizzud
22cbf95e07
lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
...
Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.
PiperOrigin-RevId: 514745247
2023-03-07 08:49:42 -08:00
jax authors
00f1abe401
Disable 2 failing jax tests.
...
PiperOrigin-RevId: 514515343
2023-03-06 13:50:40 -08:00
pizzud
ef28dcf091
lax_scipy_test: Split into three targets, take 2.
...
The goal is to ensure that all shards fit into a medium timeout in sanitizer
configurations.
Running 256 entry vectors in spectral_dac is too slow, so let's replace that
with a smaller vector that isn't a power of 2. Avoiding a power of 2 requires
us to widen the tolerance a bit due to vectorization changes.
While here, specify deps a little more precisely as well.
PiperOrigin-RevId: 514440062
2023-03-06 09:53:23 -08:00
pizzud
0292f5d0a6
lax_scipy_test: Revert split into three targets.
...
Somehow the spectral_dac functionality is flaky on its own when run on CPU.
PiperOrigin-RevId: 512195860
2023-02-24 16:56:40 -08:00
pizzud
09afbac6ff
lax_scipy_test: Split into three so that each target is small enough to fit within a medium timeout.
...
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.
This allows them to run in ASAN in more situations.
While here, specify deps a little more precisely as well.
PiperOrigin-RevId: 511829646
2023-02-23 10:51:58 -08:00
David Pizzuto
a8f2d9a186
deprecation_module: Move to new internal_test_util directory.
...
Now we no longer need to mess with sys.path in deprecation_test.
2023-02-17 10:55:04 -08:00
pizzud
631e4ed7e0
lax_test: Create a separate module for lax-specific test utils in a new package.
...
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.
The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.
Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.
PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Peter Hawkins
43b615c0a0
Move global_device_array into its own BUILD target.
...
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
Jake VanderPlas
6608242f95
sparse_test: reduce num_generated_cases to avoid timeouts
...
PiperOrigin-RevId: 509941080
2023-02-15 15:00:28 -08:00
Peter Hawkins
69b8a03400
Disable some slow tests under asan.
...
PiperOrigin-RevId: 509828659
2023-02-15 07:41:33 -08:00
Peter Hawkins
33bed1e520
Opt into higher matmul precision for A100 and TPU tests.
...
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
Peter Hawkins
6ee67639e2
Split PyTorch interoperability tests into their own test.
...
PiperOrigin-RevId: 508722180
2023-02-10 12:17:11 -08:00
Peter Hawkins
8268cd562d
Add infrastructure for managing deprecations.
...
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.
PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Ashish Shenoy
f71a55c554
Rename tensorflow core target variable to tensorflow_core
...
PiperOrigin-RevId: 508148106
2023-02-08 12:11:59 -08:00
jax authors
b8d6efe22f
Merge pull request #14273 from mattjj:shard-map
...
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973
shard_map (shmap) prototype and JEP
...
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
jax authors
795c14b388
Merge pull request #14252 from jakevdp:sparse-conv
...
PiperOrigin-RevId: 506641181
2023-02-02 09:21:26 -08:00
Yash Katariya
e5b2c5ea44
Remove the jit_pjit_api_merge disable for api_test now that it is passing
...
PiperOrigin-RevId: 506508148
2023-02-01 21:03:30 -08:00
Jake VanderPlas
038798ed25
[sparse] add support for simple 1D convolutions
2023-02-01 18:53:49 -08:00
Peter Hawkins
c90a85403b
Merge pull request #14248 from jakevdp:dead-code
...
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Yash Katariya
1ee21d121c
Add pjit support in jax.experimental.jet
...
PiperOrigin-RevId: 504102287
2023-01-23 15:51:47 -08:00
Skye Wanderman-Milne
953910ab45
Disable timing out sparse_test.py on msan
...
PiperOrigin-RevId: 503475670
2023-01-20 10:41:20 -08:00
Skye Wanderman-Milne
068423bb96
Increase sharding on checkify_test.py to avoid asan timeouts
...
PiperOrigin-RevId: 503472266
2023-01-20 10:26:37 -08:00
Yash Katariya
4add3b8cee
Make pjit
an AxisPrimitive so that it can run the batching rules even if the argument is not batched but there is a axis_index/named shapes inside the pjitted function.
...
PiperOrigin-RevId: 502955369
2023-01-18 12:56:07 -08:00
Skye Wanderman-Milne
6d0e22eaf9
Don't run FP8 dtype test on TPU.
...
This change makes dtypes_test.py pass even when not using Bazel (e.g. with
pytest). It also improves TPU coverage when using Bazel.
PiperOrigin-RevId: 502930531
2023-01-18 11:22:17 -08:00
Yash Katariya
05e1ddd4ea
Make error_test
a jax_test so that we can test other configs and fix it with jit
/pjit
merge.
...
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
jax authors
8da6c89c7b
Merge pull request #13759 from sharadmv:io-callback
...
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716
Add IO callback
2023-01-17 13:55:05 -08:00
Yash Katariya
85654ceeab
Default dynamic_api_test and custom_object_test to take the old jit path and not the merged path since there is no pjit support for it yet.
...
PiperOrigin-RevId: 502620662
2023-01-17 10:19:39 -08:00
Yash Katariya
4601928277
Enable jit_pjit_api_merge by default "in tests" and disable the current failing tests.
...
PiperOrigin-RevId: 502088044
2023-01-14 11:15:03 -08:00
Qiao Zhang
d203926c16
Expose fp8 in jax dtypes and mlir builder.
...
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Yash Katariya
5eb23a7615
Fix name_stack
usage of pjit. Now all the metadata of transformations in hlo are correct.
...
PiperOrigin-RevId: 501918212
2023-01-13 12:54:12 -08:00
Yash Katariya
649ee1be34
Make pickle_test.py pass with jit/pjit api merge. Also rename and move some functions around
...
PiperOrigin-RevId: 501878555
2023-01-13 10:16:01 -08:00
Sharad Vikram
c9a57e1b44
Delete jax.experimental.callback
...
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
e21c29476d
Add batch_jaxpr2 which tells the caller where batch dims are.
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a
Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
...
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5
Fix debugging primitives for pjit. This came up during jit/pjit merge
...
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358
Make jit
a thin wrapper around pjit
which ignores the mesh context manager (just like how it is today)
...
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.
This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.
PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
Sharad Vikram
f729da4a36
Add shards for checkify_test on GPU
...
PiperOrigin-RevId: 501430172
2023-01-11 18:28:37 -08:00
Yash Katariya
66aafb6e16
Don't take the cpp dispatch path for pjit
if it contains ordered effects just like jit
.
...
PiperOrigin-RevId: 501141750
2023-01-10 18:07:23 -08:00
Yash Katariya
c447e987e1
Skip custom_object_test and dynamic_api_test for pjit/jit merge since it doesn't work with jax.Array's too.
...
PiperOrigin-RevId: 501129056
2023-01-10 16:55:51 -08:00
Yash Katariya
e02c1da4c7
Fix debug nans test after merging jit
and pjit
codepaths
...
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
Yash Katariya
849af498d1
Make jaxpr_util_test work with jit/pjit merge
...
PiperOrigin-RevId: 500841015
2023-01-09 16:50:04 -08:00
Adam Paszke
904cd4b98d
Internal change
...
PiperOrigin-RevId: 499812920
2023-01-05 04:13:34 -08:00
Yash Katariya
c3bb26050c
Add pjit
rule to sparse_rules to support pjit
. This is done to merge the jit and pjit API.
...
PiperOrigin-RevId: 499311841
2023-01-03 14:13:19 -08:00