9356 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
0072c32546 Update CHANGELOG and verson numbers for jaxlib 0.1.72 release 2021-10-12 17:37:29 -07:00
jax authors
d8b7bd54be Merge pull request #8181 from skye:workspace
PiperOrigin-RevId: 402632543
jaxlib-v0.1.72
2021-10-12 12:51:39 -07:00
Skye Wanderman-Milne
49a3b37a44 Update WORKSPACE for jaxlib 0.1.72 release 2021-10-12 12:46:20 -07:00
jax authors
ef3e3b5593 Rollback fix repr() of jit-compiled functions
PiperOrigin-RevId: 402621480
2021-10-12 12:05:10 -07:00
jax authors
7fa6b1b5fa Merge pull request #8160 from Jakob-Unfried:main
PiperOrigin-RevId: 402582053
2021-10-12 09:24:24 -07:00
jax authors
40ee7f2ea8 Merge pull request #8164 from jakevdp:bcoo-dedupe-pad
PiperOrigin-RevId: 402565496
2021-10-12 08:12:33 -07:00
jax authors
ccc8b720cb Merge pull request #8153 from apaszke:while-loop-batching
PiperOrigin-RevId: 402565162
2021-10-12 08:08:27 -07:00
Jakob Unfried
bec943cee0 fix fori_loop and scan when trivial and with disable_jit 2021-10-12 16:22:51 +02:00
jax authors
a8ce40be94 Merge pull request #7989 from gnecula:remat_docstring
PiperOrigin-RevId: 402551996
2021-10-12 06:57:58 -07:00
Adam Paszke
b8dc8ca004 Fix while loop batching rule for loops with constant bodies
The previous implementation failed to reach a fixpoint when the body was
ignoring the carry and was returning an unbatched constant. Ensuring
that the result carry is at least as batched as the input carry should
fix the issue.
2021-10-12 11:19:24 +00:00
jax authors
cfa0f78bed Merge pull request #8140 from jakevdp:nanstd-grad
PiperOrigin-RevId: 402431304
2021-10-11 17:25:17 -07:00
Jake VanderPlas
21a6f5c2f2 Fix repr() of jit-compiled functions (Fixes #8141)
PiperOrigin-RevId: 402402875
2021-10-11 15:11:25 -07:00
jax authors
f56a378542 Merge pull request #8162 from froystig:gitignore-virtualenv
PiperOrigin-RevId: 402372106
2021-10-11 13:01:56 -07:00
Jake VanderPlas
e30ca031e1 [sparse] fix padding bug in BCOO._dedupe() 2021-10-11 12:48:27 -07:00
Roy Frostig
f9bbab3f78 add common virtualenv directories to .gitignore 2021-10-11 11:06:33 -07:00
jax authors
d17633413d Merge pull request #8144 from jakevdp:setdiff1d-size
PiperOrigin-RevId: 402331005
2021-10-11 10:19:44 -07:00
Jake VanderPlas
9ea8ce9b58 BUG: fix gradients for nanvar & nanstd 2021-10-11 09:29:22 -07:00
Jake VanderPlas
348a098f9e jax.numpy: clarify extra docs about the size argument 2021-10-11 09:27:03 -07:00
Jake VanderPlas
2944881977 jnp.setdiff1d: add optional size and fill_value arguments 2021-10-11 09:26:08 -07:00
Marc van Zee
161363da69 Implements a modular and easily extensible evaluation framework for both TFLite and TFjs. The evaluation framework has the following features:
*  It is easy to add new Modules of examples since each Module is specified using a few lines of code (see `examples.py`).

* It is easy to add new converters since each converter is represented as a function (see `converters.py`). For instance, we could add the MLIR-based converter that the TFLite team is currently working on.

* The framework outputs a Markdown table (see `README.md`).

The framework has the following limitations:

* We only evaluate whether a Module converts, we do not compare any outputs between the converted model and the original model. This will require more effort, and it seems like we can do this as a follow-up if necessary (once a good fraction of ops are converted).

* If an example is missing multiple ops, then only the first missing op is reported. We could improve this by implementing mocked versions of non-working ops, which only output the right shapes. We could also consider doing this as a follow-up.

PiperOrigin-RevId: 402287865
2021-10-11 07:12:58 -07:00
jax authors
a47119d313 Merge pull request #8154 from apaszke:jax2tf-pjit-multihost
PiperOrigin-RevId: 402279083
2021-10-11 06:31:38 -07:00
Adam Paszke
dad29d343d Disallow jax2tf translation of multi-process pjits
I'm pretty sure it doesn't handle the local/global shape boundary
correctly which likely leads to very confusing errors on the TF side.
2021-10-11 11:44:06 +00:00
jax authors
92819f7b4b Merge pull request #8143 from jakevdp:union1d-fill-value
PiperOrigin-RevId: 402238900
2021-10-11 02:08:30 -07:00
jax authors
cd3d37f4d1 Merge pull request #8124 from gnecula:tf_poly
PiperOrigin-RevId: 402232971
2021-10-11 01:28:06 -07:00
George Necula
a75fb371f2 [jax2tf] Improved handling of getitem for shape polymorphism
* give an error for NumPy indexing with slices when the elements
  of the slices are not constant. This check existed, but was
  throing an error when the elements are dimension polynomials.
* give an error for NumPy indexing with slices when the dimension
  size is not constant.
* Improvements in the handling of enable_xla=False for shape
  polymorphism.
* Added test cases for the above.
2021-10-11 09:14:57 +02:00
George Karpenkov
f2aef25fba Use variadic reduce on GPU for argmax/argmin
PiperOrigin-RevId: 401923051
2021-10-08 21:58:56 -07:00
Jake VanderPlas
a4241a2aa3 jnp.union1d: add optional fill_value argument 2021-10-08 15:18:25 -07:00
jax authors
da1caf5d6d Merge pull request #7997 from google:aot
PiperOrigin-RevId: 401837898
2021-10-08 13:03:40 -07:00
jax authors
c353acb374 Merge pull request #8134 from jakevdp:sda-buffer-array
PiperOrigin-RevId: 401820091
2021-10-08 11:43:36 -07:00
Roy Frostig
0c75f52fa8 ahead-of-time lowering and compilation for jit 2021-10-08 10:54:45 -07:00
Roy Frostig
75468c7495 factor out jit input preparation 2021-10-08 10:54:45 -07:00
Jake VanderPlas
486aac949a jnp.array: handle raw device buffers 2021-10-08 10:41:43 -07:00
jax authors
dd5df5a562 Merge pull request #8121 from jakevdp:unique-fill-value
PiperOrigin-RevId: 401785306
2021-10-08 09:10:05 -07:00
George Necula
3938018228 Applied review suggestsions 2021-10-08 10:11:31 +02:00
jax authors
f9bead4b75 Merge pull request #8135 from google:default-rng
PiperOrigin-RevId: 401692501
2021-10-07 23:06:34 -07:00
jax authors
2028087a04 Merge pull request #8137 from mattjj:shaped-array-len
PiperOrigin-RevId: 401690179
2021-10-07 22:51:02 -07:00
Matthew Johnson
482e41d796 remove ShapedArray.__len__
It was confusing to overload, since we sometimes think of avals like
shapes paired with dtypes, and in that case len(aval) should perhaps be
like len(aval.shape). The only place where this behavior was relied on
was sparse/ops.py.
2021-10-07 22:04:16 -07:00
jax authors
f8ec664997 Merge pull request #8136 from mattjj:rbg-improvements
PiperOrigin-RevId: 401683140
2021-10-07 21:53:40 -07:00
Roy Frostig
98d245ebb4 add a config setting to control the default PRNG implementation
Also add explicit seeding functions for each PRNG implementation.
2021-10-07 21:22:40 -07:00
Matthew Johnson
022cb8c0fc rbg_split and rbg_fold_in: use vmap for fewer HLOs 2021-10-07 21:19:06 -07:00
jax authors
8af2cf12c1 Merge pull request #8133 from mattjj:dlpack-error-test-update
PiperOrigin-RevId: 401678825
2021-10-07 21:18:23 -07:00
jax authors
b002bc178e Merge pull request #8123 from mattjj:fix-rng-bit-generator-again
PiperOrigin-RevId: 401673628
2021-10-07 20:39:24 -07:00
Matthew Johnson
ef710ec1f6 update test of dlpack error message 2021-10-07 19:14:13 -07:00
Matthew Johnson
634d252bb3 improvements to RBG PRNG
1. factor out rbg_prng_impl and unsafe_rbg_prng_impl. the former uses
   threefry2x32 for split and fold_in, while the latter uses untested
   heuristics based on calling rng_bit_generator itself as a kind of
   hash function
2. for unsafe_rbg_prng_impl's split and fold_in, generate longer
   sequences from rng_bit_generator (10x iterations) which may be useful on
   some backends
3. for unsafe_rbg_prng_impl, actually apply rng_bit_generator as our
   'hash function' in fold_in

Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
2021-10-07 18:59:13 -07:00
jax authors
efa5edfd39 Merge pull request #8091 from eelregit:check-eq-float0
PiperOrigin-RevId: 401494757
2021-10-07 06:23:59 -07:00
jax authors
c877f8bcd2 Merge pull request #8094 from sracaniere:patch-2
PiperOrigin-RevId: 401485340
2021-10-07 05:18:04 -07:00
jax authors
b0b60d6293 Merge pull request #8116 from jakevdp:at-docs
PiperOrigin-RevId: 401447782
2021-10-07 01:16:53 -07:00
jax authors
8f0589f085 Merge pull request #8117 from LenaMartens:changelist/400933831
PiperOrigin-RevId: 401418826
2021-10-06 21:54:20 -07:00
Yash Katariya
bfbdfa87e7 Add a warmup loop to pmap_simple_8_devices_100_args benchmark so as to not measure the compile time.
PiperOrigin-RevId: 401402336
2021-10-06 19:51:35 -07:00
Jake VanderPlas
0b93c46c71 jnp.unique: add fill_value for when size is not None 2021-10-06 16:28:36 -07:00