150 Commits

Author SHA1 Message Date
Sharad Vikram
8031eee7ee Add in runtime tokens for effectful jaxprs 2022-05-03 15:55:07 -07:00
Matthew Johnson
11ad045dfd [remove-units] remove units from partial_eval.py
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.

(Also there are some unrelated removals of dead code in maps.py.)
2022-05-02 13:43:27 -07:00
Matthew Johnson
65bff3c856 [remove-units] avoid unit-generating function in jax.linear_transpose 2022-04-29 16:37:43 -07:00
Matthew Johnson
9fd53bc6f7 [remove-units] prevent ad.py from introducing units 2022-04-26 13:01:01 -07:00
jax authors
5013bd2e3a Merge pull request #10402 from froystig:aot-jit-avoid-trivial
PiperOrigin-RevId: 443533232
2022-04-21 18:13:10 -07:00
Roy Frostig
5c118071cb always lower/compile computations on the AOT jit path
... even trivial ones.
2022-04-21 15:30:36 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
Lucas Beyer
f7b749c99c
Explicit doc note about device_put* async 2022-04-04 23:38:51 +02:00
Jake VanderPlas
4949e78859 Re-land changes from https://github.com/google/jax/pull/10069
PiperOrigin-RevId: 439381161
2022-04-04 12:18:43 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
jax authors
1555ba147c Copybara import of the project:
--
de9a948d1ce407056de545b5717c3441298e2f36 by Jake VanderPlas <jakevdp@google.com>:

make device_array.copy() return a device array

PiperOrigin-RevId: 438308145
2022-03-30 08:30:18 -07:00
jax authors
ef2efec649 Merge pull request #10069 from jakevdp:devicearray-copy
PiperOrigin-RevId: 438292130
2022-03-30 07:01:19 -07:00
Jake VanderPlas
f4b64f48f4 doc: add examples of using partial with jit 2022-03-29 15:43:58 -07:00
Jake VanderPlas
de9a948d1c make device_array.copy() return a device array 2022-03-29 10:33:29 -07:00
jax authors
2b89236d05 Merge pull request #9967 from froystig:aot-arg-trees
PiperOrigin-RevId: 436272758
2022-03-21 12:15:38 -07:00
YouJiacheng
e5b3f0b537 Fix #9969
Fix hessian with options and add regression test
2022-03-22 00:52:57 +08:00
Roy Frostig
4b4dc3c745 track input argument information in one tree at each AOT stage
Both `Lowered` and `Compiled` carry information about input arguments
for which the underlying computation was lowered (namely avals,
donation bits, and the input pytree structure today). This change
rearranges some internals so that all of this information is held
together in a single pytree of structs. Doing so simplifies the fields
of both stage classes and helps ensure the input argument properties
are consistent with one another (e.g. now they must share a consistent
pytree structure by definition).
2022-03-19 17:46:55 -07:00
Roy Frostig
e04ea89b06 introduce a protocol for compilation wrappers 2022-03-17 17:13:33 -07:00
Jake VanderPlas
625a69d39e DOC: more info on disable_jit 2022-03-17 10:17:40 -07:00
Roy Frostig
047488446b factor AOT types out to a stages module 2022-03-15 15:18:15 -07:00
Jean-Baptiste Lespiau
8a85544537 Add the input avals to Lowered and Compiled.
PiperOrigin-RevId: 433505462
2022-03-09 09:59:45 -08:00
Roy Frostig
7890fb7596 remove _one and _zero from public jax.lax module 2022-03-08 12:56:11 -08:00
Roy Frostig
3f88518363 remove three internal functions from public jax.lax module
... namely `_float`, `_input_dtype`, and `_broadcasting_select`.
2022-03-08 12:49:36 -08:00
Jean-Baptiste Lespiau
aeba6b3438 Move the construction of the in-tree up.
PiperOrigin-RevId: 433202494
2022-03-08 07:10:03 -08:00
Jean-Baptiste Lespiau
17f11e05e0 Add accessors on Compiled returning the args and kwargs PyTreeDef working for all transforms.
This also documents the fact that `in_tree` content varies, based on the transform.

