13846 Commits

Author SHA1 Message Date
Peter Hawkins
ebee4f4bfd Disable test variants that time out in CI.
PiperOrigin-RevId: 489214464
2022-11-17 08:14:07 -08:00
jax authors
b0d5052928 Merge pull request #13286 from jakevdp:fix-jaxlib-build
PiperOrigin-RevId: 489210938
2022-11-17 07:56:45 -08:00
Jake VanderPlas
e66fe1dff4 Update XLA commit.
Fixes build error: no such target '@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:python/MlirHloModule.cc'
2022-11-16 15:51:28 -08:00
Jake VanderPlas
e7f4fe043e jaxlib: fix mlir_hlo build rule 2022-11-16 15:42:05 -08:00
jax authors
2f3a384243 Merge pull request #13278 from google:chat_notif
PiperOrigin-RevId: 489033149
2022-11-16 14:04:53 -08:00
Skye Wanderman-Milne
0a886c34fa Include which jaxlib/libtpu version failed (latest or nightly) in TPU CI chat notification 2022-11-16 21:38:36 +00:00
jax authors
3a837c8069 Merge pull request #13281 from google:tpu_notify_main_only
PiperOrigin-RevId: 489025620
2022-11-16 13:36:49 -08:00
Skye Wanderman-Milne
b4564a2a57 TPU CI: don't notify when testing the workflow from a branch 2022-11-16 21:27:24 +00:00
jax authors
58af581e3b Merge pull request #13280 from jakevdp:sparse-test-util
PiperOrigin-RevId: 489019436
2022-11-16 13:13:04 -08:00
Parker Schuh
da765a2e54 Allow compiling and then serializing jax.stages.Lowered.
This adds experimental APIs to `serialize_executable.py`:

`compile_and_serialize(lowered)`
and
`load_compiled(serialized, in_tree, out_tree)`

for serializing and deserializing executables.

PiperOrigin-RevId: 489014705
2022-11-16 12:54:10 -08:00
Parker Schuh
fb4db5b60f Delete trailing whitespace that is blocking presubmits.
PiperOrigin-RevId: 489005596
2022-11-16 12:15:40 -08:00
Jake VanderPlas
f4932acc89 [sparse] refactor with sparse.test_util 2022-11-16 12:05:24 -08:00
jax authors
b0114121dd Merge pull request #13268 from skye:ratchet
PiperOrigin-RevId: 488992919
2022-11-16 11:28:24 -08:00
jax authors
8b46faa3ae Merge pull request #11947 from ROCmSoftwarePlatform:rocm_dlpack_support
PiperOrigin-RevId: 488985406
2022-11-16 11:00:33 -08:00
Skye Wanderman-Milne
8bed9bac81 Update Github Actions workflows using Ratchet
https://opensource.google/documentation/reference/github/services#actions
mandates using a specific commit for non-Google actions in workflow
files. I used https://github.com/sethvargo/ratchet to update all our
workflow files. Example command: `ratchet pin cloud-tpu-ci-nightly.yml`

Ratchet appears to also auto-format the YAML files. It makes the diff
confusing but I'm ok with the final result.
2022-11-16 18:45:59 +00:00
jax authors
bbc3c6aa89 Merge pull request #13266 from skye:tensorboard_docs
PiperOrigin-RevId: 488975887
2022-11-16 10:27:15 -08:00
jax authors
9259c9f681 Merge pull request #13191 from jakevdp:bcoo-tests
PiperOrigin-RevId: 488974571
2022-11-16 10:20:46 -08:00
Jake VanderPlas
66262901f0 [sparse] improve testing framework 2022-11-16 09:58:06 -08:00
Yash Katariya
0f9d237e57 Remove the newlines from the error message because users seem to miss the part after the newline.
PiperOrigin-RevId: 488963130
2022-11-16 09:42:05 -08:00
Krishna Haridasan
ac2af539d3 Expand XlaExecutable.cost_analysis to call Executable.cost_analysis
in case the backend does not implement the "client" attribute on xla_executable.

