4518 Commits

Author SHA1 Message Date
jax authors
0b6657e471 Merge pull request #11556 from RuffaloLavoisier:tYpO
PiperOrigin-RevId: 462648717
2022-07-22 10:13:10 -07:00
Sharad Vikram
4870710891 Enable debugging callbacks with pjit on TPU
PiperOrigin-RevId: 462527181
2022-07-21 20:22:14 -07:00
jax authors
8a67734e7b Merge pull request #11579 from sharadmv:fix-effects
PiperOrigin-RevId: 462478510
2022-07-21 15:02:46 -07:00
jax authors
24134ec2a5 Merge pull request #11425 from pschuh:pjit-bugfix
PiperOrigin-RevId: 462469178
2022-07-21 14:20:00 -07:00
jax authors
540ee56ff2 Merge pull request #11576 from jakevdp:searchsorted-alt
PiperOrigin-RevId: 462461853
2022-07-21 13:47:43 -07:00
Sharad Vikram
d6c172d53e Fix PE not allowing double JIT-ted effectful functions 2022-07-21 11:55:48 -07:00
jax authors
a4e754849e Merge pull request #11543 from nvcastet:fix_multigpu_test
PiperOrigin-RevId: 462418103
2022-07-21 10:27:57 -07:00
Jake VanderPlas
10411bfeae jnp.searchsorted: add optional method argument to control implementation 2022-07-21 09:40:18 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
jax authors
be6db2e619 Merge pull request #10775 from pschuh:mlir-caching
PiperOrigin-RevId: 462263487
2022-07-20 17:10:40 -07:00
Parker Schuh
6c4da65af4 Add treedef_is_strict_leaf to fix _prefix_error's semantics.
Empty nodes like [] and {} have 1 node and 0 leaves. This does not make
them a leaf treedef.

Reproducer:
```
pjit.pjit(lambda x: x, None, (None, {}))((3, {'a': []}))
```
2022-07-20 17:02:59 -07:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
Yash Katariya
d8cbb29d14 OpSharding doesn't have __eq__ defined on it. Don't check sharding equality using opsharding until it does support that.
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
jax authors
d0162cd37e Merge pull request #11533 from ROCmSoftwarePlatform:rocm_disable_lobpcg_test
PiperOrigin-RevId: 462182051
2022-07-20 10:56:54 -07:00
Adam Paszke
117da44712 Internal change
PiperOrigin-RevId: 462110048
2022-07-20 04:31:21 -07:00
RuffaloLavoisier
9f770425ac Correct spelling on word 2022-07-20 18:57:12 +09:00
jax authors
7f1813c5e3 Merge pull request #11539 from gnecula:ds_reshape
PiperOrigin-RevId: 462061742
2022-07-19 23:13:03 -07:00
Matthew Johnson
7cb5c2447e [dynamic-shapes] fix minor bint bugs
Co-authored-by: Eugene Burmako <burmako@google.com>
2022-07-19 16:38:40 -07:00
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
Nicolas Castet
7589c6d7f0 Fix MultiProcessGpuTest test
Since MultiProcessGpuTest was using 'shell=True', only the first element
of the args was executed (i.e. python). Therefore the spawn processes
never executed jax code.
Fix the test and make sure 'jax.distributed' initialize by checking
jax.device_count().
2022-07-19 11:13:32 -05:00
George Necula
c45fe49821 [dynamic-shapes] Add typechecking rule for reshape 2022-07-19 15:10:14 +02:00
Parker Schuh
704f125c88 Add caching to trace_to_subjaxpr_dynamic2.
This allows the MLIR lowering code to cache call lowerings.

example output:

```
module @jit_fun.0 {
  func.func public @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = call @square(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    %1 = call @square(%0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    return %1 : tensor<4x8xf32>
  }
  func.func private @square(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = mhlo.multiply %arg0, %arg0 : tensor<4x8xf32>
    return %0 : tensor<4x8xf32>
  }
}
```

