Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:
* remove symbolic_equal_dim and symbolic_equal_shape in favor of the
newer definitely_equal and definitely_equal_shape
* remove is_special_dim_size, which checks that a value is a
dimension expression (not a constant). Some uses are replaced
with `not is_constant_dim` and others with `is_dim`.
* introduce concrete_dim_or_error to check that a value is
a dimension
Previously, we used the following pattern to generate the 1D
tensors representing dynamic shapes:
```
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, shape))
```
Now we write:
```
mlir.eval_dynamic_shape_as_tensor(ctx, shape)
```
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.
The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix#10818.
However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
When we create "vmap"-based test harnesses from primitive harnesses
we used to exclude certain primitives. We reduced the list to one
primitive, "tridiagonal_solve" for which vmap is not defined.
We have also added a more explicit error about certain unsupported
dynamic shape features for convolution (waiting for StableHLO feature).
Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.
Unchanged occurrences:
1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
argument value in Lowering.as_text and Lowering.compiler_ir.
2) Documentation (changelog, JEPs, IR examples, etc).
3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
so both are necessary to disambiguate.
PiperOrigin-RevId: 495771153
MLIR bytecode is more compact to represent and should be faster to generate and parse.
The previous attempt at this change broke for 0D convolutions. JAX was not ensuring that the padding attribute had the correct [N, 2] shape when N was 0.
PiperOrigin-RevId: 472991661
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.
In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
* when converting from a non-bool type to a boolean, lower it as x != 0 rather than convert(x, i1). Convert has truncation semantics, but we are expecting XLA's x != 0 semantics instead.
* revert https://github.com/google/jax/pull/8825 and part of https://github.com/google/jax/pull/8810. PR https://github.com/google/jax/pull/8828 means that we now will never have a non-canonical preferred_element_type, and so the output type is once again always equal to the preferred element type.
PiperOrigin-RevId: 414716056
This avoids non-canonical types showing up in surprising places.
It is possible that some users are specifying a 64-bit type here intentionally, but that seems unlikely. The fix in that case would be to disable non-x64 mode.
PiperOrigin-RevId: 414511197
Similar to the fix to dot_general in https://github.com/google/jax/pull/8810
This is hard to detect from a direct test, except by inspecting the IR, which I'd rather avoid. However the jax2tf tests already catch it since they have a very tight test tolerance.
PiperOrigin-RevId: 414479170
The remaining failures relate to buffer donation and xmap_p, which are not yet implemented.
Quite a few primitives still use fallback paths.
PiperOrigin-RevId: 413130158