Yash Katariya
e6c4d4a30e
Add docstrings for Sharding
classes. Right now I am only documenting Sharding
, XLACompatibleSharding
, MeshPspecSharding
and SingleDeviceSharding
.
...
Also moving jax_array_migration guide to reference documentation.
PiperOrigin-RevId: 488489503
2022-11-14 15:47:46 -08:00
Yash Katariya
c42bad85ef
Make MeshPspecSharding
an alias for NamedSharding
(it was the other way around before this CL).
...
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
jax authors
d1e26d9c5d
Merge pull request #13139 from mattjj:djax-vmap4
...
PiperOrigin-RevId: 488458141
2022-11-14 13:48:28 -08:00
jax authors
f715ca6824
Merge pull request #13236 from skye:build_chat
...
PiperOrigin-RevId: 488432432
2022-11-14 12:06:04 -08:00
Skye Wanderman-Milne
5da7976093
Send message to internal chat room on Cloud TPU CI failure
2022-11-14 19:44:45 +00:00
Peter Hawkins
da130cb074
Disable more tests under tsan/asan.
...
PiperOrigin-RevId: 488406459
2022-11-14 10:34:30 -08:00
John QiangZhang
7c6a65bee2
Add tf.get_current_name_scope() as prefix of name_stack during tf execution.
...
PiperOrigin-RevId: 488399560
2022-11-14 10:12:04 -08:00
jax authors
b086e73d36
Merge pull request #13189 from Ishticode:lcm_update
...
PiperOrigin-RevId: 488383042
2022-11-14 09:10:39 -08:00
jax authors
4f15563b65
Merge pull request #13234 from hawkinsp:docs
...
PiperOrigin-RevId: 488369070
2022-11-14 08:11:36 -08:00
Peter Hawkins
ce17ce0550
Mention in the pmap() documentation that all devices must be identical.
...
Fixes https://github.com/google/jax/issues/13203
2022-11-14 10:43:53 -05:00
Peter Hawkins
ebd9840e1f
Add several recent changes to the CHANGELOG.
...
PiperOrigin-RevId: 488362198
2022-11-14 07:39:13 -08:00
Yash Katariya
7600cc8a8e
Make jax.Array default to False for external colab.
...
PiperOrigin-RevId: 488360010
2022-11-14 07:28:00 -08:00
Peter Hawkins
aa658bde6f
Disable asan/tsan for a number of slow tests.
...
PiperOrigin-RevId: 488356786
2022-11-14 07:12:16 -08:00
Peter Hawkins
40e81c3a86
Revert: Use pinv to compute lstsq.
...
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.
This PR appears to have caused numerical problems in downstream tests.
PiperOrigin-RevId: 487942754
2022-11-11 16:24:05 -08:00
jax authors
d4cc5882d2
Merge pull request #13215 from ROCmSoftwarePlatform:sytrd
...
PiperOrigin-RevId: 487940256
2022-11-11 16:11:27 -08:00
Sharad Vikram
e15619ceab
Convert string axis name into tuple of strings in Mesh constructor
...
PiperOrigin-RevId: 487930412
2022-11-11 15:27:51 -08:00
Yash Katariya
6897d37562
Add docstrings for jax.Array APIs make_array_from_callback
and make_array_from_single_device_arrays
.
...
PiperOrigin-RevId: 487929688
2022-11-11 15:21:10 -08:00
Rahul Batra
31d8f62826
Sytrd solver and SytrdDescriptor should NOT be CUDA only
2022-11-11 22:41:51 +00:00
Ishtiaq Hussain
09f62dec3c
Moved abs to inputs of lcm and added specific test
2022-11-11 22:31:06 +00:00
jax authors
19d76a7818
Merge pull request #13212 from sharadmv:fix-changelog
...
PiperOrigin-RevId: 487915320
2022-11-11 14:18:19 -08:00
Sharad Vikram
4bdfdd7363
Update changelog w/ info about deleting jax_experimental_name_stack
2022-11-11 14:02:30 -08:00
Peter Hawkins
7c3fb81310
Use pinv to compute lstsq.
...
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.
PiperOrigin-RevId: 487903227
2022-11-11 13:28:48 -08:00
Peter Hawkins
047974dd0c
Be more economical when computing the JVP of the SVD of non-square matrices.
...
(Note this isn't a regression from #13147 : the previous change did not alter the order of operations.)
PiperOrigin-RevId: 487896154
2022-11-11 12:55:53 -08:00
Peter Hawkins
c9ebf60f4e
Compute the JVP of jnp.linalg.pinv more economically for non-square matrices.
...
The order of the matrix products matters.
PiperOrigin-RevId: 487879202
2022-11-11 11:45:02 -08:00
jax authors
995736119e
Merge pull request #13198 from patrick-kidger:prng-isinstance
...
PiperOrigin-RevId: 487859092
2022-11-11 10:31:24 -08:00
jax authors
c359c7976b
Merge pull request #13196 from jakevdp:simpler-sparsify
...
PiperOrigin-RevId: 487837115
2022-11-11 09:01:55 -08:00
Patrick Kidger
d2afa84a6e
PRNGKeyArray is now a virtual subclass of ndarray
2022-11-11 08:04:38 -08:00
Felix Chern
10e6fe8cde
Correct norm in ann.py doc.
...
PiperOrigin-RevId: 487814084
2022-11-11 07:08:54 -08:00
jax authors
ce85106578
Merge pull request #13193 from tlu7:bcsr-fromdense-batching
...
PiperOrigin-RevId: 487810600
2022-11-11 06:51:30 -08:00
jax authors
71f92a7cd5
Merge pull request #13182 from canyon289:docs_update
...
PiperOrigin-RevId: 487808462
2022-11-11 06:44:29 -08:00
jax authors
bdf3bd5472
Merge pull request #13155 from jakevdp:bcoo-gather
...
PiperOrigin-RevId: 487808456
2022-11-11 06:37:30 -08:00
Jake VanderPlas
ea25b79b87
[sparse] streamline sparse rules for standard primitives
2022-11-11 04:50:33 -08:00
Jake VanderPlas
90dc008340
[sparse] add bcoo_gather & support for sparse indexing
2022-11-11 04:25:14 -08:00
Peter Hawkins
a13541441b
Reenable a TPU test now that the compiler bug is fixed.
...
PiperOrigin-RevId: 487705048
2022-11-10 19:38:01 -08:00
Tianjian Lu
332fced0cc
sparse] BCSR batching rule.
...
[Co-authored-by: Jake Vanderplas: <vanderplas@google.com>
2022-11-10 19:33:32 -08:00
jax authors
dc0d7ba368
Merge pull request #13202 from google:yashk2810-patch-18
...
PiperOrigin-RevId: 487701933
2022-11-10 19:15:29 -08:00
Yash Katariya
f0c0689a8a
Remove internal information
2022-11-10 19:09:19 -08:00
Parker Schuh
4a3b7f16ff
Change pickling for jax.sharding to not serialize device ids.
...
PiperOrigin-RevId: 487700467
2022-11-10 19:05:02 -08:00
jax authors
f9d7a6ae20
Merge pull request #13197 from google:yashk2810-patch-17
...
PiperOrigin-RevId: 487687006
2022-11-10 17:56:29 -08:00
Yash Katariya
73935a5bd1
Update jax_array_migration.md
2022-11-10 17:23:16 -08:00
jax authors
c318f771cb
Merge pull request #13185 from tlu7:bcsr-from-scipy
...
PiperOrigin-RevId: 487680265
2022-11-10 17:18:44 -08:00
Yash Katariya
aa66b939f9
Fix the docs build
2022-11-10 17:08:57 -08:00
Yash Katariya
b49a1bda15
Add jax.Array migration doc to OSS
...
PiperOrigin-RevId: 487673643
2022-11-10 16:46:30 -08:00
Tianjian Lu
311fb24ff9
[sparse] Add BCSR from_scipy_sparse.
...
Co-authored-by: Jake Vanderplas <vanderplas@google.com>
2022-11-10 16:44:59 -08:00
Peter Hawkins
352b042fe9
Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
...
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.
Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Sharad Vikram
74b136e62c
Delete jax_experimental_name_stack
flag
...
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
jax authors
0ebb6b4215
Merge pull request #13180 from jakevdp:bcoo-slice
...
PiperOrigin-RevId: 487568853
2022-11-10 10:04:35 -08:00
Yash Katariya
cc41ee85c4
Mark scipy_signal_test and sparse_test optonly
because it times out under debug mode.
...
PiperOrigin-RevId: 487533356
2022-11-10 07:38:58 -08:00
Yash Katariya
71360edf90
Bump the shard count for TPU to avoid timeouts
...
PiperOrigin-RevId: 487421018
2022-11-09 20:32:12 -08:00
Peter Hawkins
e42e52d4aa
Rename test flag --num_generated_cases to --jax_num_generated_cases.
...
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00