13722 Commits

Author SHA1 Message Date
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
jax authors
30637d052b Merge pull request #13168 from hawkinsp:fixbuild
PiperOrigin-RevId: 487219149
2022-11-09 05:50:34 -08:00
Peter Hawkins
41c90a838e Add missing stablehlo dialect files to jaxlib build.
Unbreaks the build.
2022-11-09 13:37:49 +00:00
Tianjian Lu
3b1ddf2881 [linalg] Add jax.scipy.special.bessel_jn (Bessel function of the first kind).
PiperOrigin-RevId: 487146250
2022-11-08 23:03:21 -08:00
Eugene Burmako
55996328f2 Introduce XlaLowering::stablehlo() and use it in associated APIs
See tests/api_test.py for usage examples.

At the moment, stablehlo() works by using the hlo-legalize-to-stablehlo pass, which takes MHLO natively produced by JAX and converts it into StableHLO. This is an intermediate step towards switching JAX to natively produce StableHLO.

This CL adds both mhlo_to_stablehlo and stablehlo_to_mhlo to jaxlib, even though only the former is used at the moment. This is done in anticipation of switching JAX to natively produce StableHLO, where stablehlo_to_mhlo will be needed to provide backward compatibility for XlaLowering::mhlo(). We're adding stablehlo_to_mhlo now, so that in the future we don't have to update jaxlib again which will make deployment easier.

PiperOrigin-RevId: 487144342
2022-11-08 22:50:06 -08:00
Skye Wanderman-Milne
df963bd72d Remove flaky Array defragmentation test check
PiperOrigin-RevId: 487120630
2022-11-08 20:06:36 -08:00
jax authors
0cf220f397 Merge pull request #13162 from jakevdp:bcoo-reshape
PiperOrigin-RevId: 487106270
2022-11-08 18:37:45 -08:00
Yash Katariya
53344b885d Don't create copies by device_putting a host local jax.Array if the sharding matches with the input.
PiperOrigin-RevId: 487090094
2022-11-08 17:02:23 -08:00
Jake VanderPlas
7d3b1d6439 [sparse] fix bcoo_reshape under jit 2022-11-08 17:00:25 -08:00
Skye Wanderman-Milne
0d2cd6dca1 [jax] Fix manual defragment method to work with Arrays
PiperOrigin-RevId: 487068409
2022-11-08 15:32:30 -08:00
jax authors
5e1d7cd52e Merge pull request #13032 from jakevdp:sharding-attr
PiperOrigin-RevId: 487061046
2022-11-08 15:01:23 -08:00
Jake VanderPlas
8fbf8da810 Declare Array.sharding & raise an error on tracers 2022-11-08 14:20:46 -08:00
jax authors
af017d44f5 Merge pull request #13153 from jakevdp:bcoo-reshape
PiperOrigin-RevId: 487046508
2022-11-08 14:11:51 -08:00
jax authors
768076eec4 Merge pull request #13157 from jakevdp:bcoo-astype
PiperOrigin-RevId: 487046458
2022-11-08 14:05:09 -08:00
Jake VanderPlas
7c0d0e67c8 [sparse] add support for BCOO.astype method 2022-11-08 13:30:22 -08:00
Jake VanderPlas
af956636b8 [sparse] fix bcoo_reshape when n_sparse=0 2022-11-08 12:00:24 -08:00
Yuxin Wu
96f6c1c9d4 Let is_user_frame ignore frames from stdlib.
When using decorators, we found contextlib.py from stdlib sometimes become the most recent non-jax frame. But it's not a user frame.

PiperOrigin-RevId: 486993924
2022-11-08 10:50:08 -08:00
jax authors
500cd859bf Merge pull request #13144 from LenaMartens:donate-no-more
PiperOrigin-RevId: 486979733
2022-11-08 09:57:44 -08:00
jax authors
3994ac30d5 Merge pull request #13145 from hawkinsp:pinv
PiperOrigin-RevId: 486935918
2022-11-08 06:39:54 -08:00
Peter Hawkins
ab8cde9ed4 Add support for the hermitian option on jnp.linalg.pinv.
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
2022-11-08 08:53:00 -05:00
lenamartens
e80c34d624 Don't donate arguments in jit/pmap/pjit when debug_nans=True. 2022-11-08 13:33:59 +00:00
jax authors
1e7e8e8d5c Merge pull request #13147 from hawkinsp:eyes
PiperOrigin-RevId: 486826532
2022-11-07 19:25:15 -08:00
jax authors
85f43dd902 Merge pull request #13061 from nouiz:test_doc
PiperOrigin-RevId: 486816419
2022-11-07 18:23:41 -08:00
jax authors
eb9e8c243a Merge pull request #13117 from 8bitmp3:move-multihost-multiprocess-toc
PiperOrigin-RevId: 486780726
2022-11-07 15:31:10 -08:00
jax authors
e00f7e7967 Merge pull request #13093 from PhilipVinc:patch-1
PiperOrigin-RevId: 486753425
2022-11-07 13:49:52 -08:00
Parker Schuh
2c1fe45997 Add UnloadedMeshExecutable to represent a MeshExecutable that is not loaded
on any physical devices for the purposes of serialization. This type is easier
to serialize because it has not yet been converted into arg-handlers.

