138 Commits

Author SHA1 Message Date
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
jax authors
1555ba147c Copybara import of the project:
--
de9a948d1ce407056de545b5717c3441298e2f36 by Jake VanderPlas <jakevdp@google.com>:

make device_array.copy() return a device array

PiperOrigin-RevId: 438308145
2022-03-30 08:30:18 -07:00
jax authors
ef2efec649 Merge pull request #10069 from jakevdp:devicearray-copy
PiperOrigin-RevId: 438292130
2022-03-30 07:01:19 -07:00
Jake VanderPlas
f4b64f48f4 doc: add examples of using partial with jit 2022-03-29 15:43:58 -07:00
Jake VanderPlas
de9a948d1c make device_array.copy() return a device array 2022-03-29 10:33:29 -07:00
jax authors
2b89236d05 Merge pull request #9967 from froystig:aot-arg-trees
PiperOrigin-RevId: 436272758
2022-03-21 12:15:38 -07:00
YouJiacheng
e5b3f0b537 Fix #9969
Fix hessian with options and add regression test
2022-03-22 00:52:57 +08:00
Roy Frostig
4b4dc3c745 track input argument information in one tree at each AOT stage
Both `Lowered` and `Compiled` carry information about input arguments
for which the underlying computation was lowered (namely avals,
donation bits, and the input pytree structure today). This change
rearranges some internals so that all of this information is held
together in a single pytree of structs. Doing so simplifies the fields
of both stage classes and helps ensure the input argument properties
are consistent with one another (e.g. now they must share a consistent
pytree structure by definition).
2022-03-19 17:46:55 -07:00
Roy Frostig
e04ea89b06 introduce a protocol for compilation wrappers 2022-03-17 17:13:33 -07:00
Jake VanderPlas
625a69d39e DOC: more info on disable_jit 2022-03-17 10:17:40 -07:00
Roy Frostig
047488446b factor AOT types out to a stages module 2022-03-15 15:18:15 -07:00
Jean-Baptiste Lespiau
8a85544537 Add the input avals to Lowered and Compiled.
PiperOrigin-RevId: 433505462
2022-03-09 09:59:45 -08:00
Roy Frostig
7890fb7596 remove _one and _zero from public jax.lax module 2022-03-08 12:56:11 -08:00
Roy Frostig
3f88518363 remove three internal functions from public jax.lax module
... namely `_float`, `_input_dtype`, and `_broadcasting_select`.
2022-03-08 12:49:36 -08:00
Jean-Baptiste Lespiau
aeba6b3438 Move the construction of the in-tree up.
PiperOrigin-RevId: 433202494
2022-03-08 07:10:03 -08:00
Jean-Baptiste Lespiau
17f11e05e0 Add accessors on Compiled returning the args and kwargs PyTreeDef working for all transforms.
This also documents the fact that `in_tree` content varies, based on the transform.

PiperOrigin-RevId: 432895923
2022-03-07 02:36:42 -08:00
Roy Frostig
947b7b88e1 re-implement custom_transpose without upfront staging.
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-04 16:50:51 -08:00
Jean-Baptiste Lespiau
25472c238b Remove an unnecessary condition.
This makes sure jit, xla_computation and pjit share the same logic for processing static arguments.
2022-02-28 11:34:01 +01:00
Roy Frostig
d636e74626 make xla_executable a property, consistent across executable types
Also test IR and executable-related methods of `Lowered` and
`Compiled`.
2022-02-25 19:05:44 -08:00
Yash Katariya
98e114da4f Rename unmapped_local_out_avals to out_avals since it can contain global avals (via GDA) as well as local avals.
PiperOrigin-RevId: 430539281
2022-02-23 14:27:00 -08:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
Jean-Baptiste Lespiau
607e7033a6 Turn execute_replicated into a class so we can access its fields.
It's more readable than inspecting the internals of a `functools.partial`.

