If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).
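For reference, with static shapes `x[0]` can already be written as a `dynamic_slice` followed by a squeeze. The sketch below uses the existing `jax.lax.dynamic_slice`, whose slice sizes must be static Python integers. The generalized version described above would allow the `n` in the slice sizes to be a dynamic dimension. The concrete value `n = 4` and the array contents are assumptions for illustration only.

```python
import jax.numpy as jnp
from jax import lax

n = 4  # stand-in for the dynamic dimension; static here by assumption
x = jnp.arange(10 * n, dtype=jnp.float32).reshape(10, n)

# x[0] as a slice of size (1, n) starting at (0, 0), then drop the leading axis.
row = lax.dynamic_slice(x, (0, 0), (1, n))
row = lax.squeeze(row, dimensions=(0,))
```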
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison, and squeeze) for a simple
dynamic-batch-shape gather to work.
I also skipped two tests and deleted some old dynamic shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's just convenient given I'm looking at indexing again.
This also adds an `mlir.delegate_lowering` helper that takes optional
replacements to the lowering context and restores/updates it on
return. This is used to define a gather lowering rule in tests here,
by calling out to the base gather lowering rule. I picked it off of a
working branch that uses it as well (for key-eltyped lowering rules).
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.
Within jaxprs, we also need to handle transformations with these types.
In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
* allowing ShapedArray to have any object for its 'dtype' attribute
* adding a core.custom_eltype set
* extending existing handlers for ShapedArray to call the corresponding custom
element type handlers
* adding mlir lowerings of some fully-element-type-polymorphic primitives
* adding tests
In this PR, we only actually use these new extension points in tests.
The applications to come that we have in mind are:
* arrays of prngkeys (and even custom prngs, as well as reuse error checking)
* arrays of bounded int type for dynamic shapes (and especially raggedness)
* float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.
Jargon-wise, we may want to distinguish:
* 'eltype' meaning jaxpr element types
* 'dtype' meaning numpy dtypes (an existing convention)
* 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.
We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
* [x] broadcast
* [ ] reshape
* [x] transpose
* [ ] pad
* [x] slice, dynamic_slice, dynamic_update_slice
* [ ] concatenate
* [ ] all_to_all, gather, scatter, all_gather, collective_permute
* [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, to have a more complicated lowering rule.
We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.
Co-authored-by: Roy Frostig <frostig@google.com>
We were previously missing a shape check for this, meaning that if non-scalars were passed we would receive a cryptic MHLO shape error rather than a helpful Python backtrace.
PiperOrigin-RevId: 452682084
`mhlo.scatter` can still take the three arguments `(operand, scatter_indices, updates)` as before, but it can now operate on multiple operands, each with its own set of updates: `(operands..., scatter_indices, updates...)`.
This follows the change in HLO (6c9157ffb6), in order to support JAX FR: https://github.com/google/jax/issues/10079
This CL updates the op definition to allow for variadic scatter, and stubs out all clients of ScatterOp so that support can be implemented in following CLs.
PiperOrigin-RevId: 452559990
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.
The previous behavior can be approximated by passing `mode="clip"` or `mode="wrap"`.
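As a small illustration of the difference, the sketch below calls `jnp.take` with one in-bounds and one out-of-bounds index under each mode. The array values and the `fill_value` of `-1.0` are arbitrary choices for the example.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 4])  # 4 is out of bounds for an array of length 3

filled = jnp.take(x, idx)                 # default: out-of-bounds entry is NaN
clipped = jnp.take(x, idx, mode="clip")   # index 4 clamps to 2
wrapped = jnp.take(x, idx, mode="wrap")   # index 4 wraps to 4 % 3 == 1
custom = jnp.take(x, idx, mode="fill", fill_value=-1.0)  # explicit fill value
```

Code that relied on the old clamping behavior would most likely want `mode="clip"`.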
PiperOrigin-RevId: 445130143
This is a reasonably safe change, because it has no effect on the forward pass of a computation: the default behavior (PROMISE_IN_BOUNDS) also drops out-of-bounds scatters.
This change does however affect the transpose (gradient) of a scatter with out-of-bounds indices: the gradient of a PROMISE_IN_BOUNDS scatter is a PROMISE_IN_BOUNDS gather, and a PROMISE_IN_BOUNDS gather clips out-of-bounds indices into range. This is not mathematically correct: a dropped scatter index does not contribute to the primal output, and so its transpose should yield a zero cotangent.
After this change, the gradient of a default scatter is a gather with a fill value of 0: i.e., the indices that were dropped do not make gradient contributions, which is mathematically correct.
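The effect can be seen in a minimal sketch like the following, which differentiates a default-mode scatter-add with respect to its updates. The function name `total`, the index values, and the operand size are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

idx = jnp.array([0, 5])  # 5 is out of bounds for an operand of length 3

def total(updates):
    # Default scatter mode: the out-of-bounds update is dropped.
    out = jnp.zeros(3).at[idx].add(updates)
    return out.sum()

updates = jnp.array([1.0, 2.0])
value = total(updates)            # only the in-bounds update contributes
grad = jax.grad(total)(updates)   # the dropped update gets a zero cotangent
```

With the behavior described above, the gradient entry for the dropped update is 0 rather than a value gathered from a clipped index.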
Separately, I am working towards switching out-of-bounds gather() operations to also have FILL_OR_DROP semantics, although that change is more disruptive because a number of users have out-of-bounds indices in their gather()s.
Issues: https://github.com/google/jax/issues/278 https://github.com/google/jax/issues/9839
PiperOrigin-RevId: 444935241
a87b21148c did not account for `_scatter_add_lower_gpu` using `mlir.lower_fun` instead of `xla.lower_fun`.
This follows the change made in that commit for `_scatter_lower`.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):
```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
The current gradients are incorrect if unique_indices=False. No gradient is better than an incorrect gradient.
https://github.com/google/jax/issues/9296
PiperOrigin-RevId: 423917753
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params):
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
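The design can be sketched in plain Python as below. This is not the actual `mlir` implementation: the class names `ModuleContext` and `RuleContext`, the `platform` field, and the toy `add_rule` are hypothetical. Only the attribute names `module_context`, `avals_in`, and `avals_out` come from the description above.

```python
import dataclasses
from typing import Any, Optional, Sequence

@dataclasses.dataclass(frozen=True)
class ModuleContext:
    """Hypothetical stand-in for module-wide lowering state."""
    platform: str

@dataclasses.dataclass(frozen=True)
class RuleContext:
    """Hypothetical per-rule context object."""
    module_context: ModuleContext
    avals_in: Sequence[Any]
    avals_out: Sequence[Any]
    # A field added later; existing rules keep their signature.
    primitive_name: Optional[str] = None

def add_rule(ctx, *args, **jaxpr_params):
    # A toy "lowering rule": it reads only what it needs from ctx.
    assert len(args) == len(ctx.avals_in)
    return f"{ctx.module_context.platform}:add{tuple(args)}"

ctx = RuleContext(ModuleContext("cpu"), avals_in=["f32[]", "f32[]"], avals_out=["f32[]"])
lowered = add_rule(ctx, "%0", "%1")
named = dataclasses.replace(ctx, primitive_name="add")
```

Because rules take a single context object, adding `primitive_name` (or a shape environment) changes only the context class and the code that constructs it.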
PiperOrigin-RevId: 416698663
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.
The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.
PiperOrigin-RevId: 414704700
The remaining failures relate to buffer donation and xmap_p, which are not yet implemented.
Quite a few primitives still use fallback paths.
PiperOrigin-RevId: 413130158
Also handle a limited form of shape polymorphism, where
`operand.shape - update.shape` is a constant in the scatter dimensions,
even when the shapes may contain dimension variables.