79 Commits

Author SHA1 Message Date
Peter Hawkins
e20523c2e3 Make api_test.py work when test cases are run using multiple threads.
* keep track of all known config.State objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options.
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.
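The thread-local override described above can be sketched as follows. This is a hypothetical illustration of the pattern, not JAX's actual `jtu` internals; the names `thread_local_config_context` and `current_config` here are illustrative.

```python
import contextlib
import threading

# Hypothetical sketch of a thread-local configuration override, in the
# spirit of `jtu.thread_local_config_context`. Each thread sees only its
# own overrides; other threads' configuration is untouched.
_state = threading.local()

def current_config():
    return getattr(_state, "values", {})

@contextlib.contextmanager
def thread_local_config_context(**overrides):
    old = current_config()
    _state.values = {**old, **overrides}
    try:
        yield
    finally:
        # Restore the previous thread-local configuration on exit.
        _state.values = old

with thread_local_config_context(pretty_printer_color=False):
    assert current_config()["pretty_printer_color"] is False
# The override is gone once the context exits.
assert "pretty_printer_color" not in current_config()
```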

PiperOrigin-RevId: 713411171
2025-01-08 14:09:07 -08:00
George Necula
bc3306c8bc [shape_poly] Improve threefry with symbolic shapes
Previously, we could handle threefry only when it was possible
to tell statically whether the size of the `count` array is
even or odd. This often meant that we had to add a constraint
that one of the dimensions is even.

Here we rewrite the handling of threefry so that it does not
require a Python-level conditional on the evenness of the size
of the count array. We use a couple of `lax.dynamic_slice`
operations rather than a `lax.split`.

We also generalize the tests to cases where the size is fully
symbolic, and we cannot tell statically that it is even.
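The idea can be sketched as follows (a simplified illustration, not the actual threefry code): pad the array by one element and take two slices of size `ceil(n / 2)`, instead of branching in Python on evenness.

```python
import jax.numpy as jnp
from jax import lax

def split_halves(count):
    # Works whether len(count) is even or odd (and, with symbolic shapes,
    # when evenness cannot be decided statically): pad by one element and
    # take two slices of size ceil(n / 2).
    n = count.shape[0]
    half = (n + 1) // 2
    padded = jnp.concatenate([count, jnp.zeros((1,), count.dtype)])
    lo = lax.dynamic_slice_in_dim(padded, 0, half)
    hi = lax.dynamic_slice_in_dim(padded, half, half)
    return lo, hi
```

For an odd input like `jnp.arange(5)` the second half picks up the zero padding element, which is what the padding in the odd case is for.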
2025-01-07 09:10:04 +02:00
George Necula
e87a2a5929 [shape_poly] Remove old non_negative support.
This was deprecated in January 2024, replaced by
`core.max_dim(..., 0)`.

PiperOrigin-RevId: 712523579
2025-01-06 07:36:11 -08:00
Jake VanderPlas
8c3c441ee4 jax.nn.one_hot: deprecate non-integer inputs 2024-12-19 07:11:31 -08:00
George Necula
27b024b240 [shape_poly] Improve handling of mod(e, k) == 0 constraints.
These constraints turn out to be quite useful, e.g., when
we want to say that certain dimensions are a multiple of
a device axis.

Previously, the constraint `mod(e, k) == 0` was useful
only for normalizing away `mod(e, k)`. In particular, it was
not useful for proving `e == k * floordiv(e, k)`. Now we add
that feature.
2024-12-12 10:31:02 +01:00
George Necula
60f9da5d58 [shape_poly] Improve reasoning for >= in presence of == constraints.
Previously, an equality constraint was used only as a normalization
rule. This created a problem for constraints of the form `4*b == c`,
because it did not allow proving that `b <= c` (since the
normalization of `4*b` kicks in only when `b` is multiplied by a
multiple of 4).

Now we also add the equality constraints to the inequality
reasoning state.
2024-12-11 10:51:49 +01:00
George Necula
4e17bea91a [shape_poly] Fix the handling of __pow__ for symbolic dimensions
The code for handling exponentiation was wrong, and there were
no tests.
2024-12-05 11:11:02 +01:00
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives. 2024-11-21 06:17:01 -08:00
George Necula
fb68c97a0d [shape_poly] Fix the handling of jvp(lax.sort)
Previously, `jvp(lax.sort)` used a shape-dependent dtype, for
the types of indices (either `int32` or `int64`, depending on
the size of the dimension). For shape polymorphism, input shapes
can affect other intermediate shapes, but not `dtype`s.

In this case it is easy to just use `int64`, independent of
the actual shape.
2024-11-12 03:36:05 -08:00
George Necula
e5f4be5564 [shape_poly] Expands support for random.choice
`random.choice` uses `np.insert(arr.shape, new_shape)`, which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. We replace the use of `np.insert` with tuple slicing and
concatenation.

The case when the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` on
arrays with non-constant size is not supported.

Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
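The replacement is straightforward tuple surgery. A minimal sketch of the equivalent operation (the helper name here is hypothetical, for illustration only):

```python
def insert_dim(shape: tuple, axis: int, size) -> tuple:
    # Equivalent of np.insert on a shape tuple, but built from slicing
    # and concatenation, so symbolic dimensions pass through unchanged
    # instead of being coerced to constants.
    return shape[:axis] + (size,) + shape[axis:]

assert insert_dim((2, 3), 0, 5) == (5, 2, 3)
assert insert_dim((2, 3), 1, 5) == (2, 5, 3)
```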
2024-10-24 17:20:09 +03:00
Dan Foreman-Mackey
8361eb58e1 Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions, with some caveats. By default, we use some heuristics based on the matrix sizes to select the algorithm, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
2024-10-17 17:57:06 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
jax authors
6d2c8cf5de Merge pull request #23656 from tchatow:fix-inv
PiperOrigin-RevId: 683112267
2024-10-07 03:38:04 -07:00
Dan Foreman-Mackey
67f24df740 Activate FFI implementation of symmetric Eigendecomposition.
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
2024-10-04 12:38:26 -07:00
Dan Foreman-Mackey
c0240764bc Activate FFI implementation of the QR decomposition.
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.

The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.

PiperOrigin-RevId: 682312752
2024-10-04 07:27:11 -07:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
tchatow
520980171f Fix jax.numpy.linalg.inv with shape polymorphism 2024-09-24 12:03:06 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
940860625e Remove code that existed to support jaxlib < 0.4.32.
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57

PiperOrigin-RevId: 675291231
2024-09-16 14:30:00 -07:00
jax authors
d776f1da76 Merge pull request #23470 from gnecula:poly_fix_eq_constraints
PiperOrigin-RevId: 671727351
2024-09-06 05:53:53 -07:00
George Necula
0d8ffd33ab [shape_poly] Improve handling of equality shape constraints
This fixes several bugs in presence of equality constraints where
the left-hand side is just a dimension variable.

First, such constraints were not applied when parsing variables.
Now, with a constraint `a == b` when we parse "a" we obtain `b`.

Second, when we evaluate symbolic dimensions that contain
dimension variables that are constrained to be equal to something
else, we may fail to find the dimension variable in the environment
because the environment construction has applied the constraints.
We fix this by looking up the unknown dimension variable in
the equality constraints.

Fixes: #23437
Fixes: #23456
2024-09-06 13:55:38 +03:00
Peter Hawkins
db4be03f02 Disable many eigh tests.
These started failing due to a compiler change internally at Google, but the tests themselves are buggy. It is not correct to compare an eigendecomposition for equality up to a tolerance, because the eigenvalues are sorted, and all it takes is a tiny perturbation to reorder the eigenvalues and eigenvectors, which leads to a result that looks very different.
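A more robust check is to verify the defining property `A v = λ v` rather than comparing the factors elementwise. A sketch in plain NumPy, for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
a = a + a.T  # symmetric input
w, v = np.linalg.eigh(a)
# Check A @ V == V @ diag(w); this is invariant to the sign and
# (degenerate-eigenvalue) ordering ambiguities that make direct
# elementwise comparisons of eigenvectors fragile.
assert np.allclose(a @ v, v * w, atol=1e-8)
```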

PiperOrigin-RevId: 669346013
2024-08-30 09:09:37 -07:00
Dan Foreman-Mackey
d49d070f0e Skip shape polymorphism tests that are incompatible with released jaxlib version.
PiperOrigin-RevId: 665893050
2024-08-21 08:35:35 -07:00
Dan Foreman-Mackey
e51848ea3d Activate GPU kernel for LU decomposition.
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.

One of the biggest changes here is to move all the lowering logic for the getrf call into jax (lax/linalg.py) instead of jaxlib (gpu_solver.py and lapack.py), since the lowering code is now identical for CPU and GPU (the only difference is the handler names).

PiperOrigin-RevId: 665829252
2024-08-21 05:08:41 -07:00
Dan Foreman-Mackey
4eb5ef28ef Update shape polymorphism tests to skip lu_pivots_to_permutations tests when jaxlib version is too old.
PiperOrigin-RevId: 662088901
2024-08-12 08:13:27 -07:00
Dan Foreman-Mackey
3c014a4c27 Add support for shape polymorphism with lu_pivots_to_permutation.
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.
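The static/dynamic distinction for `fori_loop` can be seen in a small sketch (illustrative, not the op's actual implementation): with a concrete trip count the loop can lower to a `scan`, while a traced count forces a `while_loop`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def sum_first(n, xs):
    # n may be a Python int (static) or a traced value (dynamic);
    # lax.fori_loop handles both, lowering to scan or while_loop.
    return lax.fori_loop(0, n, lambda i, acc: acc + xs[i], 0.0)

xs = jnp.arange(5.0)
assert float(sum_first(3, xs)) == 3.0           # static trip count: 0+1+2
assert float(jax.jit(sum_first)(3, xs)) == 3.0  # traced trip count
```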

PiperOrigin-RevId: 662024940
2024-08-12 03:39:54 -07:00
George Necula
ffd2b00516 Add concretization error check in core.min_dim and core.max_dim
Fixes: #22751
2024-08-01 07:27:35 +02:00
George Necula
459b83cf4a Reverts 093b92be8ed7bd979486614325956e88cc474ff1
PiperOrigin-RevId: 655114622
2024-07-23 04:32:56 -07:00
George Necula
093b92be8e Reverts 5216719996d4468f750725ef70cef6f97ac45c27
PiperOrigin-RevId: 653237245
2024-07-17 08:10:01 -07:00
George Necula
7817b6785b [shape_poly] Expand the support for shape polymorphism for jnp.pad
Handle several new padding modes: wrap, reflect, symmetric, linear_ramp, maximum.
Not all situations are handled; we try to give a clear error for
the unsupported cases.

While implementing this, I needed to add shape polymorphism support
also for jnp.linspace.

And I discovered a bug in the implementation of `divmod(0, b)`.
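A concrete illustration of one of the newly supported modes (shown here with a static shape; per this change the same padding works when the padded dimension is symbolic):

```python
import jax.numpy as jnp

x = jnp.arange(3)  # [0, 1, 2]
# "wrap" pads by cycling the array: 2 elements on the left, 1 on the right.
padded = jnp.pad(x, (2, 1), mode="wrap")
assert padded.tolist() == [1, 2, 0, 1, 2, 0]
```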
2024-07-15 17:04:54 +02:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
George Necula
b58ff2ba20 [shape_poly] Add documentation for shape polymorphism
This involved writing some new content and also moving and adapting
the documentation that existed as part of the jax2tf
README file:

https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion
2024-06-15 18:20:54 +03:00
jax authors
02b5d4769d Swap operands of dot if the LHS is fed by a parameter
PiperOrigin-RevId: 642090766
2024-06-10 18:33:05 -07:00
George Necula
b33aca6b08 [export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Exported = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.

Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * The `jax.experimental.export.export` allowed the function
    argument to be any Python callable and it would wrap it with
    a `jax.jit`. This is not supported anymore by export, and instead
    the user must use `jax.jit`.
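A self-contained round trip with the new API (the function and shapes below are chosen purely for illustration):

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sin(x) * 2.0

# Export, serialize, and rehydrate in one process.
ex = export.export(jax.jit(f))(jax.ShapeDtypeStruct((3,), jnp.float32))
blob = ex.serialize()
rehydrated = export.deserialize(blob)

out = rehydrated.call(jnp.arange(3.0))
assert jnp.allclose(out, jnp.sin(jnp.arange(3.0)) * 2.0)
```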
2024-06-10 19:31:51 +02:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
Jake VanderPlas
f04a2279a5 shape_poly_test: adjust configs via jtu.global_config_context 2024-06-05 10:45:56 -07:00
George Necula
dbad518d2b [shape_poly] Add limited support for lax.approx_top_k.
This relies on newly introduced support for dynamic `k`
for approx_top_k, using the `stablehlo.dynamic_approx_top_k`
custom call.

We also add a backwards compatibility test.

PiperOrigin-RevId: 640557581
2024-06-05 09:51:47 -07:00
George Necula
39ac584729 [shape_poly] Move to jax._src in preparation for adding to AOT APIs.
The shape polymorphism APIs are still private and are only exposed through `jax.experimental.export` as before.

PiperOrigin-RevId: 640393089
2024-06-04 22:03:24 -07:00
jax authors
f72b0f0ca6 Merge pull request #21504 from gnecula:poly_approx
PiperOrigin-RevId: 638550165
2024-05-30 00:22:24 -07:00
George Necula
c6a47316be [shape_poly] Fixes for approx_top_k when aggregate_to_topk=True
When `aggregate_to_topk=True` (the default), the output reduction
dimension size is `k`, and we do not need to invoke `ApproxTopKReductionOutputSize`.

Add a set of test cases for shape polymorphism for approx_top_k.

The case when `aggregate_to_topk=True` and `k` is symbolic will
be fixed separately.

The case when `aggregate_to_topk=False` raises a clearer NotImplementedError.
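For reference, the default path being fixed here looks like this (a small sketch; on tiny inputs like this one the approximate top-k contains the true maximum):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 5.0, 3.0, 9.0, 2.0])
# With aggregate_to_topk=True (the default) the output size is exactly k.
vals, idxs = lax.approx_max_k(x, k=2)
assert vals.shape == (2,) and idxs.shape == (2,)
assert 9.0 in vals.tolist()
```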
2024-05-30 04:17:13 +03:00
George Necula
87b81fc768 [shape_poly] Add support for jnp.tril. 2024-05-30 02:53:00 +03:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
George Necula
6705a96b56 [shape_poly] Cleaning up naming of terms and factors.
In the past symbolic expressions were polynomials, consisting of sums
of monomials, which were products of atoms. Over time the language
of symbolic expressions has become richer. Now expressions
are sums of terms, which are products of factors.

Here we rename references to monomials to terms, and `_DimMon`
to `_DimTerm`. We also rename reference of atoms to factors,
and `_DimAtom` to `_DimFactor`.

At the same time we rename most of the methods of `_DimExpr`
to have a leading underscore, to indicate that they are
private methods.
2024-02-21 09:18:22 +01:00
George Necula
30ddc400b8 [shape_poly] Fix handling of stride_in_dim with symbolic stride.
The fix is simple, just avoid using `int(stride)`.
While fixing this I discovered some issues with a test
being disabled and handling of division by 0 when
computing the bounds of floordiv.
2024-02-19 12:36:26 +01:00
George Necula
bb57fb71e2 [shape_poly] Performance improvements for symbolic dimension manipulations (step 3)
We make the following improvements:

  * Add a `linear_combination` function to use for computing
    linear combinations of symbolic expressions. E.g., `a - b` used
    to involve 2 operations: "-1 * b" and "a + -1*b".
  * Change the representation of terms (_DimMon) from a dictionary
    mapping factors (_DimAtom) to exponents, into a sorted tuple of
    pairs (factor, exponent). This is worthwhile because in almost
    all cases a term contains a single factor. Everywhere we used
    `term.items()` now we use `term._factors`.
  * Make the computation of `._hash` lazy. Previously, we used dictionaries
    heavily for symbolic expressions and we always needed the hash value,
    now we use dictionaries less.
  * Replace `t.degree` with `t.is_constant`.
  * Add `__slots__` to the representation of symbolic expressions

Micro benchmark: `a * 2 - b * 2 - a * 3 + c * 4`

After: 12.51 μsec (mean 12.6 μsec ± 105.2 nsec, of 7 runs, 20000 loops each)
Before: 40.33 μsec (mean 40.5 μsec ± 247.6 nsec, of 7 runs, 5000 loops each)
2024-02-16 17:33:34 +01:00
George Necula
eb9caf0d16 [shape_poly] Performance improvements for symbolic dimension manipulations (step 2)
We make the following improvements:

  * Cache the state of the decision procedure after we process the explicit
    constraints, and reuse it for new decisions.
  * Rationalize the usage of add_implicit_constraints. We used to call it
    conservatively, too often. Now we call it only once for each explicit constraint,
    and once for each bounds decision we make. Then, in the add_implicit_constraints
    we call it recursively when we encounter new sub-expressions.
  * Eliminate some usage of __str__ for symbolic expressions in combine_and_add_constraints
    since we should only need it for reporting error messages.

This speeds up inequality reasoning:

Before:
```
In [1]:     from jax.experimental import export
   ...:     from jax import core
   ...:     a, b, c = export.symbolic_shape("a, b, c", constraints=["a >= 3", "a <= 5", "b >= 8"])

In [2]: %timeit a >= b
109 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [3]: %timeit core.max_dim(a, c) >= a - c
442 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

After:
```
In [2]: %timeit a >= b
11.7 µs ± 27.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [3]: %timeit core.max_dim(a, c) >= a - c
34.8 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
2024-02-15 18:07:57 +01:00
George Necula
18698a1f19 [shape_poly] Add support for jnp.split 2024-02-15 14:43:41 +01:00
George Necula
ed735608b5 [shape_poly] Improve the symbolic expressions pretty-printer and parser.
Now we allow parsing: "+ a", "-a ", "-b + a".
We also print "- a" instead of "-1*a".
2024-02-14 12:03:42 +02:00