9462 Commits

Author SHA1 Message Date
Yash Katariya
fbdeff0a80 Update the workspace file
PiperOrigin-RevId: 404076864
jaxlib-v0.1.73
2021-10-18 15:26:32 -07:00
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
jax authors
6c833a16a1 Merge pull request #8240 from jakevdp:dtype-annotations
PiperOrigin-RevId: 404066753
2021-10-18 14:48:41 -07:00
Peter Hawkins
95f47074da Remove xla_bridge.{constant, register_constant_handler, _python_scalar_constant} from API.
An upcoming change will move and rename these functions, and it's not clear they should have been public in the first place.

PiperOrigin-RevId: 404051961
2021-10-18 13:56:58 -07:00
Yash Katariya
93fe3ab492 Replace _ with - because wheel.py normalizes it to .
PiperOrigin-RevId: 404049619
2021-10-18 13:47:43 -07:00
Yash Katariya
0f0bfcaef5 Skip test_bcoo_spdot_general because its failing OSS tests
PiperOrigin-RevId: 404042363
2021-10-18 13:20:09 -07:00
jax authors
8a261f04d5 Merge pull request #8261 from hawkinsp:real
PiperOrigin-RevId: 404013628
2021-10-18 11:28:12 -07:00
jax authors
9eb06800fe Merge pull request #8260 from hawkinsp:unpack
PiperOrigin-RevId: 404001515
2021-10-18 10:50:44 -07:00
Peter Hawkins
8c3b212dd6 Improve real type conversion in a couple more places. 2021-10-18 13:50:11 -04:00
jax authors
ef9ce1c39d Merge pull request #8259 from hawkinsp:xlabuilder
PiperOrigin-RevId: 404001408
2021-10-18 10:46:45 -07:00
Peter Hawkins
051375976a Remove unused backward compatibility code in cusolver.py.
Simplify implementation of _real_type in passing.
2021-10-18 13:27:10 -04:00
Peter Hawkins
714e19a794 Remove xla_bridge.make_computation_builder().
This is a vestigal wrapper around xla_client.XlaBuilder whose purpose is long gone.

Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
2021-10-18 13:20:34 -04:00
jax authors
391cafb0e5 Merge pull request #8225 from jakevdp:bcoo-padding
PiperOrigin-RevId: 403978219
2021-10-18 09:19:33 -07:00
Jake VanderPlas
f424a90c71 [sparse]: change bcoo pad values to use OOB indices 2021-10-18 08:57:57 -07:00
Marc van Zee
46b9653e28 Factors out all enable_xla=False ops into a separate library impl_no_xla.py.
Also creates a new map tf_impl_no_xla containing the functions that should be called when XLA is disabled, simplifying the logic in the regular converter functions.

PiperOrigin-RevId: 403956433
2021-10-18 07:38:11 -07:00
Yash Katariya
e6e81ba885 Add Cuda 11.4 with cudnn 8.2 and cudnn 8.0.5 release builds
PiperOrigin-RevId: 403661187
2021-10-16 16:13:43 -07:00
Peter Hawkins
267a4ca4cb Reenable a test that was disabled due to an (apparently fixed) LLVM bug.
PiperOrigin-RevId: 403623977
2021-10-16 10:34:36 -07:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
jax authors
69d7a813e7 Merge pull request #8236 from jakevdp:fix-bincount
PiperOrigin-RevId: 403514221
2021-10-15 18:39:20 -07:00
jax authors
875efeebe7 Merge pull request #8242 from mattjj:remat-fix6
PiperOrigin-RevId: 403504234
2021-10-15 17:31:06 -07:00
jax authors
40103f6a71 Merge pull request #8241 from froystig:xla-computation-in-avals
PiperOrigin-RevId: 403500602
2021-10-15 17:08:57 -07:00
Matthew Johnson
89606c2c35 remat: fix regression of broke calling convention
In #7631 we made `_partial_eval_jaxpr_custom` follow a convention: drop
unit outputs from the known jaxprs it returned:

3377abaef4/jax/interpreters/partial_eval.py (L937)

The caller needed to compensate for that, e.g. by dropping the
corresponding binders in an outer jaxpr eqn representing an application
of that inner jaxpr. But that logic was written in terms of checking for
dropped outputs in the outer jaxpr (since units are typically not
consumed downstream):

https://github.com/google/jax/pull/8227/files#diff-440d9df723b313bb263bc7704103cad1dcc886ff6553aa78c30188b0b323b686L981

That worked (or at least we never noticed a failure, though now it seems
sketchy...) with the 'classic' `jax.checkpoint` / `jax.remat`
implementation, before #8191, because of how that implementation relied
on tracing-based partial evaluation, which would detect and mark dropped
outputs in the outer jaxpr as part of jaxpr formation.

But then in #8191 we no longer marked dropvars in the same way. That led
to assertion failures, and #8227 attempted to fix those. That fix made
sense with the new remat implementation, but not the old one! (In the
intervening period I forgot about this unit-dropping convention...)

