Jake VanderPlas
dafb88a649
jax.numpy reductions: require initial to be a scalar
...
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
jax authors
c2b7c5f132
Merge pull request #14474 from jakevdp:doc-array-methods
...
PiperOrigin-RevId: 509639140
2023-02-14 14:29:13 -08:00
Katherine Wu
59e9746552
Fix issue where HLO could not be generated for custom gradient.
...
It appears that the custom gradient function must be traced in the same context as the context in which it was defined. Fixed by shuffling around the default graphs.
PiperOrigin-RevId: 509618802
2023-02-14 13:22:30 -08:00
jax authors
a9ef98992c
Merge pull request #14472 from nouiz:shmap_jep_fixes
...
PiperOrigin-RevId: 509617771
2023-02-14 13:14:33 -08:00
Jake VanderPlas
5958bf0d2f
DOC: improve documentation for jax.Array methods
2023-02-14 13:04:27 -08:00
Jake VanderPlas
967f2118bf
DOC: improve documentation for jax.Array methods
2023-02-14 13:03:10 -08:00
jax authors
5860cfdc71
Merge pull request #14453 from jakevdp:dtypes-doc
...
PiperOrigin-RevId: 509610755
2023-02-14 12:48:40 -08:00
Peter Hawkins
33bed1e520
Opt into higher matmul precision for A100 and TPU tests.
...
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
jax authors
aa98c99d3a
Merge pull request #14275 from xoiga123:fix-jax.numpy.hsplit
...
PiperOrigin-RevId: 509585801
2023-02-14 11:24:55 -08:00
Frederic Bastien
93c93133ea
Use right fct name.
2023-02-14 11:21:16 -08:00
Frederic Bastien
d2bb1e089d
Be consistent in the index used
2023-02-14 11:21:03 -08:00
Jake VanderPlas
11e32196cc
DOC: add docs for jax.dtypes module
2023-02-14 11:18:59 -08:00
jax authors
47dca6760c
Merge pull request #14456 from jakevdp:jax-typing-public
...
PiperOrigin-RevId: 509585352
2023-02-14 11:17:05 -08:00
Frederic Bastien
673510202d
Small crash fixes
2023-02-14 11:14:26 -08:00
Peter Hawkins
658a934821
License cleanup.
...
PiperOrigin-RevId: 509563977
2023-02-14 10:07:02 -08:00
Yash Katariya
1c651f2ea4
Catch the NaN's and raise a better error message when jax_debug_nans flag is True.
...
PiperOrigin-RevId: 509552717
2023-02-14 09:27:36 -08:00
Zeynep Cankara
995ef40f68
[JAX] Improve error message when jit tracer passed to a shape.
...
Adds additional debugging message to the shape explaining why the value is a tracer.
Fixes #14279
PiperOrigin-RevId: 509545985
2023-02-14 09:13:01 -08:00
John QiangZhang
4237f22939
clean up no-need exception raise.
...
PiperOrigin-RevId: 509545493
2023-02-14 09:04:19 -08:00
Jake VanderPlas
15196bc1aa
[sparse] enable bcsr_dot_general cusparse lowering
...
PiperOrigin-RevId: 509537223
2023-02-14 08:32:04 -08:00
George Necula
582c042079
Implement lowering for convolutions with dynamic padding
...
PiperOrigin-RevId: 509451627
2023-02-14 00:55:45 -08:00
Peter Hawkins
0f7ffb8699
Bump minimum Python version in "contributing" docs.
...
PiperOrigin-RevId: 509385522
2023-02-13 18:05:11 -08:00
jax authors
9e01ee4d50
Merge pull request #14457 from mattjj:djax-bug-fix
...
PiperOrigin-RevId: 509377741
2023-02-13 17:28:37 -08:00
Sharad Vikram
442aa028c2
Fix xmap staging rule to handle positional semantics
...
PiperOrigin-RevId: 509356614
2023-02-13 16:05:17 -08:00
Jake VanderPlas
e1ff0c1d7a
Make colab_gpu.ipynb compatible with newer JAX versions
...
PiperOrigin-RevId: 509356393
2023-02-13 15:56:58 -08:00
Jake VanderPlas
7975192f92
Expose jax.typing & update docs
2023-02-13 15:53:08 -08:00
Matthew Johnson
96c558d5de
fix minor broadcasting bug
...
Co-authored-by: Adam Paszke <apaszke@google.com>
2023-02-13 15:13:13 -08:00
Yash Katariya
d0eedf7e57
Plumb spmd_axis_name through batch_jaxpr2 and batch_jaxpr
...
PiperOrigin-RevId: 509341618
2023-02-13 14:58:20 -08:00
Peter Hawkins
4a523e3d74
Minimize exported names from jax.experimental.maps.
...
Move implementation of maps to jax._src.maps.
PiperOrigin-RevId: 509309092
2023-02-13 12:57:54 -08:00
Yash Katariya
2fc64bee13
Change the axis_resources
argument of with_sharding_constraint
to shardings
to match pjit
and jit
.
...
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
jax authors
c49af18b9b
Merge pull request #14365 from jakevdp:reducers-initial
...
PiperOrigin-RevId: 509253981
2023-02-13 09:43:46 -08:00
jax authors
83b7ba2ba0
Merge pull request #14444 from jakevdp:fix-csr-lowering
...
PiperOrigin-RevId: 509241948
2023-02-13 08:59:37 -08:00
Jake VanderPlas
58323d5b40
jax.numpy reductions: better validation of initial value
2023-02-13 08:43:25 -08:00
Jake VanderPlas
ddae1d00ea
fix change to csr lowering rule
2023-02-13 08:39:05 -08:00
George Necula
2d47921a34
[shape_poly] Fix some tests for shape pol with native lowering
...
Native lowering requires that dimension variables be computable from the shapes of the arguments that are kept by lowering. We had some tests that were using dimension variables but were not using the actual inputs. Then lowering removes the inputs and it is not possible anymore to recover the values of the dimension variables at invocation time.
Here we primarily changed the tests to ensure they use not just the shape of the input but the actual value. In some cases we have disabled some testing, until
https://github.com/google/jax/issues/14437 is fixed
PiperOrigin-RevId: 509171805
2023-02-13 02:48:11 -08:00
George Necula
6b70728bc9
Fix sharding_test
...
PiperOrigin-RevId: 509048406
2023-02-12 10:26:17 -08:00
jax authors
002cdf688d
Merge pull request #14432 from gnecula:call_tf_checks_2
...
PiperOrigin-RevId: 509043710
2023-02-12 09:34:50 -08:00
George Necula
f280d31066
[jax2tf] Minor fix: remove dedundant check
2023-02-12 16:02:53 +01:00
Yash Katariya
6caaffc20c
Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
...
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
Joan Puigcerver
bc1f5f1cbb
Register missing standard primitives to shard_map.py.
...
PiperOrigin-RevId: 508920824
2023-02-11 12:58:22 -08:00
Joan Puigcerver
654c1d3b2b
Fix _standard_rep_rule in shard_map.py when in_rep is empty.
...
set.intersection() with no arguments (in_rep is empty) raises an exception.
PiperOrigin-RevId: 508910287
2023-02-11 11:09:47 -08:00
jax authors
1089992017
Merge pull request #14424 from stellaraccident:devenh
...
PiperOrigin-RevId: 508902521
2023-02-11 09:50:50 -08:00
Peter Hawkins
612a940160
Minimize the set of names exported from jax.experimental.pjit.
...
PiperOrigin-RevId: 508889911
2023-02-11 07:37:32 -08:00
Yash Katariya
9316188b3a
[Rollback] Convert _arrays to return PyArray instead of PyBuffer.
...
PiperOrigin-RevId: 508827908
2023-02-10 21:36:56 -08:00
Stella Laurenzo
c1e13bdf3f
A few developer workflow enhancements for working with jaxlib.
...
It seems to me that jaxlib development must be mostly happening on CI, because some basics are pretty essential. Here are a few things I've been typing/carrying for a while in my flow:
* Add .bazelrc.user to .gitignore so it doesn't accidentally get checked in.
* Add configs for 'debug_symbols' and 'debug' that make some things minimally workable under a debugger (or to get backtraces, etc).
* Add `--force-reinstall` to the copy/paste command to update a built jaxlib wheel (without this, if you are iterating, it fairly quietly does nothing).
2023-02-10 21:03:21 -08:00
Peter Hawkins
61da781174
[JAX] Replace uses of jax.interpreters.xla.DeviceArray with jax.Array.
...
PiperOrigin-RevId: 508822404
2023-02-10 20:56:34 -08:00
jax authors
1bdcd5e138
Merge pull request #14415 from jakevdp:bcsr-matmul
...
PiperOrigin-RevId: 508785095
2023-02-10 16:55:05 -08:00
Roy Frostig
a262314934
prune unintended exports from jax.interpreters.batching
...
PiperOrigin-RevId: 508784928
2023-02-10 16:47:28 -08:00
Skye Wanderman-Milne
e54858522c
Add back loading TPU plugin for older jaxlib versions.
...
This was removed in 668b82d529
.
PiperOrigin-RevId: 508777939
2023-02-10 16:16:20 -08:00
jax authors
26ddf3b571
Merge pull request #14419 from jakevdp:spsolve-cpu-lowering
...
PiperOrigin-RevId: 508777573
2023-02-10 16:16:05 -08:00
Jake VanderPlas
de8a77a3eb
[sparse] implement BCSR.__matmul__
2023-02-10 16:11:57 -08:00