From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we pass input or output values to
and from a jaxpr evaluation. So if we're lowering, we need ways to map jaxpr
types to lowered IR types, and ways to map the operations allowed on those
types to lowered IR operations. We may also want Python objects representing
values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.
Within jaxprs, we also need transformations to handle these types.
In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
* allowing ShapedArray to have any object for its 'dtype' attribute
* adding a core.custom_eltype set
* extending existing handlers for ShapedArray to call the corresponding custom
element type handlers (a sketch follows this list)
* adding mlir lowerings of some fully-element-type-polymorphic primitives
* adding tests
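To make the first two bullets concrete, here is a minimal sketch of the
pattern. The BoundedInt class is hypothetical, and we assume ordinary set
semantics for core.custom_eltype; only ShapedArray and the custom_eltype set
come from this change:

```python
from jax import core

class BoundedInt:
  """Hypothetical custom element type: integers in [0, bound)."""
  def __init__(self, bound):
    self.bound = bound

  def __repr__(self):
    return f"BoundedInt({self.bound})"

# Register the type so that handlers for ShapedArray know to dispatch to the
# corresponding custom element type handlers.
core.custom_eltype.add(BoundedInt)

# ShapedArray's 'dtype' attribute can now be any such object, not just a
# NumPy dtype.
aval = core.ShapedArray((3, 4), BoundedInt(bound=5))
```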
In this PR, we only actually use these new extension points in tests.
The applications to come that we have in mind are:
* arrays of PRNG keys (including custom PRNGs, as well as reuse error checking)
* arrays of bounded int type for dynamic shapes (and especially raggedness)
* float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.
Jargon-wise, we may want to distinguish:
* 'eltype' meaning jaxpr element types
* 'dtype' meaning numpy dtypes (an existing convention)
* 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.
We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
* [x] broadcast
* [ ] reshape
* [x] transpose
* [ ] pad
* [x] slice, dynamic_slice, dynamic_update_slice
* [ ] concatenate
* [ ] all_to_all, gather, scatter, all_gather, collective_permute
* [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule
can be made simpler. That rule lowers first to a "lowered jaxpr dialect"
involving only the eltypes which correspond to etypes, and only while_loop,
ds/dus, etc. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, to have a more complicated lowering rule.
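As a user-level sketch of that pattern (not the actual lowering code), here is
scan expressed with while_loop and dynamic-update-slice, where a preallocated
output buffer stands in for the make-empty-scalar primitive:

```python
import jax
import jax.numpy as jnp
from jax import lax

def scan_via_while(f, init, xs):
  n = xs.shape[0]
  # Preallocate the output buffer; internally, the make-empty primitive plays
  # a similar role for eltypes that correspond to etypes.
  y_aval = jax.eval_shape(lambda c, x: f(c, x)[1], init, xs[0])
  ys = jnp.zeros((n,) + y_aval.shape, y_aval.dtype)

  def body(state):
    i, carry, ys = state
    carry, y = f(carry, lax.dynamic_index_in_dim(xs, i, keepdims=False))
    ys = lax.dynamic_update_index_in_dim(ys, y, i, axis=0)
    return i + 1, carry, ys

  _, carry, ys = lax.while_loop(lambda s: s[0] < n, body, (0, init, ys))
  return carry, ys

# Usage: a cumulative sum, matching lax.scan's semantics.
carry, ys = scan_via_while(lambda c, x: (c + x, c + x),
                           jnp.float32(0.), jnp.arange(4, dtype=jnp.float32))
```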
We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.
Co-authored-by: Roy Frostig <frostig@google.com>
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
* add benchmarks for new remat eager performance, and some caching to make those
benchmarks fast
* warn when the old-remat-exclusive `concrete` feature is used, with an
actionable message pointing to the new recommended approach involving
static_argnums (see the example after this list)
* add the static_argnums parameter to both new and old remat
* update docstrings (and de-duplicate them too)
* add new tests, especially around caching and errors/warnings
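For example, where old remat's `concrete=True` allowed Python control flow to
depend on argument values, the recommended pattern with the public
jax.checkpoint API is:

```python
from functools import partial
import jax
import jax.numpy as jnp

# Marking an argument static lets trace-time Python control flow depend on
# it, replacing the old `concrete` option.
@partial(jax.checkpoint, static_argnums=(1,))
def f(x, is_training):
  if is_training:  # Python-level branch on the static argument
    x = jnp.sin(x)
  return (x ** 2).sum()

print(jax.grad(f)(jnp.ones(3), True))
```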
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
In cases like unit tests, users may want to clean up all the backends, along with the resources they use, at the end of a test, and reinitialize them in the next test.
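As a sketch, assuming the entry point is jax.clear_backends (the commit
message doesn't name the API, so treat that name as an assumption):

```python
import unittest
import jax
import jax.numpy as jnp

class MyTest(unittest.TestCase):
  def tearDown(self):
    # Release all backend clients and their resources; the next JAX call
    # reinitializes the backends from scratch.
    jax.clear_backends()

  def test_something(self):
    self.assertEqual(jnp.add(1, 1), 2)
```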
PiperOrigin-RevId: 462239974
The checks are:
(1) Check that the in_axes given to pmap matches the sharding of each input Array.
(2) Check that the devices in `array.sharding` are equal to the devices provided to pmap.
(3) Check that the devices of all input Arrays are the same.
(4) If devices are not provided to pmap, use the devices of the `Array` inputs, after check (3) passes.
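A minimal sketch of checks (2) through (4), using the public
Sharding.device_set attribute (the helper name and error messages are
hypothetical):

```python
def check_array_devices(arrays, pmap_devices=None):
  # (3) All input Arrays must be committed to the same set of devices.
  device_sets = {frozenset(a.sharding.device_set) for a in arrays}
  if len(device_sets) > 1:
    raise ValueError("pmap inputs must all be on the same devices")
  # (2) If devices were passed to pmap, they must match the arrays' devices.
  if pmap_devices is not None and device_sets != {frozenset(pmap_devices)}:
    raise ValueError("devices in `array.sharding` must equal pmap's devices")
  # (4) Otherwise, fall back to the devices the arrays are already on.
  return pmap_devices if pmap_devices is not None else next(iter(device_sets))
```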
PiperOrigin-RevId: 456567562
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451320477
Currently we can't block on *unordered* effectful computations because
there are no runtime tokens for them. This change adds a per-device token
that is returned by effectful computations. This enables us
to block on them if we want. See the design note added in https://github.com/google/jax/pull/10657.
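With today's API names (jax.debug.print as the canonical unordered effect and
jax.effects_barrier as the blocking call; both accompany or postdate this
change, so treat them as illustrative):

```python
import jax

@jax.jit
def f(x):
  # An unordered effect: the print may run asynchronously after dispatch.
  jax.debug.print("x = {x}", x=x)
  return x + 1

f(1)
jax.effects_barrier()  # block until the runtime token for the effect resolves
```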
PiperOrigin-RevId: 449106281
This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.
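Assuming this refers to the jax.default_device config (the message doesn't
name it, so the identification is an assumption), usage looks like:

```python
import jax
import jax.numpy as jnp

cpu0 = jax.devices("cpu")[0]  # a concrete Device object, not a platform name
with jax.default_device(cpu0):
  x = jnp.arange(3)  # placed on cpu0 by default
```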
Bumps the minimum jaxlib version in order to include
https://github.com/tensorflow/tensorflow/pull/53656
This way, code using the output XLA executable does not also need to drop the unused arguments, simplifying downstream code.
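Illustrating the before/after with schematic names (kept_var_idx is the
internal name for the retained argument indices; the execute call shown is
schematic, not a specific API):

```python
def call_executable(executable, args, kept_var_idx):
  # Before this change, downstream code had to prune the arguments the
  # compiler had marked unused:
  #   args = [a for i, a in enumerate(args) if i in kept_var_idx]
  # Now the executable drops unused arguments itself, so the full list can
  # be passed through unchanged.
  return executable.execute(args)
```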
PiperOrigin-RevId: 446391558
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.
(This change also removes some unrelated dead code in maps.py.)