64 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
c0240764bc Activate FFI implementation of the QR decomposition.
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.

The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.

PiperOrigin-RevId: 682312752
2024-10-04 07:27:11 -07:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
940860625e Remove code that existed to support jaxlib < 0.4.32.
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57

PiperOrigin-RevId: 675291231
2024-09-16 14:30:00 -07:00
jax authors
d776f1da76 Merge pull request #23470 from gnecula:poly_fix_eq_constraints
PiperOrigin-RevId: 671727351
2024-09-06 05:53:53 -07:00
George Necula
0d8ffd33ab [shape_polyO] Improve handling of equality shape constraints
This fixes several bugs in presence of equality constraints where
the left-hand side is just a dimension variable.

First, such constraints were not applied when parsing variables.
Now, with a constraint `a == b` when we parse "a" we obtain `b`.

Second, when we evaluate symbolic dimensions that contain
dimension variables that are constrained to be equal to something
else, we may fail to find the dimension variable in the environment
because the environment construction has applied the constraints.
We fix this by looking up the unknown dimension variable in
the equality constraints.

Fixes: #23437
Fixes: #23456
2024-09-06 13:55:38 +03:00
Peter Hawkins
db4be03f02 Disable many eigh tests.
These started failing due to a compiler change internally at Google, but the tests themselves are buggy. It is not correct to compare an eigendecomposition for equality up to a tolerance, because the eigenvalues are sorted, and all it takes is a tiny perturbation to reorder the eigenvalues and eigenvectors, which leads to a result that looks very different.

PiperOrigin-RevId: 669346013
2024-08-30 09:09:37 -07:00
Dan Foreman-Mackey
d49d070f0e Skip shape polymorphism tests that are incompatible with released jaxlib version.
PiperOrigin-RevId: 665893050
2024-08-21 08:35:35 -07:00
Dan Foreman-Mackey
e51848ea3d Activate GPU kernel for LU decomposition.
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.

One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).

PiperOrigin-RevId: 665829252
2024-08-21 05:08:41 -07:00
Dan Foreman-Mackey
4eb5ef28ef Update shape polymorphism tests to skip lu_pivots_to_permutations tests when jaxlib version is too old.
PiperOrigin-RevId: 662088901
2024-08-12 08:13:27 -07:00
Dan Foreman-Mackey
3c014a4c27 Add support for shape polymorphism with lu_pivots_to_permutation.
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.

PiperOrigin-RevId: 662024940
2024-08-12 03:39:54 -07:00
George Necula
ffd2b00516 Add concretization error check in core.min_dim and core.max_dim
Fixes: #22751
2024-08-01 07:27:35 +02:00
George Necula
459b83cf4a Reverts 093b92be8ed7bd979486614325956e88cc474ff1
PiperOrigin-RevId: 655114622
2024-07-23 04:32:56 -07:00
George Necula
093b92be8e Reverts 5216719996d4468f750725ef70cef6f97ac45c27
PiperOrigin-RevId: 653237245
2024-07-17 08:10:01 -07:00
George Necula
7817b6785b [shape_poly] Expand the support for shape polymorphism for jnp.pad
Handle several new padding modes: wrap, reflect, symmetric, linear_ramp, maximum.
Not all situations are handled; try to give a clear error for the unsupported
cases.

While implementing this, I needed to add shape polymorphism support
also for jnp.linspace.

And I discovered a bug in the implementation of `divmod(0, b)`.
2024-07-15 17:04:54 +02:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
George Necula
b58ff2ba20 [shape_poly] Add documentation for shape polymorphism
This involved writing some new content and also moving and adapting
the documentation that existed as part of the jax2tf
README file:

