* keep track of all known `config.State` objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options (a usage sketch follows this list).
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.
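A hypothetical usage sketch of the new helpers (here `jtu` is assumed to be `jax._src.test_util`, and the exact signatures of `with_config` and `thread_local_config_context` are assumptions based on the description above):
```
import jax._src.test_util as jtu

# Class-level decorator; per this change it now sets thread-local
# configuration values by default.
@jtu.with_config(jax_numpy_rank_promotion="raise")
class MyTest(jtu.JaxTestCase):

  def test_something(self):
    # Temporarily override a thread-local option, mirroring the way
    # jtu.global_config_context overrides global options.
    with jtu.thread_local_config_context(jax_debug_nans=True):
      ...
```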
PiperOrigin-RevId: 713411171
Previously, we could only handle threefry for the case when
it was possible to tell statically that the size of the `count`
array is even or odd. This meant that often we had to add a constraint
that one of the dimensions is even.
Here we rewrite the handling of threefry to not require a Python-level
conditional about evenness of the size of the count array. We use
a couple of `lax.dynamic_slice` calls rather than a `lax.split`.
We also generalize the tests to cases where the size is fully symbolic,
and we cannot tell statically that it is even.
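A minimal sketch of the general technique (not JAX's actual threefry code): pad to an even length and take the halves with dynamic slices, so no Python-level branch on parity is needed.
```
import jax.numpy as jnp
from jax import lax

def split_halves(x):
  # Works for even or odd sizes without an `if size % 2` branch.
  n = x.shape[0]
  half = (n + 1) // 2                      # ceil(n / 2)
  # Pad with at most one element so the size becomes 2 * half.
  x = jnp.concatenate([x, jnp.zeros((2 * half - n,), x.dtype)])
  lo = lax.dynamic_slice_in_dim(x, 0, half)
  hi = lax.dynamic_slice_in_dim(x, half, half)
  return lo, hi
```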
These constraints turn out to be quite useful, e.g., when
we want to say that certain dimensions are a multiple of
a device axis.
Previously, the constraint `mod(e, k) == 0` was useful
only for normalizing away `mod(e, k)`. In particular, it was not
useful for reasoning about `k * floordiv(e, k)` (e.g., proving that
it equals `e`). Now we add that feature.
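A hedged illustration of the behavior described above (assuming the constraint syntax accepted by `jax.export.symbolic_shape`):
```
from jax import export

# With mod(a, 4) == 0 as an explicit constraint, the reasoner should now be
# able to relate 4 * floordiv(a, 4) back to a.
a, = export.symbolic_shape("a", constraints=["mod(a, 4) == 0"])
assert 4 * (a // 4) == a
```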
Previously, an equality constraint was used only as a normalization
rule. This created a problem for constraints of the form `4*b == c`,
because it would not allow proving that `b <= c` (since the
normalization of `4*b` kicks in only if `b` is multiplied by a
multiple of 4).
Now we add the equality constraints also in the inequality
reasoning state.
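A hedged illustration, using the example from the text:
```
from jax import export

# The equality constraint is now also fed to the inequality reasoner,
# so b <= c becomes provable.
b, c = export.symbolic_shape("b, c", constraints=["4*b == c"])
assert b <= c
```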
Previously, `jvp(lax.sort)` used a shape-dependent dtype, for
the types of indices (either `int32` or `int64`, depending on
the size of the dimension). For shape polymorphism, input shapes
can affect other intermediate shapes, but not `dtype`s.
In this case it is easy to just use `int64` independent of
the actual shape.
`random.choice` uses `np.insert(arr.shape, new_shape)` which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. Replace use of `np.insert` with tuple slicing and
concatenation.
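An illustrative sketch of the idea (not JAX's actual internals): building the shape by tuple slicing and concatenation keeps symbolic dimension expressions intact, whereas going through `np.insert` round-trips them through a NumPy array and tries to concretize them.
```
def _splice_shape(base_shape, axis, new_shape):
  # Hypothetical helper: replace the dimension at `axis` with `new_shape`,
  # operating purely on tuples so symbolic entries are preserved.
  return base_shape[:axis] + tuple(new_shape) + base_shape[axis + 1:]
```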
The case when the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` on
arrays with non-constant size is not supported.
Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:
1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.
2. This implementation supports shape polymorphism in all dimensions, with some caveats. By default, we use some heuristics based on the matrix sizes to select the algorithm, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly, because our heuristics are not very carefully optimized.
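A hedged usage sketch of the new parameter (assuming the enum is exposed as `jax.lax.linalg.SvdAlgorithm`, alongside `jax.lax.linalg.svd`):
```
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

x = jnp.eye(6, 4)
# Request the QR-based algorithm explicitly, bypassing the size heuristics.
u, s, vh = lax_linalg.svd(x, full_matrices=False, compute_uv=True,
                          algorithm=lax_linalg.SvdAlgorithm.QR)
```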
Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.
PiperOrigin-RevId: 687106965
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.
We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.
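A small usage sketch of the newly public name:
```
import jax.numpy as jnp
from jax import lax

x = jnp.ones((8,), jnp.complex64)
# jax.lax.FftType replaces the semi-private jax.lib.xla_client.FftType.
y = lax.fft(x, fft_type=lax.FftType.FFT, fft_lengths=(8,))
```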
PiperOrigin-RevId: 684447186
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.
PiperOrigin-RevId: 682415506
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.
The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.
PiperOrigin-RevId: 682312752
This fixes several bugs in presence of equality constraints where
the left-hand side is just a dimension variable.
First, such constraints were not applied when parsing variables.
Now, with a constraint `a == b`, when we parse `a` we obtain `b`.
Second, when we evaluate symbolic dimensions that contain
dimension variables that are constrained to be equal to something
else, we may fail to find the dimension variable in the environment
because the environment construction has applied the constraints.
We fix this by looking up the unknown dimension variable in
the equality constraints.
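A hedged illustration of the first fix:
```
from jax import export

# An equality constraint with a plain variable on the left-hand side is now
# applied when parsing, so "a" parses to the same expression as "b".
a, b = export.symbolic_shape("a, b", constraints=["a == b"])
assert a == b
```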
Fixes: #23437
Fixes: #23456
These started failing due to a compiler change internally at Google, but the tests themselves are buggy. It is not correct to compare an eigendecomposition for equality up to a tolerance, because the eigenvalues are sorted, and all it takes is a tiny perturbation to reorder the eigenvalues and eigenvectors, which leads to a result that looks very different.
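For reference, a sketch of a check that is insensitive to eigenvalue reordering: instead of comparing eigenvalues and eigenvectors elementwise, verify the decomposition itself.
```
import jax.numpy as jnp

def eig_close(a, w, v, tol=1e-5):
  # For a valid eigendecomposition, a @ v == v @ diag(w), regardless of how
  # the eigenpairs happen to be ordered.
  return jnp.allclose(a @ v, v * w, rtol=tol, atol=tol)
```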
PiperOrigin-RevId: 669346013
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.
One of the biggest changes here is to move all the lowering logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).
PiperOrigin-RevId: 665829252
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.
PiperOrigin-RevId: 662024940
Handle several new padding modes: wrap, reflect, symmetric, linear_ramp, maximum.
Not all situations are handled; we try to give a clear error for the
unsupported cases.
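For reference (with concrete shapes), the newly handled modes in `jnp.pad`:
```
import jax.numpy as jnp

x = jnp.arange(1, 6)
for mode in ("wrap", "reflect", "symmetric", "linear_ramp", "maximum"):
  print(mode, jnp.pad(x, (2, 3), mode=mode))
```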
While implementing this, I needed to add shape polymorphism support
also for jnp.linspace.
And I discovered a bug in the implementation of `divmod(0, b)`.
The functionality comes from the jax.experimental.export
module, which will be deprecated.
The following APIs are introduced:
```
from jax import export
def f(...): ...
ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)
blob: bytearray = ex.serialize()
rehydrated: export.Exported = export.deserialize(blob)
def caller(...):
  ... rehydrated.call(*args, **kwargs)
```
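A concrete, runnable version of the sketch above:
```
import jax
import jax.numpy as jnp
from jax import export

def f(x):
  return jnp.sin(x) * 2.0

exp = export.export(jax.jit(f))(jnp.zeros((4,), jnp.float32))
blob = exp.serialize()
rehydrated = export.deserialize(blob)
print(rehydrated.call(jnp.arange(4.0, dtype=jnp.float32)))
```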
Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.
Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:
* Instead of `jax.experimental.export.call(exp)` we now write
`exp.call`
* `jax.experimental.export.export` allowed the function
argument to be any Python callable, and it would wrap it with
a `jax.jit`. This is no longer supported by `jax.export.export`;
instead the user must apply `jax.jit` explicitly.
This relies on newly introduced support for dynamic `k`
for approx_top_k, using the `stablehlo.dynamic_approx_top_k`
custom call.
We also add a backwards compatibility test.
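A hedged sketch of the use case (assuming a symbolic dimension expression is accepted as `k` here):
```
import jax
import jax.numpy as jnp
from jax import export, lax

# Export approx_max_k where k is a symbolic expression of the input size.
a, = export.symbolic_shape("a", constraints=["a >= 8"])
f = lambda x: lax.approx_max_k(x, k=a // 2)
exp = export.export(jax.jit(f))(jax.ShapeDtypeStruct((a,), jnp.float32))
```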
PiperOrigin-RevId: 640557581
When `aggregate_to_topk=True` (the default) the output reduction
dimension size is `k`, and we do not need to invoke `ApproxTopKReductionOutputSize`.
Add a set of test cases for shape polymorphism for approx_top_k.
The case when `aggregate_to_topk=True` and `k` is symbolic will
be fixed separately.
The case when `aggregate_to_topk=False` raises a clearer NotImplementedError.
In the past symbolic expressions were polynomials, consisting of sums
of monomials, which were products of atoms. Over time the language
of symbolic expressions has become richer. Now expressions
are sums of terms, which are products of factors.
Here we rename references to monomials to terms, and `_DimMon`
to `_DimTerm`. We also rename references to atoms to factors,
and `_DimAtom` to `_DimFactor`.
At the same time we rename most of the methods of `_DimExpr`
to have a leading underscore, to indicate that they are
private methods.
The fix is simple: just avoid using `int(stride)`.
While fixing this I discovered some issues with a test
being disabled, and with the handling of division by 0 when
computing the bounds of `floordiv`.
We make the following improvements:
* Add a `linear_combination` function to use for computing
linear combinations of symbolic expressions. E.g., `a - b` used
to involve 2 operations: `-1 * b` and `a + -1*b`.
* Change the representation of terms (`_DimMon`) from a dictionary
mapping factors (`_DimAtom`) to exponents, to a sorted tuple of
pairs (factor, exponent). This is worthwhile because in almost
all cases a term contains a single factor. Everywhere we used
`term.items()`, we now use `term._factors`.
* Make the computation of `._hash` lazy. Previously, we used dictionaries
heavily for symbolic expressions and we always needed the hash value;
now we use dictionaries less.
* Replace `t.degree` with `t.is_constant`.
* Add `__slots__` to the representation of symbolic expressions.
Micro benchmark: `a * 2 - b * 2 - a * 3 + c * 4`
Before: 40.33 μsec (mean 40.5 μsec ± 247.6 nsec, of 7 runs, 5000 loops each)
After: 12.51 μsec (mean 12.6 μsec ± 105.2 nsec, of 7 runs, 20000 loops each)
We make the following improvements:
* Cache the state of the decision procedure after we process the explicit
constraints, and reuse it for new decisions.
* Rationalize the usage of `add_implicit_constraints`. We used to call it
conservatively, too often. Now we call it only once for each explicit constraint,
and once for each bounds decision we make. Inside `add_implicit_constraints`
we then recurse when we encounter new sub-expressions.
* Eliminate some usage of `__str__` for symbolic expressions in `combine_and_add_constraints`,
since we should only need it for reporting error messages.
This speeds up inequality reasoning:
Before:
```
In [1]: from jax.experimental import export
...: from jax import core
...: a, b, c = export.symbolic_shape("a, b, c", constraints=["a >= 3", "a <= 5", "b >= 8"])
In [2]: %timeit a >= b
109 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [3]: %timeit core.max_dim(a, c) >= a - c
442 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
After:
```
In [2]: %timeit a >= b
11.7 µs ± 27.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [3]: %timeit core.max_dim(a, c) >= a - c
34.8 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```