14819 Commits

Author SHA1 Message Date
Jake VanderPlas
dafb88a649 jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
jax authors
c2b7c5f132 Merge pull request #14474 from jakevdp:doc-array-methods
PiperOrigin-RevId: 509639140
2023-02-14 14:29:13 -08:00
Katherine Wu
59e9746552 Fix issue where HLO could not be generated for custom gradient.
It appears that the custom gradient function must be traced in the same context as the context in which it was defined. Fixed by shuffling around the default graphs.

PiperOrigin-RevId: 509618802
2023-02-14 13:22:30 -08:00
jax authors
a9ef98992c Merge pull request #14472 from nouiz:shmap_jep_fixes
PiperOrigin-RevId: 509617771
2023-02-14 13:14:33 -08:00
Jake VanderPlas
5958bf0d2f DOC: improve documentation for jax.Array methods 2023-02-14 13:04:27 -08:00
Jake VanderPlas
967f2118bf DOC: improve documentation for jax.Array methods 2023-02-14 13:03:10 -08:00
jax authors
5860cfdc71 Merge pull request #14453 from jakevdp:dtypes-doc
PiperOrigin-RevId: 509610755
2023-02-14 12:48:40 -08:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
jax authors
aa98c99d3a Merge pull request #14275 from xoiga123:fix-jax.numpy.hsplit
PiperOrigin-RevId: 509585801
2023-02-14 11:24:55 -08:00
Frederic Bastien
93c93133ea Use right fct name. 2023-02-14 11:21:16 -08:00
Frederic Bastien
d2bb1e089d Be consistent in the index used 2023-02-14 11:21:03 -08:00
Jake VanderPlas
11e32196cc DOC: add docs for jax.dtypes module 2023-02-14 11:18:59 -08:00
jax authors
47dca6760c Merge pull request #14456 from jakevdp:jax-typing-public
PiperOrigin-RevId: 509585352
2023-02-14 11:17:05 -08:00
Frederic Bastien
673510202d Small crash fixes 2023-02-14 11:14:26 -08:00
Peter Hawkins
658a934821 License cleanup.
PiperOrigin-RevId: 509563977
2023-02-14 10:07:02 -08:00
Yash Katariya
1c651f2ea4 Catch the NaN's and raise a better error message when jax_debug_nans flag is True.
PiperOrigin-RevId: 509552717
2023-02-14 09:27:36 -08:00
Zeynep Cankara
995ef40f68 [JAX] Improve error message when jit tracer passed to a shape.
Adds additional debugging message to the shape explaining why the value is a tracer.

Fixes #14279

PiperOrigin-RevId: 509545985
2023-02-14 09:13:01 -08:00
John QiangZhang
4237f22939 clean up no-need exception raise.
PiperOrigin-RevId: 509545493
2023-02-14 09:04:19 -08:00
Jake VanderPlas
15196bc1aa [sparse] enable bcsr_dot_general cusparse lowering
PiperOrigin-RevId: 509537223
2023-02-14 08:32:04 -08:00
George Necula
582c042079 Implement lowering for convolutions with dynamic padding
PiperOrigin-RevId: 509451627
2023-02-14 00:55:45 -08:00
Peter Hawkins
0f7ffb8699 Bump minimum Python version in "contributing" docs.
PiperOrigin-RevId: 509385522
2023-02-13 18:05:11 -08:00
jax authors
9e01ee4d50 Merge pull request #14457 from mattjj:djax-bug-fix
PiperOrigin-RevId: 509377741
2023-02-13 17:28:37 -08:00
Sharad Vikram
442aa028c2 Fix xmap staging rule to handle positional semantics
PiperOrigin-RevId: 509356614
2023-02-13 16:05:17 -08:00
Jake VanderPlas
e1ff0c1d7a Make colab_gpu.ipynb compatible with newer JAX versions
PiperOrigin-RevId: 509356393
2023-02-13 15:56:58 -08:00
Jake VanderPlas
7975192f92 Expose jax.typing & update docs 2023-02-13 15:53:08 -08:00
Matthew Johnson
96c558d5de fix minor broadcasting bug
Co-authored-by: Adam Paszke <apaszke@google.com>
2023-02-13 15:13:13 -08:00
Yash Katariya
d0eedf7e57 Plumb spmd_axis_name through batch_jaxpr2 and batch_jaxpr
PiperOrigin-RevId: 509341618
2023-02-13 14:58:20 -08:00
Peter Hawkins
4a523e3d74 Minimize exported names from jax.experimental.maps.
Move implementation of maps to jax._src.maps.

