10550 Commits

Author SHA1 Message Date
Jake VanderPlas
da3aaa1960 Add deprecation warning to JaxTestCase and JaxTestLoader 2022-02-17 14:58:58 -08:00
jax authors
e545daa1e5 Merge pull request #9621 from hawkinsp:docs3
PiperOrigin-RevId: 429401742
2022-02-17 14:23:13 -08:00
Peter Hawkins
3e5ecfe363 Add jax.distributed and jax.dlpack to the docs.
Reorder the doc modules into something closer to alphabetical order.

Add missing functions from jax.scipy.linalg and jax.scipy.signal to the docs.
2022-02-17 16:10:07 -05:00
jax authors
54a6e4dad3 Merge pull request #9422 from yotarok:signal_stft
PiperOrigin-RevId: 429377655
2022-02-17 12:46:12 -08:00
jax authors
5bb140f659 Merge pull request #9615 from froystig:jnp-expand-dims-error
PiperOrigin-RevId: 429376858
2022-02-17 12:45:55 -08:00
Yash Katariya
6bb58e6fde Xmap GDA integration. Non-contiguous mesh is allowed!
PiperOrigin-RevId: 429376557
2022-02-17 12:41:12 -08:00
jax authors
83a50202e2 Merge pull request #9616 from hawkinsp:doc
PiperOrigin-RevId: 429376341
2022-02-17 12:36:21 -08:00
Peter Hawkins
d704c151fa Clarify the NVidia driver version requirements. 2022-02-17 14:37:29 -05:00
Roy Frostig
35fab1a95a err on repeated axes to expand_dims, as numpy does 2022-02-17 11:27:20 -08:00
jax authors
032bfe0915 Merge pull request #9609 from froystig:prng-array-stack
PiperOrigin-RevId: 429342174
2022-02-17 10:25:29 -08:00
Lena Martens
73f23705d0 Checkify: explicitly export public API, hide private symbols.
PiperOrigin-RevId: 429277551
2022-02-17 04:30:59 -08:00
Adam Paszke
57f423203d Fix uninitialized axis_env error when MLIR lowering is disabled
PiperOrigin-RevId: 429267926
2022-02-17 03:28:13 -08:00
jax authors
15295a82da Merge pull request #9544 from SaturdayGenfo:adds-matrix-sqrt
PiperOrigin-RevId: 429264231
2022-02-17 03:02:47 -08:00
Yash Katariya
bd2a6a07c8 del self._old_env so that you can use with global_mesh multiple times. This was Matt's idea.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 429229193
2022-02-16 23:15:36 -08:00
Yotaro Kubo
e085370ec4 Add some functions for spectral analysis.
This commit adds "stft", "csd", and "welch" functions in scipy.signal.
2022-02-17 15:59:24 +09:00
Roy Frostig
0f7904f883 implement jnp.expand_dims and jnp.stack for PRNGKeyArrays
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
2022-02-16 20:47:27 -08:00
Yash Katariya
a83695a783 Merge mesh and Mesh. Make Mesh a context manager + class so that it can be used in the following ways:
```
global_mesh = Mesh(devices, axis_names)
with global_mesh:
  ...

OR

with Mesh(devices, axis_names) as global_mesh:
  ...

OR

global_mesh = Mesh(devices, axis_names)
with global_mesh as m:
  ...
```
PiperOrigin-RevId: 429201126
2022-02-16 19:44:38 -08:00
Leello Tadesse Dadi
cb732323f3 adds jax.scipy.linalg.sqrtm 2022-02-16 22:33:47 +01:00
Leello Tadesse Dadi
514d8883ce adds jax.scipy.schur 2022-02-16 22:33:37 +01:00
jax authors
e25259e596 Merge pull request #9605 from hawkinsp:pickle
PiperOrigin-RevId: 429111539
2022-02-16 12:20:54 -08:00
Peter Hawkins
901d459e0d Add cloudpickle as a test requirement.
We have at least one test that tests pickling JAX objects.
2022-02-16 15:04:56 -05:00
jax authors
052b5c36b5 Merge pull request #9602 from LenaMartens:changelist/429094321
PiperOrigin-RevId: 429102776
2022-02-16 11:46:00 -08:00
jax authors
e1fd6304d8 Merge pull request #9493 from mattjj:better-pjit-pytree-prefix-error
PiperOrigin-RevId: 429097694
2022-02-16 11:27:05 -08:00
Lena Martens
758c721605 Checkify: fix nd-error case when array only has 1 element. 2022-02-16 19:15:48 +00:00
Yash Katariya
94aade035a Set the aval inside _create_local_shards iteration. Since we are iterating over device buffers there, why pay the cost twice!
PiperOrigin-RevId: 429058996
2022-02-16 08:52:24 -08:00
jax authors
c97fefcc6c Merge pull request #9597 from gnecula:tf_take_along_axis
PiperOrigin-RevId: 429050829
2022-02-16 08:14:39 -08:00
jax authors
35082fce53 Merge pull request #9595 from gnecula:tf_metadata
PiperOrigin-RevId: 429041948
2022-02-16 07:29:48 -08:00
George Necula
1928f6e6b1 [jax2tf] Fixes shape polymorphism for jnp.take_along_axes
Fixes: #9552
2022-02-16 16:16:08 +01:00
George Necula
461b37b2a8 [jax2tf] Fixed stale documentation about XLA metadata.
jax2tf does not yet support passing source location information
through to TF. The mechanism is partially implemented but disabled.
Here we remove misleading documentation that suggests the mechanism
is enabled.
2022-02-16 15:41:37 +01:00
jax authors
c49fb9c280 Merge pull request #9561 from pschuh:opt-barrier
PiperOrigin-RevId: 428877597
2022-02-15 14:34:26 -08:00
jax authors
95c486a5a3 Merge pull request #9585 from jakevdp:typos
PiperOrigin-RevId: 428871943
2022-02-15 14:16:43 -08:00
Parker Schuh
662c4416a3
Merge branch 'main' into opt-barrier 2022-02-15 14:16:20 -08:00
jax authors
1da8b50238 Merge pull request #9587 from jakevdp:fix-type-promotion
PiperOrigin-RevId: 428871857
2022-02-15 14:11:55 -08:00
Jake VanderPlas
e82f232ea9 Fix nomenclature in type promotion doc 2022-02-15 13:26:06 -08:00
Lena Martens
b15c7f609a Checkify: fix check_error of nd-error.
PiperOrigin-RevId: 428857813
2022-02-15 13:12:53 -08:00
Jake VanderPlas
de360833e0 type promotion design doc: minor typos 2022-02-15 12:17:18 -08:00
jax authors
bf3c658114 Merge pull request #9573 from hawkinsp:cuda
PiperOrigin-RevId: 428831952
2022-02-15 11:25:26 -08:00
Peter Hawkins
b0b8f037b0 [JAX] Fix crash when applying jit() to a callable that is not weak-referenceable.
Fixes https://github.com/google/jax/issues/9541