The fix here is not to rely on dropvars but to more directly encode the
convention that _partial_eval_jaxpr_custom drops unit outputs in the
known jaxpr it produces.
2021-10-15 17:00:44 -07:00
Roy Frostig
f68c0a42c7 drop unused arguments in the jit AOT call path 2021-10-15 16:12:05 -07:00
Jake VanderPlas
afe7e194e9 Fix inaccurate type annotations 2021-10-15 15:35:43 -07:00
jax authors
3377abaef4 Merge pull request #8237 from hawkinsp:argsort
PiperOrigin-RevId: 403445597
2021-10-15 12:54:52 -07:00
Peter Hawkins
af5d3675dd Change default kind for jnp.argsort to stable. Warn if anything other than stable is passed. 2021-10-15 15:43:53 -04:00
Jake VanderPlas
7a2686f366 jnp.bincount: fix corner cases 2021-10-15 12:31:17 -07:00
jax authors
613dc4bffa Merge pull request #8234 from apaszke:all-gather-batching
PiperOrigin-RevId: 403414516
2021-10-15 10:51:33 -07:00
jax authors
a64ce45c8e Merge pull request #8222 from mattjj:document-vmap-axis-name
PiperOrigin-RevId: 403412220
2021-10-15 10:43:00 -07:00
jax authors
3ee94d93c2 Merge pull request #7803 from jakevdp:jnp-take-validation
PiperOrigin-RevId: 403408483
2021-10-15 10:31:29 -07:00
jax authors
8f1d7beace Merge pull request #8217 from LenaMartens:changelist/403115357
PiperOrigin-RevId: 403402974
2021-10-15 10:10:06 -07:00
jax authors
239662efb7 Merge pull request #8228 from mattjj:remat-fix5
PiperOrigin-RevId: 403397740
2021-10-15 09:48:47 -07:00
Jake VanderPlas
a353e3eafa jnp.take/jnp.take_along_axis: require array inputs 2021-10-15 09:37:05 -07:00
jax authors
267e1ec5a4 Merge pull request #8223 from jakevdp:unique-speed
PiperOrigin-RevId: 403394797
2021-10-15 09:34:56 -07:00
Jake VanderPlas
a3a6a5b137 jnp.unique: improve efficiency & consolidate implementation 2021-10-15 07:59:40 -07:00
Adam Paszke
49d9affce0 Enable batcher and batched collective rules for tiled all gathers
Fixes #8221.
2021-10-15 14:37:38 +00:00
Marc van Zee
aaf3bb789e Improves support for conv_general_dilated in JAX for models running on the web or mobile through TFLite and TFjs.
### Before this change

Prior to my change, there were a number of limitations to using convolutions for web/mobile:

* No strides other than (1,1) could be used.
* Padding was only possible for values ["VALID", "SAME"]
* Transposed convolutions were unsupported
* Depthwise convolutions were unsupported
* Input could only be provided in a very specific format, which prevented many use cases.

### After this change

After this change, we now can support the following cases:
* Any strides size can be used
* Any padding can be used (VALID, SAME, or custom numbers)
* Transposed convolutions are supported
* Depthwise convolutions are supported
* Input can be provided in any format.

### Impact on examples

Before, most of the Flax examples using convolutions were failing.
After, all convolutions are converting successfully.

PiperOrigin-RevId: 403302738
2021-10-15 01:00:48 -07:00
Yash Katariya
0578ba68f4 Docker file for Cuda 11.4 built with Cudnn 8.0.5
PiperOrigin-RevId: 403283584
2021-10-14 22:51:23 -07:00
Matthew Johnson
8bfa5eec8c fix dce logic 2021-10-14 20:43:40 -07:00
jax authors
804b0b39e6 Merge pull request #8227 from mattjj:remat-fix4
PiperOrigin-RevId: 403253854
2021-10-14 19:48:13 -07:00
Matthew Johnson
d1f0c60b7b keep dropvar binders in call_partial_eval_custom_rule
The dropvars indicate that these binders/outputs aren't used in the
outer jaxpr and so they could be dropped, but to drop the binders would
require also editing the called jaxpr to be consistent. For completeness
that editing could involve DCE, which in turn can affect the jaxpr's
inputs.

Instead of doing that bookkeeping, we can just keep the dropvars.
There's a DCE pass to follow in the remat primitive's partial eval rule
which will clean these up.

(This commit also contains unrelated tweaks to comments and strings.)
2021-10-14 19:27:38 -07:00
jax authors
fcbbd29dab Merge pull request #8215 from jakevdp:unique-multi-fill
PiperOrigin-RevId: 403179982
2021-10-14 13:59:28 -07:00
Yash Katariya
ac0796048f Move cuda .py files to :gpu_support so that if :gpu_support is not present, then internal jaxlib will act like a CPU jaxlib even if --config=cuda is specified.
PiperOrigin-RevId: 403170945
2021-10-14 13:20:01 -07:00
Matthew Johnson
584aa13360 document axis_name in the vmap docstring
fixes #8220
2021-10-14 13:09:24 -07:00
Jake VanderPlas
1bafdb6d7e fix repr() of jit-compiled functions
PiperOrigin-RevId: 403157400
2021-10-14 12:25:29 -07:00
Jake VanderPlas
c5a8c5c826 jnp.unique: allow fill_value to be a slice 2021-10-14 12:07:29 -07:00
jax authors
26549bdca0 Merge pull request #8219 from mattjj:fix-saved-residuals-utility
PiperOrigin-RevId: 403147655
2021-10-14 11:46:05 -07:00
Matthew Johnson
297f79c1de make saved_residuals utility work w/ literals 2021-10-14 11:32:09 -07:00
Peter Hawkins
e0d23a7ff0 Improve performance of JIT dispatch when output arity is 1.
Building an output tuple has a non-zero cost on TPU. We can avoid it in the output arity 1 case.

PiperOrigin-RevId: 403142765
2021-10-14 11:28:03 -07:00
Lena Martens
f5d8d4bc4e Replace loop with map in RBG batching_rule. 2021-10-14 18:01:03 +01:00