Peter Hawkins
1cead779a3
Add support for Hessenberg and tridiagonal matrix reductions on CPU.
...
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
jax authors
30637d052b
Merge pull request #13168 from hawkinsp:fixbuild
...
PiperOrigin-RevId: 487219149
2022-11-09 05:50:34 -08:00
Peter Hawkins
41c90a838e
Add missing stablehlo dialect files to jaxlib build.
...
Unbreaks the build.
2022-11-09 13:37:49 +00:00
Tianjian Lu
3b1ddf2881
[linalg] Add jax.scipy.special.bessel_jn (Bessel function of the first kind).
...
PiperOrigin-RevId: 487146250
2022-11-08 23:03:21 -08:00
Eugene Burmako
55996328f2
Introduce XlaLowering::stablehlo() and use it in associated APIs
...
See tests/api_test.py for usage examples.
At the moment, stablehlo() works by using the hlo-legalize-to-stablehlo pass, which takes MHLO natively produced by JAX and converts it into StableHLO. This is an intermediate step towards switching JAX to natively produce StableHLO.
This CL adds both mhlo_to_stablehlo and stablehlo_to_mhlo to jaxlib, even though only the former is used at the moment. This is done in anticipation of switching JAX to natively produce StableHLO, where stablehlo_to_mhlo will be needed to provide backward compatibility for XlaLowering::mhlo(). We're adding stablehlo_to_mhlo now, so that in the future we don't have to update jaxlib again which will make deployment easier.
PiperOrigin-RevId: 487144342
2022-11-08 22:50:06 -08:00
Skye Wanderman-Milne
df963bd72d
Remove flaky Array defragmentation test check
...
PiperOrigin-RevId: 487120630
2022-11-08 20:06:36 -08:00
jax authors
0cf220f397
Merge pull request #13162 from jakevdp:bcoo-reshape
...
PiperOrigin-RevId: 487106270
2022-11-08 18:37:45 -08:00
Yash Katariya
53344b885d
Don't create copies by device_putting a host local jax.Array if the sharding matches with the input.
...
PiperOrigin-RevId: 487090094
2022-11-08 17:02:23 -08:00
Jake VanderPlas
7d3b1d6439
[sparse] fix bcoo_reshape under jit
2022-11-08 17:00:25 -08:00
Skye Wanderman-Milne
0d2cd6dca1
[jax] Fix manual defragment method to work with Arrays
...
PiperOrigin-RevId: 487068409
2022-11-08 15:32:30 -08:00
jax authors
5e1d7cd52e
Merge pull request #13032 from jakevdp:sharding-attr
...
PiperOrigin-RevId: 487061046
2022-11-08 15:01:23 -08:00
Jake VanderPlas
8fbf8da810
Declare Array.sharding & raise an error on tracers
2022-11-08 14:20:46 -08:00
jax authors
af017d44f5
Merge pull request #13153 from jakevdp:bcoo-reshape
...
PiperOrigin-RevId: 487046508
2022-11-08 14:11:51 -08:00
jax authors
768076eec4
Merge pull request #13157 from jakevdp:bcoo-astype
...
PiperOrigin-RevId: 487046458
2022-11-08 14:05:09 -08:00
Jake VanderPlas
7c0d0e67c8
[sparse] add support for BCOO.astype method
2022-11-08 13:30:22 -08:00
Jake VanderPlas
af956636b8
[sparse] fix bcoo_reshape when n_sparse=0
2022-11-08 12:00:24 -08:00
Yuxin Wu
96f6c1c9d4
Let is_user_frame ignore frames from stdlib.
...
When using decorators, we found contextlib.py from stdlib sometimes become the most recent non-jax frame. But it's not a user frame.
PiperOrigin-RevId: 486993924
2022-11-08 10:50:08 -08:00
jax authors
500cd859bf
Merge pull request #13144 from LenaMartens:donate-no-more
...
PiperOrigin-RevId: 486979733
2022-11-08 09:57:44 -08:00
jax authors
3994ac30d5
Merge pull request #13145 from hawkinsp:pinv
...
PiperOrigin-RevId: 486935918
2022-11-08 06:39:54 -08:00
Peter Hawkins
ab8cde9ed4
Add support for the hermitian option on jnp.linalg.pinv.
...
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
2022-11-08 08:53:00 -05:00
lenamartens
e80c34d624
Don't donate arguments in jit/pmap/pjit when debug_nans=True.
2022-11-08 13:33:59 +00:00
jax authors
1e7e8e8d5c
Merge pull request #13147 from hawkinsp:eyes
...
PiperOrigin-RevId: 486826532
2022-11-07 19:25:15 -08:00
jax authors
85f43dd902
Merge pull request #13061 from nouiz:test_doc
...
PiperOrigin-RevId: 486816419
2022-11-07 18:23:41 -08:00
jax authors
eb9e8c243a
Merge pull request #13117 from 8bitmp3:move-multihost-multiprocess-toc
...
PiperOrigin-RevId: 486780726
2022-11-07 15:31:10 -08:00
jax authors
e00f7e7967
Merge pull request #13093 from PhilipVinc:patch-1
...
PiperOrigin-RevId: 486753425
2022-11-07 13:49:52 -08:00
Parker Schuh
2c1fe45997
Add UnloadedMeshExecutable to represent a MeshExecutable that is not loaded
...
on any physical devices for the purposes of serialization. This type is easier
to serialize because it has not yet been converted into arg-handlers.
Potential API:
```
str, in_tree, out_tree = lowered.compile_and_serialize()
exec = jax.experimental.load_serialized(str, in_tree, out_tree, backend)
exec // identical to lowered.compile().
```
PiperOrigin-RevId: 486751141
2022-11-07 13:40:40 -08:00
Filippo Vicentini
793fb9b22c
Fix issue in check_tree
, so that custom_linear_solve supports
...
hax_aux=True when the vector and the aux are both pytrees.
2022-11-07 21:29:43 +01:00
Hyeontaek Lim
218305964f
Fix GDA error message formatting
...
PiperOrigin-RevId: 486724647
2022-11-07 11:55:28 -08:00
Peter Hawkins
845f8df837
Avoid forming identity matrix in SVD JVP.
...
Set the default matmul precision in the SVD JVP, and use @ to express matmuls.
Also fix a flaky test failure in QR test on Mac ARM.
2022-11-07 13:55:45 -05:00
Kuangyuan Chen
b127b70e30
Remove static_argnums from AOT invocation.
...
Static args are not needed during invoking an AOT computation.
PiperOrigin-RevId: 486698420
2022-11-07 10:21:57 -08:00
Yash Katariya
da519f3b2c
Check for ArrayImpl
rather than sharding
because this code is supposed to check for concrete Array until a shard_like
primitive exists.
...
PiperOrigin-RevId: 486689809
2022-11-07 09:52:41 -08:00
jax authors
587885bbd3
Merge pull request #12077 from hawkinsp:docs
...
PiperOrigin-RevId: 486681964
2022-11-07 09:28:57 -08:00
jax authors
3da554d235
Merge pull request #13146 from jakevdp:fix-flake8
...
PiperOrigin-RevId: 486680771
2022-11-07 09:17:35 -08:00
Jake VanderPlas
b0e03fb747
Remove whitespace to fix flake8
2022-11-07 09:10:05 -08:00
Peter Hawkins
cd84eb10a6
Add a number of missing function cross-references in the docs.
2022-11-07 12:00:26 -05:00
jax authors
042595bc4c
Merge pull request #12890 from ROCmSoftwarePlatform:rocm_enable_multi_gpu_test
...
PiperOrigin-RevId: 486675297
2022-11-07 08:56:24 -08:00
jax authors
e9e014f432
Merge pull request #13140 from gnecula:clean_limitations
...
PiperOrigin-RevId: 486610612
2022-11-07 02:57:50 -08:00
George Necula
d9b1dc336d
[jax2tf] Fixed jax2tf Limitations
...
Improved the documentation, and fixed a dot_general limitations for preferred_element_type on GPU
2022-11-07 11:26:39 +02:00
jax authors
8d59b0d47a
Merge pull request #13108 from mattjj:djax-vmap2
...
PiperOrigin-RevId: 486534673
2022-11-06 17:27:53 -08:00
Matthew Johnson
f2f2faa4fa
add a basic prototype of piles, behind jax_dynamic_shapes
...
Co-authored-by: Adam Paszke <apaszke@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-11-06 17:03:04 -08:00
Yash Katariya
2a262e9567
If the input to host_local_array_to_global_array
is not fully addressable (i.e. not host local), return it as is.
...
Also if the input to `global_array_to_host_local_array` is fully addressable (i.e. host local), return it as is.
PiperOrigin-RevId: 486419066
2022-11-05 20:16:14 -07:00
Yash Katariya
3dbe10177e
Remove device_indices
method which is redundant because of the existence of devices_indices_map
and is slower because pinging a cache for every device is not free.
...
PiperOrigin-RevId: 486405037
2022-11-05 17:33:37 -07:00
Tom Hennigan
5b45357a1d
Make Sharding instances picklable.
...
PiperOrigin-RevId: 486386106
2022-11-05 13:36:52 -07:00
Yash Katariya
d9f3dc35ad
Improve the sharding mismatch error message by adding the arg to the message too.
...
PiperOrigin-RevId: 486310996
2022-11-05 00:15:15 -07:00
jax authors
0c36f844d3
Merge pull request #13127 from mattjj:issue13124
...
PiperOrigin-RevId: 486296679
2022-11-04 21:44:17 -07:00
Yash Katariya
e161d20dc3
Improve the error message when the avals a function was AOT compiled with doesn't match the input avals when its called.
...
PiperOrigin-RevId: 486294881
2022-11-04 21:25:46 -07:00
Matthew Johnson
190204ff7d
fix jax.random.logits shape argument
...
fixes #13124
2022-11-04 19:51:39 -07:00
jax authors
2932c1ef06
Merge pull request #13122 from yejingxin:main
...
PiperOrigin-RevId: 486253810
2022-11-04 16:15:14 -07:00
jax authors
75b5ccc355
Merge pull request #13024 from treyra:patch-2
...
PiperOrigin-RevId: 486244906
2022-11-04 15:31:34 -07:00
Jingxin Ye
e6c88f2c58
update pytest.ini to print warning message for compilation_cache_test
2022-11-04 21:43:51 +00:00