Add primitives tests for the case of dot_general with different lhs_dtype and
rhs_dtype. Then fix the lowering to work with dynamic shapes.
PiperOrigin-RevId: 551915175
An upcoming pytype release complains about unpacking a non-deterministic order iterable for this line of code. Work around pytype.
PiperOrigin-RevId: 551627521
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:
Rename opaque dtype to extended dtype.
This includes three deprecations:
- jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
- jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
- the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.
One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.
PiperOrigin-RevId: 549301796
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.
We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.
We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule such that it re-uses the input jaxpr rather than always
re-generating it.
This touches _gather_batching_rule because slicing is implemented as a
gather, but we only test the case exercised by the slice that occurs
in our test transformer model, namely the unstack operation
q, k, v = qkv
(which turns into three slices on an non-batched and non-ragged axis).
Co-authored-by: Matthew Johnson <mattjj@google.com>
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:
* remove symbolic_equal_dim and symbolic_equal_shape in favor of the
newer definitely_equal and definitely_equal_shape
* remove is_special_dim_size, which checks that a value is a
dimension expression (not a constant). Some uses are replaced
with `not is_constant_dim` and others with `is_dim`.
* introduce concrete_dim_or_error to check that a value is
a dimension
If both the second and third operand of a `lax.cond` call are callable, then
resolve it as a new-style (default) conditional, where both branches act on the
same operands.
This changes the behavior of five-argument `lax.cond` calls. It is a breaking
change for callers using the old-style `cond` calling convention (`pred`,
`true_arg`, `true_fn`, `false_arg`, `false_fn`) with a callable `true_arg`.
PiperOrigin-RevId: 543912445
The support for dynamic shapes for linalg.eig and linalg.eigh has been added
before we added the helper function `mk_result_types_and_shapes`, which has
been used for all other linalg primitives. Here we refactor linalg.eig and
linalg.eigh support to use these helper functions and follow the same style
as for other linalg primitives.
PiperOrigin-RevId: 543495381
Previously, we used the following pattern to generate the 1D
tensors representing dynamic shapes:
```
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, shape))
```
Now we write:
```
mlir.eval_dynamic_shape_as_tensor(ctx, shape)
```