10550 Commits

Author SHA1 Message Date
jax authors
74fb420130 Merge pull request #9510 from mattjj:checkify-custom-vjp
PiperOrigin-RevId: 427546471
2022-02-09 12:48:00 -08:00
Matthew Johnson
d9270b242d [checkify] add custom_vjp support 2022-02-09 12:31:16 -08:00
jax authors
b82ef91f42 Merge pull request #9509 from mattjj:checkify-custom-jvp
PiperOrigin-RevId: 427541020
2022-02-09 12:25:22 -08:00
jax authors
61b884b0d4 Merge pull request #9494 from superbobry:lax_numpy-any
PiperOrigin-RevId: 427539082
2022-02-09 12:17:59 -08:00
Matthew Johnson
4ce749e681 [checkify] handle custom_jvp 2022-02-09 12:12:58 -08:00
Yash Katariya
02b8ce3373 Use thread local positional semantics and change thread_resources to subclass from threading.local.
Both these fixes makes sure that you can compile pjit in multiple threads.

PiperOrigin-RevId: 427517953
2022-02-09 11:19:56 -08:00
Peter Hawkins
8ca6622c0b Change lax.select_p to be an n-ary predicate, 'lax.select_n_p'. Change lax.select() to be a thin shim around the new n-ary version.
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.

Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduces the number of jaxpr equations significantly.

PiperOrigin-RevId: 427517899
2022-02-09 11:03:09 -08:00
jax authors
e219ea08f5 Merge pull request #9204 from LenaMartens:changelist/420726115
PiperOrigin-RevId: 427470746
2022-02-09 07:54:17 -08:00
Matthew Johnson
d57990ecf9 improve pjit in/out_axis_resources pytree errors
This is an application of the utilities in #9372.
2022-02-08 16:23:15 -08:00
Peter Hawkins
82d8261308 Speed up source location computation when lowering a jaxpr to HLO/MHLO.
Speed up source_info_util.user_frames by using a newly refactored Traceback.raw_frames() attribute. Since we are interested only in one frame, it's best to avoid doing wasted work on all the frames we are going to ignore.

Change traceback.raw_frames() to return the transpose of what it previously returned because it means we only need to build 3 Python objects, rather than n + 1 Python objects for n frames.

PiperOrigin-RevId: 427320674
2022-02-08 16:17:40 -08:00
Yash Katariya
0a060cf183 Catch the device buffer order with the expected global_mesh.local_devices order that we should get. This is to make sure that we don't get some cryptic message from XLA and catch it in GDA itself.
PiperOrigin-RevId: 427315736
2022-02-08 15:58:20 -08:00
jax authors
4e8043f2d1 Merge pull request #9461 from MichaelMarien:quantile-tuple-axis
PiperOrigin-RevId: 427313122
2022-02-08 15:49:54 -08:00
jax authors
e8ec9570dd Merge pull request #9471 from jakevdp:generic
PiperOrigin-RevId: 427310192
2022-02-08 15:39:58 -08:00
jax authors
bce332bf83 Merge pull request #9481 from jakevdp:jax-enable-checks
PiperOrigin-RevId: 427310155
2022-02-08 15:34:51 -08:00
Sergei Lebedev
0fe377ce42 Added an explicit Any return type to lax_numpy.ndarray methods
This change makes ndarray a bit easier for tooling to handle, since de-facto
all these methods are supposed to return *something*, but the type inferrable
from their default implementations is None.

As a hand-wavy aside, in a type stub

    def f(): ...

could be treated equivalently to

    def f() -> Any: ...

because there is no body to infer return type from, and Any is a reasonable
fallback type. In a .py file, however, f is no longer just a function *type*
(as opposed to function *implementation*), and thus it has an inferrable
return type.
2022-02-08 22:18:16 +00:00
lenamartens
4d0db5d975 Fix build and address suggestion 2022-02-08 20:57:51 +00:00
Lena Martens
b2cf12aa7e
Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-02-08 20:23:40 +00:00
Lena Martens
010eb82ad0 Rename wrapper functions to always refer to the JAX api function.
eg. batched_fun -> vmap_fun
2022-02-08 20:10:39 +00:00
michaelmarien
3e9f8248f2 Expand implementation of lax_numpy._quantile to allow the input of a tuple as axis argument
* support and test edge case where axis argument is empty tuple ()
* replace swapaxis + reshape methodology by one call to lax.reshape for computational efficiency's sake
* add check on repeated axis and throw ValueError
* introduced and changed corresponding numpy code to swap and reshape axis to be quantiled
* introduced code to accomodate the reintroduction of those axes if keepdims=True
* added testcases
2022-02-08 21:03:02 +01:00
Ryan Sepassi
be12038b97
Merge branch 'main' into compilelog 2022-02-08 11:21:56 -08:00
Lena Martens
0042edb5f4 Checkify: rename some symbols and add some docstrings. 2022-02-08 17:40:04 +00:00
jax authors
fbda1a650f Merge pull request #9489 from hawkinsp:sda
PiperOrigin-RevId: 427210907
2022-02-08 09:28:14 -08:00
jax authors
c137aadad7 Merge pull request #9488 from pnkraemer:jet-primitive-dynamic-slice
PiperOrigin-RevId: 427210782
2022-02-08 09:23:27 -08:00
Peter Hawkins
5679fedd2c Fix missing handler when lexically capturing a ShardedDeviceArray when MLIR enabled. 2022-02-08 09:51:57 -05:00
Nicholas Krämer
c07a0f1139 Add test and jet-primitive for dynamic_slice 2022-02-08 13:28:41 +01:00
jax authors
44c6c055d3 Merge pull request #9442 from mattjj:zeros
PiperOrigin-RevId: 427078926
2022-02-07 19:17:57 -08:00
jax authors
3c1a1ca6f0 Merge pull request #9419 from hawkinsp:versioning
PiperOrigin-RevId: 427073382
2022-02-07 18:31:52 -08:00
Ryan Sepassi
085e65863a Update compilation cache logging to log individual hashes as well as cumulative hash 2022-02-07 17:17:31 -08:00
Peter Hawkins
8be057de1f Introduce a new jax/jaxlib versioning scheme.
Adds a design note that describes the scheme and how the jax and jaxlib versions
are related.
2022-02-07 17:59:42 -05:00
Jake VanderPlas
760f309fb5 Add jax.numpy.generic 2022-02-07 14:56:39 -08:00
Jake VanderPlas
f2222bb1cf CI: error if docstring rewrite fails 2022-02-07 14:43:00 -08:00
Peter Hawkins
287c476eec Cache traceback to MLIR location conversion.
Finding the user frame in a traceback is something we do for every jaxpr equation, and it shows up in profiles. We expect a reasonable amount of locality, e.g., many lines of code with similar provenance appearing together, so this seems like a place for a small LRU cache.

