5826 Commits

Author SHA1 Message Date
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
jax authors
6c833a16a1 Merge pull request #8240 from jakevdp:dtype-annotations
PiperOrigin-RevId: 404066753
2021-10-18 14:48:41 -07:00
Peter Hawkins
95f47074da Remove xla_bridge.{constant, register_constant_handler, _python_scalar_constant} from API.
An upcoming change will move and rename these functions, and it's not clear they should have been public in the first place.

PiperOrigin-RevId: 404051961
2021-10-18 13:56:58 -07:00
Peter Hawkins
8c3b212dd6 Improve real type conversion in a couple more places. 2021-10-18 13:50:11 -04:00
Peter Hawkins
714e19a794 Remove xla_bridge.make_computation_builder().
This is a vestigal wrapper around xla_client.XlaBuilder whose purpose is long gone.

Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
2021-10-18 13:20:34 -04:00
Jake VanderPlas
f424a90c71 [sparse]: change bcoo pad values to use OOB indices 2021-10-18 08:57:57 -07:00
Marc van Zee
46b9653e28 Factors out all enable_xla=False ops into a separate library impl_no_xla.py.
Also creates a new map tf_impl_no_xla containing the functions that should be called when XLA is disabled, simplifying the logic in the regular converter functions.

PiperOrigin-RevId: 403956433
2021-10-18 07:38:11 -07:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
jax authors
69d7a813e7 Merge pull request #8236 from jakevdp:fix-bincount
PiperOrigin-RevId: 403514221
2021-10-15 18:39:20 -07:00
jax authors
875efeebe7 Merge pull request #8242 from mattjj:remat-fix6
PiperOrigin-RevId: 403504234
2021-10-15 17:31:06 -07:00
Matthew Johnson
89606c2c35 remat: fix regression of broke calling convention
In #7631 we made `_partial_eval_jaxpr_custom` follow a convention: drop
unit outputs from the known jaxprs it returned:

3377abaef4/jax/interpreters/partial_eval.py (L937)

The caller needed to compensate for that, e.g. by dropping the
corresponding binders in an outer jaxpr eqn representing an application
of that inner jaxpr. But that logic was written in terms of checking for
dropped outputs in the outer jaxpr (since units are typically not
consumed downstream):

https://github.com/google/jax/pull/8227/files#diff-440d9df723b313bb263bc7704103cad1dcc886ff6553aa78c30188b0b323b686L981

That worked (or at least we never noticed a failure, though now it seems
sketchy...) with the 'classic' `jax.checkpoint` / `jax.remat`
implementation, before #8191, because of how that implementation relied
on tracing-based partial evaluation, which would detect and mark dropped
outputs in the outer jaxpr as part of jaxpr formation.

But then in #8191 we no longer marked dropvars in the same way. That led
to assertion failures, and #8227 attempted to fix those. That fix made
sense with the new remat implementation, but not the old one! (In the
intervening period I forgot about this unit-dropping convention...)

The fix here is not to rely on dropvars but to more directly encode the
convention that _partial_eval_jaxpr_custom drops unit outputs in the
known jaxpr it produces.
2021-10-15 17:00:44 -07:00
Roy Frostig
f68c0a42c7 drop unused arguments in the jit AOT call path 2021-10-15 16:12:05 -07:00
Jake VanderPlas
afe7e194e9 Fix inaccurate type annotations 2021-10-15 15:35:43 -07:00
Peter Hawkins
af5d3675dd Change default kind for jnp.argsort to stable. Warn if anything other than stable is passed. 2021-10-15 15:43:53 -04:00
Jake VanderPlas
7a2686f366 jnp.bincount: fix corner cases 2021-10-15 12:31:17 -07:00
jax authors
613dc4bffa Merge pull request #8234 from apaszke:all-gather-batching
PiperOrigin-RevId: 403414516
2021-10-15 10:51:33 -07:00
jax authors
a64ce45c8e Merge pull request #8222 from mattjj:document-vmap-axis-name
PiperOrigin-RevId: 403412220
2021-10-15 10:43:00 -07:00
jax authors
3ee94d93c2 Merge pull request #7803 from jakevdp:jnp-take-validation
PiperOrigin-RevId: 403408483
2021-10-15 10:31:29 -07:00
jax authors
8f1d7beace Merge pull request #8217 from LenaMartens:changelist/403115357
PiperOrigin-RevId: 403402974
2021-10-15 10:10:06 -07:00
jax authors
239662efb7 Merge pull request #8228 from mattjj:remat-fix5
PiperOrigin-RevId: 403397740
2021-10-15 09:48:47 -07:00
Jake VanderPlas
a353e3eafa jnp.take/jnp.take_along_axis: require array inputs 2021-10-15 09:37:05 -07:00
Jake VanderPlas
a3a6a5b137 jnp.unique: improve efficiency & consolidate implementation 2021-10-15 07:59:40 -07:00
Adam Paszke
49d9affce0 Enable batcher and batched collective rules for tiled all gathers
Fixes #8221.
2021-10-15 14:37:38 +00:00
Marc van Zee
aaf3bb789e Improves support for conv_general_dilated in JAX for models running on the web or mobile through TFLite and TFjs.
### Before this change

Prior to my change, there were a number of limitations to using convolutions for web/mobile:

* No strides other than (1,1) could be used.
* Padding was only possible for values ["VALID", "SAME"]
* Transposed convolutions were unsupported
* Depthwise convolutions were unsupported
* Input could only be provided in a very specific format, which prevented many use cases.

### After this change

