53 Commits

Author SHA1 Message Date
Adam Paszke
08685efb22 Keep axis_env initialized during jaxpr_subcomp
``jaxpr_subcomp`` likes to lower control-flow primitives by tracing them
again as JAX callables, but they're all axis primitives now and so they
do require a properly initialized axis env.
2021-10-01 11:14:55 +00:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
jax authors
b992034eaf Merge pull request #7983 from jakevdp:fix-std-basis
PiperOrigin-RevId: 398502689
2021-09-23 09:15:51 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Jake VanderPlas
0957e81655 Use traced identity in jacobian std_basis 2021-09-22 16:08:18 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Sharad Vikram
cc3e197991 Combine initial_style_batchers with collective_rules 2021-09-09 11:23:51 -07:00
Peter Hawkins
e869e5e0f8 Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
One of many changes to codify the set of exported symbols in the jax.* namespace.

PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00
Adam Paszke
1158530faa Remove axis name from named_shape when unmapping avals
Even though `vmap` and `pmap` don't use avals with names, the batching infrastructure
is used to implement xmap and pjit. So while we keep the introduction of names carefully
scoped, forgetting to remove them at the right points leads to extremely confusing errors.

PiperOrigin-RevId: 395423006
2021-09-08 01:42:15 -07:00
Jean-Baptiste Lespiau
9c782e2289 Move ShardedDeviceArray & PmapFunction to the raw C API and implement pickling/unpickling.
PiperOrigin-RevId: 395256774
2021-09-07 08:50:48 -07:00
Matthew Johnson
2d28951ba4 address comments form @apaszke 2021-08-26 14:10:58 -07:00
Matthew Johnson
037e420568 reviewer comments 2021-08-25 20:46:32 -07:00
Matthew Johnson
542641ca87 rejames/reblake implementation 2021-08-25 20:46:32 -07:00
jax authors
aea51c83e4 Merge pull request #7188 from cccntu:add-doc-jax.device_get
PiperOrigin-RevId: 392521558
2021-08-23 14:39:37 -07:00
Jake VanderPlas
00f36173bd Specify weak_type in DeviceArray repr 2021-08-23 13:19:33 -07:00
Jonathan Chang
8536780455 Add documentation for jax.device_get 2021-08-24 00:57:17 +08:00
jax authors
05913a912e Merge pull request #7386 from superbobry:xla-type-annotations
PiperOrigin-RevId: 391381136
2021-08-17 14:43:43 -07:00
Jean-Baptiste Lespiau
f6f1debf70 Add post_hook support for pmap, to support debug_nans and debug_infs.
It's the exact same code as for JIT. We just modify the Python function to accept ShardedDeviceArray in addition to DeviceArray objects. The test is updated accordingly.

