599 Commits

Author SHA1 Message Date
Roy Frostig
60e0e9f929 implement backwards-compatible behavior and enable custom PRNGs only conditionally
Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.

Backwards-compatible behavior is meant in an external sense: downstream
code sees the same API and semantics, even though our code is no longer
the same internally.
2021-08-19 20:43:11 -07:00
Roy Frostig
aa265cce95 introduce custom PRNG implementations and an array-like adapter for them
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.

A PRNG implementation can then be lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:

1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
   back to it from the functions in random.
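A minimal sketch of this shape, in plain Python (the names and signatures here are illustrative stand-ins, not JAX's actual internal API):

```python
from typing import Callable, NamedTuple

# Hypothetical sketch: a PRNG implementation bundles the shape of a
# single key with the basic functions that operate on keys.
class PRNGImpl(NamedTuple):
    key_shape: tuple
    seed: Callable         # seed value -> key
    split: Callable        # key, num -> list of keys
    random_bits: Callable  # key, bit_width, shape -> bits
    fold_in: Callable      # key, data -> key

# The adapter carries both the raw key data and a reference to the
# implementation, so the functions in random can delegate back to it.
class PRNGKeyArray:
    def __init__(self, impl, key_data):
        self.impl = impl
        self.key_data = key_data

    def split(self, num):
        # Delegate to the implementation, rewrapping the result.
        return PRNGKeyArray(self.impl, self.impl.split(self.key_data, num))
```

This shows the two roles at once: the object is "array-like" over keys, and it pins down which implementation to call.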
2021-08-19 20:43:11 -07:00
Matthew Johnson
b90daf9cda custom_vjp: automatically handle float0 cotangents 2021-08-17 16:18:57 -07:00
Matthew Johnson
2e6a30a595 always use same object for vmap temp axis name 2021-08-13 14:54:17 -07:00
Matthew Johnson
a0b9946a30 add regression test for #7613 2021-08-12 21:55:51 -07:00
Markus Kunesch
5552db724d Do not unflatten trees with None values in grad.
When checking the data type of the dynamic arguments in `jax.value_and_grad`, the
PyTree is unflattened with `None` (the output of `_check_input_dtype_grad`) as
the value for each leaf. This causes an issue if a custom PyTree does not accept
`None` as a leaf value (issue #7546), even though the tree that is
returned from the data type check is never used.

This commit solves this issue by iterating over tree_leaves when checking data
types rather than using tree_map.
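The difference can be illustrated in plain Python, with a hypothetical strict container standing in for a custom pytree that rejects `None` leaves:

```python
# Illustrative stand-in for a custom pytree that refuses None leaves
# (cf. issue #7546); not JAX's actual tree machinery.
class StrictPair:
    def __init__(self, a, b):
        assert a is not None and b is not None
        self.a, self.b = a, b

    def leaves(self):
        return [self.a, self.b]

    def map(self, f):
        # tree_map-style: rebuilds the container from the mapped leaves.
        return StrictPair(f(self.a), f(self.b))

def check_dtype(leaf):
    assert isinstance(leaf, float), "grad requires float inputs"
    return None  # the check's return value is never used

tree = StrictPair(1.0, 2.0)

# tree_map-style checking would fail here: rebuilding the container from
# None leaves trips its own validation, even though the rebuilt tree is
# discarded.  tree.map(check_dtype) would raise.

# tree_leaves-style checking only iterates and never rebuilds:
for leaf in tree.leaves():
    check_dtype(leaf)
```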
2021-08-10 20:13:12 +00:00
Peter Hawkins
beddf598bd Add @jit decorators to jax.numpy operators.
By wrapping common operators in `jit`, we get a number of benefits:
* `jit` has a faster, more optimized dispatch path compared to the primitive dispatch path in JAX. It's faster to dispatch a `jit` computation than a single primitive.
* `jit` allows us to cache and reuse logic such as broadcasting and type promotion.

One downside is that we now report an error when large Python integer scalars (e.g. `2**32 - 1`) are passed as arguments to JAX array operators. The workaround is to use explicitly typed constants instead of Python scalars.

On my laptop, this benchmark improves from 95us to 4us:

```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jax.device_put(7)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

In [3]: %timeit jnp.add(x, x).block_until_ready()
4.18 µs ± 159 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
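The typed-constant workaround mentioned above can be sketched with a NumPy scalar (illustrative; a typed `jnp` constant works the same way):

```python
import numpy as np

# A bare Python int like 2**32 - 1 overflows the default (weak) int32
# type, so a jit-wrapped operator now rejects it. An explicitly typed
# constant sidesteps the issue:
big = np.uint32(2**32 - 1)
assert big == np.iinfo(np.uint32).max
```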

PiperOrigin-RevId: 389871450
2021-08-10 06:49:28 -07:00
Peter Hawkins
b232d09440 Enable flake8 checks for spaces around operators. 2021-07-30 08:45:38 -04:00
Lena Martens
2190734637 Add tracers to LeakChecker error, and filter out false positives this way.
If we can't find any hanging tracers in the gc.get_referrers chain, is it
really a leak? Probably not!
2021-07-29 15:45:24 +01:00
Peter Hawkins
1c9dbd12c4 Remove Python 3.6 compatibility code. 2021-07-29 09:09:02 -04:00
Lena Martens
19ee7b22e1 Expose UnexpectedTracerError and add docs. 2021-07-27 23:23:28 +01:00
Roy Frostig
52f0cbe35f improve and add to closure_convert testing
* Test closure conversion with mixed values in the closure, one
  participating in AD and the other not.
* Simplify the basic closure_convert test and give its intermediates
  more descriptive names.
2021-07-26 18:50:45 -07:00
jax authors
63357689af Merge pull request #7191 from inailuig:linear_transpose_partial
PiperOrigin-RevId: 386259543
2021-07-22 09:40:43 -07:00
jax authors
08eb89ec07 Merge pull request #7350 from google:jflynn
PiperOrigin-RevId: 386254653
2021-07-22 09:16:57 -07:00
Matthew Johnson
fd5dec7b68 fix xla.lower_fun axis env issue 2021-07-21 21:19:38 -07:00
zafarali
00a273d1fe Add assert raises check. 2021-07-21 16:12:52 -04:00
zafarali
8143773b79 Remove unnecessary message 2021-07-21 15:04:54 -04:00
zafarali
215e7d618f Make jax arrays not pass isinstance checks for hashable. 2021-07-21 14:50:10 -04:00
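The standard Python mechanism for opting out of hashability is setting `__hash__ = None`, which makes `isinstance` checks against `collections.abc.Hashable` return False (a sketch of the mechanism, not JAX's code):

```python
from collections.abc import Hashable

class ArrayLike:
    # Defining __eq__ with __hash__ = None marks instances as unhashable,
    # so isinstance(x, Hashable) is False and hash(x) raises TypeError.
    def __eq__(self, other):
        return NotImplemented
    __hash__ = None

x = ArrayLike()
assert not isinstance(x, Hashable)
```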
Lena Martens
24c9a933d6 Add shape and dtype of leaked tracer to UnexpectedTracerError. 2021-07-21 17:50:44 +01:00
Clemens Giuliani
5cf31667cb Test that the results of vjp, linearize and linear_transpose are jit-compatible 2021-07-21 15:59:07 +02:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Adam Paszke
1c1ec79edd Clarify the error message for out-of-bounds in_axes in pmap and vmap
Fixes #5201.
2021-07-14 12:11:06 +00:00
Qiao Zhang
72b436f9ed Add a test to repro bugs in TFRT CPU backend. 2021-07-12 14:53:49 -07:00
George Necula
022514e04c Updated the error message 2021-07-11 10:49:30 +03:00
George Necula
5520fcb59f Improve error message when vjp is called with cotangent of wrong shape.
Previously the error was an internal assertion error.
2021-07-10 19:12:11 +03:00
James Bradbury
8e86952ee4 AWN-enabled reduction over named axes in reverse-mode AD
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.

In batching tracer terms, this "elementwise" behavior means that, if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, would also be mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place would likely be
an unacceptable waste of memory and can't necessarily be optimized away.

If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know about whether primal
values are mapped or unmapped. This is made possible by avals-with-names,
which encodes that information in the avals of the primal jaxpr.

Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave akin to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.

Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
  - reductions aren't fused into any first-order primitives (e.g. a `pdot`
    should have a named contracting axis added rather than being followed by a
    `psum`; this can be implemented by putting these primitives into
    `reducing_transposes`)
  - reductions are performed eagerly, even over axes that are mapped to
    hardware resources (the optimal thing to do would be to reduce eagerly
    over any vectorized axis component while delaying the reduction over any
    hardware-mapped component until the end of the overall backward pass; this
    would require a way to represent these partially-reduced values)
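In pure-NumPy terms, the distinction above is between a per-example gradient and a summed (total) gradient over the named axis (a conceptual sketch, not the actual xmap API):

```python
import numpy as np

# f(w, x_i) = w * x_i over a batch axis of x. The per-example gradient
# of f with respect to w is just x_i -- one value per batch index, which
# is what vmap(grad(f)) computes.
x = np.array([1.0, 2.0, 4.0])
per_example_grad_w = x

# Treating the batch axis collectively instead: reduce (psum) the
# cotangents to get the total gradient. The option described above fuses
# this reduction into the backward pass itself rather than leaving a
# mapped cotangent to be summed afterwards.
total_grad_w = per_example_grad_w.sum()
```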

PiperOrigin-RevId: 383685336
2021-07-08 12:06:29 -07:00
Matthew Johnson
a0eb1126e4 remat: don't apply cse-foiling widget to primal 2021-06-30 09:29:47 -07:00
Jake VanderPlas
c8e571ad84 Allow suppression of GPU warning via jax_platform_name 2021-06-28 12:54:21 -07:00
Peter Hawkins
15fe683945 Disable float0 tests that fail under NumPy 1.21.
https://github.com/numpy/numpy/issues/19305
2021-06-24 11:30:16 -04:00
Peter Hawkins
75c9bf01f3 Fix most test failures under NumPy 1.21. 2021-06-22 16:31:44 -04:00
Peter Hawkins
d5ba87ad7f Add a device_put handler for tokens.
Fixes bug with tokens passed to trivial computations.
2021-06-07 16:19:14 -04:00
Qiao Zhang
f5f62ce0d5 Remove stale cache_clear in api_test. 2021-06-07 12:10:54 -07:00
Qiao Zhang
6d77b9f447 [JAX:TFRT:CPU] Fix TfrtCpuClient::BufferFromHostBuffer bug when shape has
non-default layout (e.g., from TPU).

PiperOrigin-RevId: 375165974
2021-05-21 14:36:35 -07:00
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
Peter Hawkins
e6fb6e0881 Improve type error when an object dtype is passed to an operator like +.
Fixes #856.
2021-05-10 11:52:12 -04:00
Parker Schuh
92246017d2 Revert "Use convert_element_type instead of device_put_raw." 2021-05-06 20:19:23 -07:00
Parker Schuh
9d3e535ad2 Merge branch 'master' into convert_element 2021-05-06 13:18:01 -07:00
jax authors
3c6a41eb9c Merge pull request #6612 from google:tracer-errors
PiperOrigin-RevId: 372211269
2021-05-05 14:45:57 -07:00
Matthew Johnson
7ec0b40173 Roll-forward of #6584, which broke internal tests.
PiperOrigin-RevId: 371839298
2021-05-03 21:41:23 -07:00
Matthew Johnson
b9d72a480f improve concreteness error from arguments
also tweak some error message wording
2021-05-03 17:37:34 -07:00
Qiao Zhang
850bd66242 [JAX] Prune unused inputs in jit.
- Python part based on: https://github.com/google/jax/pull/6567
- Added cpp_jit path to handle pruned args

PiperOrigin-RevId: 371743277
2021-05-03 11:41:29 -07:00
jax authors
75b00a1235 Copybara import of the project:
--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:

add 'inline' option to xla_call for jaxpr inlining

--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:

add 'inline' to jit docstring

--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:

new_sublevel in jax2tf

PiperOrigin-RevId: 371542778
2021-05-01 22:18:39 -07:00
Matthew Johnson
3c400a3e58 add 'inline' option to xla_call for jaxpr inlining 2021-04-28 19:38:15 -07:00
jax authors
43d273399b Merge pull request #6377 from boyentenbi:changelist/367423156
PiperOrigin-RevId: 370628969
2021-04-27 01:28:07 -07:00
Matthew Johnson
8f434539e1 re-enable a working test 2021-04-24 15:18:26 -07:00
Lena Martens
b244e2b8c8 Add eval_shape to the UnexpectedTracerError too. 2021-04-23 14:46:34 +01:00
Matthew Johnson
ba9233b9b6 prune trivial convert_element_types from jaxprs
also add a test for not performing H2D transfers while tracing jnp.array
2021-04-22 12:46:26 -07:00
Peter Choy
eb9d6e4d21 Pass axis name to _match_axes and add to error message. 2021-04-22 13:34:04 +00:00
Peter Hawkins
5261b776d2 Handle context manager configuration settings for matmul precision and numpy rank promotion correctly in JIT and linear_util caches.
PiperOrigin-RevId: 369643419
2021-04-21 06:36:35 -07:00
jax authors
bbc7be064c Merge pull request #6239 from j-towns:lt-allow-integers
PiperOrigin-RevId: 369467931
2021-04-20 10:23:10 -07:00