205 Commits

Author SHA1 Message Date
jax authors
9ca37c9e33 Merge pull request #11950 from mattjj:delete-old-remat
PiperOrigin-RevId: 468173667
2022-08-17 05:40:26 -07:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
jax authors
0abbdd0648 Add a backend field to mlir.ModuleContext so that host callback lowering can use the correct backend
PiperOrigin-RevId: 468024979
2022-08-16 14:26:53 -07:00
Neil Girdhar
ad38a6bb28 Fix common typo: Tuple[X] -> Tuple[X, ...] 2022-08-16 11:47:22 -04:00
Sharad Vikram
fe040cc01e Cleaning up eager pmap implementation
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 11:10:16 -07:00
Matthew Johnson
5310515c80 Initial implementation of eager pmap
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 10:21:55 -07:00
Sharad Vikram
88f2b5e86d Add functionality for "pure" callbacks
Also avoids using CPP dispatch path when host callbacks are involved

PiperOrigin-RevId: 467270949
2022-08-12 12:39:53 -07:00
Yash Katariya
18b6a32db2 Make all pmap tests pass with Array! I am skipping all soft pmap tests for now.
PiperOrigin-RevId: 467264992
2022-08-12 12:09:49 -07:00
Yash Katariya
33c4fc4fe2 Pmap should output SDA like Arrays to maintain the current behavior exactly. Split the shard_arg_handler for Array based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.
PiperOrigin-RevId: 466849614
2022-08-10 20:11:37 -07:00
Roy Frostig
7d494a3852 update checkpoint attributes according to functools.wraps
This updates the signature in addition to `__doc__`, and that gets
picked up by generated API docs.
2022-08-10 13:33:07 -07:00
jax authors
8b2e4f975c Merge pull request #11825 from mattjj:fix-type-annotation
PiperOrigin-RevId: 466550958
2022-08-09 20:21:10 -07:00
Matthew Johnson
d76754e40e fix type annotation on remat 2022-08-09 19:57:40 -07:00
Parker Schuh
01df754630
Remove docs 2022-08-09 12:36:49 -07:00
Parker Schuh
8fb957350c Add spmd_axis_name to vmap to allow constraining mapped PartitionSpecs. 2022-08-08 19:41:42 -07:00
Matthew Johnson
81b6263ed0 Rolling forward #11768 after test failures caused roll-back (from use of np.empty).
PiperOrigin-RevId: 465712458
2022-08-05 22:19:33 -07:00
jax authors
6b0c0dc321 Internal change
PiperOrigin-RevId: 465705931
2022-08-05 21:08:43 -07:00
Matthew Johnson
348da51dc6 prototype unfettered element types in jaxpr arrays
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.

Within jaxprs, we also need to handle transformations with these types.

In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
  * allowing ShapedArray to have any object for its 'dtype' attribute
  * added core.custom_eltype set
  * extended existing handlers for ShapedArray to call the corresponding custom
    element type handlers
  * mlir lowerings of some fully-element-type-polymorphic primitives
  * tests

In this PR, we only actually use these new extension points in tests.

The applications to come that we have in mind are:
  * arrays of prngkeys (and even custom prngs, as well as reuse error checking)
  * arrays of bounded int type for dynamic shapes (and especially raggedness)
  * float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.

Jargon-wise, we may want to distinguish:
  * 'eltype' meaning jaxpr element types
  * 'dtype' meaning numpy dtypes (an existing convention)
  * 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.

