Previously, we optimized `core.max_dim(a, b)` to `a`
if `a >= b` and to `b` if `a < b`. Now we also optimize
it to `b` if `a <= b`.
Similarly for `core.min_dim`.
At the same time we move more of the logic from `core.py`
to `shape_poly.py`.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
The public APIs can be accessed through `jax.experimental.export`.
The shape_poly and serialization modules are still changing and I saw
external references to various symbols in them, even protected ones.
I have removed such references from the Google code base, and I want to take
another step to discourage direct access to its symbols.
PiperOrigin-RevId: 598119703
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.
Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.
Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
We rename it to `symbolic_args_specs` in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons, we now move it to
shape_poly.py but we export the `symbolci_args_specs` from
the public `jax.experimental.export`.
The improvement is that for the case when the `args` passed in
are TF arrays, we move the logic to extract the shapes and dtypes
from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expressions is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.
PiperOrigin-RevId: 596696013
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.
One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:
```
from jax.experimental import export
exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```
This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.
In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.
PiperOrigin-RevId: 596563481
Since all the inequality comparisons are done via the
`ge` method, until now all the error messages were about
greater or equal. Now we specify the actual original comparison.
Hashing is performance critical for symbolic expressions because
their internal representation is as dictionaries. In some of the
unit tests by caching the hash we save 20% of the calls. Hashing
will be even more important for the upcoming decision procedure
for symbolic expressions.
We also cache the sorting of the monomials in a symbolic expression.
There were several bugs in the ordering of atoms and
monomials. The ordering for atoms and moomials is used
for sorting, and the __eq__ is also used for hashing.
One bug was that the ordering of atoms sometimes used
the `id` ordering. Another (performance) bug was that
the __eq__ for atoms used the (semantic) __eq__ for
DimExpr. The latter is expensive to compute, but for
sorting all we need is a syntactic comparison.
We introduce a `_syntactic_cmp` method for atoms,
monomials and expressions and we use it exclusively
for the ordering of atoms and monomials.
We also clean up printing and add tests for ordering and
pretty printing. Now we print monomial in "decreasing" order.
This is a change from before, in the sense that "a + b" is
printed as "b + a".
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).
If your code ends up taking the python dispatch, then something is going wrong anyways.
PiperOrigin-RevId: 596081987
Currently, the persistent compilation cache has a time
threshold: the entry is cached only if the compilation
time is less than the threshold. If compilation happens
to take a while, but the resulting executable is small,
there is nothing that prevents caching. This can result
in a large number of small files in the cache.
Introduce a size threshold. If the resulting executable's
size (after serialization and compression) is less than
this threshold, don't cache. This check is in addition to
the compilation time check described above.
Testing: new unit test, test workload.
PiperOrigin-RevId: 595815611