5130 Commits

Author SHA1 Message Date
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
jax authors
d1e26d9c5d Merge pull request #13139 from mattjj:djax-vmap4
PiperOrigin-RevId: 488458141
2022-11-14 13:48:28 -08:00
Peter Hawkins
da130cb074 Disable more tests under tsan/asan.
PiperOrigin-RevId: 488406459
2022-11-14 10:34:30 -08:00
jax authors
b086e73d36 Merge pull request #13189 from Ishticode:lcm_update
PiperOrigin-RevId: 488383042
2022-11-14 09:10:39 -08:00
Peter Hawkins
aa658bde6f Disable asan/tsan for a number of slow tests.
PiperOrigin-RevId: 488356786
2022-11-14 07:12:16 -08:00
Peter Hawkins
40e81c3a86 Revert: Use pinv to compute lstsq.
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.

This PR appears to have caused numerical problems in downstream tests.

PiperOrigin-RevId: 487942754
2022-11-11 16:24:05 -08:00
Sharad Vikram
e15619ceab Convert string axis name into tuple of strings in Mesh constructor
PiperOrigin-RevId: 487930412
2022-11-11 15:27:51 -08:00
Ishtiaq Hussain
09f62dec3c Moved abs to inputs of lcm and added specific test 2022-11-11 22:31:06 +00:00
Peter Hawkins
7c3fb81310 Use pinv to compute lstsq.
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.