PiperOrigin-RevId: 432895923
2022-03-07 02:36:42 -08:00
Roy Frostig
947b7b88e1 re-implement custom_transpose without upfront staging.
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-04 16:50:51 -08:00
Jean-Baptiste Lespiau
25472c238b Remove an unnecessary condition.
This makes sure jit, xla_computation and pjit share the same logic for processing static arguments.
2022-02-28 11:34:01 +01:00
Roy Frostig
d636e74626 make xla_executable a property, consistent across executable types
Also test IR and executable-related methods of `Lowered` and
`Compiled`.
2022-02-25 19:05:44 -08:00
Yash Katariya
98e114da4f Rename unmapped_local_out_avals to out_avals since it can contain global avals (via GDA) as well as local avals.
PiperOrigin-RevId: 430539281
2022-02-23 14:27:00 -08:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
Jean-Baptiste Lespiau
607e7033a6 Turn execute_replicated into a class so we can access its fields.
It's more readable than inspecting the internals of a `functools.partial`.

PiperOrigin-RevId: 429523075
2022-02-18 03:18:47 -08:00
Hyeontaek Lim
beaa00c460 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 428590081
2022-02-14 13:11:49 -08:00
Jean-Baptiste Lespiau
799ecfa920 Remove e Type annotation for jit an pmap as there are additional attributes on the returned callable.
Using the experimental jax.jit(lambda x: x+1).lower(...) is raising an error with pytype.
2022-02-14 16:39:51 +00:00
Lena Martens
010eb82ad0 Rename wrapper functions to always refer to the JAX api function.
eg. batched_fun -> vmap_fun
2022-02-08 20:10:39 +00:00
Peter Hawkins
8be057de1f Introduce a new jax/jaxlib versioning scheme.
Adds a design note that describes the scheme and how the jax and jaxlib versions
are related.
2022-02-07 17:59:42 -05:00
Matthew Johnson
e186aa3f1e add and test pytree utils for better errors 2022-02-03 17:04:38 -08:00
Matthew Johnson
d9dcd1394a djax: let make_jaxpr build dyn shape jaxprs 2022-02-01 00:10:21 -08:00
George Necula
83b818d45c Add more documentation for buffer donation
Fixes: #9237
2022-01-24 09:33:08 +01:00
jax authors
21ddd83615 Merge pull request #8420 from Huizerd:dev/fwd
PiperOrigin-RevId: 423232626
2022-01-20 21:45:53 -08:00
Peter Hawkins
74e4db47da Change the default IR dialect returned by .compiler_ir() to MHLO.
PiperOrigin-RevId: 423091674
2022-01-20 09:50:17 -08:00
Huizerd
d05431a1ff has_aux for jvp and forward-mode AD value_and_grad
Changes:
- revert value_and_grad_fwd
- add has_aux to jacfwd and jacrev
- tests

fix mypy error
2022-01-20 11:13:58 +01:00
Roy Frostig
1709e06800 introduce custom_transpose and a corresponding primitive
Includes rules for impl, transpose, abstract eval, and xla/mlir
translation.
2022-01-11 12:51:17 -08:00
Roy Frostig
0ab93a039e custom batching vmap tests 2022-01-05 18:07:20 -08:00
Matthew Johnson
c8a34fe5cc add jax.block_until_ready function
fixes #8536
2021-12-14 11:02:14 -08:00
Peter Hawkins
add967db88 [JAX] Add a dialect option to jit(...).lower(...).compiler_ir().
The dialect allows the user to select between HLO and MHLO output.

PiperOrigin-RevId: 415591372
2021-12-10 13:02:25 -08:00
Roy Frostig
b980acf375 detect and err on transformation of AOT-compiled function calls 2021-12-07 17:20:27 -08:00
Peter Hawkins
06cd1fedee Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
2021-12-07 06:13:07 -08:00