Potential API:
```
  str, in_tree, out_tree = lowered.compile_and_serialize()
  exec = jax.experimental.load_serialized(str, in_tree, out_tree, backend)
  exec // identical to lowered.compile().
```

PiperOrigin-RevId: 486751141
2022-11-07 13:40:40 -08:00
Filippo Vicentini
793fb9b22c Fix issue in check_tree, so that custom_linear_solve supports
hax_aux=True when the vector and the aux are both pytrees.
2022-11-07 21:29:43 +01:00
Hyeontaek Lim
218305964f Fix GDA error message formatting
PiperOrigin-RevId: 486724647
2022-11-07 11:55:28 -08:00
Peter Hawkins
845f8df837 Avoid forming identity matrix in SVD JVP.
Set the default matmul precision in the SVD JVP, and use @ to express matmuls.
Also fix a flaky test failure in QR test on Mac ARM.
2022-11-07 13:55:45 -05:00
Kuangyuan Chen
b127b70e30 Remove static_argnums from AOT invocation.
Static args are not needed during invoking an AOT computation.

PiperOrigin-RevId: 486698420
2022-11-07 10:21:57 -08:00
Yash Katariya
da519f3b2c Check for ArrayImpl rather than sharding because this code is supposed to check for concrete Array until a shard_like primitive exists.
PiperOrigin-RevId: 486689809
2022-11-07 09:52:41 -08:00
jax authors
587885bbd3 Merge pull request #12077 from hawkinsp:docs
PiperOrigin-RevId: 486681964
2022-11-07 09:28:57 -08:00
jax authors
3da554d235 Merge pull request #13146 from jakevdp:fix-flake8
PiperOrigin-RevId: 486680771
2022-11-07 09:17:35 -08:00
Jake VanderPlas
b0e03fb747 Remove whitespace to fix flake8 2022-11-07 09:10:05 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
jax authors
042595bc4c Merge pull request #12890 from ROCmSoftwarePlatform:rocm_enable_multi_gpu_test
PiperOrigin-RevId: 486675297
2022-11-07 08:56:24 -08:00
jax authors
e9e014f432 Merge pull request #13140 from gnecula:clean_limitations
PiperOrigin-RevId: 486610612
2022-11-07 02:57:50 -08:00
George Necula
d9b1dc336d [jax2tf] Fixed jax2tf Limitations
Improved the documentation, and fixed a dot_general limitations for preferred_element_type on GPU
2022-11-07 11:26:39 +02:00
jax authors
8d59b0d47a Merge pull request #13108 from mattjj:djax-vmap2
PiperOrigin-RevId: 486534673
2022-11-06 17:27:53 -08:00
Matthew Johnson
f2f2faa4fa add a basic prototype of piles, behind jax_dynamic_shapes
Co-authored-by: Adam Paszke <apaszke@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-11-06 17:03:04 -08:00
Yash Katariya
2a262e9567 If the input to host_local_array_to_global_array is not fully addressable (i.e. not host local), return it as is.
Also if the input to `global_array_to_host_local_array` is fully addressable (i.e. host local), return it as is.

PiperOrigin-RevId: 486419066
2022-11-05 20:16:14 -07:00
Yash Katariya
3dbe10177e Remove device_indices method which is redundant because of the existence of devices_indices_map and is slower because pinging a cache for every device is not free.
PiperOrigin-RevId: 486405037
2022-11-05 17:33:37 -07:00
Tom Hennigan
5b45357a1d Make Sharding instances picklable.
PiperOrigin-RevId: 486386106
2022-11-05 13:36:52 -07:00
Yash Katariya
d9f3dc35ad Improve the sharding mismatch error message by adding the arg to the message too.
PiperOrigin-RevId: 486310996
2022-11-05 00:15:15 -07:00
jax authors
0c36f844d3 Merge pull request #13127 from mattjj:issue13124
PiperOrigin-RevId: 486296679
2022-11-04 21:44:17 -07:00
Yash Katariya
e161d20dc3 Improve the error message when the avals a function was AOT compiled with doesn't match the input avals when its called.
PiperOrigin-RevId: 486294881
2022-11-04 21:25:46 -07:00
Matthew Johnson
190204ff7d fix jax.random.logits shape argument
fixes #13124
2022-11-04 19:51:39 -07:00
jax authors
2932c1ef06 Merge pull request #13122 from yejingxin:main
PiperOrigin-RevId: 486253810
2022-11-04 16:15:14 -07:00
jax authors
75b5ccc355 Merge pull request #13024 from treyra:patch-2
PiperOrigin-RevId: 486244906
2022-11-04 15:31:34 -07:00
Jingxin Ye
e6c88f2c58 update pytest.ini to print warning message for compilation_cache_test 2022-11-04 21:43:51 +00:00