After this change, we now can support the following cases:
* Any strides size can be used
* Any padding can be used (VALID, SAME, or custom numbers)
* Transposed convolutions are supported
* Depthwise convolutions are supported
* Input can be provided in any format.

### Impact on examples

Before, most of the Flax examples using convolutions were failing.
After, all convolutions are converting successfully.

PiperOrigin-RevId: 403302738
2021-10-15 01:00:48 -07:00
Matthew Johnson
8bfa5eec8c fix dce logic 2021-10-14 20:43:40 -07:00
jax authors
804b0b39e6 Merge pull request #8227 from mattjj:remat-fix4
PiperOrigin-RevId: 403253854
2021-10-14 19:48:13 -07:00
Matthew Johnson
d1f0c60b7b keep dropvar binders in call_partial_eval_custom_rule
The dropvars indicate that these binders/outputs aren't used in the
outer jaxpr and so they could be dropped, but to drop the binders would
require also editing the called jaxpr to be consistent. For completeness
that editing could involve DCE, which in turn can affect the jaxpr's
inputs.

Instead of doing that bookkeeping, we can just keep the dropvars.
There's a DCE pass to follow in the remat primitive's partial eval rule
which will clean these up.

(This commit also contains unrelated tweaks to comments and strings.)
2021-10-14 19:27:38 -07:00
Matthew Johnson
584aa13360 document axis_name in the vmap docstring
fixes #8220
2021-10-14 13:09:24 -07:00
Jake VanderPlas
c5a8c5c826 jnp.unique: allow fill_value to be a slice 2021-10-14 12:07:29 -07:00
jax authors
26549bdca0 Merge pull request #8219 from mattjj:fix-saved-residuals-utility
PiperOrigin-RevId: 403147655
2021-10-14 11:46:05 -07:00
Matthew Johnson
297f79c1de make saved_residuals utility work w/ literals 2021-10-14 11:32:09 -07:00
Peter Hawkins
e0d23a7ff0 Improve performance of JIT dispatch when output arity is 1.
Building an output tuple has a non-zero cost on TPU. We can avoid it in the output arity 1 case.

PiperOrigin-RevId: 403142765
2021-10-14 11:28:03 -07:00
Lena Martens
f5d8d4bc4e Replace loop with map in RBG batching_rule. 2021-10-14 18:01:03 +01:00
Danilo Jimenez Rezende
de777c8d37 Set dtypes of constant coefficients appropriately based on the state dtypes. This avoids unexpected casting of the output states.
PiperOrigin-RevId: 403116057
2021-10-14 09:41:22 -07:00
jax authors
80e118aa57 Merge pull request #8213 from hawkinsp:testutil
PiperOrigin-RevId: 403114132
2021-10-14 09:33:31 -07:00
Peter Hawkins
c491203bdd Readd jax.test_util.check_jvp and check_vjp to the public JAX API. 2021-10-14 11:55:11 -04:00
Jake VanderPlas
405ada1553 jnp.nonzero: allow fill_value to be a tuple 2021-10-14 08:40:08 -07:00
Matthew Johnson
725fe3abd4 don't automatically use new checkpoint implementation
There's a bug we're struggling to repro.

To use the new checkpoint, just use

```python
from jax.ad_checkpoint import checkpoint
```

rather than `from jax import checkpoint.
2021-10-14 07:09:06 -07:00
Jake VanderPlas
bbbd5e83cd jnp.piecewise: avoid unnecessary recompilation 2021-10-14 05:44:38 -07:00
jax authors
4648f33256 Merge pull request #8207 from mattjj:remat-by-name
PiperOrigin-RevId: 402998585
2021-10-13 21:55:07 -07:00
Matthew Johnson
6741da632f checkpoint_name for checkpoint policies by name
Also add the jax.ad_checkpoint.print_saved_residuals utility.

All this is experimental, and undocumented for now...

Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
2021-10-13 21:31:02 -07:00
jax authors
eb3b113358 Merge pull request #8206 from jakevdp:unique-fv-indices
PiperOrigin-RevId: 402989001
2021-10-13 20:37:14 -07:00
jax authors
28ac8dbfb7 Merge pull request #8199 from froystig:aot-jit-aval-check
PiperOrigin-RevId: 402966752
2021-10-13 18:06:51 -07:00
Roy Frostig
2f43f336f3 typecheck mesh executable call arguments 2021-10-13 16:28:15 -07:00
Jake VanderPlas
583a6d35e8 jnp.unique: don't apply fill_value to indices 2021-10-13 16:23:14 -07:00
jax authors
6779c840de Merge pull request #8202 from juliuskunze:perm-fix
PiperOrigin-RevId: 402918845
2021-10-13 14:13:10 -07:00
Roy Frostig
9d3dc6f2b0 typecheck XLA compiled computation call arguments 2021-10-13 14:03:09 -07:00
Julius Kunze
1934fd6e65 Cleanup random.permutation 2021-10-13 14:13:00 -06:00
Peter Hawkins
6a45a9236d Remove the _num_buffers attribute from core.AbstractValue.
The number of buffers used to represent an abstract value is a property specific to a particular representation of that abstract value. Currently the only representation is an XLA representation, but that may change in the future. Instead, callers who want to know how XLA would represent an aval should ask the XLA module instead. In this case, we call len(xla.aval_to_xla_shapes(...)) instead.
2021-10-13 14:35:07 -04:00
jax authors
147f145a66 Merge pull request #8197 from hawkinsp:primitives
PiperOrigin-RevId: 402866468
2021-10-13 10:34:06 -07:00