PiperOrigin-RevId: 509309092
2023-02-13 12:57:54 -08:00
Yash Katariya
2fc64bee13 Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
jax authors
c49af18b9b Merge pull request #14365 from jakevdp:reducers-initial
PiperOrigin-RevId: 509253981
2023-02-13 09:43:46 -08:00
jax authors
83b7ba2ba0 Merge pull request #14444 from jakevdp:fix-csr-lowering
PiperOrigin-RevId: 509241948
2023-02-13 08:59:37 -08:00
Jake VanderPlas
58323d5b40 jax.numpy reductions: better validation of initial value 2023-02-13 08:43:25 -08:00
Jake VanderPlas
ddae1d00ea fix change to csr lowering rule 2023-02-13 08:39:05 -08:00
George Necula
2d47921a34 [shape_poly] Fix some tests for shape pol with native lowering
Native lowering requires that dimension variables be computable from the shapes of the arguments that are kept by lowering. We had some tests that were using dimension variables but were not using the actual inputs. Then lowering removes the inputs and it is not possible anymore to recover the values of the dimension variables at invocation time.

Here we primarily changed the tests to ensure they use not just the shape of the input but the actual value. In some cases we have disabled some testing, until
https://github.com/google/jax/issues/14437 is fixed

PiperOrigin-RevId: 509171805
2023-02-13 02:48:11 -08:00
George Necula
6b70728bc9 Fix sharding_test
PiperOrigin-RevId: 509048406
2023-02-12 10:26:17 -08:00
jax authors
002cdf688d Merge pull request #14432 from gnecula:call_tf_checks_2
PiperOrigin-RevId: 509043710
2023-02-12 09:34:50 -08:00
George Necula
f280d31066 [jax2tf] Minor fix: remove dedundant check 2023-02-12 16:02:53 +01:00
Yash Katariya
6caaffc20c Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
Joan Puigcerver
bc1f5f1cbb Register missing standard primitives to shard_map.py.
PiperOrigin-RevId: 508920824
2023-02-11 12:58:22 -08:00
Joan Puigcerver
654c1d3b2b Fix _standard_rep_rule in shard_map.py when in_rep is empty.
set.intersection() with no arguments (in_rep is empty) raises an exception.

PiperOrigin-RevId: 508910287
2023-02-11 11:09:47 -08:00
jax authors
1089992017 Merge pull request #14424 from stellaraccident:devenh
PiperOrigin-RevId: 508902521
2023-02-11 09:50:50 -08:00
Peter Hawkins
612a940160 Minimize the set of names exported from jax.experimental.pjit.
PiperOrigin-RevId: 508889911
2023-02-11 07:37:32 -08:00
Yash Katariya
9316188b3a [Rollback] Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508827908
2023-02-10 21:36:56 -08:00
Stella Laurenzo
c1e13bdf3f A few developer workflow enhancements for working with jaxlib.
It seems to me that jaxlib development must be mostly happening on CI, because some basics are pretty essential. Here are a few things I've been typing/carrying for a while in my flow:

* Add .bazelrc.user to .gitignore so it doesn't accidentally get checked in.
* Add configs for 'debug_symbols' and 'debug' that make some things minimally workable under a debugger (or to get backtraces, etc).
* Add `--force-reinstall` to the copy/paste command to update a built jaxlib wheel (without this, if you are iterating, it fairly quietly does nothing).
2023-02-10 21:03:21 -08:00
Peter Hawkins
61da781174 [JAX] Replace uses of jax.interpreters.xla.DeviceArray with jax.Array.
PiperOrigin-RevId: 508822404
2023-02-10 20:56:34 -08:00
jax authors
1bdcd5e138 Merge pull request #14415 from jakevdp:bcsr-matmul
PiperOrigin-RevId: 508785095
2023-02-10 16:55:05 -08:00
Roy Frostig
a262314934 prune unintended exports from jax.interpreters.batching
PiperOrigin-RevId: 508784928
2023-02-10 16:47:28 -08:00
Skye Wanderman-Milne
e54858522c Add back loading TPU plugin for older jaxlib versions.
This was removed in 668b82d529.

PiperOrigin-RevId: 508777939
2023-02-10 16:16:20 -08:00
jax authors
26ddf3b571 Merge pull request #14419 from jakevdp:spsolve-cpu-lowering
PiperOrigin-RevId: 508777573
2023-02-10 16:16:05 -08:00
Jake VanderPlas
de8a77a3eb [sparse] implement BCSR.__matmul__ 2023-02-10 16:11:57 -08:00