190 Commits

Author SHA1 Message Date
jax authors
6b0c0dc321 Internal change
PiperOrigin-RevId: 465705931
2022-08-05 21:08:43 -07:00
Matthew Johnson
348da51dc6 prototype unfettered element types in jaxpr arrays
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.

Within jaxprs, we also need to handle transformations with these types.

In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
  * allowing ShapedArray to have any object for its 'dtype' attribute
  * added core.custom_eltype set
  * extended existing handlers for ShapedArray to call the corresponding custom
    element type handlers
  * mlir lowerings of some fully-element-type-polymorphic primitives
  * tests

In this PR, we only actually use these new extension points in tests.

The applications to come that we have in mind are:
  * arrays of prngkeys (and even custom prngs, as well as reuse error checking)
  * arrays of bounded int type for dynamic shapes (and especially raggedness)
  * float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.

Jargon-wise, we may want to distinguish:
  * 'eltype' meaning jaxpr element types
  * 'dtype' meaning numpy dtypes (an existing convention)
  * 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.

We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
  * [x] broadcast
  * [ ] reshape
  * [x] transpose
  * [ ] pad
  * [x] slice, dynamic_slice, dynamic_update_slice
  * [ ] concatenate
  * [ ] all_to_all, gather, scatter, all_gather, collective_permute
  * [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.

We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.

Co-authored-by: Roy Frostig <frostig@google.com>
2022-08-05 19:23:55 -07:00
Matthew Johnson
e3a92d52ba prepare to switch to new remat
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
  * add benchmarks for new remat eager performance, and some caching to make those
    benchmarks fast
  * warn when the old-remat-exclusive `concrete` feature is used, with an
    actionable message pointing to the new recommended approach involving static_argnums
  * add the static_argnums parameter to both new and old remt
  * update docstrings (and de-duplicate them to)
  * add new tests, especially around caching and errors/warnings
2022-08-04 12:25:03 -07:00
jax authors
75d69725c3 Merge pull request #11640 from pschuh:pmap-shaped-array
PiperOrigin-RevId: 464623040
2022-08-01 14:22:41 -07:00
Matthew Johnson
cbcfe95e80 fix ad_checkpoint.checkpoint caching issue
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
Parker Schuh
3344f89a1e Allow jax.pmap(f).lower() to take jax.core.ShapedArray.
This will be useful for calling lower on the result of
jax.eval_shape().
2022-07-27 19:17:37 -07:00
Jake VanderPlas
108376d792 Remove deprecated function jax.tree_util.tree_multimap 2022-07-26 09:37:27 -07:00
George Necula
66dc95e2de removes the jax.mask and jax.shapecheck APIs.
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
jax authors
9ee6cacdc8 Merge pull request #11540 from gnecula:ds_check_flag
PiperOrigin-RevId: 462061356
2022-07-19 23:07:14 -07:00
Yash Katariya
9f914a93d6 Replace input sharding_specs with in_shardings in InputsHandler
PiperOrigin-RevId: 461963206
2022-07-19 13:43:28 -07:00
George Necula
2106d65561 [dynamic-shapes] Add check that --jax_dynamic_shapes is set when using abstracted_axes.
abstracted_axes has no effect without the --jax_dynamic_shapes. Make this and
explicit error.
2022-07-19 19:48:45 +02:00
Yash Katariya
ea627b807b Replace out_specs with out_shardings and remove out_indices in ResultsHandler.
PiperOrigin-RevId: 461788039
2022-07-18 20:57:02 -07:00
jax authors
ed51c65576 Merge pull request #11405 from mattjj:djax-vmap
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Matthew Johnson
12a56c3064 [dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...) 2022-07-07 12:48:29 -07:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Jake VanderPlas
cb25a96d43 vmap: better errors for mismatched axis in keyword arguments 2022-06-29 14:31:03 -07:00
Yash Katariya
766c5ba0a2 Check sharding in pmap for jax.Array.
The checks are:

(1) Check if the in_axes given to pmap matches the sharding of Array.

(2) Check if devices in `array.sharding` is equal to the devices provided to pmap

(3) Check if devices for all array inputs are the same.

(4) If devices are not provided to pmap, use the devices on `Array` after checking point (3).

PiperOrigin-RevId: 456567562
2022-06-22 11:37:10 -07:00
Yash Katariya
dce8f64b40 Make device_put_sharded and device_put_replicated return Arrays.
PiperOrigin-RevId: 456525113
2022-06-22 08:51:29 -07:00
Jake VanderPlas
abcfaec6e3 DOC: clarify variable names 2022-06-21 13:20:53 -07:00
Neil Girdhar
7697616b85 Annotate vmap 2022-06-14 15:41:47 -04:00
Peter Hawkins
78312a7fff Add an undocumented method on jit() functions to clear the function cache. 2022-06-10 18:36:18 -07:00
Peter Hawkins
b32f83d84d Clarify that functions passed to jax.jit must be weakly referenceable. 2022-06-10 12:21:23 -07:00
Sharad Vikram
289610eb02 Add a public facing named_scope function to allow adding to the name stack. 2022-06-08 17:23:57 -07:00
Sharad Vikram
426c7356fb Guard has_explicit_device with xla_client version 2022-06-02 12:03:27 -07:00
jax authors
ea54754c49 Merge pull request #9118 from skye:device_context_manager
PiperOrigin-RevId: 452570041
2022-06-02 10:33:53 -07:00
jax authors
094e706498 Merge pull request #10823 from LenaMartens:changelist/450914215
PiperOrigin-RevId: 451427821
2022-05-27 10:39:24 -07:00
Matthew Johnson
ffa9328a68 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451320477
2022-05-26 23:21:40 -07:00
Matthew Johnson
995220a739 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451268007
2022-05-26 16:26:49 -07:00
Matthew Johnson
9b724647d1 [djax] add support for dynamic-shape outputs 2022-05-26 13:22:06 -07:00
Lena Martens
f8f5a5dca3 Add notes in buffer donation FAQ about key-word args limitation. 2022-05-25 15:33:04 +01:00
Jeppe Klitgaard
838a05329d feat: validate jit args 2022-05-18 21:54:47 +01:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Sharad Vikram
268b4be21b Add output token for unordered effects
Currently we can't block on *unordered* effectful computations because
there are no runtime tokens for them. This change adds a per-device token
that is returned by effectful computations. This enables us
to block on them if we want. See the design note added in https://github.com/google/jax/pull/10657.

PiperOrigin-RevId: 449106281
2022-05-16 18:56:33 -07:00
Skye Wanderman-Milne
f26b866e08 Add jax.default_device context manager
This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.

Bumps the minimum jaxlib version in order to include
https://github.com/tensorflow/tensorflow/pull/53656
2022-05-07 00:31:00 +00:00
Jean-Baptiste Lespiau
5c838d2f6f Add an option when lowering to not remove unused arguments.
This way, code using the output xla executable does not need to also drop the unused arguments, simplifying downstream code.

PiperOrigin-RevId: 446391558
2022-05-04 01:22:14 -07:00
Sharad Vikram
ef982cfa8c Attach keepalive to executable 2022-05-03 20:08:03 -07:00
Sharad Vikram
8031eee7ee Add in runtime tokens for effectful jaxprs 2022-05-03 15:55:07 -07:00
Matthew Johnson
11ad045dfd [remove-units] remove units from partial_eval.py
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.

(Also there are some unrelated removals of dead code in maps.py.)
2022-05-02 13:43:27 -07:00
Matthew Johnson
65bff3c856 [remove-units] avoid unit-generating function in jax.linear_transpose 2022-04-29 16:37:43 -07:00
Matthew Johnson
9fd53bc6f7 [remove-units] prevent ad.py from introducing units 2022-04-26 13:01:01 -07:00
jax authors
5013bd2e3a Merge pull request #10402 from froystig:aot-jit-avoid-trivial
PiperOrigin-RevId: 443533232
2022-04-21 18:13:10 -07:00
Roy Frostig
5c118071cb always lower/compile computations on the AOT jit path
... even trivial ones.
2022-04-21 15:30:36 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00