The fix is simple: just avoid using `int(stride)`.
While fixing this I discovered that a test had been
disabled, and issues in the handling of division by 0 when
computing the bounds of `floordiv`.
We make the following improvements:
* Add a `linear_combination` function to use for computing
linear combinations of symbolic expressions. E.g., `a - b` used
to involve 2 operations: "-1 * b" and "a + -1*b".
* Change the representation of terms (_DimMon) from a dictionary
mapping factors (_DimAtom) to exponents, into a sorted tuple of
pairs (factor, exponent). This is worthwhile because in almost
all cases a term contains a single factor. Everywhere we used
`term.items()` now we use `term._factors`.
* Make the computation of `._hash` lazy. Previously, we used dictionaries
heavily for symbolic expressions and we always needed the hash value,
now we use dictionaries less.
* Replace `t.degree` with `t.is_constant`.
* Add `__slots__` to the representation of symbolic expressions
Micro benchmark: `a * 2 - b * 2 - a * 3 + c * 4`
After: 12.51 μsec (mean 12.6 μsec ± 105.2 nsec, of 7 runs, 20000 loops each)
Before: 40.33 μsec (mean 40.5 μsec ± 247.6 nsec, of 7 runs, 5000 loops each)
We make the following improvements:
* Cache the state of the decision procedure after we process the explicit
constraints, and reuse it for new decisions.
* Rationalize the usage of `add_implicit_constraints`. We used to call it
conservatively, more often than needed. Now we call it only once for each
explicit constraint, and once for each bounds decision we make. Then,
`add_implicit_constraints` calls itself recursively when it encounters
new sub-expressions.
* Eliminate some usage of `__str__` for symbolic expressions in `combine_and_add_constraints`,
since we should only need it for reporting error messages.
This speeds up inequality reasoning:
Before:
```
In [1]: from jax.experimental import export
...: from jax import core
...: a, b, c = export.symbolic_shape("a, b, c", constraints=["a >= 3", "a <= 5", "b >= 8"])
In [2]: %timeit a >= b
109 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [3]: %timeit core.max_dim(a, c) >= a - c
442 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
After:
```
In [2]: %timeit a >= b
11.7 µs ± 27.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [3]: %timeit core.max_dim(a, c) >= a - c
34.8 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
Previously, when we had the constraint `f == e` we would only
replace `f` when it appears as the whole term. Now, we also
handle `f * f1 * f2` and we rewrite it to `e * f1 * f2`.
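The new rewriting can be sketched as follows. This is an illustrative model, not the JAX code: a term is modeled as a tuple of factors, and the equality constraint `f == e` replaces `f` wherever it occurs as a factor, not just when it is the whole term:

```python
def rewrite_term(term, f, e):
    """Replace factor `f` with `e` anywhere it occurs in `term`."""
    return tuple(e if factor == f else factor for factor in term)

# `mod(a,2)` stands in for a non-trivial factor `f`.
term = ("mod(a,2)", "f1", "f2")
assert rewrite_term(term, "mod(a,2)", "e") == ("e", "f1", "f2")
# The previously supported whole-term case still works:
assert rewrite_term(("mod(a,2)",), "mod(a,2)", "e") == ("e",)
```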
This adds two improvements:
* Change the computation of the `_size` attribute for expressions, to be more
sensitive to some of the integer coefficients, e.g., `4*a` is now structurally
larger than `3*a`. Similarly with `a**4` and `a**3`. This allows the
`_syntactic_cmp` to short circuit most of the comparison. This operation
is used a lot when sorting and normalizing the representation of symbolic
expressions.
* Change the caching of the bounds computation. Now we store in the cache
the precision with which we computed the previous bounds, which allows
better reuse of the cache.
On a microbenchmark, this resulted in a reduction by 30% (before 3210 and after 2145)
of the number of calls to `bounds_for_sorted_terms` and a 10% reduction in the
total time spent in `bounds`:
After:
```
tests/shape_poly_test.py::ShapePolyTest::test_constraints_for_profile 2307348 function calls (2260293 primitive calls) in 0.962 seconds
1 0.000 0.000 0.969 0.969 shape_poly_test.py:1580(test_constraints_for_profile)
1 0.000 0.000 0.234 0.234 shape_poly_test.py:1583(f)
320 0.000 0.000 0.095 0.000 _shape_poly_decision.py:41(bounds_decision)
425/280 0.001 0.000 0.094 0.000 _shape_poly_decision.py:234(bounds)
513/51 0.002 0.000 0.091 0.002 _shape_poly_decision.py:260(_bounds_for_sorted_terms)
1230/135 0.001 0.000 0.081 0.001 _shape_poly_decision.py:330(add_implicit_constraints)
250 0.000 0.000 0.076 0.000 _shape_poly.py:784(__ge__)
250 0.000 0.000 0.076 0.000 _shape_poly.py:1077(_geq_decision)
381/289 0.001 0.000 0.069 0.000 _shape_poly_decision.py:102(combine_and_add_constraint)
695 0.001 0.000 0.065 0.000 _shape_poly.py:1673(_evaluate_multiply)
3572/766 0.002 0.000 0.051 0.000 _shape_poly.py:637(__str__)
```
Before:
```
tests/shape_poly_test.py::ShapePolyTest::test_constraints_for_profile 3486289 function calls (3318484 primitive calls) in 1.240 seconds
1 0.000 0.000 1.247 1.247 shape_poly_test.py:1569(test_constraints_for_profile)
992/320 0.001 0.000 0.424 0.001 _shape_poly_decision.py:269(bounds)
3210/280 0.008 0.000 0.423 0.002 _shape_poly_decision.py:292(_bounds_for_sorted_terms)
250 0.000 0.000 0.400 0.002 _shape_poly.py:783(__ge__)
250 0.000 0.000 0.399 0.002 _shape_poly_decision.py:39(geq_decision)
```
Previously, we added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
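Schematically, an equality constraint can act as a rewrite rule during normalization, so that the two sides of an otherwise undecidable equality normalize to the same expression. The string-based expression form and the rewrite table below are illustrative, not the JAX internals:

```python
# User-supplied equality constraint, recorded as a rewrite rule.
equalities = {"mod(mod(a,2),2)": "mod(a,2)"}

def normalize(expr):
    # Apply equality rewrites to a fixed point.
    while expr in equalities:
        expr = equalities[expr]
    return expr

# The decision procedure alone cannot prove this equality, but with the
# explicit constraint both sides normalize to the same expression:
assert normalize("mod(mod(a,2),2)") == normalize("mod(a,2)")
```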
In a previous PR (#19285) we added support for inequality
constraints on symbolic expressions, but with limited support
for the cases when a constraint contains more than one term,
e.g., "a >= b".
Here we add a simple decision procedure for such inequalities,
based on an elimination algorithm that uses the following properties:
* if we have two constraints "a + b >= 0" and "-a + c >= 0" we can
eliminate "a" and infer the derived constraint "b + c >= 0".
* in the presence of a constraint "a >= b", the lower bound of "a + c"
is greater-or-equal to that of "b + c".
The above rules can be generalized to cases when the eliminated
terms have coefficients different than 1.
This algorithm is exponential in the number of constraints, but
we implement a limited form. When we add a constraint we combine
it with already added constraints, but the result of the combination
is not combined further. This is sufficient for the cases we
have encountered so far.
The termination of the algorithm is ensured by always eliminating
the largest (leading) term, ensuring that the result of a combination of
constraints has a smaller leading term.
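A single elimination step can be sketched as follows. The representation is an assumption: a constraint is modeled as a dict mapping terms to coefficients, meaning "sum of coeff * term >= 0":

```python
def eliminate(c1, c2, t):
    """Combine two constraints so that term `t` cancels.

    Requires `t` to appear with opposite signs. Coefficients may differ
    from 1, in which case each constraint is scaled by the absolute value
    of the other's coefficient of `t`.
    """
    a1, a2 = c1[t], c2[t]
    assert a1 * a2 < 0, "t must have opposite signs to be eliminated"
    s1, s2 = abs(a2), abs(a1)   # cross-multiply so that t cancels
    out = {}
    for term in set(c1) | set(c2):
        coeff = s1 * c1.get(term, 0) + s2 * c2.get(term, 0)
        if coeff:
            out[term] = coeff
    return out

# "a + b >= 0" combined with "-a + c >= 0" eliminates "a": "b + c >= 0".
assert eliminate({"a": 1, "b": 1}, {"a": -1, "c": 1}, "a") == {"b": 1, "c": 1}
# With coefficients != 1: "2a + b >= 0" and "-3a + c >= 0" give "3b + 2c >= 0".
assert eliminate({"a": 2, "b": 1}, {"a": -3, "c": 1}, "a") == {"b": 3, "c": 2}
```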
With this added power for reasoning, we can retire the previous
heuristics for handling "min", "max", "floordiv" and "mod" and replace
them with the addition of some implicit constraints for them,
e.g., "max(a, b) >= a", etc., and then letting the decision procedure
do its job.
We moved the logic for deciding inequalities to a new file: shape_poly_decision.py.
As I explore more powerful ways to reason about inequalities,
I came up with more tests of inequalities that I wish we could handle.
This PR adds the tests I have so far, even if they do not produce
the correct result yet. I write the expected values for tests as
`_expect(best=v1, current=v2)`
to document that the current logic produces `v2` but the best value
we can hope for is `v1`.
This PR also adds more support for profiling tests.
Until now all the reasoning about symbolic dimensions was
done with the implicit assumption that the dimension variables
range over strictly positive integers. Here we allow the
user to specify stronger constraints, so that they can be
used in the reasoning about inequalities of symbolic dimensions.
These explicit constraints are checked at compilation time, when
the shapes are known.
This adds significant power to the implementation of
shape polymorphism, and in particular it adds an
escape hatch for cases in which users previously
saw inconclusive comparison exceptions.
See more details in the README.md in this PR.
Previously, we optimized `core.max_dim(a, b)` to `a`
if `a >= b` and to `b` if `a < b`. Now we also optimize
it to `b` if `a <= b`.
Similarly for `core.min_dim`.
At the same time we move more of the logic from `core.py`
to `shape_poly.py`.
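The optimization can be sketched as follows. `definitely_geq` stands in for the real decision procedure, and the tuple fallback is illustrative; only the three-way case analysis mirrors the description above:

```python
def max_dim(a, b, definitely_geq):
    if definitely_geq(a, b):
        return a
    if definitely_geq(b, a):   # the newly added `a <= b` case
        return b
    return ("max", a, b)       # undecidable: stays symbolic

# With concrete values the decision procedure is just `>=`:
geq = lambda x, y: x >= y
assert max_dim(3, 5, geq) == 5   # simplified via the new a <= b case
assert max_dim(5, 3, geq) == 5
```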
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
The public APIs can be accessed through `jax.experimental.export`.
The shape_poly and serialization modules are still changing and I saw
external references to various symbols in them, even protected ones.
I have removed such references from the Google code base, and I want to take
another step to discourage direct access to their symbols.
PiperOrigin-RevId: 598119703
We rename it to `symbolic_args_specs` in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons; we now move it to
shape_poly.py but we export `symbolic_args_specs` from
the public `jax.experimental.export`.
The improvement is that for the case when the `args` passed in
are TF arrays, we move the logic to extract the shapes and dtypes
from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expression is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.
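The two rewrites can be checked on concrete integers, since `non_negative(d)` is just `max(d, 0)`:

```python
def non_negative(d):
    return max(d, 0)

# max(a, b) == a + non_negative(b - a)  and
# min(a, b) == a - non_negative(a - b)  hold for all integers:
for a in range(-3, 4):
    for b in range(-3, 4):
        assert max(a, b) == a + non_negative(b - a)
        assert min(a, b) == a - non_negative(a - b)
```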
One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:
```
from jax.experimental import export
exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialize(ser)
export.call(exp1)
```
This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.
In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.
PiperOrigin-RevId: 596563481
Since all the inequality comparisons are done via the
`ge` method, until now all the error messages were about
greater or equal. Now we specify the actual original comparison.
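The plumbing can be sketched as follows. The `cmp_str` argument and the `Dim` class are assumptions for illustration; the point is that every comparison funnels through a single `ge`-style check while the error message names the operator the user actually wrote:

```python
class Dim:
    def __init__(self, v):
        self.v = v

    def _ge(self, other, cmp_str):
        if self.v is None or other.v is None:   # stand-in for "inconclusive"
            raise ValueError(f"Cannot decide '{cmp_str}'")
        return self.v >= other.v

    def __ge__(self, other):
        return self._ge(other, "a >= b")

    def __lt__(self, other):
        return not self._ge(other, "a < b")     # a < b  iff  not (a >= b)

try:
    Dim(None) < Dim(3)
except ValueError as e:
    assert "a < b" in str(e)   # the error names the original comparison
```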
Hashing is performance critical for symbolic expressions because
their internal representation is based on dictionaries. In some of the
unit tests, caching the hash saves 20% of the hashing calls. Hashing
will be even more important for the upcoming decision procedure
for symbolic expressions.
We also cache the sorting of the monomials in a symbolic expression.
There were several bugs in the ordering of atoms and
monomials. The ordering for atoms and monomials is used
for sorting, and the __eq__ is also used for hashing.
One bug was that the ordering of atoms sometimes used
the `id` ordering. Another (performance) bug was that
the __eq__ for atoms used the (semantic) __eq__ for
DimExpr. The latter is expensive to compute, but for
sorting all we need is a syntactic comparison.
We introduce a `_syntactic_cmp` method for atoms,
monomials and expressions and we use it exclusively
for the ordering of atoms and monomials.
We also clean up printing and add tests for ordering and
pretty printing. Now we print monomials in "decreasing" order.
This is a change from before, in the sense that "a + b" is
printed as "b + a".
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
The rules for deciding inequalities of symbolic expressions
are incomplete. Here we add two heuristics that help decide
the bounds checking of indices computed for indexing with slices:
To decide whether an expression that contains `non_negative(e)` is
>= 0, it is sufficient to show that the expression is >=0 if we
replace the `non_negative(e)` with `0` and with `e`.
To decide whether `floordiv(e, k)` is >= 0, when `k >= 0`, it
is sufficient to show that `e` is >= 0.
These are sufficient for the bounds checking that JAX is doing
internally, but may not be for the cases when the user program
does index computations using those operators.
This enables us to re-enable the shape_poly indexing tests.
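The two heuristics can be sketched over concrete integers. The helper names below are assumptions; checking with concrete values stands in for the real bounds machinery:

```python
def geq0_with_non_negative(expr_of_nn, e):
    """Is an expression containing non_negative(e) provably >= 0?

    Sufficient condition: the expression is >= 0 both when
    non_negative(e) is replaced by 0 and when it is replaced by e.
    `expr_of_nn` is a callable taking the substituted value.
    """
    return expr_of_nn(0) >= 0 and expr_of_nn(e) >= 0

# Is `non_negative(e) + 1 >= 0`?  Check with both substitutions.
assert geq0_with_non_negative(lambda nn: nn + 1, e=5)

def geq0_floordiv(e, k):
    """Sufficient condition for floordiv(e, k) >= 0 when k >= 0."""
    assert k >= 0
    return e >= 0

assert geq0_floordiv(7, 2)   # and indeed 7 // 2 == 3 >= 0
```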
Currently, we do not support shape polymorphism when we index with a
slice, e.g., `x[a:b:c]`, and instead we direct the user to
`lax.dynamic_slice`. This is only because so far we have not tried
to ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g., `if start >= stop` that cannot be
handled in general in presence of symbolic shapes.
Here we introduce a new helper function `_preprocess_slice` to contain
all the computations for the start and the size of the slice.
To test that this does not break the JAX index computations, I ran
the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`.
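The kind of computation such a helper performs can be sketched with clamping arithmetic, which avoids data-dependent branches like `if start >= stop`. The function name and the restriction to a positive step are assumptions for illustration, not the JAX implementation:

```python
def preprocess_slice(start, stop, length):
    # Positive step only, for brevity. Normalize negative indices and
    # clamp into [0, length], mirroring Python slice semantics.
    start = max(0, min(start if start >= 0 else start + length, length))
    stop = max(0, min(stop if stop >= 0 else stop + length, length))
    size = max(0, stop - start)   # no `if start >= stop` branch needed
    return start, size

# Matches Python's own slicing on a concrete sequence:
for start, stop in [(1, 4), (4, 1), (-3, -1), (0, 100)]:
    s, size = preprocess_slice(start, stop, length=10)
    assert size == len(list(range(10))[start:stop])
```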
Before, we had `export.poly_spec` to create a `jax.ShapeDtypeStruct`
given a polymorphic shape specification. This function was
invoked `poly_spec(arg_shape, arg_dtype, polymorphic_shape)`.
The `arg_shape` was only needed when the polymorphic shape spec
contained placeholders.
We break out an `export.symbolic_shape` that is just a parser
of polymorphic shape specs and we ask the user to invoke
`jax.ShapeDtypeStruct` directly:
`jax.ShapeDtypeStruct(export.symbolic_shape(polymorphic_shape, like=arg_shape), arg_dtype)`.
We also rename the `export.poly_specs` to `export.arg_specs`.
When we recently moved much of shape_poly_test out of jax2tf
we had to add a number of flags to avoid warnings (which are
errors in GitHub CI). Here we clean the tests so that we
can run them without the flags.
The most common problem was that tests were relying on
implicit rank promotion. We added a number of `jnp.expand_dims`
to fix the rank and let the implicit broadcasting do the rest.
Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf.
For now we leave the existing tests in jax2tf, because some of those tests exercise
other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515).
Since we now run these tests in GitHub and Kokoro, this has revealed a couple
of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py.
PiperOrigin-RevId: 583816243