PiperOrigin-RevId: 488962373
2022-11-16 09:35:18 -08:00
jax authors
e4806e00bb Merge pull request #13270 from patrick-kidger:array-module-fix
PiperOrigin-RevId: 488956649
2022-11-16 09:11:58 -08:00
Peter Hawkins
51c69ac594 Tag several tests as optonly to prevent test timeouts in debug mode CI builds.
PiperOrigin-RevId: 488950972
2022-11-16 08:50:23 -08:00
jax authors
f2bd1afb7e Change repr on NamedSharding to match variable names.
PiperOrigin-RevId: 488950019
2022-11-16 08:43:24 -08:00
Yash Katariya
9799d5b139 Add the jax.Array change to the changelog.
PiperOrigin-RevId: 488929264
2022-11-16 06:56:09 -08:00
Peter Hawkins
0548c2d23b Disable sanitizer builds that are timing out or that are incompatible with test targets.
PiperOrigin-RevId: 488919571
2022-11-16 06:00:37 -08:00
jax authors
5675b4572e Merge pull request #13272 from gnecula:tf_fix2
PiperOrigin-RevId: 488892036
2022-11-16 03:21:30 -08:00
George Necula
c7031ff369 [jax2tf] Disable experimental_native_lowering polymorphism tests when not jax_dynamic_shapes 2022-11-16 11:58:23 +01:00
Patrick Kidger
c6cf195cd6 Fix less-nice type annotations for e.g. List[Array] 2022-11-15 22:56:15 -08:00
jax authors
1ce985fecb Merge pull request #13237 from jakevdp:bcoo-reshape
PiperOrigin-RevId: 488814859
2022-11-15 18:46:23 -08:00
jax authors
fba279a1c6 Merge pull request #13242 from google:array_tutorial
PiperOrigin-RevId: 488807156
2022-11-15 18:01:03 -08:00
Skye Wanderman-Milne
66f3e0da9c Update TensorBoard install instructions for profiling 2022-11-16 01:02:45 +00:00
yashkatariya
aca7e4ade2 jax.Array tutorial 2022-11-15 16:49:17 -08:00
Yash Katariya
ea930e1d8d Default jax.Array to True globally. See https://jax.readthedocs.io/en/latest/jax_array_migration.html for migration to jax.Array.
PiperOrigin-RevId: 488764287
2022-11-15 14:50:05 -08:00
jax authors
d9e4058c4b Merge pull request #13263 from jakevdp:bcoo-iter
PiperOrigin-RevId: 488763629
2022-11-15 14:43:40 -08:00
Peter Hawkins
99e1c3dd66 [JAX] Opt into high precision matrix multiplications in JAX tests that fail on A100.
With these changes the JAX test suite passes on A100, which uses TF32 math by default. As a side effect, we can also remove a number of TPU-specific tolerances once we have opted into high precision.

Fixes https://github.com/google/jax/issues/12008

PiperOrigin-RevId: 488749199
2022-11-15 13:50:21 -08:00
jax authors
31c8bfe2a8 Merge pull request #13264 from jakevdp:fix-pre-commit
PiperOrigin-RevId: 488746356
2022-11-15 13:39:00 -08:00
Jake VanderPlas
3d5eadde26 CI: update flake8 pre-commit URL 2022-11-15 13:32:48 -08:00
Jake VanderPlas
c85230c2c6 [sparse] support dense dimensions in bcoo_reshape 2022-11-15 13:19:44 -08:00
Jake VanderPlas
c8ee485de8 [sparse] implement iter() of BCOO 2022-11-15 13:17:06 -08:00
jax authors
726b2bc2ee Add JAX monitoring library that instruments code via events.
PiperOrigin-RevId: 488731805
2022-11-15 12:41:41 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Yash Katariya
eca12411e7 Disable some tests with jax.Array that are failing in OSS due to using minimum_jaxlib_version. I will bump the version again this week.
PiperOrigin-RevId: 488708528
2022-11-15 11:10:29 -08:00
Adam Paszke
d742e6a410 Transpose all_gather to reduce_scatter
Also, add support for AD and batching of reduce_scatter (with its transpose being all_gather again).

PiperOrigin-RevId: 488706478
2022-11-15 11:03:22 -08:00
Yash Katariya
8c42edfec1 Finish jax and jaxlib release 0.3.25. The next release will be 0.4.0 (since jax.Array will be enabled in that release)
PiperOrigin-RevId: 488672395
2022-11-15 09:02:53 -08:00
George Necula
e4751d4b02 [jax2tf] Enable StableHLO in jax2tf native lowering.
PiperOrigin-RevId: 488654050
2022-11-15 07:42:49 -08:00
jax authors
20f092c916 As explained in the JAX FAQ, jax.numpy.where has suprising behavior when
its gradient is taken and one of the inputs is NaN.  This CL adds
a short description of the behavior to the jax.numpy.where docs,
which is the logical place that users would look for it.

PiperOrigin-RevId: 488654036
2022-11-15 07:42:35 -08:00
Lena Martens
3116ed52a9 Checkify: fix nan check when primitive has multiple results.
PiperOrigin-RevId: 488653856
2022-11-15 07:35:54 -08:00
jax authors
108bc83520 Merge pull request #13231 from marcvanzee:patch-5
PiperOrigin-RevId: 488633221
2022-11-15 05:45:06 -08:00
jax authors
d4b061bb28 Merge pull request #13235 from LenaMartens:nan-set
PiperOrigin-RevId: 488612854
2022-11-15 03:36:42 -08:00
Yash Katariya
a683186570 Use the 11/09 libtpu build for jaxlib release since that passes all the tests.
PiperOrigin-RevId: 488543322
jax-v0.3.25 jaxlib-v0.3.25 jax-v0.3.25-rc3 0.3.25
2022-11-14 20:37:41 -08:00