49 Commits

Author SHA1 Message Date
Yash Katariya
f905d989c1 Make eager pmap tests pass with Array. Also add a slow path for Array in pmap similar to what SDA has. This is required for eager pmap. Adding a slow path removes the need for doing sharding checks in api.py because SDA doesn't do those checks and if the sharding does not match with pmap sharding, then it just defaults to the slow path (exactly like SDA).
PiperOrigin-RevId: 468843310
2022-08-19 21:37:22 -07:00
Yash Katariya
314cf8a439 Use .device() to get the device and platform from the device and fix TODO to point to github issue
PiperOrigin-RevId: 468769708
2022-08-19 13:14:13 -07:00
Yash Katariya
d77848bcc9 Enable jax_array on CPU for the entire JAX test suite!
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
Yash Katariya
2231b0c054 Add array test counterparts to dtypes_test.
PiperOrigin-RevId: 468568788
2022-08-18 15:58:39 -07:00
Yash Katariya
acdae7c237 Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0 test for now until I investigate.
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Sharad Vikram
8068e4638c Re-bump shard count for pmap_test
PiperOrigin-RevId: 468239588
2022-08-17 10:46:19 -07:00
jax authors
7721579700 Internal change
PiperOrigin-RevId: 468068879
2022-08-16 17:43:03 -07:00
Sharad Vikram
6ae46c3d69 Bump pmap_test size to handle new eager tests
PiperOrigin-RevId: 467967021
2022-08-16 10:45:50 -07:00
Peter Hawkins
cd62dbffb3 Increase Bazel sharding of scipy_signal_test on CPU. 2022-08-12 15:14:35 +00:00
Peter Hawkins
3a693e98c0 Increase Bazel sharding of long-running CPU tests. 2022-08-12 13:21:40 +00:00
Yash Katariya
33c4fc4fe2 Pmap should output SDA like Arrays to maintain the current behavior exactly. Split the shard_arg_handler for Array based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.
PiperOrigin-RevId: 466849614
2022-08-10 20:11:37 -07:00
Sharad Vikram
2d72dc899b Enable receives in TPU callbacks and add tests
PiperOrigin-RevId: 466103306
2022-08-08 11:42:16 -07:00
Peter Hawkins
efe1b6b44c Disable tsan on linalg_test on TPU.
The test is too slow under tsan and times out in CI.

PiperOrigin-RevId: 465570229
2022-08-05 08:33:42 -07:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Vlad Feinberg
269067e3e8 Make LOBPCG test plots compatible with bazel.
bazel test invocations would previously not work, because the lobpcg_test did not include the appropriate flag parsing and absl test invocations when run as a script. This change fixes that, and in addition shards tests and removes needless and redundant slow tests with larger matrix sizes to make the tests finish in a smaller amount of time. Now, generated pngs with debug information are properly reported via the undeclared outputs directory when the environment variable to emit them, LOBPCG_EMIT_DEBUG_PLOTS, is set to a non-falsy value.

PiperOrigin-RevId: 465465731
2022-08-04 20:05:53 -07:00
jax authors
0a8ca1982c Merge pull request #11721 from sudhakarsingh27:main
PiperOrigin-RevId: 465381834
2022-08-04 12:52:16 -07:00
jax authors
01819257f6 Merge pull request #11701 from sharadmv:state
PiperOrigin-RevId: 464658336
2022-08-01 17:05:43 -07:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
Peter Hawkins
9e6254e058 Increase shard counts for TPU tests in an attempt to fix CI timeouts under asan.
PiperOrigin-RevId: 463830139
2022-07-28 07:14:36 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
Adam Paszke
117da44712 Internal change
PiperOrigin-RevId: 462110048
2022-07-20 04:31:21 -07:00
George Necula
777c129dfb [dynamic-shapes] Split dynamic_api_test.py
PiperOrigin-RevId: 461109288
2022-07-14 20:18:53 -07:00
jax authors
3eff9d11d2 Internal change
PiperOrigin-RevId: 460434859
2022-07-12 05:21:20 -07:00
Peter Hawkins
64e0b5d801 Increase bazel sharding of GPU tests.
Reduces the maximum time for some test shards to avoid flaky timeouts.
2022-07-11 14:19:43 +00:00
Sharad Vikram
b666f665ec Rollback of HCB GPU custom call due to internal failures
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
jax authors
66ab792fc0 Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
YouJiacheng
7c707832aa Enable CustomCall implementation on GPU 2022-07-09 02:29:08 +08:00
Peter Hawkins
5a7bedca37 Increase shard_count for sparse_test_gpu to 20.
1918d39765 updated the wrong test!

This test is close to the timeout in the GPU CI and flakes sometimes.

PiperOrigin-RevId: 459762867
2022-07-08 08:30:26 -07:00
Peter Hawkins
1918d39765 Increase number of shards for GPU sparse_test to 20. 2022-07-07 21:14:25 -04:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Peter Hawkins
95e79332c0 Add JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS environment variables to assist with skipping tests under Bazel.
Add "multiaccelerator" test tags to mark tests that would meaningfully run with more than one accelerator (e.g., GPU).

PiperOrigin-RevId: 459320212
2022-07-06 12:51:43 -07:00
Peter Hawkins
1fc9afd03a Add support for running JAX tests under Bazel.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
2022-07-01 15:07:22 -07:00
Roy Frostig
28828772b8 remove unuzed bazel build rules, including bazel test definitions 2018-12-16 10:42:50 -08:00
Roy Frostig
a97f820961 update examples/BUILD 2018-12-15 15:54:31 -08:00
Peter Hawkins
3561b432c2 Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.

Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
Roy Frostig
cfc37eb0be add a test of the resnet50 example 2018-12-12 16:09:20 -08:00
Roy Frostig
9c69014318 Only build-depend on libjax optionally (via bazel CLI flag) in the generated bazel test targets. 2018-12-12 18:01:44 -05:00
Dougal Maclaurin
c3374a9d5f Added build rule for generated_fun_test (formerly quickish_check) 2018-12-06 17:04:00 -05:00
Dougal Maclaurin
29113dd606 Made tests runnable with bazel 2018-12-06 17:00:47 -05:00
Roy Frostig
f5b051b431 Double gpu test shards
PiperOrigin-RevId: 223647959
2018-12-02 11:50:46 -08:00
Roy Frostig
20878c76f4 source sync
PiperOrigin-RevId: 223530503
2018-12-02 11:50:39 -08:00
Peter Hawkins
6361b784a8 source sync
PiperOrigin-RevId: 222456068
2018-11-21 20:22:56 -08:00
Matthew Johnson
25fb9b421d source sync
PiperOrigin-RevId: 222170151
2018-11-21 20:22:33 -08:00
Roy Frostig
a3619ca89d source sync
PiperOrigin-RevId: 222153576
2018-11-21 20:22:30 -08:00
Matthew Johnson
50038c07c8 fix build file issues 2018-11-19 20:18:31 -08:00
Roy Frostig
99f98f8e8c source sync 2018-11-19 13:50:57 -08:00
Roy Frostig
da2d53ad33 source sync 2018-11-19 13:29:47 -08:00
Matthew Johnson
9ae0f3a610 split BUILD file, move up license files 2018-11-18 15:43:09 -08:00