7907 Commits

Author SHA1 Message Date
jax authors
e121e811ab Merge pull request #11536 from sharadmv:colab-debugger
PiperOrigin-RevId: 462665740
2022-07-22 11:28:02 -07:00
jax authors
0b6657e471 Merge pull request #11556 from RuffaloLavoisier:tYpO
PiperOrigin-RevId: 462648717
2022-07-22 10:13:10 -07:00
Sharad Vikram
4870710891 Enable debugging callbacks with pjit on TPU
PiperOrigin-RevId: 462527181
2022-07-21 20:22:14 -07:00
jax authors
8a67734e7b Merge pull request #11579 from sharadmv:fix-effects
PiperOrigin-RevId: 462478510
2022-07-21 15:02:46 -07:00
jax authors
7f0b9179f2 Merge pull request #11575 from gnecula:ds_progress
PiperOrigin-RevId: 462475336
2022-07-21 14:48:24 -07:00
jax authors
24134ec2a5 Merge pull request #11425 from pschuh:pjit-bugfix
PiperOrigin-RevId: 462469178
2022-07-21 14:20:00 -07:00
jax authors
540ee56ff2 Merge pull request #11576 from jakevdp:searchsorted-alt
PiperOrigin-RevId: 462461853
2022-07-21 13:47:43 -07:00
Sharad Vikram
d6c172d53e Fix PE not allowing double JIT-ted effectful functions 2022-07-21 11:55:48 -07:00
jax authors
f6c168276b Merge pull request #11578 from jakevdp:wraps-mod
PiperOrigin-RevId: 462437654
2022-07-21 11:50:47 -07:00
Jake VanderPlas
9769a0accf DOC: ensure that _wraps() generates correct links to wrapped functions 2022-07-21 11:12:35 -07:00
jax authors
1e05a1cfbc Merge pull request #10816 from mattjj:remove-old-pjit-comment
PiperOrigin-RevId: 462411602
2022-07-21 10:01:57 -07:00
George Necula
6c9d2a0b54 [jax2tf] Raise errors for experimental_native_lowering and custom_call
Raise explicit error when the experimental_native_lowering encounters
a mhlo.custom_call. This would lead to failure when trying to run in TF.
2022-07-21 19:58:05 +03:00
Jake VanderPlas
10411bfeae jnp.searchsorted: add optional method argument to control implementation 2022-07-21 09:40:18 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
jax authors
be6db2e619 Merge pull request #10775 from pschuh:mlir-caching
PiperOrigin-RevId: 462263487
2022-07-20 17:10:40 -07:00
Parker Schuh
6c4da65af4 Add treedef_is_strict_leaf to fix _prefix_error's semantics.
Empty nodes like [] and {} have 1 node and 0 leaves. This does not make
them a leaf treedef.

Reproducer:
```
pjit.pjit(lambda x: x, None, (None, {}))((3, {'a': []}))
```
2022-07-20 17:02:59 -07:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
Yash Katariya
d8cbb29d14 OpSharding doesn't have __eq__ defined on it. Don't check sharding equality using opsharding until it does support that.
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
Yash Katariya
ad67d825fe Add a faster __eq__ check for Mesh. When the id of self and other is the same, there is no need to compare the devices which can be slow when there are 1000s of devices.
PiperOrigin-RevId: 462230016
2022-07-20 14:25:41 -07:00
Yash Katariya
026636951a Add lru_cache and use it instead of util.cache() in places where tracing user code is not required.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 462212010
2022-07-20 13:05:20 -07:00
Jake VanderPlas
114b03670c Add missing f-string marker 2022-07-20 10:48:07 -07:00
RuffaloLavoisier
9f770425ac Correct spelling on word 2022-07-20 18:57:12 +09:00
jax authors
7f1813c5e3 Merge pull request #11539 from gnecula:ds_reshape
PiperOrigin-RevId: 462061742
2022-07-19 23:13:03 -07:00
jax authors
9ee6cacdc8 Merge pull request #11540 from gnecula:ds_check_flag
PiperOrigin-RevId: 462061356
2022-07-19 23:07:14 -07:00
jax authors
7e5bc2977b Merge pull request #11552 from mattjj:mhlo-bint-progress
PiperOrigin-RevId: 462015062
2022-07-19 17:32:00 -07:00
jax authors
18541e2efa Merge pull request #11542 from mattjj:remove-resnet50-example
PiperOrigin-RevId: 462008290
2022-07-19 16:56:56 -07:00
Matthew Johnson
7cb5c2447e [dynamic-shapes] fix minor bint bugs
Co-authored-by: Eugene Burmako <burmako@google.com>
2022-07-19 16:38:40 -07:00
jax authors
6ea9e4d4dd Merge pull request #11546 from jakevdp:fix-scipy-sym-pos
PiperOrigin-RevId: 461976166
2022-07-19 14:26:51 -07:00
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
Yash Katariya
9f914a93d6 Replace input sharding_specs with in_shardings in InputsHandler
PiperOrigin-RevId: 461963206
2022-07-19 13:43:28 -07:00
jax authors
226cef08bf Merge pull request #11544 from jakevdp:conv-general-pad
PiperOrigin-RevId: 461937001
2022-07-19 11:46:08 -07:00
Jake VanderPlas
489596c0e2 lax.conv_general_dilated: validate negative paddings 2022-07-19 11:15:18 -07:00
George Necula
2106d65561 [dynamic-shapes] Add check that --jax_dynamic_shapes is set when using abstracted_axes.
abstracted_axes has no effect without the --jax_dynamic_shapes. Make this and
explicit error.
2022-07-19 19:48:45 +02:00
Jake VanderPlas
2543542fa8 jax.profiler: remove deprecated functions 2022-07-19 08:13:44 -07:00
George Necula
c45fe49821 [dynamic-shapes] Add typechecking rule for reshape 2022-07-19 15:10:14 +02:00
George Necula
ee50140701 [jax2tf] A new experimental version with JAX native lowering.
In the future JAX will be able to use a serialization format
based on a variant of MHLO. This is not yet ready, but in this PR
we are starting to get jax2tf ready for this. As a temporary
step, we had introduced a TF op called XlaCallModule which carries
a serialized MHLO module and which e can use to wrap the JAX native
MHLO as a TF op. We still reuse parts of jax2tf, in particular
the gradient machinery.

