This normalizes loads and stores with dynamic base indices into reference
slicing followed by statically indexed loads/stores. This should both simplify
the code (we only have to deal with dynamism in slicing) and improve performance
(the address may only need to be offset once, rather than on every access).
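
As a rough illustration of the rewrite, here is a minimal sketch in plain Python rather than Mosaic IR. `RefSlice`, `load_direct` and `load_normalized` are hypothetical names invented for this example, and a flat list stands in for a memory reference.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RefSlice:
    """A window into `buffer` that starts at a (possibly dynamic) `base`."""
    buffer: list
    base: int

    def load(self, static_index):
        # Inside the window, every access uses a compile-time constant index.
        return self.buffer[self.base + static_index]


def load_direct(buffer, dynamic_base, static_index):
    # Un-normalized form: the dynamic base is folded into every single load.
    return buffer[dynamic_base + static_index]


def load_normalized(buffer, dynamic_base, static_indices):
    # Normalized form: the dynamic part is handled once, by the slice,
    # and the loads that follow are all statically indexed.
    window = RefSlice(buffer, dynamic_base)
    return [window.load(i) for i in static_indices]
```

Both forms read the same elements; the normalized one confines the dynamic index to the slicing step.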
PiperOrigin-RevId: 597546106
The previous patch simply changed the type we use to represent semaphores,
but didn't actually add support for any more operations. With this one,
semaphore memrefs can be allocated and (dynamically) indexed.
PiperOrigin-RevId: 597538913
This replicates the optimization we already apply while truncating floating point types.
Also, the heuristic used previously didn't include the tpu.matmul op, which could have
led to some performance degradation.
PiperOrigin-RevId: 597514672
This is part 1 of a change that enables allocating arrays of semaphores. It does
not add any new public facing functionality and only changes how semaphores
are represented in Mosaic.
PiperOrigin-RevId: 595848688
Currently, the ml_dtypes C++ sources are included in the set of sources at jaxlib build time. This is unnecessary, and can lead to problematic version skew in some cases (e.g. nightly builds).
PiperOrigin-RevId: 595725529
This lets us break a dependency on standard MLIR dialects while serializing
the program into HLO. The scheme is simple: we make a lightweight lazy fork
of existing dialects by mangling the dialect name and otherwise keeping the
structure of the ops identical. This keeps serialization and deserialization
simple, for as long as the upstream dialects don't change much. If they do,
we have to increment our version counter and write rules that update the IR
structure.
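
A minimal sketch of the idea, using plain Python dicts in place of MLIR ops. The prefix `fork.`, the version numbers, and the attribute rename in the upgrade rule are all hypothetical and chosen only for illustration; they are not the names used by the actual pass.

```python
MANGLE_PREFIX = "fork."  # hypothetical mangled-dialect prefix
CURRENT_VERSION = 2      # hypothetical version counter


def _upgrade_v1_to_v2(op):
    # Hypothetical rule: pretend upstream renamed the attribute `pred`
    # to `predicate` on `arith.cmpi` between versions 1 and 2.
    if op["name"] == "arith.cmpi" and "pred" in op["attrs"]:
        attrs = dict(op["attrs"])
        attrs["predicate"] = attrs.pop("pred")
        return {"name": op["name"], "attrs": attrs}
    return op


# Maps a version to the rule that rewrites ops from that version to the next.
UPGRADE_RULES = {1: _upgrade_v1_to_v2}


def serialize(ops, version=CURRENT_VERSION):
    # Mangle only the dialect name; the structure of each op is kept as-is.
    mangled = [{"name": MANGLE_PREFIX + op["name"], "attrs": dict(op["attrs"])}
               for op in ops]
    return {"version": version, "ops": mangled}


def deserialize(module):
    version = module["version"]
    if version > CURRENT_VERSION:
        raise ValueError(f"unsupported version {version}")
    ops = []
    for op in module["ops"]:
        if not op["name"].startswith(MANGLE_PREFIX):
            raise ValueError(f"unexpected op {op['name']!r}")
        ops.append({"name": op["name"][len(MANGLE_PREFIX):],
                    "attrs": dict(op["attrs"])})
    # Replay the upgrade rules from the stored version up to the current one.
    for v in range(version, CURRENT_VERSION):
        ops = [UPGRADE_RULES[v](op) for op in ops]
    return ops
```

As long as upstream does not change, deserialization is just un-mangling; the upgrade rules only run for programs serialized at an older version.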
Note that this scheme only protects us from changes to how the ops are
annotated with attributes (renaming, etc.). It doesn't protect us from
changes to the attributes defined by a dialect. Still, as far as
I can tell, the only attributes we depend on are enums (which are simply
plain integer attributes, so we can remap their values) and affine maps
(which are unlikely to change much, I hope).
This does not actually wire up the pass yet, as we are currently reorganizing
the Python/C++ boundary significantly. The integration should be completed
once that work is done.
PiperOrigin-RevId: 595128374
This lets us easily catch things such as out-of-bounds loads
or reference slices (leading to OOB DMAs or loads downstream).
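
For illustration only, a bounds check of the kind described could look like the sketch below. `check_slice_in_bounds` is a hypothetical helper written in plain Python; it is not the check Mosaic itself performs.

```python
def check_slice_in_bounds(shape, starts, sizes):
    """Raise IndexError if [start, start + size) leaves any dimension."""
    if not (len(shape) == len(starts) == len(sizes)):
        raise ValueError("shape, starts and sizes must have the same rank")
    for dim, (extent, start, size) in enumerate(zip(shape, starts, sizes)):
        if start < 0 or size < 0 or start + size > extent:
            raise IndexError(
                f"dimension {dim}: slice [{start}, {start + size}) "
                f"is outside [0, {extent})"
            )
```

Failing at the slice makes the error point at the offending index, instead of surfacing later as a bad DMA or load.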
PiperOrigin-RevId: 595072511
This PR is a follow-up to #18881.
The changes were generated by adding
`from __future__ import annotations`
to the files that did not already have it and running
`pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py`
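
As an illustration of the kind of rewrite this produces, the sketch below shows a made-up function after the change; `first_or_none` is a hypothetical example, not code from the repository. With the `__future__` import in place, annotations are not evaluated at runtime, so the newer syntax is safe on Python 3.9.

```python
from __future__ import annotations  # annotations are stored as strings

# Before the rewrite, this signature would typically have read
#   def first_or_none(xs: List[int]) -> Optional[int]:
# with `from typing import List, Optional` at the top of the file.


def first_or_none(xs: list[int]) -> int | None:
    """Return the first element of `xs`, or None if `xs` is empty."""
    return xs[0] if xs else None
```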
JAX isn't using this, and our code to build it wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it.
PiperOrigin-RevId: 587323808
We infer a missing cudnn if cudnnGetVersion() returns 0, since the stub implementation in TSL will do that if the library isn't found (10a378f499/third_party/tsl/tsl/cuda/cudnn_stub.cc (L58)).
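
The inference can be sketched as follows. This is a simplified Python illustration, not the actual check: `cudnn_is_available` and its `get_version` argument are hypothetical stand-ins for a call into `cudnnGetVersion()`.

```python
def cudnn_is_available(get_version):
    """Return False when the version query reports 0.

    `get_version` is a hypothetical zero-argument callable standing in for
    cudnnGetVersion(). The stub returns 0 when the real library could not
    be found, so 0 is treated as "cuDNN is missing".
    """
    return get_version() != 0
```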
PiperOrigin-RevId: 587056454
Also remove the vector-avoiding specialization. For some reason
is_same<ssize_t, int64_t> evaluates to true on macOS, but then
the compiler complains that int64_t is a long long, while
ssize_t is only a long.
Although the TODO says to return failure, this is actually done at the end of the function (and this way we handle the case for ops without vector args).
PiperOrigin-RevId: 584575120
The argument to the cast is of type ssize_t. On macOS, int64_t and ssize_t are different types, which causes the build to fail:
`error: const_cast from 'const pybind11::ssize_t *' (aka 'const long *') to 'int64_t *' (aka 'long long *') is not allowed`
PiperOrigin-RevId: 584457599