Kuangyuan Chen
c0ec3b33e6
Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
...
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.
PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
Yash Katariya
d8cbb29d14
OpSharding doesn't have __eq__
defined on it. Don't check sharding equality using opsharding until it does support that.
...
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
Parker Schuh
d8f0099f68
_mlirTransforms merged into _mlirRegisterEverything.
...
PiperOrigin-RevId: 462233907
2022-07-20 14:43:27 -07:00
Yash Katariya
ad67d825fe
Add a faster __eq__ check for Mesh. When the id
of self and other is the same, there is no need to compare the devices which can be slow when there are 1000s of devices.
...
PiperOrigin-RevId: 462230016
2022-07-20 14:25:41 -07:00
Yash Katariya
026636951a
Add lru_cache
and use it instead of util.cache()
in places where tracing user code is not required.
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 462212010
2022-07-20 13:05:20 -07:00
jax authors
ffe67c1042
Merge pull request #11563 from jakevdp:upstream-ci
...
PiperOrigin-RevId: 462200130
2022-07-20 12:09:08 -07:00
jax authors
5370f59036
Merge pull request #11566 from jakevdp:missing-f
...
PiperOrigin-RevId: 462187015
2022-07-20 11:16:22 -07:00
jax authors
d0162cd37e
Merge pull request #11533 from ROCmSoftwarePlatform:rocm_disable_lobpcg_test
...
PiperOrigin-RevId: 462182051
2022-07-20 10:56:54 -07:00
Jake VanderPlas
114b03670c
Add missing f-string marker
2022-07-20 10:48:07 -07:00
Jake VanderPlas
993196c451
CI: make parse_logs more robust to errors
2022-07-20 10:33:32 -07:00
Adam Paszke
117da44712
Internal change
...
PiperOrigin-RevId: 462110048
2022-07-20 04:31:21 -07:00
jax authors
7f1813c5e3
Merge pull request #11539 from gnecula:ds_reshape
...
PiperOrigin-RevId: 462061742
2022-07-19 23:13:03 -07:00
jax authors
9ee6cacdc8
Merge pull request #11540 from gnecula:ds_check_flag
...
PiperOrigin-RevId: 462061356
2022-07-19 23:07:14 -07:00
jax authors
02c3d6bf2c
Merge pull request #11553 from mattjj:readme-windows-update
...
PiperOrigin-RevId: 462022337
2022-07-19 18:19:37 -07:00
jax authors
7e5bc2977b
Merge pull request #11552 from mattjj:mhlo-bint-progress
...
PiperOrigin-RevId: 462015062
2022-07-19 17:32:00 -07:00
jax authors
5a9b8490eb
Merge pull request #11554 from google:skye-patch-1
...
PiperOrigin-RevId: 462013784
2022-07-19 17:24:56 -07:00
Matthew Johnson
3ab8637a73
Update README discussion of Windows support:
...
* right next to the pip installation instructions, mention they don't work for Windows;
* add a link to #5795 for an unofficial discussion of Windows native support
2022-07-19 17:12:40 -07:00
Skye Wanderman-Milne
186a4f83e3
Update libtpu version for 0.3.15 release
2022-07-19 17:11:43 -07:00
jax authors
18541e2efa
Merge pull request #11542 from mattjj:remove-resnet50-example
...
PiperOrigin-RevId: 462008290
2022-07-19 16:56:56 -07:00
Matthew Johnson
7cb5c2447e
[dynamic-shapes] fix minor bint bugs
...
Co-authored-by: Eugene Burmako <burmako@google.com>
2022-07-19 16:38:40 -07:00
jax authors
6ea9e4d4dd
Merge pull request #11546 from jakevdp:fix-scipy-sym-pos
...
PiperOrigin-RevId: 461976166
2022-07-19 14:26:51 -07:00
Jake VanderPlas
9090dd179d
jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0
2022-07-19 13:57:49 -07:00
Yash Katariya
9f914a93d6
Replace input sharding_specs
with in_shardings
in InputsHandler
...
PiperOrigin-RevId: 461963206
2022-07-19 13:43:28 -07:00
jax authors
226cef08bf
Merge pull request #11544 from jakevdp:conv-general-pad
...
PiperOrigin-RevId: 461937001
2022-07-19 11:46:08 -07:00
Jake VanderPlas
489596c0e2
lax.conv_general_dilated: validate negative paddings
2022-07-19 11:15:18 -07:00
George Necula
2106d65561
[dynamic-shapes] Add check that --jax_dynamic_shapes is set when using abstracted_axes.
...
abstracted_axes has no effect without the --jax_dynamic_shapes. Make this and
explicit error.
2022-07-19 19:48:45 +02:00
jax authors
ac731bb4cc
Merge pull request #11518 from ROCmSoftwarePlatform:rocm_disable_dlpack_unit_tests
...
PiperOrigin-RevId: 461904181
2022-07-19 09:48:36 -07:00
jax authors
97fe2262d3
Merge pull request #11526 from atgctg:patch-2
...
PiperOrigin-RevId: 461903614
2022-07-19 09:42:30 -07:00
Matthew Johnson
e350894371
remove resnet50 example
2022-07-19 09:40:39 -07:00
jax authors
388733b533
Merge pull request #11460 from jakevdp:profiler-deprecation
...
PiperOrigin-RevId: 461903612
2022-07-19 09:36:36 -07:00
Benjamin Kramer
2c72858928
Integrate LLVM at llvm/llvm-project@8aff88fd3a
...
Updates LLVM usage to match
[8aff88fd3a5f](https://github.com/llvm/llvm-project/commit/8aff88fd3a5f )
PiperOrigin-RevId: 461889195
2022-07-19 08:31:24 -07:00
Jake VanderPlas
2543542fa8
jax.profiler: remove deprecated functions
2022-07-19 08:13:44 -07:00
George Necula
c45fe49821
[dynamic-shapes] Add typechecking rule for reshape
2022-07-19 15:10:14 +02:00
jax authors
12ce369c94
Merge pull request #11506 from gnecula:jax2tf_call_module
...
PiperOrigin-RevId: 461834766
2022-07-19 02:42:57 -07:00
George Necula
ee50140701
[jax2tf] A new experimental version with JAX native lowering.
...
In the future JAX will be able to use a serialization format
based on a variant of MHLO. This is not yet ready, but in this PR
we are starting to get jax2tf ready for this. As a temporary
step, we had introduced a TF op called XlaCallModule which carries
a serialized MHLO module and which e can use to wrap the JAX native
MHLO as a TF op. We still reuse parts of jax2tf, in particular
the gradient machinery.
This functionality can be enabled locally with a
`experimental_native_lowering` flag for `jax2tf.convert`, or
globally with the flag `--jax2tf_default_experimental_native_lowering`.
2022-07-19 10:50:04 +02:00
Yash Katariya
ea627b807b
Replace out_specs
with out_shardings
and remove out_indices
in ResultsHandler.
...
PiperOrigin-RevId: 461788039
2022-07-18 20:57:02 -07:00
jax authors
e1fdd57ff3
Merge pull request #11532 from jakevdp:fix-upstream-ci
...
PiperOrigin-RevId: 461706572
2022-07-18 13:57:56 -07:00
jax authors
b0805a8a31
Fixes the JAX implementation of CELU returning NaN gradients for input
...
values >= 88.7229.
When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:
https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.
PiperOrigin-RevId: 461678140
2022-07-18 11:58:05 -07:00
Rohit Santhanam
c4b37ad8a1
[ROCm] Disable lobpcg unit test for ROCm until performance issue is resolved.
2022-07-18 18:12:19 +00:00
Jake VanderPlas
2eaa44e8d8
[CI] upstream-dev: run all tests
2022-07-18 11:06:56 -07:00
jax authors
d98d5ddce5
[JAX] Add jax_unique_mhlo_module_names
flag to control if MHLO should be made unique.
...
Some clients of JAX expect module names to not be altered so that they can cache XLA compilations.
PiperOrigin-RevId: 461648129
2022-07-18 10:05:44 -07:00
atgctg
1d5fbeadd2
Fix typo
2022-07-18 11:23:21 +02:00
jax authors
ae4aee762a
[jax2tf] Fix conv1d padding; it's already normalized before the _pad_spatial_dims call. Enable non-XLA tests of conv1d.
...
PiperOrigin-RevId: 461556553
2022-07-18 01:28:18 -07:00
Rohit Santhanam
235ea7059c
[ROCm] Disable new array_interoperability dlpack tests.
2022-07-17 04:48:11 +00:00
jax authors
a08a1f284a
Merge pull request #11504 from gnecula:shape_poly_conv2
...
PiperOrigin-RevId: 461368469
2022-07-16 11:33:59 -07:00
jax authors
a2371925fd
Merge pull request #11513 from jakevdp:from-dlpack
...
PiperOrigin-RevId: 461274756
2022-07-15 17:46:32 -07:00
Jake VanderPlas
2f4c485a54
Add dlpack support to device_array and jax.numpy
2022-07-15 17:31:11 -07:00
jax authors
7d7aa467f4
Merge pull request #11514 from jakevdp:stdout-atty
...
PiperOrigin-RevId: 461266411
2022-07-15 16:53:33 -07:00
jax authors
0da5657ac7
Merge pull request #11507 from jakevdp:tree-util-warning-level
...
PiperOrigin-RevId: 461266381
2022-07-15 16:48:13 -07:00
Yash Katariya
90687cc1ff
Make lower_mesh_computation accept sharding instances. The new path is tested as everything in pjit goes through the new lower_sharding_computation
except of AUTO
and UNSPECIFIED
(see below for these 2).
...
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths; `spmd lowering path`, `non-spmd lowering path` and `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
* `lower_sharding_computation` works in SPMD mode since its only used in pjit. Majority of the logic is the same. The only difference is that `mesh` does not exist in this function.
* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.
* `AUTO` and `UNSPECIFIED` cannot be used without mesh right now but I have a CL to fix this.
* Rest of the changes are to make all other functions play nicely with sharding instances.
PiperOrigin-RevId: 461260553
2022-07-15 16:16:23 -07:00