PiperOrigin-RevId: 487903227
2022-11-11 13:28:48 -08:00
Peter Hawkins
047974dd0c Be more economical when computing the JVP of the SVD of non-square matrices.
(Note this isn't a regression from #13147: the previous change did not alter the order of operations.)

PiperOrigin-RevId: 487896154
2022-11-11 12:55:53 -08:00
Peter Hawkins
c9ebf60f4e Compute the JVP of jnp.linalg.pinv more economically for non-square matrices.
The order of the matrix products matters.

PiperOrigin-RevId: 487879202
2022-11-11 11:45:02 -08:00
jax authors
995736119e Merge pull request #13198 from patrick-kidger:prng-isinstance
PiperOrigin-RevId: 487859092
2022-11-11 10:31:24 -08:00
Patrick Kidger
d2afa84a6e PRNGKeyArray is now a virtual subclass of ndarray 2022-11-11 08:04:38 -08:00
jax authors
ce85106578 Merge pull request #13193 from tlu7:bcsr-fromdense-batching
PiperOrigin-RevId: 487810600
2022-11-11 06:51:30 -08:00
Jake VanderPlas
90dc008340 [sparse] add bcoo_gather & support for sparse indexing 2022-11-11 04:25:14 -08:00
Peter Hawkins
a13541441b Reenable a TPU test now that the compiler bug is fixed.
PiperOrigin-RevId: 487705048
2022-11-10 19:38:01 -08:00
Tianjian Lu
332fced0cc sparse] BCSR batching rule.
[Co-authored-by: Jake Vanderplas: <vanderplas@google.com>
2022-11-10 19:33:32 -08:00
Parker Schuh
4a3b7f16ff Change pickling for jax.sharding to not serialize device ids.
PiperOrigin-RevId: 487700467
2022-11-10 19:05:02 -08:00
Tianjian Lu
311fb24ff9 [sparse] Add BCSR from_scipy_sparse.
Co-authored-by: Jake Vanderplas <vanderplas@google.com>
2022-11-10 16:44:59 -08:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
jax authors
0ebb6b4215 Merge pull request #13180 from jakevdp:bcoo-slice
PiperOrigin-RevId: 487568853
2022-11-10 10:04:35 -08:00
Yash Katariya
cc41ee85c4 Mark scipy_signal_test and sparse_test optonly because it times out under debug mode.
PiperOrigin-RevId: 487533356
2022-11-10 07:38:58 -08:00
Yash Katariya
71360edf90 Bump the shard count for TPU to avoid timeouts
PiperOrigin-RevId: 487421018
2022-11-09 20:32:12 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
jax authors
b36afc5b0d Merge pull request #13177 from jakevdp:bcoo-dynamic-slice
PiperOrigin-RevId: 487390430
2022-11-09 17:29:30 -08:00
Sharad Vikram
3731e446c0 Set default layout for Python callback
PiperOrigin-RevId: 487388682
2022-11-09 17:18:49 -08:00
Yash Katariya
f9bbd585b9 Improve the error message when @pjit (with no {in_axis|out_axis}_resources is used without jax.Array enabled.
PiperOrigin-RevId: 487380328
2022-11-09 16:38:00 -08:00
Jake VanderPlas
4c4f2a3ad2 [sparse] support strides in bcoo_slice 2022-11-09 15:03:21 -08:00
Jake VanderPlas
46d9cac122 [sparse] bcoo_dynamic_slice: remove unnecessary padding from output 2022-11-09 13:56:18 -08:00
Jake VanderPlas
0c3e330148 [sparse] fix shape bug in bcoo_transpose 2022-11-09 12:53:13 -08:00
Felix Chern
8ac7422e26 [JAX] Disables large k test cases in ann_test.
Will investigate probability properties for the corner cases in the future.

PiperOrigin-RevId: 487302143
2022-11-09 11:32:47 -08:00
jax authors
63e3152764 Merge pull request #13160 from jakevdp:bcoo-squeeze
PiperOrigin-RevId: 487280563
2022-11-09 10:18:22 -08:00
jax authors
f697b8e087 Merge pull request #13166 from LenaMartens:checking-keys
PiperOrigin-RevId: 487267267
2022-11-09 09:31:39 -08:00
lenamartens
053b8b5bcd Checkify: fix nan_checks+PRNGKeys - a PRNGKey is never NaN!
Add a guard to the nan_error_rule to not call jnp.isnan on keys.
2022-11-09 17:08:21 +00:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Tianjian Lu
3b1ddf2881 [linalg] Add jax.scipy.special.bessel_jn (Bessel function of the first kind).
PiperOrigin-RevId: 487146250
2022-11-08 23:03:21 -08:00
Eugene Burmako
55996328f2 Introduce XlaLowering::stablehlo() and use it in associated APIs
See tests/api_test.py for usage examples.

At the moment, stablehlo() works by using the hlo-legalize-to-stablehlo pass, which takes MHLO natively produced by JAX and converts it into StableHLO. This is an intermediate step towards switching JAX to natively produce StableHLO.

This CL adds both mhlo_to_stablehlo and stablehlo_to_mhlo to jaxlib, even though only the former is used at the moment. This is done in anticipation of switching JAX to natively produce StableHLO, where stablehlo_to_mhlo will be needed to provide backward compatibility for XlaLowering::mhlo(). We're adding stablehlo_to_mhlo now, so that in the future we don't have to update jaxlib again which will make deployment easier.

PiperOrigin-RevId: 487144342
2022-11-08 22:50:06 -08:00
Skye Wanderman-Milne
df963bd72d Remove flaky Array defragmentation test check
PiperOrigin-RevId: 487120630
2022-11-08 20:06:36 -08:00
Jake VanderPlas
4255697610 [sparse] add bcoo_squeeze function 2022-11-08 18:16:20 -08:00
Jake VanderPlas
7d3b1d6439 [sparse] fix bcoo_reshape under jit 2022-11-08 17:00:25 -08:00
Skye Wanderman-Milne
0d2cd6dca1 [jax] Fix manual defragment method to work with Arrays
PiperOrigin-RevId: 487068409
2022-11-08 15:32:30 -08:00
jax authors
af017d44f5 Merge pull request #13153 from jakevdp:bcoo-reshape
PiperOrigin-RevId: 487046508
2022-11-08 14:11:51 -08:00
Jake VanderPlas
7c0d0e67c8 [sparse] add support for BCOO.astype method 2022-11-08 13:30:22 -08:00
Jake VanderPlas
af956636b8 [sparse] fix bcoo_reshape when n_sparse=0 2022-11-08 12:00:24 -08:00
Yuxin Wu
96f6c1c9d4 Let is_user_frame ignore frames from stdlib.
When using decorators, we found contextlib.py from stdlib sometimes become the most recent non-jax frame. But it's not a user frame.

PiperOrigin-RevId: 486993924
2022-11-08 10:50:08 -08:00
Matthew Johnson
0b463efb70 tighten up vmap w/ piles: require pile_axis in_axes/out_axes 2022-11-08 10:27:55 -08:00
jax authors
500cd859bf Merge pull request #13144 from LenaMartens:donate-no-more
PiperOrigin-RevId: 486979733
2022-11-08 09:57:44 -08:00
Peter Hawkins
ab8cde9ed4 Add support for the hermitian option on jnp.linalg.pinv.
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
2022-11-08 08:53:00 -05:00
lenamartens
e80c34d624 Don't donate arguments in jit/pmap/pjit when debug_nans=True. 2022-11-08 13:33:59 +00:00