As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.
The FFI kernels have been available in jaxlib for over 3 weeks and are included in the latest jaxlib release on PyPI, so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.
PiperOrigin-RevId: 682312752
This fixes several bugs in presence of equality constraints where
the left-hand side is just a dimension variable.
First, such constraints were not applied when parsing variables.
Now, with a constraint `a == b`, when we parse "a" we obtain `b`.
Second, when we evaluate symbolic dimensions that contain
dimension variables that are constrained to be equal to something
else, we may fail to find the dimension variable in the environment
because the environment construction has applied the constraints.
We fix this by looking up the unknown dimension variable in
the equality constraints.
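A minimal sketch of the first fix, using the public `jax.export` API (the specific constraint is illustrative):
```
from jax import export

# With the equality constraint "a == b", the constraint is now applied
# when parsing the variables, so "a" resolves to "b".
a, b = export.symbolic_shape("a, b", constraints=["a == b"])
assert str(a) == str(b)
```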
Fixes: #23437
Fixes: #23456
These started failing due to a compiler change internally at Google, but the tests themselves are buggy. It is not correct to compare an eigendecomposition for equality up to a tolerance, because the eigenvalues are sorted, and all it takes is a tiny perturbation to reorder the eigenvalues and eigenvectors, which leads to a result that looks very different.
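For illustration, a perturbation-robust way to check an eigendecomposition is to verify the defining property `A @ v == w * v` rather than comparing two decompositions elementwise (a sketch, not the actual test fix):
```
import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(4, 4)
A = A + A.T  # symmetric, so the eigendecomposition is real and well defined
w, v = np.linalg.eigh(A)
# Robust to eigenvalue reordering: checks A @ v[:, i] == w[i] * v[:, i].
np.testing.assert_allclose(A @ v, v * w, atol=1e-10)
```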
PiperOrigin-RevId: 669346013
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.
One of the biggest changes here is to move all the lowering logic for the getrf call into jax (lax/linalg.py) instead of jaxlib (gpu_solver.py and lapack.py), since the lowering code is now identical for CPU and GPU (the only difference is the handler names).
PiperOrigin-RevId: 665829252
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation, which is used for lowering on CPU, to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary; see the sketch below). This change doesn't affect differentiability (this op doesn't support AD), and the behavior won't change when static shapes are used.
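A minimal sketch of the `fori_loop` behavior referenced above (generic usage, not the actual kernel):
```
import jax
import jax.numpy as jnp

def cumsum_loop(x):
  # Under shape polymorphism x.shape[0] may be a symbolic dimension, in
  # which case the fori_loop lowers to a while_loop instead of a scan.
  body = lambda i, acc: acc + x[i]
  return jax.lax.fori_loop(0, x.shape[0], body, jnp.zeros((), x.dtype))

print(jax.jit(cumsum_loop)(jnp.arange(4.0)))  # 6.0
```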
PiperOrigin-RevId: 662024940
Handle several new padding modes: wrap, reflect, symmetric, linear_ramp, maximum.
Not all situations are handled; we try to give a clear error for the unsupported
cases.
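A quick illustration of the newly handled modes, using the standard `jnp.pad` API:
```
import jax.numpy as jnp

x = jnp.arange(4)
print(jnp.pad(x, (2, 2), mode="wrap"))       # [2 3 0 1 2 3 0 1]
print(jnp.pad(x, (2, 2), mode="reflect"))    # [2 1 0 1 2 3 2 1]
print(jnp.pad(x, (2, 2), mode="symmetric"))  # [1 0 0 1 2 3 3 2]
```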
While implementing this, I also needed to add shape polymorphism support
for jnp.linspace. Along the way I discovered a bug in the implementation
of `divmod(0, b)`.
The functionality comes from the jax.experimental.export
module, which will be deprecated.
The following APIs are introduced:
```
from jax import export
def f(...): ...
ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)
blob: bytearray = ex.serialize()
rehydrated: export.Exported = export.deserialize(blob)
def caller(...):
  ... rehydrated.call(*args, **kwargs)
```
Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.
Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:
* Instead of `jax.experimental.export.call(exp)` we now write
`exp.call`
* `jax.experimental.export.export` allowed the function
argument to be any Python callable and would wrap it in
`jax.jit`. This is no longer supported by `jax.export.export`;
the user must apply `jax.jit` explicitly, as in the sketch below.
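Putting the two differences together, a minimal usage sketch:
```
import jax
import jax.numpy as jnp
from jax import export

def f(x):
  return jnp.sin(x)

x = jnp.arange(3.0)
exp = export.export(jax.jit(f))(x)  # jax.jit is now required
y = exp.call(x)                     # replaces jax.experimental.export.call(exp)(x)
```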
This relies on newly introduced support for dynamic `k`
for approx_top_k, using the `stablehlo.dynamic_approx_top_k`
custom call.
We also add a backwards compatibility test.
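A hedged sketch of a dynamic `k` use case (the shapes and the choice of `k` are illustrative):
```
import jax
import jax.numpy as jnp
from jax.experimental import export

def f(x):
  # k is a symbolic dimension derived from the input shape.
  return jax.lax.approx_max_k(x, k=x.shape[-1] // 2)

a, = export.symbolic_shape("a")
exp = export.export(jax.jit(f))(jax.ShapeDtypeStruct((a,), jnp.float32))
```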
PiperOrigin-RevId: 640557581
When `aggregate_to_topk=True` (the default), the output reduction
dimension size is `k`, and we do not need to invoke `ApproxTopKReductionOutputSize`.
Add a set of test cases for shape polymorphism for approx_top_k.
The case when `aggregate_to_topk=True` and `k` is symbolic will
be fixed separately.
The case when `aggregate_to_topk=False` raises a clearer NotImplementedError.
In the past symbolic expressions were polynomials, consisting of sums
of monomials, which were products of atoms. Over time the language
of symbolic expressions has become richer. Now expressions
are sums of terms, which are products of factors.
Here we rename references to monomials to terms, and `_DimMon`
to `_DimTerm`. We also rename references to atoms to factors,
and `_DimAtom` to `_DimFactor`.
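For illustration, in the new vocabulary (using the public API; the internal classes are private):
```
from jax.experimental import export

a, b = export.symbolic_shape("a, b")
e = 3 * a * b + b**2
# e is a _DimExpr: a sum of two terms (_DimTerm), 3*a*b and b**2;
# the term a*b is a product of the factors (_DimFactor) a and b.
```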
At the same time we rename most of the methods of `_DimExpr`
to have a leading underscore, to indicate that they are
private methods.
The fix is simple: just avoid using `int(stride)`.
While fixing this I discovered some issues with a test
being disabled, and with the handling of division by 0 when
computing the bounds of floordiv.
We make the following improvements:
* Add a `linear_combination` function to use for computing
linear combinations of symbolic expressions. E.g., `a - b` used
to involve 2 operations: "-1 * b" and "a + -1*b".
* Change the representation of terms (_DimMon) from a dictionary
mapping factors (_DimAtom) to exponents, into a sorted tuple of
pairs (factor, exponent). This is worthwhile because in almost
all cases a term contains a single factor. Everywhere we used
`term.items()` now we use `term._factors`.
* Make the computation of `._hash` lazy. Previously we used dictionaries
heavily for symbolic expressions and always needed the hash value;
now we use dictionaries less.
* Replace `t.degree` with `t.is_constant`.
* Add `__slots__` to the representation of symbolic expressions
Micro benchmark: `a * 2 - b * 2 - a * 3 + c * 4`
After: 12.51 μsec (mean 12.6 μsec ± 105.2 nsec, of 7 runs, 20000 loops each)
Before: 40.33 μsec (mean 40.5 μsec ± 247.6 nsec, of 7 runs, 5000 loops each)
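The micro benchmark above, reproduced as a sketch (absolute timings are machine dependent):
```
import timeit
from jax.experimental import export

a, b, c = export.symbolic_shape("a, b, c")
print(timeit.timeit(lambda: a * 2 - b * 2 - a * 3 + c * 4, number=20000))
```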
We make the following improvements:
* Cache the state of the decision procedure after we process the explicit
constraints, and reuse it for new decisions.
* Rationalize the usage of `add_implicit_constraints`. We used to call it
conservatively, too often. Now we call it only once for each explicit constraint,
and once for each bounds decision we make. Then, within `add_implicit_constraints`
we call it recursively when we encounter new sub-expressions.
* Eliminate some usage of `__str__` for symbolic expressions in `combine_and_add_constraints`,
since we should only need it for reporting error messages.
This speeds up inequality reasoning:
Before:
```
In [1]: from jax.experimental import export
...: from jax import core
...: a, b, c = export.symbolic_shape("a, b, c", constraints=["a >= 3", "a <= 5", "b >= 8"])
In [2]: %timeit a >= b
109 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [3]: %timeit core.max_dim(a, c) >= a - c
442 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
After:
```
In [2]: %timeit a >= b
11.7 µs ± 27.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [3]: %timeit core.max_dim(a, c) >= a - c
34.8 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
Previously, when we had the constraint `f == e` we would replace
`f` only when it appeared as the whole term. Now we also
handle `f * f1 * f2`, rewriting it to `e * f1 * f2`.
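A hedged illustration (the constraint syntax follows the shape polymorphism README; the particular expressions are assumptions):
```
from jax.experimental import export

a, b, c = export.symbolic_shape("a, b, c", constraints=["mod(a, 2) == b"])
# Previously only the bare term mod(a, 2) was replaced by b; now a product
# such as mod(a, 2) * c is also rewritten to b * c.
```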
This adds two improvements:
* Change the computation of the `_size` attribute for expressions, to be more
sensitive to some of the integer coefficients, e.g., `4*a` is now structurally
larger than `3*a`. Similarly with `a**4` and `a**3`. This allows the
`_syntactic_cmp` to short circuit most of the comparison. This operation
is used a lot when sorting and normalizing the representation of symbolic
expressions.
* Change the caching of the bounds computation. Now we store in the cache
the precision with which we computed the previous bounds, which allows
better reuse of the cache.
On a microbenchmark, this resulted in a roughly 30% reduction (from 3210 to 2145)
in the number of calls to `bounds_for_sorted_terms` and a 10% reduction in the
total time spent in `bounds`:
After:
```
tests/shape_poly_test.py::ShapePolyTest::test_constraints_for_profile 2307348 function calls (2260293 primitive calls) in 0.962 seconds
1 0.000 0.000 0.969 0.969 shape_poly_test.py:1580(test_constraints_for_profile)
1 0.000 0.000 0.234 0.234 shape_poly_test.py:1583(f)
320 0.000 0.000 0.095 0.000 _shape_poly_decision.py:41(bounds_decision)
425/280 0.001 0.000 0.094 0.000 _shape_poly_decision.py:234(bounds)
513/51 0.002 0.000 0.091 0.002 _shape_poly_decision.py:260(_bounds_for_sorted_terms)
1230/135 0.001 0.000 0.081 0.001 _shape_poly_decision.py:330(add_implicit_constraints)
250 0.000 0.000 0.076 0.000 _shape_poly.py:784(__ge__)
250 0.000 0.000 0.076 0.000 _shape_poly.py:1077(_geq_decision)
381/289 0.001 0.000 0.069 0.000 _shape_poly_decision.py:102(combine_and_add_constraint)
695 0.001 0.000 0.065 0.000 _shape_poly.py:1673(_evaluate_multiply)
3572/766 0.002 0.000 0.051 0.000 _shape_poly.py:637(__str__)
```
Before:
```
tests/shape_poly_test.py::ShapePolyTest::test_constraints_for_profile 3486289 function calls (3318484 primitive calls) in 1.240 seconds
1 0.000 0.000 1.247 1.247 shape_poly_test.py:1569(test_constraints_for_profile)
992/320 0.001 0.000 0.424 0.001 _shape_poly_decision.py:269(bounds)
3210/280 0.008 0.000 0.423 0.002 _shape_poly_decision.py:292(_bounds_for_sorted_terms)
250 0.000 0.000 0.400 0.002 _shape_poly.py:783(__ge__)
250 0.000 0.000 0.399 0.002 _shape_poly_decision.py:39(geq_decision)
```
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
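A minimal sketch of the workaround described above:
```
from jax.experimental import export

a, = export.symbolic_shape(
    "a", constraints=["mod(mod(a, 2), 2) == mod(a, 2)"])
# The equality is now assumed during staging and checked at shape
# refinement time, when the actual dimension values are known.
```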
In a previous PR (#19285) we added support for inequality
constraints on symbolic expressions, but with limited support
for the cases when a constraint contains more than one term,
e.g., "a >= b".
Here we add a simple decision procedure for such inequalities,
based on an elimination algorithm that uses the following properties:
* if we have two constraints "a + b >= 0" and "-a + c >= 0", we can
eliminate "a" and infer the derived constraint "b + c >= 0".
* in the presence of a constraint "a >= b", the lower bound of "a + c"
is greater than or equal to the lower bound of "b + c".
The above rules generalize to cases when the eliminated
terms have coefficients different from 1, as shown below.
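For instance, a numeric spot check of the scaled elimination (plain arithmetic, not the internal API):
```
# From "2*a + b >= 0" and "-3*a + c >= 0", scale by 3 and 2 to eliminate a:
#   3*(2*a + b) + 2*(-3*a + c) = 3*b + 2*c >= 0
a, b, c = 5, -7, 20
assert 2 * a + b >= 0 and -3 * a + c >= 0
assert 3 * b + 2 * c >= 0  # the derived constraint
```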
This algorithm is exponential in the number of constraints, but
we implement a limited form. When we add a constraint we combine
it with already added constraints, but the result of the combination
is not combined further. This is sufficient for the cases we
have encountered so far.
The termination of the algorithm is ensured by always eliminating
the largest (leading) term, ensuring that the result of a combination of
constraints has a smaller leading term.
With this added reasoning power, we can retire the previous
heuristics for handling "min", "max", "floordiv" and "mod", replacing
them with a few implicit constraints,
e.g., "max(a, b) >= a", and then letting the decision procedure
do its job (see the demonstration below).
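A small demonstration of the retired heuristics being subsumed (using the public symbolic API):
```
from jax import core
from jax.experimental import export

a, b = export.symbolic_shape("a, b")
# The implicit constraint max(a, b) >= a lets the decision procedure
# resolve this comparison without a dedicated "max" heuristic:
assert core.max_dim(a, b) >= a
```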
We moved the logic for deciding inequalities to a new file: shape_poly_decision.py.
As I explored more powerful ways to reason about inequalities,
I came up with more tests of inequalities that I wish we could handle.
This PR adds the tests I have so far, even if they do not yet produce
the correct result. I write the expected values for tests as
`_expect(best=v1, current=v2)`
to document that the current logic produces `v2` but the best value
we can hope for is `v1`.
This PR also adds more support for profiling tests.
Until now all the reasoning about symbolic dimensions was
done with the implicit assumption that the dimension variables
range over strictly positive integers. Here we allow the
user to specify stronger constraints, so that they can be
used in the reasoning about inequalities of symbolic dimensions.
These explicit constraints are checked at compilation time, when
the shapes are known.
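A minimal sketch of the new user-specified constraints (syntax per the README changes in this PR):
```
from jax.experimental import export

a, = export.symbolic_shape("a", constraints=["a >= 8"])
assert a >= 8  # decidable from the explicit constraint
assert a >= 1  # the implicit strict-positivity assumption still holds
```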
This adds significant power to the implementation of
shape polymorphism, and in particular it adds an
escape hatch for cases where users previously saw
inconclusive comparison exceptions.
See more details in the README.md in this PR.
Previously, we optimized `core.max_dim(a, b)` to `a`
if `a >= b` and to `b` if `a < b`. Now we also optimize
it to `b` if `a <= b`.
Similarly for `core.min_dim`.
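An example of the newly covered case (the particular expressions are illustrative):
```
from jax import core
from jax.experimental import export

a, b = export.symbolic_shape("a, b")
# a <= a*b holds (since b >= 1), but a < a*b is not decidable (b may be 1),
# so previously max_dim(a, a*b) stayed symbolic; now it simplifies:
assert core.max_dim(a, a * b) == a * b
```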
At the same time we move more of the logic from `core.py`
to `shape_poly.py`.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
The public APIs can be accessed through `jax.experimental.export`.
The shape_poly and serialization modules are still changing and I saw
external references to various symbols in them, even protected ones.
I have removed such references from the Google code base, and I want to take
another step to discourage direct access to their symbols.
PiperOrigin-RevId: 598119703
We rename it to `symbolic_args_specs`, in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons; we now move it to
shape_poly.py but export `symbolic_args_specs` from
the public `jax.experimental.export`.
The improvement is that when the `args` passed in
are TF arrays, the logic to extract the shapes and dtypes
has moved from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
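A hedged sketch of the renamed function (the shape-spec syntax follows the shape polymorphism docs; the exact arguments here are illustrative):
```
import numpy as np
from jax.experimental import export

args = (np.zeros((8, 4), np.float32),)
# One shapes spec per argument; "_" keeps a dimension static.
specs = export.symbolic_args_specs(args, ("b, _",))
# specs[0] is a ShapeDtypeStruct with shape (b, 4) and dtype float32.
```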
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expression is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.
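A quick numeric check of these identities (plain Python):
```
def non_negative(d):
  return max(d, 0)

for a in range(-3, 4):
  for b in range(-3, 4):
    assert max(a, b) == a + non_negative(b - a)
    assert min(a, b) == a - non_negative(a - b)
```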
One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:
```
from jax.experimental import export
exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialize(ser)
export.call(exp1)
```
This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.
In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.
PiperOrigin-RevId: 596563481