PiperOrigin-RevId: 391272270
2021-08-17 06:11:47 -07:00
Sergei Lebedev
af41a959d3 Most of JAX now uses concrete types for things defined in jaxlib.xla_client
Note that a few call sites in the diff got a ``# type: ignore``, because
the latest jaxlib does not have up-to-date signatures for the correpsonding
callables.
2021-08-16 20:33:36 +01:00
Matthew Johnson
2e6a30a595 always use same object for vmap temp axis name 2021-08-13 14:54:17 -07:00
Jean-Baptiste Lespiau
7821d07c9d Efficient C++ pmap implementation.
PiperOrigin-RevId: 390595576
2021-08-13 06:05:48 -07:00
Markus Kunesch
5552db724d Do not unflatten trees with None values in grad.
When checking the data type of the dynamic arguments in jax.value_and_grad the
PyTree is unflattened with `None` (the output of `_check_input_dtype_grad`) as
value for each leaf. This causes an issue if a custom PyTree does not accept
None as a value for the leaves (issue #7546) even though the tree that is
returned from the data type check is never used.

This commit solves this issue by iterating over tree_leaves when checking data
types rather than using tree_map.
2021-08-10 20:13:12 +00:00
Lena Martens
2f9caf3d64 Ensure reduce_axes is a tuple. 2021-08-10 18:49:29 +01:00
Jean-Baptiste Lespiau
45aaf8a647 Make it possible to return a C++ ShardedDeviceArray.
This **will** be a **breaking** change, as pxla.ShardedDeviceArray constructor won't be valid anymore:
- for the next Jax release
- on the condition _USE_EXPERIMENTAL_CPP_SDA is switch to `_xla_extension_version > xx` and with the associated jaxlib release.

I am already adding the impact for the users in the CHANGELOG, we can still move it to the next version depending on when it's shipped.

Similarly to JAX.jit, for which we have a C++ `DeviceArray` and a Python `_DeviceArray`, we will introduce 2 objects for ShardedDeviceArray, with the Python object only for JAX extensions not compatible with the C++ object (e.g. Cloud TPU).

- Add `make_sharded_device_array` to be used within JAX and for hackers that need to construct SDA objects.
- Make sure the C++ object is valid by
  (a) extending `DeviceArrayBase` (done in Python), as it brings a bunch of methods and enable `isinstance(x, DeviceArray)`
  (b) Adding the same methods as the Python SDA.

NOTE: mypy has troubled with the " -> pxla.ShardedDeviceArray` function return type annotation, I had to remove 2.
PiperOrigin-RevId: 389876734
2021-08-10 07:16:24 -07:00
elliotwaite
7392a57b75 DOC: many small fixes 2021-08-04 16:55:13 -07:00
Jean-Baptiste Lespiau
ad4c670f37 Add Python code for the future C++ pmap and pass the data to C++ as a namedtuple.
PiperOrigin-RevId: 388788330
2021-08-04 14:46:51 -07:00
Clemens Giuliani
26eef81b51 turn transposed_fun into a PyTree 2021-07-21 15:58:01 +02:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Adam Paszke
1c1ec79edd Clarify the error message for out-of-bounds in_axes in pmap and vmap
Fixes #5201.
2021-07-14 12:11:06 +00:00
George Necula
022514e04c Updated the error message 2021-07-11 10:49:30 +03:00
George Necula
5520fcb59f Improve error message when vjp is called with cotangent of wrong shape.
Previously the error was an internal assertion error.
2021-07-10 19:12:11 +03:00
James Bradbury
8e86952ee4 AWN-enabled reduction over named axes in reverse-mode AD
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.

In batching tracer terms, this "elementwise" behavior means that, if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, would also be mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place would likely be
an unacceptable waste of memory and can't necessarily be optimized away.

If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know about whether primal
values are mapped or unmapped. This is made possible by avals-with-names,
which encodes that information in the avals of the primal jaxpr.

Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave akin to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.

Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
  - reductions aren't fused into any first-order primitives (e.g. a `pdot`
    should have a named contracting axis added rather than being followed by a
    `psum`; this can be implemented by putting these primitives into
    `reducing_transposes`)
  - reductions are performed eagerly, even over axes that are mapped to
    hardware resources (the optimal thing to do would be to reduce eagerly
    over any vectorized axis component while delaying the reduction over any
    hardware-mapped component until the end of the overall backward pass; this
    would require a way to represent these partially-reduced values)

PiperOrigin-RevId: 383685336
2021-07-08 12:06:29 -07:00
James Martens
f925b62ea0 Clarifying docstring for devices argument of pmap.
PiperOrigin-RevId: 383486168
2021-07-07 13:51:11 -07:00
Matthew Johnson
a0eb1126e4 remat: don't apply cse-foiling widget to primal 2021-06-30 09:29:47 -07:00
Clemens Giuliani
3041c18250 turn lifted_jvp into a PyTree 2021-06-24 23:55:49 +02:00
Peter Hawkins
e9611eb090 Move jax.ad_util to jax._src.ad_util.
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.

PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00
jax authors
3c6a41eb9c Merge pull request #6612 from google:tracer-errors
PiperOrigin-RevId: 372211269
2021-05-05 14:45:57 -07:00
Matthew Johnson
7ec0b40173 Roll-forward of #6584, which broke internal tests.
PiperOrigin-RevId: 371839298
2021-05-03 21:41:23 -07:00
Matthew Johnson
b9d72a480f improve concreteness error from arguments
also tweak some error message wording
2021-05-03 17:37:34 -07:00
Qiao Zhang
850bd66242 [JAX] Prune unused inputs in jit.
- Python part based on: https://github.com/google/jax/pull/6567
- Added cpp_jit path to handle pruned args

PiperOrigin-RevId: 371743277
2021-05-03 11:41:29 -07:00
jax authors
75b00a1235 Copybara import of the project:
--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:

add 'inline' option to xla_call for jaxpr inlining

--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:

add 'inline' to jit docstring

--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:

new_sublevel in jax2tf

PiperOrigin-RevId: 371542778
2021-05-01 22:18:39 -07:00
Matthew Johnson
fe297e39ca add 'inline' to jit docstring 2021-05-01 12:32:44 -07:00
Matthew Johnson
3c400a3e58 add 'inline' option to xla_call for jaxpr inlining 2021-04-28 19:38:15 -07:00
Lena Martens
b244e2b8c8 Add eval_shape to the UnexpectedTracerError too. 2021-04-23 14:46:34 +01:00
Lena Martens
deb2227f4a Make sure the out_axes in the HashableFunction closure are hashable.
By flattening them before putting them in the closure.
2021-04-21 12:32:19 +01:00
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
jax authors
bbc7be064c Merge pull request #6239 from j-towns:lt-allow-integers
PiperOrigin-RevId: 369467931
2021-04-20 10:23:10 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00