If / when jaxprs support recursion, this approach will still work because the mlir lowering cache operates on Jaxpr object identity.
2022-07-18 17:51:05 -07:00
Rohit Santhanam
c4b37ad8a1 [ROCm] Disable lobpcg unit test for ROCm until performance issue is resolved. 2022-07-18 18:12:19 +00:00
Rohit Santhanam
235ea7059c [ROCm] Disable new array_interoperability dlpack tests. 2022-07-17 04:48:11 +00:00
Jake VanderPlas
2f4c485a54 Add dlpack support to device_array and jax.numpy 2022-07-15 17:31:11 -07:00
Yash Katariya
90687cc1ff Make lower_mesh_computation accept sharding instances. The new path is tested as everything in pjit goes through the new lower_sharding_computation except of AUTO and UNSPECIFIED (see below for these 2).
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths; `spmd lowering path`, `non-spmd lowering path` and `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
  * `lower_sharding_computation` works in SPMD mode since its only used in pjit. Majority of the logic is the same. The only difference is that `mesh` does not exist in this function.

* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.

* `AUTO` and `UNSPECIFIED` cannot be used without mesh right now but I have a CL to fix this.

* Rest of the changes are to make all other functions play nicely with sharding instances.

PiperOrigin-RevId: 461260553
2022-07-15 16:16:23 -07:00
Jake VanderPlas
c1549a0a16 [sparse] make sparse objects compatible with jax.jit.lower() 2022-07-15 09:58:31 -07:00
Tom Hennigan
10720258ea Reduce the verbosity of treedef printing for custom nodes.
For very large trees of custom nodes this printing can be very verbose with a
lot or repetition. Our internal repository also encourages very deep package
names which exacerbates this issue.

Users encounter treedef printing when interacting with some staging APIs in JAX,
for example:

    >>> params = { .. some params .. }
    >>> f = jax.jit(..).lower(params).compile()
    >>> f(params)  # fine
    >>> params['some_new_thing'] = something
    >>> f(params)
    TypeError: function compiled for {treedef}, called with {treedef}.

PiperOrigin-RevId: 461190971
2022-07-15 07:14:28 -07:00
George Necula
777c129dfb [dynamic-shapes] Split dynamic_api_test.py
PiperOrigin-RevId: 461109288
2022-07-14 20:18:53 -07:00
Jake VanderPlas
0f14943524 lax_numpy_test: make compatible with numpy 1.24-dev 2022-07-14 14:35:10 -07:00
jax authors
4d1c6dfefa Merge pull request #11469 from jakevdp:fix-rem-jvp
PiperOrigin-RevId: 460517781
2022-07-12 11:53:27 -07:00
Jake VanderPlas
daf6e3b065 BUG: fix jvp rule for lax.rem 2022-07-12 09:50:42 -07:00
jax authors
3eff9d11d2 Internal change
PiperOrigin-RevId: 460434859
2022-07-12 05:21:20 -07:00
Yash Katariya
0bc8f8abeb * Check if the device assignment is the same across input and output shardings.
* Allow mixed inputs only if the sharding matches with what is specified in in_axis_resources.

PiperOrigin-RevId: 460326054
2022-07-11 16:27:11 -07:00
jax authors
11896b68a2 Merge pull request #11429 from sharadmv:for-loop
PiperOrigin-RevId: 460318883
2022-07-11 15:52:42 -07:00
Benjamin Kramer
9e16efa98a Integrate LLVM at llvm/llvm-project@71c9757474
Updates LLVM usage to match
[71c9757474c3](https://github.com/llvm/llvm-project/commit/71c9757474c3)

PiperOrigin-RevId: 460299215
2022-07-11 14:21:09 -07:00
Sharad Vikram
9d610e2de6 Add loop invariant residual fixpoint test 2022-07-11 13:10:03 -07:00
jax authors
cc42c8091d Merge pull request #11406 from jakevdp:bcoo-add-batchdim
PiperOrigin-RevId: 460226570
2022-07-11 09:06:42 -07:00
Peter Hawkins
64e0b5d801 Increase bazel sharding of GPU tests.
Reduces the maximum time for some test shards to avoid flaky timeouts.
2022-07-11 14:19:43 +00:00
Sharad Vikram
b666f665ec Rollback of HCB GPU custom call due to internal failures
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
jax authors
ed51c65576 Merge pull request #11405 from mattjj:djax-vmap
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Yash Katariya
09ba51f323 Move _get_array_mapping from gda.py to pxla.py
PiperOrigin-RevId: 459891853
2022-07-08 21:38:06 -07:00
jax authors
df993ea32f Merge pull request #11410 from sharadmv:for-loop
PiperOrigin-RevId: 459879694
2022-07-08 19:37:57 -07:00
Sharad Vikram
bff71b2c4f Add loop-invariant residual optimization for for 2022-07-08 18:54:51 -07:00
jax authors
66ab792fc0 Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
jax authors
dac310c221 Merge pull request #11421 from jakevdp:scalar-meta-nocopy
PiperOrigin-RevId: 459823335
2022-07-08 13:30:20 -07:00
Yash Katariya
bb2c5f111a Resolve TODOs and add some more checks for the jax.Array path.
PiperOrigin-RevId: 459808511
2022-07-08 12:19:19 -07:00
YouJiacheng
7c707832aa Enable CustomCall implementation on GPU 2022-07-09 02:29:08 +08:00
Jake VanderPlas
e19df1a9bf Use asarray rather than array in ScalarMeta
Why? This will make it so that jnp.int32(x) and friends no longer insert
a gratuitous copy_p operation in the jaxpr
2022-07-08 11:16:40 -07:00