Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python constants or C constants.
This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.
PiperOrigin-RevId: 502568204
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.
This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.
PiperOrigin-RevId: 497171518
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.
This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.
PiperOrigin-RevId: 497079129
--
a74c74c25572eec23c28e08dbe67781a23be19fb by George Necula <gcnecula@gmail.com>:
Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.
PiperOrigin-RevId: 496883024
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.
Unchanged occurrences:
1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
argument value in Lowering.as_text and Lowering.compiler_ir.
2) Documentation (changelog, JEPs, IR examples, etc).
3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
so both are necessary to disambiguate.
PiperOrigin-RevId: 495771153
The immediate motivation for this is to support the lowering
to StableHLO for programs with polymorphic shapes. This requires
mixing of dynamic shapes with opaque types.
The general strategy is to push the actual selection of the MHLO ops
down into mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim)
so that we have one place where we pick whether we use the Dynamic
or static ops. These routines can also handle the opaque type.
This will result in a recursive
call to, e.g., mlir.slice_op, but the inner call will be using
the physical avals, which should not be opaque anymore.
While making this change I was confused by the fact that the
custom KeyTyRules in prng.py have lowerings that return multiple
MHLO ops. See https://github.com/google/jax/pull/11768#issuecomment-1342349102
and I changed the rules to return a single op.
.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.
For native serialization we will support two lowering implementations:
* one is using the growing support in JAX for dynamic shapes,
of which shape polymorphism is a special case.
This implementation is enabled with the --jax_dynamic_shapes flag.
At the moment, the JAX dynamic shapes support is still
incomplete and over 300 jax2tf shape polymorphism tests fail.
* a new one (added) here in which we form a Jaxpr using abstract
values that express dimension sizes as dimension polynomials
(as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
This implementation is enabled when --jax_dynamic_shapes is off.
With this implementation only 50 jax2tf tests fail (to be fixed
separately).
The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.
The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.
Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.
The key code pattern used in the lowering rule is::
if not core.is_constant_shape(shape): # Handles both Var, and polynomials
shape = mlir.eval_dynamic_shape(ctx, shape)
return mhlo.DynamicXXX(..., shape)
else:
return mhlo.XXX(..., shape)
with `mlir.eval_dynamic_shape` handling both cases::
def eval_dynamic_shape(ctx, shape):
if config.jax_dynamic_shapes:
# Using Var
return ... subst using ctx.axis_size_env ...
else:
# Using polynomials
return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values
In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.
I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison, and squeeze) for a simple
dynamic-batch-shape gather to work.
I also skipped two tests and deleted some old dynamic shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's just convenient given I'm looking at indexing again.
This also adds an `mlir.delegate_lowering` helper that takes optional
replacements to the lowering context and restores/updates it on
return. This is used to define a gather lowering rule in tests here,
by calling out to the base gather lowering rule. I picked it off of a
working branch that uses it as well (for key-eltyped lowering rules).
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.
Within jaxprs, we also need to handle transformations with these types.
In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
* allowing ShapedArray to have any object for its 'dtype' attribute
* added core.custom_eltype set
* extended existing handlers for ShapedArray to call the corresponding custom
element type handlers
* mlir lowerings of some fully-element-type-polymorphic primitives
* tests
In this PR, we only actually use these new extension points in tests.
The applications to come that we have in mind are:
* arrays of prngkeys (and even custom prngs, as well as reuse error checking)
* arrays of bounded int type for dynamic shapes (and especially raggedness)
* float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.
Jargon-wise, we may want to distinguish:
* 'eltype' meaning jaxpr element types
* 'dtype' meaning numpy dtypes (an existing convention)
* 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.
We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
* [x] broadcast
* [ ] reshape
* [x] transpose
* [ ] pad
* [x] slice, dynamic_slice, dynamic_update_slice
* [ ] concatenate
* [ ] all_to_all, gather, scatter, all_gather, collective_permute
* [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.
We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.
Co-authored-by: Roy Frostig <frostig@google.com>
We were previously missing a shape check for this, meaning that if non-scalars were passed we would receive a cryptic MHLO shape error rather than a helpful Python backtrace.
PiperOrigin-RevId: 452682084
`mhlo.scatter` as before can still take the three arguments `(operand, scatter_indices, updates)`, but now it can operate on multiple operands with their own sets of updates, `(operands..., scatter_indicies, updates...)`
Following the change in HLO,
6c9157ffb6
In order to support JAX FR: https://github.com/google/jax/issues/10079
This CL updates the op definition to allow for variadic scatter, and stubs out all clients of ScatterOp so that support can be implemented in following CLs.
PiperOrigin-RevId: 452559990
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.
The previous behavior can be approximated using the "clip" or "wrap" fill modes.
PiperOrigin-RevId: 445130143
This is a reasonably safe change, because it has no effect on the forward pass of a computation: the default behavior (PROMISE_IN_BOUNDS) also drops out-of-bounds scatters.
This change does however affect the transpose (gradient) of a scatter with out-of-bounds indices: the gradient of a PROMISE_IN_BOUNDS scatter is a PROMISE_IN_BOUNDS gather, and a PROMISE_IN_BOUNDS gather clips out-of-bounds indices into range. This is not mathematically correct: a dropped scatter index does not contribute to the primal output, and so its transpose should yield a zero cotangent.
After this change, the gradient of a default scatter is a gather with a fill value of 0: i.e., the indices that were dropped do not make gradient contributions, which is mathematically correct.
Separately, I am working towards switching out-of-bounds gather() operations to also have FILL_OR_DROP semantics, although that change is more disruptive because a number of users have out-of-bounds indices in their gather()s.
Issues: https://github.com/google/jax/issues/278https://github.com/google/jax/issues/9839
PiperOrigin-RevId: 444935241
a87b21148c doesn't notice `_scatter_add_lower_gpu` using `mlir.lower_fun` instead of `xla.lower_fun`.
I follow the change done in that commit for _scatter_lower.