PiperOrigin-RevId: 428829999
2022-02-15 11:18:05 -08:00
Adam Paszke
a14ccb99d9 Try adding support for nesting with_sharding_constraint in MANUAL xmaps
PiperOrigin-RevId: 428812729
2022-02-15 10:13:11 -08:00
jax authors
0b92052130 Merge pull request #9551 from RuffaloLavoisier:typo
PiperOrigin-RevId: 428799601
2022-02-15 09:20:41 -08:00
Peter Hawkins
2e0cfe8e42 Update the list of default CUDA capabilities used for wheel builds to match build.py. 2022-02-15 09:23:28 -05:00
Adam Paszke
b75b0c04de Limit the set of unspecified dims to those that are not explicitly converted to MANUAL.
PiperOrigin-RevId: 428740792
2022-02-15 04:00:48 -08:00
Adam Paszke
c551beda7a Add a partial_eval_jaxpr_custom_rule for xmap
Additionaly fix a bug in partial_eval rule for xmap.

PiperOrigin-RevId: 428738277
2022-02-15 03:44:19 -08:00
Lena Martens
0d9990e4f3 Run all tests with jax_traceback_filtering=off.
Context: If an AssertionError is thrown inside a test and traceback filtering
is enabled, most of the stack-trace is swallowed (due to
https://bugs.python.org/issue24959).
PiperOrigin-RevId: 428729211
2022-02-15 02:42:58 -08:00
Yash Katariya
7613d2a5df Add multi-host utilities to JAX core. Adapted from https://github.com/google-research/t5x/blob/main/t5x/multihost_utils.py
PiperOrigin-RevId: 428680123
2022-02-14 21:04:15 -08:00
jax authors
924f7aab68 Merge pull request #9485 from rsepassi:compilelog
PiperOrigin-RevId: 428622031
2022-02-14 15:29:07 -08:00
Tianjian Lu
273ea62624 [sparse] Updates bcoo_dot_general cuSparse lowering rule by adding sorted indices.
PiperOrigin-RevId: 428621454
2022-02-14 15:24:14 -08:00
jax authors
7204ac306d Merge pull request #9407 from jakevdp:type-promotion-design
PiperOrigin-RevId: 428613200
2022-02-14 14:49:17 -08:00
Jake VanderPlas
7381bbe8bb Add type promotion design doc 2022-02-14 14:16:42 -08:00
jax authors
d5694402bc Merge pull request #9564 from MichaelMarien:random-choice-docstring
PiperOrigin-RevId: 428596963
2022-02-14 13:41:06 -08:00