PiperOrigin-RevId: 429523075
2022-02-18 03:18:47 -08:00
Hyeontaek Lim
beaa00c460 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 428590081
2022-02-14 13:11:49 -08:00
Jean-Baptiste Lespiau
799ecfa920 Remove e Type annotation for jit an pmap as there are additional attributes on the returned callable.
Using the experimental jax.jit(lambda x: x+1).lower(...) is raising an error with pytype.
2022-02-14 16:39:51 +00:00
Lena Martens
010eb82ad0 Rename wrapper functions to always refer to the JAX api function.
eg. batched_fun -> vmap_fun
2022-02-08 20:10:39 +00:00
Peter Hawkins
8be057de1f Introduce a new jax/jaxlib versioning scheme.
Adds a design note that describes the scheme and how the jax and jaxlib versions
are related.
2022-02-07 17:59:42 -05:00
Matthew Johnson
e186aa3f1e add and test pytree utils for better errors 2022-02-03 17:04:38 -08:00
Matthew Johnson
d9dcd1394a djax: let make_jaxpr build dyn shape jaxprs 2022-02-01 00:10:21 -08:00
George Necula
83b818d45c Add more documentation for buffer donation
Fixes: #9237
2022-01-24 09:33:08 +01:00
jax authors
21ddd83615 Merge pull request #8420 from Huizerd:dev/fwd
PiperOrigin-RevId: 423232626
2022-01-20 21:45:53 -08:00
Peter Hawkins
74e4db47da Change the default IR dialect returned by .compiler_ir() to MHLO.
PiperOrigin-RevId: 423091674
2022-01-20 09:50:17 -08:00
Huizerd
d05431a1ff has_aux for jvp and forward-mode AD value_and_grad
Changes:
- revert value_and_grad_fwd
- add has_aux to jacfwd and jacrev
- tests

fix mypy error
2022-01-20 11:13:58 +01:00
Roy Frostig
1709e06800 introduce custom_transpose and a corresponding primitive
Includes rules for impl, transpose, abstract eval, and xla/mlir
translation.
2022-01-11 12:51:17 -08:00
Roy Frostig
0ab93a039e custom batching vmap tests 2022-01-05 18:07:20 -08:00
Matthew Johnson
c8a34fe5cc add jax.block_until_ready function
fixes #8536
2021-12-14 11:02:14 -08:00
Peter Hawkins
add967db88 [JAX] Add a dialect option to jit(...).lower(...).compiler_ir().
The dialect allows the user to select between HLO and MHLO output.

PiperOrigin-RevId: 415591372
2021-12-10 13:02:25 -08:00
Roy Frostig
b980acf375 detect and err on transformation of AOT-compiled function calls 2021-12-07 17:20:27 -08:00
Peter Hawkins
06cd1fedee Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
2021-12-07 06:13:07 -08:00
Roy Frostig
90361dc345 methods for retrieving IRs from each AOT stage 2021-12-03 14:53:22 -08:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Parker Schuh
46a1033311 Update device_get docs to mention parrallelism. 2021-11-30 10:20:11 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
496e400c71 [x64] Make autodiff respect weak types 2021-11-23 15:04:08 -08:00
jax authors
2ec1488876 Merge pull request #8629 from jakevdp:dtypes-dtype
PiperOrigin-RevId: 411791488
2021-11-23 05:58:15 -08:00
Jake VanderPlas
c4d9c4674f [x64] regularize dtype helpers 2021-11-22 15:35:12 -08:00
Roy Frostig
20a1517eeb factor tuple conversions into common pmap setup logic 2021-11-22 13:49:44 -08:00
Roy Frostig
cf64a945cf refine pmap-related annotations 2021-11-22 13:49:44 -08:00
Roy Frostig
fcdc0a6c1a ahead-of-time lowering and compilation frontend for pmap 2021-11-22 08:33:04 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Peter Hawkins
3fd3c46f20 Increase minimum jaxlib version to 0.1.74. 2021-11-18 15:06:58 -05:00