13841 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
120125f3dd Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).

By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).

I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, whereas pytest reuses the number of processes
initially specified. Starting and stopping the TPU runtime takes a
few seconds, so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
2022-11-18 22:05:13 +00:00
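The commit above runs single-device tests on a single TPU chip under pytest-xdist. A minimal sketch of how each xdist worker might be pinned to its own chip follows. It is not the code from the commit. `PYTEST_XDIST_WORKER` is the environment variable pytest-xdist sets in each worker; the helper names are hypothetical, and using `TPU_VISIBLE_CHIPS` to restrict the runtime to one chip is an assumption.

```python
import os
from typing import MutableMapping, Mapping, Optional


def chip_for_xdist_worker(environ: Mapping[str, str]) -> Optional[int]:
    """Return a chip index for a pytest-xdist worker, or None outside xdist."""
    # pytest-xdist names its workers "gw0", "gw1", ... in PYTEST_XDIST_WORKER.
    worker = environ.get("PYTEST_XDIST_WORKER", "")
    suffix = worker[2:]
    if not worker.startswith("gw") or not suffix.isdigit():
        return None
    return int(suffix)


def pin_worker_to_chip(environ: MutableMapping[str, str] = os.environ) -> None:
    """Hypothetical helper: give each worker process exactly one chip."""
    chip = chip_for_xdist_worker(environ)
    if chip is not None:
        # Assumption: the TPU runtime honors TPU_VISIBLE_CHIPS.
        # setdefault keeps any value the user has already exported.
        environ.setdefault("TPU_VISIBLE_CHIPS", str(chip))
```

Because the worker processes are long-lived, each one would start the TPU runtime once, which matches the observation in the commit message that reusing processes avoids repeated runtime startup cost.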
jax authors
3a837c8069 Merge pull request #13281 from google:tpu_notify_main_only
PiperOrigin-RevId: 489025620
2022-11-16 13:36:49 -08:00
Skye Wanderman-Milne
b4564a2a57 TPU CI: don't notify when testing the workflow from a branch 2022-11-16 21:27:24 +00:00
jax authors
58af581e3b Merge pull request #13280 from jakevdp:sparse-test-util
PiperOrigin-RevId: 489019436
2022-11-16 13:13:04 -08:00
Parker Schuh
da765a2e54 Allow compiling and then serializing jax.stages.Lowered.
This adds experimental APIs to `serialize_executable.py`:

`compile_and_serialize(lowered)`
and
`load_compiled(serialized, in_tree, out_tree)`

for serializing and deserializing executables.

PiperOrigin-RevId: 489014705
2022-11-16 12:54:10 -08:00
Parker Schuh
fb4db5b60f Delete trailing whitespace that is blocking presubmits.
PiperOrigin-RevId: 489005596
2022-11-16 12:15:40 -08:00
Jake VanderPlas
f4932acc89 [sparse] refactor with sparse.test_util 2022-11-16 12:05:24 -08:00
jax authors
b0114121dd Merge pull request #13268 from skye:ratchet
PiperOrigin-RevId: 488992919
2022-11-16 11:28:24 -08:00
jax authors
8b46faa3ae Merge pull request #11947 from ROCmSoftwarePlatform:rocm_dlpack_support
PiperOrigin-RevId: 488985406
2022-11-16 11:00:33 -08:00
Skye Wanderman-Milne
8bed9bac81 Update Github Actions workflows using Ratchet
https://opensource.google/documentation/reference/github/services#actions
mandates using a specific commit for non-Google actions in workflow
files. I used https://github.com/sethvargo/ratchet to update all our
workflow files. Example command: `ratchet pin cloud-tpu-ci-nightly.yml`

Ratchet appears to also auto-format the YAML files. It makes the diff
confusing but I'm ok with the final result.
2022-11-16 18:45:59 +00:00
jax authors
bbc3c6aa89 Merge pull request #13266 from skye:tensorboard_docs
PiperOrigin-RevId: 488975887
2022-11-16 10:27:15 -08:00
jax authors
9259c9f681 Merge pull request #13191 from jakevdp:bcoo-tests
PiperOrigin-RevId: 488974571
2022-11-16 10:20:46 -08:00
Jake VanderPlas
66262901f0 [sparse] improve testing framework 2022-11-16 09:58:06 -08:00
Yash Katariya
0f9d237e57 Remove the newlines from the error message because users seem to miss the part after the newline.
PiperOrigin-RevId: 488963130
2022-11-16 09:42:05 -08:00
Krishna Haridasan
ac2af539d3 Expand XlaExecutable.cost_analysis to call Executable.cost_analysis
in case the backend does not implement the "client" attribute on xla_executable.

PiperOrigin-RevId: 488962373
2022-11-16 09:35:18 -08:00
jax authors
e4806e00bb Merge pull request #13270 from patrick-kidger:array-module-fix
PiperOrigin-RevId: 488956649
2022-11-16 09:11:58 -08:00
Peter Hawkins
51c69ac594 Tag several tests as optonly to prevent test timeouts in debug mode CI builds.
PiperOrigin-RevId: 488950972
2022-11-16 08:50:23 -08:00
jax authors
f2bd1afb7e Change repr on NamedSharding to match variable names.
PiperOrigin-RevId: 488950019
2022-11-16 08:43:24 -08:00
Yash Katariya
9799d5b139 Add the jax.Array change to the changelog.
PiperOrigin-RevId: 488929264
2022-11-16 06:56:09 -08:00
Peter Hawkins
0548c2d23b Disable sanitizer builds that are timing out or that are incompatible with test targets.
PiperOrigin-RevId: 488919571
2022-11-16 06:00:37 -08:00
jax authors
5675b4572e Merge pull request #13272 from gnecula:tf_fix2
PiperOrigin-RevId: 488892036
2022-11-16 03:21:30 -08:00
George Necula
c7031ff369 [jax2tf] Disable experimental_native_lowering polymorphism tests when not jax_dynamic_shapes 2022-11-16 11:58:23 +01:00
Patrick Kidger
c6cf195cd6 Fix less-nice type annotations for e.g. List[Array] 2022-11-15 22:56:15 -08:00
jax authors
1ce985fecb Merge pull request #13237 from jakevdp:bcoo-reshape
PiperOrigin-RevId: 488814859
2022-11-15 18:46:23 -08:00
jax authors
fba279a1c6 Merge pull request #13242 from google:array_tutorial
PiperOrigin-RevId: 488807156
2022-11-15 18:01:03 -08:00
Skye Wanderman-Milne
66f3e0da9c Update TensorBoard install instructions for profiling 2022-11-16 01:02:45 +00:00
yashkatariya
aca7e4ade2 jax.Array tutorial 2022-11-15 16:49:17 -08:00
Yash Katariya
ea930e1d8d Default jax.Array to True globally. See https://jax.readthedocs.io/en/latest/jax_array_migration.html for migration to jax.Array.
PiperOrigin-RevId: 488764287
2022-11-15 14:50:05 -08:00
jax authors
d9e4058c4b Merge pull request #13263 from jakevdp:bcoo-iter
PiperOrigin-RevId: 488763629
2022-11-15 14:43:40 -08:00
Peter Hawkins
99e1c3dd66 [JAX] Opt into high precision matrix multiplications in JAX tests that fail on A100.
With these changes the JAX test suite passes on A100, which uses TF32 math by default. As a side effect, we can also remove a number of TPU-specific tolerances once we have opted into high precision.

Fixes https://github.com/google/jax/issues/12008

PiperOrigin-RevId: 488749199
2022-11-15 13:50:21 -08:00
jax authors
31c8bfe2a8 Merge pull request #13264 from jakevdp:fix-pre-commit
PiperOrigin-RevId: 488746356
2022-11-15 13:39:00 -08:00
Jake VanderPlas
3d5eadde26 CI: update flake8 pre-commit URL 2022-11-15 13:32:48 -08:00
Jake VanderPlas
c85230c2c6 [sparse] support dense dimensions in bcoo_reshape 2022-11-15 13:19:44 -08:00
Jake VanderPlas
c8ee485de8 [sparse] implement iter() of BCOO 2022-11-15 13:17:06 -08:00
jax authors
726b2bc2ee Add JAX monitoring library that instruments code via events.
PiperOrigin-RevId: 488731805
2022-11-15 12:41:41 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Yash Katariya
eca12411e7 Disable some tests with jax.Array that are failing in OSS due to using minimum_jaxlib_version. I will bump the version again this week.
PiperOrigin-RevId: 488708528
2022-11-15 11:10:29 -08:00
Adam Paszke
d742e6a410 Transpose all_gather to reduce_scatter
Also, add support for AD and batching of reduce_scatter (with its transpose being all_gather again).

PiperOrigin-RevId: 488706478
2022-11-15 11:03:22 -08:00
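The claim in the commit above, that the transpose of all_gather is reduce_scatter, can be checked on a toy model. The sketch below simulates devices as plain Python lists and is not JAX's implementation; all function names are illustrative. A linear map `A` and its transpose `A^T` satisfy `<A x, y> == <x, A^T y>`, so the check compares those two inner products.

```python
from typing import List

PerDevice = List[List[int]]  # one list of values per simulated device


def all_gather(shards: PerDevice) -> PerDevice:
    """Every device receives the concatenation of all shards."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]


def reduce_scatter(per_device: PerDevice) -> PerDevice:
    """Sum elementwise across devices, then device i keeps chunk i."""
    n = len(per_device)
    k = len(per_device[0]) // n
    summed = [sum(column) for column in zip(*per_device)]
    return [summed[i * k:(i + 1) * k] for i in range(n)]


def inner(a: PerDevice, b: PerDevice) -> int:
    """Inner product over all values on all devices."""
    return sum(u * v for ra, rb in zip(a, b) for u, v in zip(ra, rb))


# Three devices, each holding a shard of length 2.
x = [[1, 2], [3, 4], [5, 6]]
# An arbitrary cotangent with the shape of all_gather's output.
y = [[1, 0, 2, 0, 3, 0], [0, 1, 0, 1, 0, 1], [2, 2, 2, 2, 2, 2]]

lhs = inner(all_gather(x), y)
rhs = inner(x, reduce_scatter(y))
assert lhs == rhs
```

Integers are used so the two sides compare exactly; with floats the check would need a tolerance.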
Yash Katariya
8c42edfec1 Finish jax and jaxlib release 0.3.25. The next release will be 0.4.0 (since jax.Array will be enabled in that release)
PiperOrigin-RevId: 488672395
2022-11-15 09:02:53 -08:00
George Necula
e4751d4b02 [jax2tf] Enable StableHLO in jax2tf native lowering.
PiperOrigin-RevId: 488654050
2022-11-15 07:42:49 -08:00
jax authors
20f092c916 As explained in the JAX FAQ, jax.numpy.where has surprising behavior when
its gradient is taken and one of the inputs is NaN. This CL adds
a short description of the behavior to the jax.numpy.where docs,
which is the logical place that users would look for it.

PiperOrigin-RevId: 488654036
2022-11-15 07:42:35 -08:00
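The behavior the commit above documents comes from the chain rule: the branch that is not selected receives a zero cotangent, but zero times an infinite or NaN local derivative is still NaN. The sketch below models this in plain Python for `where(x > 0, log(x), 0.0)`. It does not call JAX, and every function name in it is illustrative. The second function models the usual workaround of also applying `where` to the input of the unsafe operation.

```python
import math


def where(cond: bool, a: float, b: float) -> float:
    return a if cond else b


def log_grad(x: float) -> float:
    # IEEE arithmetic gives 1/0 = inf; Python raises instead, so spell it out.
    return math.inf if x == 0.0 else 1.0 / x


def naive_grad(x: float) -> float:
    """Chain-rule gradient of where(x > 0, log(x), 0.0)."""
    # Cotangent routed to the log branch: 1 if it was selected, else 0.
    branch_cotangent = where(x > 0.0, 1.0, 0.0)
    return branch_cotangent * log_grad(x)  # 0 * inf is NaN at x == 0


def safe_grad(x: float) -> float:
    """Chain-rule gradient of where(x > 0, log(where(x > 0, x, 1.0)), 0.0)."""
    safe_x = where(x > 0.0, x, 1.0)
    branch_cotangent = where(x > 0.0, 1.0, 0.0)
    # d(safe_x)/dx is 1 where x was kept and 0 where it was replaced.
    inner_grad = where(x > 0.0, 1.0, 0.0)
    return branch_cotangent * log_grad(safe_x) * inner_grad


assert math.isnan(naive_grad(0.0))
assert safe_grad(0.0) == 0.0
assert naive_grad(2.0) == safe_grad(2.0) == 0.5
```

The workaround succeeds because the unsafe operation only ever sees an input where its derivative is finite, so the zero cotangent multiplies a finite number.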
Lena Martens
3116ed52a9 Checkify: fix nan check when primitive has multiple results.
PiperOrigin-RevId: 488653856
2022-11-15 07:35:54 -08:00
jax authors
108bc83520 Merge pull request #13231 from marcvanzee:patch-5
PiperOrigin-RevId: 488633221
2022-11-15 05:45:06 -08:00
jax authors
d4b061bb28 Merge pull request #13235 from LenaMartens:nan-set
PiperOrigin-RevId: 488612854
2022-11-15 03:36:42 -08:00
Yash Katariya
a683186570 Use the 11/09 libtpu build for jaxlib release since that passes all the tests.
PiperOrigin-RevId: 488543322
jax-v0.3.25 jaxlib-v0.3.25 jax-v0.3.25-rc3 0.3.25
2022-11-14 20:37:41 -08:00
Parker Schuh
7635df84f0 Remove custom potrf kernels in favor of native XLA cholesky support.
PiperOrigin-RevId: 488525158
2022-11-14 18:45:25 -08:00
Yash Katariya
f36084acd3 Update the values for jaxlib release (again)
PiperOrigin-RevId: 488522992
jax-v0.3.25-rc1
2022-11-14 18:31:08 -08:00
Yash Katariya
0da02dd41c Update the values needed for a jaxlib release
PiperOrigin-RevId: 488508360
2022-11-14 17:08:59 -08:00
jax authors
465d4ece36 Merge pull request #13243 from google:yashk2810-patch-19
PiperOrigin-RevId: 488506173
2022-11-14 16:58:33 -08:00
Yash Katariya
50f7512a6e Update WORKSPACE for jaxlib release 2022-11-14 16:37:09 -08:00