jax authors
0b6657e471
Merge pull request #11556 from RuffaloLavoisier:tYpO
...
PiperOrigin-RevId: 462648717
2022-07-22 10:13:10 -07:00
Sharad Vikram
4870710891
Enable debugging callbacks with pjit on TPU
...
PiperOrigin-RevId: 462527181
2022-07-21 20:22:14 -07:00
jax authors
8a67734e7b
Merge pull request #11579 from sharadmv:fix-effects
...
PiperOrigin-RevId: 462478510
2022-07-21 15:02:46 -07:00
jax authors
24134ec2a5
Merge pull request #11425 from pschuh:pjit-bugfix
...
PiperOrigin-RevId: 462469178
2022-07-21 14:20:00 -07:00
jax authors
540ee56ff2
Merge pull request #11576 from jakevdp:searchsorted-alt
...
PiperOrigin-RevId: 462461853
2022-07-21 13:47:43 -07:00
Sharad Vikram
d6c172d53e
Fix PE not allowing double JIT-ted effectful functions
2022-07-21 11:55:48 -07:00
jax authors
a4e754849e
Merge pull request #11543 from nvcastet:fix_multigpu_test
...
PiperOrigin-RevId: 462418103
2022-07-21 10:27:57 -07:00
Jake VanderPlas
10411bfeae
jnp.searchsorted: add optional method argument to control implementation
2022-07-21 09:40:18 -07:00
George Necula
07fcf79324
jax.mask and jax.shapecheck are being deprecated
...
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
jax authors
be6db2e619
Merge pull request #10775 from pschuh:mlir-caching
...
PiperOrigin-RevId: 462263487
2022-07-20 17:10:40 -07:00
Parker Schuh
6c4da65af4
Add treedef_is_strict_leaf to fix _prefix_error's semantics.
...
Empty nodes like [] and {} have 1 node and 0 leaves. This does not make
them a leaf treedef.
Reproducer:
```
pjit.pjit(lambda x: x, None, (None, {}))((3, {'a': []}))
```
2022-07-20 17:02:59 -07:00
Kuangyuan Chen
c0ec3b33e6
Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
...
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.
PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
Yash Katariya
d8cbb29d14
OpSharding doesn't have __eq__
defined on it. Don't check sharding equality using opsharding until it does support that.
...
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
jax authors
d0162cd37e
Merge pull request #11533 from ROCmSoftwarePlatform:rocm_disable_lobpcg_test
...
PiperOrigin-RevId: 462182051
2022-07-20 10:56:54 -07:00
Adam Paszke
117da44712
Internal change
...
PiperOrigin-RevId: 462110048
2022-07-20 04:31:21 -07:00
RuffaloLavoisier
9f770425ac
Correct spelling on word
2022-07-20 18:57:12 +09:00
jax authors
7f1813c5e3
Merge pull request #11539 from gnecula:ds_reshape
...
PiperOrigin-RevId: 462061742
2022-07-19 23:13:03 -07:00
Matthew Johnson
7cb5c2447e
[dynamic-shapes] fix minor bint bugs
...
Co-authored-by: Eugene Burmako <burmako@google.com>
2022-07-19 16:38:40 -07:00
Jake VanderPlas
9090dd179d
jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0
2022-07-19 13:57:49 -07:00
Nicolas Castet
7589c6d7f0
Fix MultiProcessGpuTest test
...
Since MultiProcessGpuTest was using 'shell=True', only the first element
of the args was executed (i.e. python). Therefore the spawn processes
never executed jax code.
Fix the test and make sure 'jax.distributed' initialize by checking
jax.device_count().
2022-07-19 11:13:32 -05:00
George Necula
c45fe49821
[dynamic-shapes] Add typechecking rule for reshape
2022-07-19 15:10:14 +02:00
Parker Schuh
704f125c88
Add caching to trace_to_subjaxpr_dynamic2
.
...
This allows the MLIR lowering code to cache call lowerings.
example output:
```
module @jit_fun.0 {
func.func public @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = call @square(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
%1 = call @square(%0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
return %1 : tensor<4x8xf32>
}
func.func private @square(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = mhlo.multiply %arg0, %arg0 : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
}
```
If / when jaxprs support recursion, this approach will still work because the mlir lowering cache operates on Jaxpr object identity.
2022-07-18 17:51:05 -07:00
Rohit Santhanam
c4b37ad8a1
[ROCm] Disable lobpcg unit test for ROCm until performance issue is resolved.
2022-07-18 18:12:19 +00:00
Rohit Santhanam
235ea7059c
[ROCm] Disable new array_interoperability dlpack tests.
2022-07-17 04:48:11 +00:00
Jake VanderPlas
2f4c485a54
Add dlpack support to device_array and jax.numpy
2022-07-15 17:31:11 -07:00
Yash Katariya
90687cc1ff
Make lower_mesh_computation accept sharding instances. The new path is tested as everything in pjit goes through the new lower_sharding_computation
except of AUTO
and UNSPECIFIED
(see below for these 2).
...
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths; `spmd lowering path`, `non-spmd lowering path` and `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
* `lower_sharding_computation` works in SPMD mode since its only used in pjit. Majority of the logic is the same. The only difference is that `mesh` does not exist in this function.
* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.
* `AUTO` and `UNSPECIFIED` cannot be used without mesh right now but I have a CL to fix this.
* Rest of the changes are to make all other functions play nicely with sharding instances.
PiperOrigin-RevId: 461260553
2022-07-15 16:16:23 -07:00
Jake VanderPlas
c1549a0a16
[sparse] make sparse objects compatible with jax.jit.lower()
2022-07-15 09:58:31 -07:00
Tom Hennigan
10720258ea
Reduce the verbosity of treedef printing for custom nodes.
...
For very large trees of custom nodes this printing can be very verbose with a
lot or repetition. Our internal repository also encourages very deep package
names which exacerbates this issue.
Users encounter treedef printing when interacting with some staging APIs in JAX,
for example:
>>> params = { .. some params .. }
>>> f = jax.jit(..).lower(params).compile()
>>> f(params) # fine
>>> params['some_new_thing'] = something
>>> f(params)
TypeError: function compiled for {treedef}, called with {treedef}.
PiperOrigin-RevId: 461190971
2022-07-15 07:14:28 -07:00
George Necula
777c129dfb
[dynamic-shapes] Split dynamic_api_test.py
...
PiperOrigin-RevId: 461109288
2022-07-14 20:18:53 -07:00
Jake VanderPlas
0f14943524
lax_numpy_test: make compatible with numpy 1.24-dev
2022-07-14 14:35:10 -07:00
jax authors
4d1c6dfefa
Merge pull request #11469 from jakevdp:fix-rem-jvp
...
PiperOrigin-RevId: 460517781
2022-07-12 11:53:27 -07:00
Jake VanderPlas
daf6e3b065
BUG: fix jvp rule for lax.rem
2022-07-12 09:50:42 -07:00
jax authors
3eff9d11d2
Internal change
...
PiperOrigin-RevId: 460434859
2022-07-12 05:21:20 -07:00
Yash Katariya
0bc8f8abeb
* Check if the device assignment is the same across input and output shardings.
...
* Allow mixed inputs only if the sharding matches with what is specified in in_axis_resources.
PiperOrigin-RevId: 460326054
2022-07-11 16:27:11 -07:00
jax authors
11896b68a2
Merge pull request #11429 from sharadmv:for-loop
...
PiperOrigin-RevId: 460318883
2022-07-11 15:52:42 -07:00
Benjamin Kramer
9e16efa98a
Integrate LLVM at llvm/llvm-project@71c9757474
...
Updates LLVM usage to match
[71c9757474c3](https://github.com/llvm/llvm-project/commit/71c9757474c3 )
PiperOrigin-RevId: 460299215
2022-07-11 14:21:09 -07:00
Sharad Vikram
9d610e2de6
Add loop invariant residual fixpoint test
2022-07-11 13:10:03 -07:00
jax authors
cc42c8091d
Merge pull request #11406 from jakevdp:bcoo-add-batchdim
...
PiperOrigin-RevId: 460226570
2022-07-11 09:06:42 -07:00
Peter Hawkins
64e0b5d801
Increase bazel sharding of GPU tests.
...
Reduces the maximum time for some test shards to avoid flaky timeouts.
2022-07-11 14:19:43 +00:00
Sharad Vikram
b666f665ec
Rollback of HCB GPU custom call due to internal failures
...
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
jax authors
ed51c65576
Merge pull request #11405 from mattjj:djax-vmap
...
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c
[dynamic-shapes] start basic vmap compatibility
2022-07-09 10:03:40 -07:00
Yash Katariya
09ba51f323
Move _get_array_mapping from gda.py to pxla.py
...
PiperOrigin-RevId: 459891853
2022-07-08 21:38:06 -07:00
jax authors
df993ea32f
Merge pull request #11410 from sharadmv:for-loop
...
PiperOrigin-RevId: 459879694
2022-07-08 19:37:57 -07:00
Sharad Vikram
bff71b2c4f
Add loop-invariant residual optimization for for
2022-07-08 18:54:51 -07:00
jax authors
66ab792fc0
Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
...
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
jax authors
dac310c221
Merge pull request #11421 from jakevdp:scalar-meta-nocopy
...
PiperOrigin-RevId: 459823335
2022-07-08 13:30:20 -07:00
Yash Katariya
bb2c5f111a
Resolve TODOs and add some more checks for the jax.Array path.
...
PiperOrigin-RevId: 459808511
2022-07-08 12:19:19 -07:00
YouJiacheng
7c707832aa
Enable CustomCall implementation on GPU
2022-07-09 02:29:08 +08:00
Jake VanderPlas
e19df1a9bf
Use asarray rather than array in ScalarMeta
...
Why? This will make it so that jnp.int32(x) and friends no longer insert
a gratuitous copy_p operation in the jaxpr
2022-07-08 11:16:40 -07:00