In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expressions is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.
PiperOrigin-RevId: 596696013
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.
One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
mosaic_grid_mapping sometimes extend the scalar prefetch operands with arrays
describing the mesh, throwing off the math in the recent patch.
PiperOrigin-RevId: 596596640
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:
```
from jax.experimental import export
exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```
This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.
In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.
PiperOrigin-RevId: 596563481
Use core.dilate_dim, it is clearer and especially when using
shape polymoprhism it will result in better error messages.
Replace the use of np.maximum with core.max_dim, because the
latter will result in fewer errors in presence of shape polymorphism
(the core.max_dim is deferring the actual computation to shape
refinement time).
Since all the inequality comparisons are done via the
`ge` method, until now all the error messages were about
greater or equal. Now we specify the actual original comparison.
Equality is used heavily for symbolic expressions because we use them
as dictionary keys or in sets. Previously, we used a more complete
and more expensive form of equality where we attempted to prove that
"e1 - e2 >= 0" and "e1 - e2 <= 0". This is an overkill and none
of the tests we have so far rely on this power. Now we just
normalize "e1 - e2" and if it reduces syntactically to an integer
we check if the integer is 0. If the difference does not reduce
to an integer we say that the expressions are disequal.
This may possibly change user-visible behavior when it depends
on the outcome of equality comparisons of symbolic dimensions
in presence of shape polymorphism.
Hashing is performance critical for symbolic expressions because
their internal representation is as dictionaries. In some of the
unit tests by caching the hash we save 20% of the calls. Hashing
will be even more important for the upcoming decision procedure
for symbolic expressions.
We also cache the sorting of the monomials in a symbolic expression.