https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion
2024-06-15 18:20:54 +03:00
jax authors
02b5d4769d Swap operands of dot if the LHS is fed by a parameter
PiperOrigin-RevId: 642090766
2024-06-10 18:33:05 -07:00
George Necula
b33aca6b08 [export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Export = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.

Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * The `jax.experimental.export.export` allowed the function
    argument to be any Python callable and it would wrap it with
    a `jax.jit`. This is not supported anymore by export, and instead
    the user must use `jax.jit`.
2024-06-10 19:31:51 +02:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
Jake VanderPlas
f04a2279a5 shape_poly_test: adjust configs via jtu.global_config_context 2024-06-05 10:45:56 -07:00
George Necula
dbad518d2b [shape_poly] Add limited support for lax.approx_top_k.
This relies on newly introduced support for dynamic `k`
for approx_top_k, using the `stablehlo.dynamic_approx_top_k`
custom call.

We also add a backwards compatibility test.

PiperOrigin-RevId: 640557581
2024-06-05 09:51:47 -07:00
George Necula
39ac584729 [shape_poly] Move to jax._src in preparation for adding to AOT APIs.
The shape polymorphism APIs are still private and are only exposed through `jax.experimental.export` as before.

PiperOrigin-RevId: 640393089
2024-06-04 22:03:24 -07:00
jax authors
f72b0f0ca6 Merge pull request #21504 from gnecula:poly_approx
PiperOrigin-RevId: 638550165
2024-05-30 00:22:24 -07:00
George Necula
c6a47316be [shape_poly] Fixes for approx_top_k when aggregated_to_topk=True
When `aggregate_to_topk=True` (the default) the output reduction
dimension size is `k`, and we do not need to invoke `ApproxtopKReductionOutputSize`.

Add a set of test cases for shape polymorphism for approx_top_k.

The case when `aggregate_to_topk=True` and `k` is symbolic will
be fixed separately.

The case when `aggregate_to_topk=False` raises a clearer NotImplementedError.
2024-05-30 04:17:13 +03:00
George Necula
87b81fc768 [shape_polyO] Add support for jnp.tril. 2024-05-30 02:53:00 +03:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
George Necula
6705a96b56 [shape_poly] Cleaning up naming of terms and factors.
In the past symbolic expressions were polynomials, consisting of sums
of monomials, which were products of atoms. Over time the language
of symbolic expressions has become richer. Now expressions
are sums of terms, which are products of factors.

Here we rename references to monomials to terms, and `_DimMon`
to `_DimTerm`. We also rename reference of atoms to factors,
and `_DimAtom` to `_DimFactor`.

At the same time we rename most of the methods of `_DimExpr`
to have a leading underscore, to indicate that they are
private methods.
2024-02-21 09:18:22 +01:00
George Necula
30ddc400b8 [shape_poly] Fix handling of stride_in_dim with symbolic stride.
The fix is simple, just avoid using `int(stride)`.
While fixing this I discovered some issues with a test
being disabled and handling of division by 0 when
computing the bounds of floordiv.
2024-02-19 12:36:26 +01:00
George Necula
bb57fb71e2 [shape_poly] Performance improvements for symbolic dimension manipulations (step 3)
We make the following improvements:

  * Add a `linear_combination` function to use for computing
    linear combinations fo symbolic expressions. E.g, `a - b` used
    to involve 2 operations: "-1 * b" and "a + -1*b".
  * Change the representation of terms (_DimMon) from a dictionary
    mapping factors (_DimAtom) to exponents, into a sorted tuple of
    pairs (factor, exponent). This is worthwhile because in almost
    all cases a term contains a single factor. Everywhere we used
    `term.items()` now we use `term._factors`.
  * Make the computation of `._hash` lazy. Previously, we used dictionaries
    heavily for symbolic expressions and we always needed the hash value,
    now we use dictionaries less.
  * Replace `t.degree` with `t.is_constant`.
  * Add `__slots__` to the representation of symbolic expressions

Micro benchmark: `a * 2 - b * 2 - a * 3 + c * 4`

After: 12.51 μsec (mean 12.6 μsec ± 105.2 nsec, of 7 runs, 20000 loops each)
Before: 40.33 μsec (mean 40.5 μsec ± 247.6 nsec, of 7 runs, 5000 loops each)
2024-02-16 17:33:34 +01:00
George Necula
eb9caf0d16 [shape_polyO] Performance improvements for symbolic dimension manipulations (step 2)
We make the following improvements:

  * Cache the state of the decision procedure after we process the explicit
    constraints, and reuse it for new decisions.
  * Rationalize the usage of add_implicit_constraints. We used to call it
    conservatively, too often. Now we call it only once for each explicit constraint,
    and once for each bounds decision we make. Then, in the add_implicit_constraints
    we call it recursively when we encounter new sub-expressions.
  * Eliminate some usage of __str__ for symbolic expressions in combine_and_add_constraints
    since we should only need it for reporting error messages.

This speeds up inequality reasoning:

Before:
```
In [1]:     from jax.experimental import export
   ...:     from jax import core
   ...:     a, b, c = export.symbolic_shape("a, b, c", constraints=["a >= 3", "a <= 5", "b >= 8"])

In [2]: %timeit a >= b
109 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [3]: %timeit core.max_dim(a, c) >= a - c
442 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

After:
```
In [2]: %timeit a >= b
11.7 µs ± 27.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [3]: %timeit core.max_dim(a, c) >= a - c
34.8 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
2024-02-15 18:07:57 +01:00
George Necula
18698a1f19 [shape_poly] Add support for jnp.split 2024-02-15 14:43:41 +01:00
George Necula
ed735608b5 [shape_poly] Improve the symbolic expressions pretty-printer and parser.
Now we allow parsing: "+ a", "-a ", "-b + a".
Also we print "- a" instead of "-1*a".
2024-02-14 12:03:42 +02:00
George Necula
a650f6c03b [shape_poly] Improve the handling of equality external constraints.
Previously, when we had the constraint `f == e` we would only
replace `f` when it appears as the whole term. Now, we also
handle `f * f1 * f2` and we rewrite it to `e * f1 * f2`.
2024-02-13 09:03:03 +02:00
George Necula
202bcd372b [shape_poly] Performance improvements for symbolic dimension manipulations (step 1)
This adds two improvements:

  * Change the computation of the `_size` attribute for expressions, to be more
   sensitive to some of the integer coefficients, e.g., `4*a` is now structurally
   larger than `3*a`. Similarly with `a**4` and `a**3`. This allows the
   `_syntactic_cmp` to short circuit most of the comparison. This operation
   is used a lot when sorting and normalizing the representation of symbolic
   expressions.
  * Change the caching of the bounds computation. Now we store in the cache
   the precision with which we computed the previous bounds, which allows
   better reuse of the cache.

On a microbenchmark, this resulted in a reduction by 30% (before 3210 and after 2145)
of the number of calls to `bounds_for_sorted_terms` and a 10% reduction in the
total time spent in `bounds`:

After:
```
tests/shape_poly_test.py::ShapePolyTest::test_constraints_for_profile          2307348 function calls (2260293 primitive calls) in 0.962 seconds
       1    0.000    0.000    0.969    0.969 shape_poly_test.py:1580(test_constraints_for_profile)
       1    0.000    0.000    0.234    0.234 shape_poly_test.py:1583(f)
     320    0.000    0.000    0.095    0.000 _shape_poly_decision.py:41(bounds_decision)
 425/280    0.001    0.000    0.094    0.000 _shape_poly_decision.py:234(bounds)
  513/51    0.002    0.000    0.091    0.002 _shape_poly_decision.py:260(_bounds_for_sorted_terms)
1230/135    0.001    0.000    0.081    0.001 _shape_poly_decision.py:330(add_implicit_constraints)
     250    0.000    0.000    0.076    0.000 _shape_poly.py:784(__ge__)
     250    0.000    0.000    0.076    0.000 _shape_poly.py:1077(_geq_decision)
 381/289    0.001    0.000    0.069    0.000 _shape_poly_decision.py:102(combine_and_add_constraint)
     695    0.001    0.000    0.065    0.000 _shape_poly.py:1673(_evaluate_multiply)
3572/766    0.002    0.000    0.051    0.000 _shape_poly.py:637(__str__)
```

Before:
```
tests/shape_poly_test.py::ShapePolyTest::test_constraints_for_profile          3486289 function calls (3318484 primitive calls) in 1.240 seconds
       1    0.000    0.000    1.247    1.247 shape_poly_test.py:1569(test_constraints_for_profile)
 992/320    0.001    0.000    0.424    0.001 _shape_poly_decision.py:269(bounds)
3210/280    0.008    0.000    0.423    0.002 _shape_poly_decision.py:292(_bounds_for_sorted_terms)
     250    0.000    0.000    0.400    0.002 _shape_poly.py:783(__ge__)
     250    0.000    0.000    0.399    0.002 _shape_poly_decision.py:39(geq_decision)
```
2024-02-12 18:30:25 +02:00
George Necula
983bb32ae6 [shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.

See more details in the README.md changes.
2024-02-08 10:09:47 +01:00
Jake VanderPlas
84ee045f55 [key reuse] handle polymorphic shapes in slice 2024-01-29 13:59:44 -08:00
George Necula
e20afac46a [shape_poly] Add a decision procedure for inequalities.
In a previous PR (#19285) we added support for inequality
constaints on symbolic expressions, but with limited support
for the cases when a constrain contains more than one term,
e.g., "a >= b".

Here we add a simple decision procedure for such inequalities,
based on the elimination algorithm based on the following properties:

  * if we have two constraints "a + b >= 0" and "-a + c >= 0" we can
    eliminate "a" and infer the derived constraint "b + c >= 0".
  * the lower bound of "a + c", in presence of a constraint "a >= b"
    it greater-or-equal to "b + c".

The above rules can be generalized to cases when the eliminated
terms have coefficients different than 1.

This algorithm is exponential in the number of constraints, but
we implement a limited form. When we add a constraint we combine
it with already added constraints, but the result of the combination
is not combined further. This is sufficient for the cases we
have encountered so far.

The termination of the algorithm is ensured by always eliminating
the largest (leading) term, ensuring that the result of a combination of
constraints has a smaller leading term.

With this added power for reasoning, we can retire the previous
heuristics for handling "min", "max", "floordiv" and "mod" and replace
them with the addition of some implicit constraints for them,
e.g., "max(a, b) >= a", etc., and then letting the decision procedure
do its job.

We moved the logic for deciding inequalities, to a new file: shape_poly_decision.py.
2024-01-29 17:26:35 +02:00
George Necula
0bd511d621 [shape_poly] Add more tests for reasoning about inequalities.
As I explore more powerful ways to reason about inequalities,
I came up with more tests of inequalities that I wish we can handle.
This PR adds the tests I have so far, even if they do not produce
the correct result yet. I write the expected values for tests as

   _expect(best=v1, current=v2)

to document that the current logic produces `v2` but the best value
we can hope for is `v1`.

This PR also adds more support for profiling tests.
2024-01-24 09:57:49 +01:00
George Necula
24201ef922 [shape_poly] Add support for symbolic constraints on dimension variables
Until now all the reasoning about symbolic dimensions was
done with the implicit assumption that the dimension variables
range over strictly positive integers. Here we allow the
user to specify stronger constraints, so that they can be
used in the reasoning about inequalities of symbolic dimensions.
These explicit constraints are checked at compilation time, when
the shapes are known.

This adds significant power to the implementation of
shape polymorphism, and in particular it adds an
escape hatch for when in the past users saw
inconclusive comparison exceptions.

See more details in the README.md in this PR.
2024-01-23 09:53:03 +01:00
Jake VanderPlas
03ce8ca0ca jax.random: deprecate passing of batched keys to APIs 2024-01-17 12:53:24 -08:00
George Necula
a1286d0021 [shape_poly] Improve core.max_dim and core.min_dim
Previously, we optimized `core.max_dim(a, b)` to `a`
if `a >= b` and to `b` if `a < b`. Now we also optimize
it to `b` if `a <= b`.

Similarly for `core.min_dim`.
At the same time we move more of the logic from `core.py`
to `shape_poly.py`.
2024-01-15 15:10:28 +02:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
George Necula
0967a797e8 [shape_poly] Protect shape_poly: rename to _shape_poly.py.
The public APIs can be accessed through `jax.experimental.export`.

The shape_poly and serialization modules are still changing and I saw
external references to various symbols in them, even protected ones.
I have removed such references from the Google code base, and I want to take
another step to discourage direct access to its symbols.

PiperOrigin-RevId: 598119703
2024-01-13 02:05:11 -08:00
George Necula
3b7917a56e [shape_poly] Improve and rename export.args_specs.
We rename it to `symbolic_args_specs` in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons, we now move it to
shape_poly.py but we export the `symbolci_args_specs` from
the public `jax.experimental.export`.

The improvement is that for the case when the `args` passed in
are TF arrays, we move the logic to extract the shapes and dtypes
from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
2024-01-12 08:11:03 +02:00
George Necula
b7f82e8cad [shape_poly] Improve the lexicographic ordering of symbolic expressions
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expressions is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
2024-01-09 08:50:54 +02:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00