PiperOrigin-RevId: 427020947
2022-02-07 14:40:46 -08:00
Peter Hawkins
f539c9b9bd Hoist construction of predicates out of cond batching rule.
Avoids building the "which path are we following" predicate once for each input.

PiperOrigin-RevId: 427012972
2022-02-07 14:13:03 -08:00
Peter Hawkins
e7032fe910 Remove unnecessary dtype canonicalization from jax.core.raise_to_shaped.
I noticed this in passing while working on https://github.com/google/jax/pull/9468. It seems strange to me that we would change the dtype when raising a ShapedArray to a ShapedArray, and indeed it seems not to be necessary.

PiperOrigin-RevId: 427011028
2022-02-07 14:06:13 -08:00
jax authors
5648f768ff Merge pull request #9477 from hawkinsp:doc2
PiperOrigin-RevId: 427002408
2022-02-07 13:34:25 -08:00
Peter Hawkins
465b593293 Update scipy intersphinx inventory for SciPy 1.8.0.
According to https://github.com/scipy/scipy/issues/14267 the SciPy docs seems to have moved.
2022-02-07 16:19:46 -05:00
Adam Paszke
296832e891 Use aval_out to construct a sharding spec in shard to full
The shard's dimensions might be too small and might trigger asserts, even though
the shape has no influence on sharding specs.

PiperOrigin-RevId: 426955706
2022-02-07 10:39:41 -08:00
Adam Paszke
42cd7ed1d0 Allow nesting MANUAL-style xmaps in pjits
PiperOrigin-RevId: 426955137
2022-02-07 10:35:06 -08:00
jax authors
1bc8ee1df7 Merge pull request #9441 from jakevdp:test-all-types
PiperOrigin-RevId: 426937874
2022-02-07 09:31:42 -08:00
Jake VanderPlas
0b4b0fd07b tests: add unsigned ints to all_dtypes 2022-02-07 08:59:44 -08:00
jax authors
5956f3940b Merge pull request #9468 from hawkinsp:opt
PiperOrigin-RevId: 426929169
2022-02-07 08:52:53 -08:00
jax authors
524ad4b2c3 Merge pull request #9437 from jakevdp:doc-toc
PiperOrigin-RevId: 426909827
2022-02-07 07:20:50 -08:00
jax authors
9c248d5a9e Merge pull request #9470 from hawkinsp:scipy
PiperOrigin-RevId: 426909482
2022-02-07 07:16:13 -08:00
Peter Hawkins
942581d136 Fix test failure in line search test due to scipy 1.8 release. 2022-02-07 14:49:54 +00:00
Peter Hawkins
d386b1f7b3 Small speedups to core.raise_to_shaped().
Avoid forming a new ShapedArray if we already have a ShapedArray.

Don't use the slower safe map() when canonicalizing shapes. We're going
to form a tuple anyway.

Before:
```
In [1]: import numpy as np ; from jax import core, numpy as jnp
In [2]: x = core.ShapedArray((100,100), jnp.float32)
In [3]: %timeit core.raise_to_shaped(x)
4.11 µs ± 30.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

After:
```
In [1]: import numpy as np ; from jax import core, numpy as jnp
In [2]: x = core.ShapedArray((100,100), jnp.float32)
In [3]: %timeit core.raise_to_shaped(x)
207 ns ± 0.131 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
2022-02-07 14:26:03 +00:00
jax authors
89be6d29c5 Merge pull request #9395 from jakevdp:faster-numpy-test
PiperOrigin-RevId: 426850791
2022-02-07 01:29:21 -08:00
jax authors
7a6986c4c8 [JAX] Update the ANN document.
Deletes the documentation that explains the algorithm.
I don't think it is the necessary detail for users.
We'll write a paper to explain it in detail very soon.

PiperOrigin-RevId: 426546480
2022-02-04 20:03:27 -08:00
James Bradbury
5dd1c75969 Add batch_axis to variance scaling initializers
PiperOrigin-RevId: 426522731
2022-02-04 17:02:11 -08:00
Jake VanderPlas
bd18ded481 NumpyUfuncTests: speed up test generation 2022-02-04 15:22:16 -08:00
Jake VanderPlas
ea8817b329 DOC: move experimental APIs to their own pages 2022-02-04 14:40:34 -08:00