We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
  * [x] broadcast
  * [ ] reshape
  * [x] transpose
  * [ ] pad
  * [x] slice, dynamic_slice, dynamic_update_slice
  * [ ] concatenate
  * [ ] all_to_all, gather, scatter, all_gather, collective_permute
  * [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.

We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.

Co-authored-by: Roy Frostig <frostig@google.com>
2022-08-05 19:23:55 -07:00
Matthew Johnson
e3a92d52ba prepare to switch to new remat
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
  * add benchmarks for new remat eager performance, and some caching to make those
    benchmarks fast
  * warn when the old-remat-exclusive `concrete` feature is used, with an
    actionable message pointing to the new recommended approach involving static_argnums
  * add the static_argnums parameter to both new and old remt
  * update docstrings (and de-duplicate them to)
  * add new tests, especially around caching and errors/warnings
2022-08-04 12:25:03 -07:00
jax authors
75d69725c3 Merge pull request #11640 from pschuh:pmap-shaped-array
PiperOrigin-RevId: 464623040
2022-08-01 14:22:41 -07:00
Matthew Johnson
cbcfe95e80 fix ad_checkpoint.checkpoint caching issue
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
Parker Schuh
3344f89a1e Allow jax.pmap(f).lower() to take jax.core.ShapedArray.
This will be useful for calling lower on the result of
jax.eval_shape().
2022-07-27 19:17:37 -07:00
Jake VanderPlas
108376d792 Remove deprecated function jax.tree_util.tree_multimap 2022-07-26 09:37:27 -07:00
George Necula
66dc95e2de removes the jax.mask and jax.shapecheck APIs.
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
jax authors
9ee6cacdc8 Merge pull request #11540 from gnecula:ds_check_flag
PiperOrigin-RevId: 462061356
2022-07-19 23:07:14 -07:00
Yash Katariya
9f914a93d6 Replace input sharding_specs with in_shardings in InputsHandler
PiperOrigin-RevId: 461963206
2022-07-19 13:43:28 -07:00
George Necula
2106d65561 [dynamic-shapes] Add check that --jax_dynamic_shapes is set when using abstracted_axes.
abstracted_axes has no effect without the --jax_dynamic_shapes. Make this and
explicit error.
2022-07-19 19:48:45 +02:00
Yash Katariya
ea627b807b Replace out_specs with out_shardings and remove out_indices in ResultsHandler.
PiperOrigin-RevId: 461788039
2022-07-18 20:57:02 -07:00
jax authors
ed51c65576 Merge pull request #11405 from mattjj:djax-vmap
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Matthew Johnson
12a56c3064 [dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...) 2022-07-07 12:48:29 -07:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Jake VanderPlas
cb25a96d43 vmap: better errors for mismatched axis in keyword arguments 2022-06-29 14:31:03 -07:00
Yash Katariya
766c5ba0a2 Check sharding in pmap for jax.Array.
The checks are:

(1) Check if the in_axes given to pmap matches the sharding of Array.

(2) Check if devices in `array.sharding` is equal to the devices provided to pmap

(3) Check if devices for all array inputs are the same.

(4) If devices are not provided to pmap, use the devices on `Array` after checking point (3).

PiperOrigin-RevId: 456567562
2022-06-22 11:37:10 -07:00
Yash Katariya
dce8f64b40 Make device_put_sharded and device_put_replicated return Arrays.
PiperOrigin-RevId: 456525113
2022-06-22 08:51:29 -07:00
Jake VanderPlas
abcfaec6e3 DOC: clarify variable names 2022-06-21 13:20:53 -07:00
Neil Girdhar
7697616b85 Annotate vmap 2022-06-14 15:41:47 -04:00
Peter Hawkins
78312a7fff Add an undocumented method on jit() functions to clear the function cache. 2022-06-10 18:36:18 -07:00
Peter Hawkins
b32f83d84d Clarify that functions passed to jax.jit must be weakly referenceable. 2022-06-10 12:21:23 -07:00
Sharad Vikram
289610eb02 Add a public facing named_scope function to allow adding to the name stack. 2022-06-08 17:23:57 -07:00
Sharad Vikram
426c7356fb Guard has_explicit_device with xla_client version 2022-06-02 12:03:27 -07:00
jax authors
ea54754c49 Merge pull request #9118 from skye:device_context_manager
PiperOrigin-RevId: 452570041
2022-06-02 10:33:53 -07:00
jax authors
094e706498 Merge pull request #10823 from LenaMartens:changelist/450914215
PiperOrigin-RevId: 451427821
2022-05-27 10:39:24 -07:00
Matthew Johnson
ffa9328a68 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451320477
2022-05-26 23:21:40 -07:00
Matthew Johnson
995220a739 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451268007
2022-05-26 16:26:49 -07:00
Matthew Johnson
9b724647d1 [djax] add support for dynamic-shape outputs 2022-05-26 13:22:06 -07:00
Lena Martens
f8f5a5dca3 Add notes in buffer donation FAQ about key-word args limitation. 2022-05-25 15:33:04 +01:00
Jeppe Klitgaard
838a05329d feat: validate jit args 2022-05-18 21:54:47 +01:00