This functionality can be enabled locally with a
`experimental_native_lowering` flag for `jax2tf.convert`, or
globally with the flag `--jax2tf_default_experimental_native_lowering`.
2022-07-19 10:50:04 +02:00
Sharad Vikram
09fd173a3e Add colab debugger 2022-07-18 22:03:27 -07:00
Yash Katariya
ea627b807b Replace out_specs with out_shardings and remove out_indices in ResultsHandler.
PiperOrigin-RevId: 461788039
2022-07-18 20:57:02 -07:00
Parker Schuh
704f125c88 Add caching to trace_to_subjaxpr_dynamic2.
This allows the MLIR lowering code to cache call lowerings.

example output:

```
module @jit_fun.0 {
  func.func public @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = call @square(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    %1 = call @square(%0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    return %1 : tensor<4x8xf32>
  }
  func.func private @square(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = mhlo.multiply %arg0, %arg0 : tensor<4x8xf32>
    return %0 : tensor<4x8xf32>
  }
}
```

If / when jaxprs support recursion, this approach will still work because the mlir lowering cache operates on Jaxpr object identity.
2022-07-18 17:51:05 -07:00
jax authors
b0805a8a31 Fixes the JAX implementation of CELU returning NaN gradients for input
values >= 88.7229.

When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:

https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.

PiperOrigin-RevId: 461678140
2022-07-18 11:58:05 -07:00
jax authors
d98d5ddce5 [JAX] Add jax_unique_mhlo_module_names flag to control if MHLO should be made unique.
Some clients of JAX expect module names to not be altered so that they can cache XLA compilations.

PiperOrigin-RevId: 461648129
2022-07-18 10:05:44 -07:00
jax authors
ae4aee762a [jax2tf] Fix conv1d padding; it's already normalized before the _pad_spatial_dims call. Enable non-XLA tests of conv1d.
PiperOrigin-RevId: 461556553
2022-07-18 01:28:18 -07:00
jax authors
a08a1f284a Merge pull request #11504 from gnecula:shape_poly_conv2
PiperOrigin-RevId: 461368469
2022-07-16 11:33:59 -07:00
Jake VanderPlas
2f4c485a54 Add dlpack support to device_array and jax.numpy 2022-07-15 17:31:11 -07:00
jax authors
7d7aa467f4 Merge pull request #11514 from jakevdp:stdout-atty
PiperOrigin-RevId: 461266411
2022-07-15 16:53:33 -07:00
jax authors
0da5657ac7 Merge pull request #11507 from jakevdp:tree-util-warning-level
PiperOrigin-RevId: 461266381
2022-07-15 16:48:13 -07:00
Yash Katariya
90687cc1ff Make lower_mesh_computation accept sharding instances. The new path is tested as everything in pjit goes through the new lower_sharding_computation except of AUTO and UNSPECIFIED (see below for these 2).
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths; `spmd lowering path`, `non-spmd lowering path` and `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
  * `lower_sharding_computation` works in SPMD mode since its only used in pjit. Majority of the logic is the same. The only difference is that `mesh` does not exist in this function.

* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.

* `AUTO` and `UNSPECIFIED` cannot be used without mesh right now but I have a CL to fix this.

* Rest of the changes are to make all other functions play nicely with sharding instances.

PiperOrigin-RevId: 461260553
2022-07-15 16:16:23 -07:00
Jake VanderPlas
b41f33b0d7 pretty_printing: handle case where stdout is patched by a logger 2022-07-15 14:50:17 -07:00
Jake VanderPlas
c1549a0a16 [sparse] make sparse objects compatible with jax.jit.lower() 2022-07-15 09:58:31 -07:00
Jake VanderPlas
6907dfad00 tree_util: fix warning category and stacklevel 2022-07-15 09:24:22 -07:00