34 Commits

Author SHA1 Message Date
George Necula
30ddc400b8 [shape_poly] Fix handling of stride_in_dim with symbolic stride.
The fix is simple, just avoid using `int(stride)`.
While fixing this I discovered some issues with a test
being disabled and handling of division by 0 when
computing the bounds of floordiv.
2024-02-19 12:36:26 +01:00
George Necula
bb57fb71e2 [shape_poly] Performance improvements for symbolic dimension manipulations (step 3)
We make the following improvements:

  * Add a `linear_combination` function to use for computing
    linear combinations of symbolic expressions. E.g., `a - b` used
    to involve 2 operations: "-1 * b" and "a + -1*b".
  * Change the representation of terms (_DimMon) from a dictionary
    mapping factors (_DimAtom) to exponents, into a sorted tuple of
    pairs (factor, exponent). This is worthwhile because in almost
    all cases a term contains a single factor. Everywhere we used
    `term.items()` now we use `term._factors`.
  * Make the computation of `._hash` lazy. Previously, we used dictionaries
    heavily for symbolic expressions and we always needed the hash value,
    now we use dictionaries less.
  * Replace `t.degree` with `t.is_constant`.
  * Add `__slots__` to the representation of symbolic expressions
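
A minimal sketch of the representation changes listed above, in plain Python. The `Term` class and this `linear_combination` signature are hypothetical stand-ins, not the actual `_DimMon` code; expressions are modeled as dictionaries from terms to integer coefficients purely for illustration.

```python
class Term:
    """Hypothetical stand-in for a term: a product of factors with exponents."""

    __slots__ = ("_factors", "_hash")

    def __init__(self, factors: dict):
        # Sorted tuple of (factor, exponent) pairs; usually a single pair.
        self._factors = tuple(sorted(factors.items()))
        self._hash = None  # computed lazily, on first use

    @property
    def is_constant(self) -> bool:
        # The constant term has no factors at all.
        return not self._factors

    def __eq__(self, other):
        return isinstance(other, Term) and self._factors == other._factors

    def __hash__(self):
        if self._hash is None:
            self._hash = hash(self._factors)
        return self._hash


def linear_combination(c1: int, e1: dict, c2: int, e2: dict) -> dict:
    """Compute c1*e1 + c2*e2 in a single pass, dropping zero coefficients."""
    acc = {}
    for coeff, expr in ((c1, e1), (c2, e2)):
        for term, k in expr.items():
            acc[term] = acc.get(term, 0) + coeff * k
    return {t: k for t, k in acc.items() if k != 0}


a, b = Term({"a": 1}), Term({"b": 1})
diff = linear_combination(1, {a: 1}, -1, {b: 1})  # a - b in one step
```

With `__slots__`, instances carry no per-object `__dict__`, which is the usual reason to add it to small, frequently allocated objects.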

Micro benchmark: `a * 2 - b * 2 - a * 3 + c * 4`

After: 12.51 μsec (mean 12.6 μsec ± 105.2 nsec, of 7 runs, 20000 loops each)
Before: 40.33 μsec (mean 40.5 μsec ± 247.6 nsec, of 7 runs, 5000 loops each)
2024-02-16 17:33:34 +01:00
George Necula
eb9caf0d16 [shape_poly] Performance improvements for symbolic dimension manipulations (step 2)
We make the following improvements:

  * Cache the state of the decision procedure after we process the explicit
    constraints, and reuse it for new decisions.
  * Rationalize the usage of add_implicit_constraints. We used to call it
    conservatively, too often. Now we call it only once for each explicit constraint,
    and once for each bounds decision we make. Then, in the add_implicit_constraints
    we call it recursively when we encounter new sub-expressions.
  * Eliminate some usage of __str__ for symbolic expressions in combine_and_add_constraints
    since we should only need it for reporting error messages.

This speeds up inequality reasoning:

Before:
```
In [1]:     from jax.experimental import export
   ...:     from jax import core
   ...:     a, b, c = export.symbolic_shape("a, b, c", constraints=["a >= 3", "a <= 5", "b >= 8"])

In [2]: %timeit a >= b
109 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [3]: %timeit core.max_dim(a, c) >= a - c
442 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

After:
```
In [2]: %timeit a >= b
11.7 µs ± 27.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [3]: %timeit core.max_dim(a, c) >= a - c
34.8 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
2024-02-15 18:07:57 +01:00
George Necula
18698a1f19 [shape_poly] Add support for jnp.split 2024-02-15 14:43:41 +01:00
George Necula
ed735608b5 [shape_poly] Improve the symbolic expressions pretty-printer and parser.
Now we allow parsing: "+ a", "-a ", "-b + a".
Also we print "- a" instead of "-1*a".
2024-02-14 12:03:42 +02:00
George Necula
a650f6c03b [shape_poly] Improve the handling of equality external constraints.
Previously, when we had the constraint `f == e` we would only
replace `f` when it appears as the whole term. Now, we also
handle `f * f1 * f2` and we rewrite it to `e * f1 * f2`.
2024-02-13 09:03:03 +02:00
George Necula
202bcd372b [shape_poly] Performance improvements for symbolic dimension manipulations (step 1)
This adds two improvements:

  * Change the computation of the `_size` attribute for expressions, to be more
   sensitive to some of the integer coefficients, e.g., `4*a` is now structurally
   larger than `3*a`. Similarly with `a**4` and `a**3`. This allows the
   `_syntactic_cmp` to short circuit most of the comparison. This operation
   is used a lot when sorting and normalizing the representation of symbolic
   expressions.
  * Change the caching of the bounds computation. Now we store in the cache
   the precision with which we computed the previous bounds, which allows
   better reuse of the cache.
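
To illustrate the first point, here is a rough sketch of how a coefficient-sensitive size can short-circuit a syntactic comparison. The size formula and the `(coefficient, variable, exponent)` triples are assumptions made for this example and do not reproduce the real `_size` or `_syntactic_cmp`.

```python
def term_size(coeff: int, exponent: int) -> int:
    """Hypothetical size measure for coeff * var**exponent."""
    return abs(coeff) + exponent


def syntactic_cmp(t1: tuple, t2: tuple) -> int:
    """Compare (coeff, var, exponent) triples; returns -1, 0, or 1."""
    s1 = term_size(t1[0], t1[2])
    s2 = term_size(t2[0], t2[2])
    if s1 != s2:
        # Short circuit: the sizes differ, so no structural walk is needed.
        return -1 if s1 < s2 else 1
    # Fall back to a full structural comparison only when sizes tie.
    key1 = (t1[1], t1[2], t1[0])
    key2 = (t2[1], t2[2], t2[0])
    return (key1 > key2) - (key1 < key2)
```

Under this measure `4*a` compares larger than `3*a`, and `a**4` larger than `a**3`, without reaching the structural fallback.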

On a microbenchmark, this resulted in a reduction by 30% (before 3210 and after 2145)
of the number of calls to `bounds_for_sorted_terms` and a 10% reduction in the
total time spent in `bounds`:

After:
```
tests/shape_poly_test.py::ShapePolyTest::test_constraints_for_profile          2307348 function calls (2260293 primitive calls) in 0.962 seconds
       1    0.000    0.000    0.969    0.969 shape_poly_test.py:1580(test_constraints_for_profile)
       1    0.000    0.000    0.234    0.234 shape_poly_test.py:1583(f)
     320    0.000    0.000    0.095    0.000 _shape_poly_decision.py:41(bounds_decision)
 425/280    0.001    0.000    0.094    0.000 _shape_poly_decision.py:234(bounds)
  513/51    0.002    0.000    0.091    0.002 _shape_poly_decision.py:260(_bounds_for_sorted_terms)
1230/135    0.001    0.000    0.081    0.001 _shape_poly_decision.py:330(add_implicit_constraints)
     250    0.000    0.000    0.076    0.000 _shape_poly.py:784(__ge__)
     250    0.000    0.000    0.076    0.000 _shape_poly.py:1077(_geq_decision)
 381/289    0.001    0.000    0.069    0.000 _shape_poly_decision.py:102(combine_and_add_constraint)
     695    0.001    0.000    0.065    0.000 _shape_poly.py:1673(_evaluate_multiply)
3572/766    0.002    0.000    0.051    0.000 _shape_poly.py:637(__str__)
```

Before:
```
tests/shape_poly_test.py::ShapePolyTest::test_constraints_for_profile          3486289 function calls (3318484 primitive calls) in 1.240 seconds
       1    0.000    0.000    1.247    1.247 shape_poly_test.py:1569(test_constraints_for_profile)
 992/320    0.001    0.000    0.424    0.001 _shape_poly_decision.py:269(bounds)
3210/280    0.008    0.000    0.423    0.002 _shape_poly_decision.py:292(_bounds_for_sorted_terms)
     250    0.000    0.000    0.400    0.002 _shape_poly.py:783(__ge__)
     250    0.000    0.000    0.399    0.002 _shape_poly_decision.py:39(geq_decision)
```
2024-02-12 18:30:25 +02:00
George Necula
983bb32ae6 [shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
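
As a sanity check that the example equality is true, the following plain-Python snippet verifies it on a range of concrete integers. It assumes that the symbolic `mod` agrees with Python's `%` for a positive divisor; the helper name is made up for this illustration.

```python
def mod_is_idempotent(divisor: int, lo: int = -100, hi: int = 100) -> bool:
    """Check mod(mod(a, d), d) == mod(a, d) for every a in [lo, hi)."""
    return all((a % divisor) % divisor == a % divisor for a in range(lo, hi))


holds_for_two = mod_is_idempotent(2)
```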

See more details in the README.md changes.
2024-02-08 10:09:47 +01:00
Jake VanderPlas
84ee045f55 [key reuse] handle polymorphic shapes in slice 2024-01-29 13:59:44 -08:00
George Necula
e20afac46a [shape_poly] Add a decision procedure for inequalities.
In a previous PR (#19285) we added support for inequality
constraints on symbolic expressions, but with limited support
for the cases when a constraint contains more than one term,
e.g., "a >= b".

Here we add a simple decision procedure for such inequalities,
using an elimination algorithm based on the following properties:

  * if we have two constraints "a + b >= 0" and "-a + c >= 0" we can
    eliminate "a" and infer the derived constraint "b + c >= 0".
  * the lower bound of "a + c", in presence of a constraint "a >= b",
    is greater-or-equal to "b + c".

The above rules can be generalized to cases when the eliminated
terms have coefficients different than 1.
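
The elimination step, including the case of coefficients other than 1, can be sketched as follows. This is an illustration only: constraints are modeled as dictionaries mapping a term name to its integer coefficient, read as "sum >= 0", and the `eliminate` helper is hypothetical rather than the code in the decision procedure.

```python
from math import gcd


def eliminate(c1: dict, c2: dict, var: str):
    """Combine c1 >= 0 and c2 >= 0 so that `var` cancels, or return None.

    The coefficients of `var` must have opposite signs, so that both
    constraints are scaled by positive multipliers and the sum stays >= 0.
    """
    k1, k2 = c1.get(var, 0), c2.get(var, 0)
    if k1 * k2 >= 0:
        return None
    g = gcd(k1, k2)
    m1, m2 = abs(k2) // g, abs(k1) // g  # positive multipliers
    out = {}
    for mult, c in ((m1, c1), (m2, c2)):
        for term, k in c.items():
            out[term] = out.get(term, 0) + mult * k
    return {t: k for t, k in out.items() if k != 0}


# "a + b >= 0" and "-a + c >= 0" give "b + c >= 0".
derived = eliminate({"a": 1, "b": 1}, {"a": -1, "c": 1}, "a")
```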

This algorithm is exponential in the number of constraints, but
we implement a limited form. When we add a constraint we combine
it with already added constraints, but the result of the combination
is not combined further. This is sufficient for the cases we
have encountered so far.

The termination of the algorithm is ensured by always eliminating
the largest (leading) term, ensuring that the result of a combination of
constraints has a smaller leading term.

With this added power for reasoning, we can retire the previous
heuristics for handling "min", "max", "floordiv" and "mod" and replace
them with the addition of some implicit constraints for them,
e.g., "max(a, b) >= a", etc., and then letting the decision procedure
do its job.

We moved the logic for deciding inequalities to a new file: shape_poly_decision.py.
2024-01-29 17:26:35 +02:00
George Necula
0bd511d621 [shape_poly] Add more tests for reasoning about inequalities.
As I explore more powerful ways to reason about inequalities,
I came up with more tests of inequalities that I wish we could handle.
This PR adds the tests I have so far, even if they do not produce
the correct result yet. I write the expected values for tests as

   _expect(best=v1, current=v2)

to document that the current logic produces `v2` but the best value
we can hope for is `v1`.

This PR also adds more support for profiling tests.
2024-01-24 09:57:49 +01:00
George Necula
24201ef922 [shape_poly] Add support for symbolic constraints on dimension variables
Until now all the reasoning about symbolic dimensions was
done with the implicit assumption that the dimension variables
range over strictly positive integers. Here we allow the
user to specify stronger constraints, so that they can be
used in the reasoning about inequalities of symbolic dimensions.
These explicit constraints are checked at compilation time, when
the shapes are known.

This adds significant power to the implementation of
shape polymorphism, and in particular it adds an
escape hatch for when in the past users saw
inconclusive comparison exceptions.

See more details in the README.md in this PR.
2024-01-23 09:53:03 +01:00
Jake VanderPlas
03ce8ca0ca jax.random: deprecate passing of batched keys to APIs 2024-01-17 12:53:24 -08:00
George Necula
a1286d0021 [shape_poly] Improve core.max_dim and core.min_dim
Previously, we optimized `core.max_dim(a, b)` to `a`
if `a >= b` and to `b` if `a < b`. Now we also optimize
it to `b` if `a <= b`.

Similarly for `core.min_dim`.
At the same time we move more of the logic from `core.py`
to `shape_poly.py`.
2024-01-15 15:10:28 +02:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
George Necula
0967a797e8 [shape_poly] Protect shape_poly: rename to _shape_poly.py.
The public APIs can be accessed through `jax.experimental.export`.

The shape_poly and serialization modules are still changing and I saw
external references to various symbols in them, even protected ones.
I have removed such references from the Google code base, and I want to take
another step to discourage direct access to its symbols.

PiperOrigin-RevId: 598119703
2024-01-13 02:05:11 -08:00
George Necula
3b7917a56e [shape_poly] Improve and rename export.args_specs.
We rename it to `symbolic_args_specs` in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons. We now move it to
shape_poly.py but we export `symbolic_args_specs` from
the public `jax.experimental.export`.

The improvement is that for the case when the `args` passed in
are TF arrays, we move the logic to extract the shapes and dtypes
from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
2024-01-12 08:11:03 +02:00
George Necula
b7f82e8cad [shape_poly] Improve the lexicographic ordering of symbolic expressions
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expression is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
2024-01-09 08:50:54 +02:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.
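
The two identities can be checked on concrete integers with a few lines of plain Python. The function names here are chosen for the illustration and are not the `core` APIs.

```python
def non_negative(d: int) -> int:
    """max(d, 0), the only primitive available before this change."""
    return max(d, 0)


def max_via_non_negative(a: int, b: int) -> int:
    return a + non_negative(b - a)


def min_via_non_negative(a: int, b: int) -> int:
    return a - non_negative(a - b)


pairs = [(a, b) for a in range(-5, 6) for b in range(-5, 6)]
identities_hold = all(
    max_via_non_negative(a, b) == max(a, b)
    and min_via_non_negative(a, b) == min(a, b)
    for a, b in pairs
)
```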

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialize(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00
George Necula
3195a069ef [shape_poly] Improved the tests for inequality comparisons.
Added more tests and broke some large tests into smaller ones.
2024-01-08 08:39:28 +02:00
George Necula
2b4177d35f [shape_poly] Improve error messages for shape polymorphism
Since all the inequality comparisons are done via the
`ge` method, until now all the error messages were about
greater or equal. Now we specify the actual original comparison.
2024-01-08 05:52:06 +02:00
George Necula
2ca7b31388 [shape_poly] Cache the hash values for symbolic expressions
Hashing is performance critical for symbolic expressions because
their internal representation is as dictionaries. In some of the
unit tests by caching the hash we save 20% of the calls. Hashing
will be even more important for the upcoming decision procedure
for symbolic expressions.

We also cache the sorting of the monomials in a symbolic expression.
2024-01-07 06:50:18 +02:00
George Necula
f1c87e0176 [shape_poly] Fixes for the lexicographic ordering of monomials.
There were several bugs in the ordering of atoms and
monomials. The ordering for atoms and monomials is used
for sorting, and the __eq__ is also used for hashing.

One bug was that the ordering of atoms sometimes used
the `id` ordering. Another (performance) bug was that
the __eq__ for atoms used the (semantic) __eq__ for
DimExpr. The latter is expensive to compute, but for
sorting all we need is a syntactic comparison.

We introduce a `_syntactic_cmp` method for atoms,
monomials and expressions and we use it exclusively
for the ordering of atoms and monomials.

We also clean up printing and add tests for ordering and
pretty printing. Now we print monomials in "decreasing" order.
This is a change from before, in the sense that "a + b" is
printed as "b + a".
2024-01-07 06:21:55 +02:00
George Necula
30bc5a2a5f [shape_poly] Update the jax.ops.segment{max|...} to work with shape polymorphism
The fix is very small, just had to check how we check for cases when tracers
are passed as num_segments. We add tests.
2023-12-19 12:02:39 +02:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
George Necula
0a02d83015 [shape_poly] Add simpler APIs max_dim and min_dim, improve >= 0
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
2023-12-07 09:41:47 +01:00
George Necula
d2f62612d7 Fix bug in indexing with slices that overflow, and add tests.
This bug was introduced in #18679, and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
2023-12-02 16:47:06 +02:00
George Necula
65fca0edf4 [shape_poly] Add heuristics for deciding >= 0
The rules for deciding inequalities of symbolic expressions
are incomplete. Here we add two heuristics that help decide
the bounds checking of indices computed for indexing with slices:

To decide whether an expression that contains `non_negative(e)` is
>= 0, it is sufficient to show that the expression is >=0 if we
replace the `non_negative(e)` with `0` and with `e`.

To decide whether `floordiv(e, k)` is >= 0, when `k >= 0`, it
is sufficient to show that `e` is >= 0.
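
Both heuristics can be exercised on concrete integers. The sketch below uses plain Python, with helper names and the sample expression `size - non_negative(size - 2)` chosen only for illustration, and it restricts the divisor to `k >= 1` to avoid division by zero.

```python
def non_negative(e: int) -> int:
    return max(e, 0)


def holds_by_case_split(expr, e: int) -> bool:
    """Sufficient check: expr is >= 0 with the non_negative slot set to 0 and to e."""
    return expr(0) >= 0 and expr(e) >= 0


def check_non_negative_rule(max_size: int = 50) -> bool:
    """Whenever the case split succeeds, the real expression must be >= 0."""
    for size in range(1, max_size):
        e = size - 2
        expr = lambda n, size=size: size - n  # size - non_negative(size - 2)
        if holds_by_case_split(expr, e) and expr(non_negative(e)) < 0:
            return False
    return True


def check_floordiv_rule(limit: int = 50) -> bool:
    """floordiv(e, k) >= 0 whenever e >= 0 and k >= 1."""
    return all(e // k >= 0 for e in range(limit) for k in range(1, 10))
```

The case split works because `non_negative(e)` always evaluates to either `0` or `e`.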

These are sufficient for the bounds checking that JAX is doing
internally, but may not be for the cases when the user program
does index computations using those operators.

This enables us to re-enable the shape_poly indexing tests.
2023-12-01 13:55:42 +02:00
George Necula
2d1ce133bc [shape_poly] Simplify the indexing with slice to make it compatible with shape polymorphism
Currently, we do not support shape polymorphism when we index with a
slice, e.g., `x[a:b:c]`, and instead we direct the user to use
`lax.dynamic_slice`. This is only because so far we have not tried
to ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g., `if start >= stop` that cannot be
handled in general in presence of symbolic shapes.

Here we introduce a new helper function `_preprocess_slice` to contain
all the computations for the start and the size of the slice.
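
A rough sketch of the idea, for the step-1 case only: the start and size are computed with `min` and `max` on expressions involving the axis length `n`, instead of branching on comparisons such as `start >= stop`. The function below is a hypothetical simplification and not the actual `_preprocess_slice`; it branches only on the sign of the slice bounds, which are assumed to be concrete integers or `None`.

```python
def preprocess_slice_step1(start, stop, n):
    """Return (start, size) for x[start:stop] on an axis of length n."""
    if start is None:
        start = 0
    elif start < 0:
        start = max(n + start, 0)  # negative index counts from the end
    else:
        start = min(start, n)      # clamp overflowing index
    if stop is None:
        stop = n
    elif stop < 0:
        stop = max(n + stop, 0)
    else:
        stop = min(stop, n)
    # Replaces `if start >= stop: size = 0 else: size = stop - start`.
    size = max(stop - start, 0)
    return start, size
```

For concrete `n`, the result agrees with Python's own slicing of `range(n)`.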

To test that this does not break the JAX index computations, I ran
the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`.
2023-12-01 08:40:07 +02:00
George Necula
c6afdfd8d6 [shape_poly] Simplify the API for processing polymorphic_shape specifications
Before, we had `export.poly_spec` to create a `jax.ShapeDtypeStruct`
given a polymorphic shape specification. This function was
invoked `poly_spec(arg_shape, arg_dtype, polymorphic_shape)`.
The `arg_shape` was only needed when the polymorphic shape spec
contained placeholders.

We break out an `export.symbolic_shape` that is just a parser
of polymorphic shape specs and we ask the user to invoke
`jax.ShapeDtypeStruct` directly:

`jax.ShapeDtypeStruct(export.symbolic_shape(polymorphic_shape, like=arg_shape), arg_dtype)`.

We also rename `export.poly_specs` to `export.args_specs`.
2023-11-28 12:45:59 +02:00
George Necula
301b9399c1 [shape_poly] Clean up the shape_poly_test.py
When we recently moved much of shape_poly_test out of jax2tf
we had to add a number of flags to avoid warnings (which are
errors in GitHub CI). Here we clean the tests so that we
can run them without the flags.

The most common problem was that tests were relying on
implicit rank promotion. We added a number of `jnp.expand_dims`
to fix the rank and let the implicit broadcasting do the rest.
2023-11-20 15:22:05 +02:00
George Necula
4fbf50dd60 [shape_poly] Copy many of the jax2tf/shape_poly_test to live outside of jax2tf.
Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf.

For now we leave the existing tests in jax2tf, because some of those tests exercise
other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515).

Since we now run these tests in GitHub and Kokoro, this has revealed a couple
of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py.

PiperOrigin-RevId: 583816243
2023-11-19 09:00:04 -08:00