702 Commits

Author SHA1 Message Date
jax authors
abe820c1e8 Merge pull request #19377 from superbobry:main
PiperOrigin-RevId: 598866324
2024-01-16 09:33:29 -08:00
Sergei Lebedev
1e9f96a574 Include Triton files into the jaxlib wheel
This PR is based on #19368.
2024-01-16 15:28:12 +00:00
Sergei Lebedev
af49b01e1f Migrated a subset of triton.compat to directly use IR builders
PiperOrigin-RevId: 598826331
2024-01-16 06:46:08 -08:00
jax authors
1db167489d Merge pull request #19328 from hawkinsp:buildtest
PiperOrigin-RevId: 597895981
2024-01-12 11:15:45 -08:00
Sergei Lebedev
87301aa737 Fixed the default api_version= in register_custom_call_target()
PiperOrigin-RevId: 597834961
2024-01-12 07:27:49 -08:00
Sergei Lebedev
935db25a2a Added an api_version field to PJRT_Gpu_Register_Custom_Call*
This allows using the correct registration API for both legacy (untyped) and
new (typed) XLA FFI custom calls.

PiperOrigin-RevId: 597818106
2024-01-12 05:47:26 -08:00
Adam Paszke
f625fb69da [Mosaic] Add support for tile-aligned dynamic offsets in loads, stores and ref slices
PiperOrigin-RevId: 597798116
2024-01-12 03:42:58 -08:00
Peter Hawkins
dedd69f323 Add a bazel test that verifies that the jaxlib wheel builds.
2024-01-11 23:22:17 +00:00
Sergei Lebedev
5b7a0d9c91 Pallas now uses MLIR Python builders to lower to Triton IR
This allows us to drop a dependency on the Triton Python package in the future,
and delegate ->ptx compilation to XLA.

PiperOrigin-RevId: 597640756
2024-01-11 13:33:26 -08:00
jax authors
59ea9f3fde [triton] Use cuLaunchKernelEx instead of cuLaunchKernel
PiperOrigin-RevId: 597555083
2024-01-11 07:52:07 -08:00
Adam Paszke
8f771b4211 [Mosaic] Simplify the handling of dynamic indices in vector.load and store
This normalizes loads and stores with dynamic base indices into reference
slicing followed by statically indexed loads/stores. This should both simplify
the code (we only have to deal with dynamism in slicing) and improve performance
(we might offset the address once).

PiperOrigin-RevId: 597546106
2024-01-11 07:08:07 -08:00
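
The normalization described in 8f771b4211 can be pictured with a small pure-Python model. This is only a sketch under stated assumptions: `RefSlice` and `load_dynamic` are hypothetical names invented for illustration and are not Mosaic APIs. The point it shows is that the dynamic value is consumed once, when the slice is formed, and the load that follows uses only static indices.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class RefSlice:
    """Hypothetical view into `ref` starting at a (possibly dynamic) offset."""
    ref: Sequence[int]
    offset: int  # the only place where a dynamic value appears
    size: int

    def load(self, start: int, length: int) -> list:
        # Statically indexed load, relative to the start of the slice.
        if start < 0 or start + length > self.size:
            raise IndexError("static load outside of slice")
        base = self.offset + start
        return list(self.ref[base:base + length])


def load_dynamic(ref: Sequence[int], dynamic_base: int, length: int) -> list:
    """Normalized form: slice once at the dynamic base, then load at static index 0."""
    view = RefSlice(ref, dynamic_base, length)
    return view.load(0, length)


memory = list(range(32))
result = load_dynamic(memory, 8, 4)
```

In this model all dynamism lives in `RefSlice.offset`, which mirrors the commit's claim that only slicing has to deal with dynamic indices.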
Adam Paszke
ce00e10d9b [Pallas][Mosaic] Add support for nontrivial semaphore memrefs
The previous patch simply changed the type we use to represent semaphores,
but didn't actually add support for any more operations. With this one,
semaphore memrefs can be allocated and (dynamically) indexed.

PiperOrigin-RevId: 597538913
2024-01-11 06:33:49 -08:00
Peter Hawkins
858fd52ac0 Fix jaxlib wheel build after removal of mosaic python files.
PiperOrigin-RevId: 597536465
2024-01-11 06:21:07 -08:00
Adam Paszke
57506b50c5 [Mosaic] Make sure to infer native tiling for results of TruncIOp that are fed into matmuls
This replicates the optimization we already apply while truncating floating point types.
Also, the heuristic used previously didn't include the tpu.matmul op, which could have
led to some performance degradation.

PiperOrigin-RevId: 597514672
2024-01-11 04:15:36 -08:00
Jevin Jiang
57f05592dd [XLA:Mosaic] Support inputs and outputs in scf::ForOp and add tpu::AssumeLayoutOp to work around block argument as operand.
PiperOrigin-RevId: 597353171
2024-01-10 14:15:05 -08:00
Sergei Lebedev
6174145386 Removed the Triton dependency from the BUILD file
PiperOrigin-RevId: 597336551
2024-01-10 13:17:47 -08:00
Tomás Longeri
027c24e602 [Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
Sergei Lebedev
ba10775eda Added a compatibility overlay for Triton Python APIs
Follow-up changes will gradually re-implement these APIs using the MLIR
builders added in google/jax#19159.

PiperOrigin-RevId: 597023799
2024-01-09 13:13:56 -08:00
Tomás Longeri
92bcd3f902 [Mosaic] apply_vector_layout: Copy docstring for VectorLayout in Python to C++
This is in preparation for removing the Python version.

PiperOrigin-RevId: 597015430
2024-01-09 12:43:59 -08:00
Sergei Lebedev
f219482212 The Triton MLIR bindings now include auto-generated wrappers for enums
PiperOrigin-RevId: 596873541
2024-01-09 03:00:47 -08:00
Tomás Longeri
88542f0e56 [Mosaic] Run C++ passes from within custom_call_emitter.cc
PiperOrigin-RevId: 596464480
2024-01-07 20:25:45 -08:00
Sharad Vikram
d6a47230b2 [Pallas/Mosaic] Change Mosaic semaphores to be MemRef types
This is part 1 of a change that enables allocating arrays of semaphores. It does
not add any new public facing functionality and only changes how semaphores
are represented in Mosaic.

PiperOrigin-RevId: 595848688
2024-01-04 17:56:30 -08:00
Jake VanderPlas
326d1d27ef jaxlib: avoid external build-time dependency on ml_dtypes
Currently, the ml_dtypes C++ sources are included in the set of sources at jaxlib build time. This is unnecessary, and can lead to problematic version skew in some cases (e.g. nightly builds).

PiperOrigin-RevId: 595725529
2024-01-04 09:26:05 -08:00
Parker Schuh
23b9c2a22f Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies.
PiperOrigin-RevId: 595529249
2024-01-03 16:12:23 -08:00
Sergei Lebedev
e6c890171b Generate Python bindings for the Triton MLIR dialect
The bindings are not yet included in the jaxlib wheel. I will do that in a
follow-up PR.

PiperOrigin-RevId: 595174466
2024-01-02 11:55:05 -08:00
Adam Paszke
a678015c74 Implement a stable serialization API for Mosaic
This lets us break a dependency on standard MLIR dialects while serializing
the program into HLO. The scheme is simple: we make a lightweight lazy fork
of existing dialects by mangling the dialect name and otherwise keeping the
structure of the ops identical. This keeps serialization and deserialization
simple, for as long as the upstream dialects don't change much. If they do,
we have to increment our version counter and write rules that update the IR
structure.

Note that this scheme only protects us from changes to the attributes
annotating the ops (renaming, etc.). It doesn't protect us from changes to
the attributes defined by a dialect. Still, as far as
I can tell, the only attributes we depend on are enums (which are simply
plain integer attributes, so we can remap their values) and affine maps
(that are unlikely to change much, I hope).

This does not actually wire up the pass yet, as we are currently reorganizing
the Python/C++ boundary significantly. The integration should be completed
once that work is done.

PiperOrigin-RevId: 595128374
2024-01-02 08:51:56 -08:00
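
The scheme described in a678015c74 (mangle the dialect name, keep op structure, bump a version counter and write upgrade rules when upstream changes) might look roughly like the toy model below. Everything in it is an assumption made for illustration: the `stable_` prefix, the `VERSION` value, the payload layout, and the op names are hypothetical and do not reflect the actual Mosaic serialization format.

```python
MANGLED_PREFIX = "stable_"  # hypothetical prefix, not the real mangling
VERSION = 2                 # hypothetical current format version


def mangle(op_name: str) -> str:
    """Rename 'dialect.op' to 'stable_dialect.op', leaving the op itself untouched."""
    dialect, sep, rest = op_name.partition(".")
    if not sep:
        raise ValueError(f"expected 'dialect.op', got {op_name!r}")
    return f"{MANGLED_PREFIX}{dialect}.{rest}"


def demangle(op_name: str) -> str:
    if not op_name.startswith(MANGLED_PREFIX):
        raise ValueError(f"not a mangled op name: {op_name!r}")
    return op_name[len(MANGLED_PREFIX):]


def serialize(ops):
    return {"version": VERSION, "ops": [mangle(op) for op in ops]}


def deserialize(payload, upgrades):
    """Apply per-version upgrade rules until the payload reaches VERSION."""
    version, ops = payload["version"], payload["ops"]
    while version < VERSION:
        ops = [upgrades[version](op) for op in ops]
        version += 1
    if version != VERSION:
        raise ValueError(f"unsupported version {version}")
    return [demangle(op) for op in ops]


# Round trip at the current version needs no upgrade rules.
ops = ["arith.addi", "vector.load"]
round_tripped = deserialize(serialize(ops), upgrades={})

# An older payload is upgraded by a rule that renames a (made-up) op.
old_payload = {"version": 1, "ops": ["stable_arith.add"]}
rules = {1: lambda op: "stable_arith.addi" if op == "stable_arith.add" else op}
upgraded = deserialize(old_payload, rules)
```

The upgrade rules here only rename ops; as the commit message notes, changes to attributes defined by a dialect would need separate handling.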
Adam Paszke
0419e014f1 [Mosaic] Add a pass to check operation invariants on-device
This lets us easily catch things such as out-of-bounds loads
or reference slices (leading to OOB DMAs or loads downstream).

PiperOrigin-RevId: 595072511
2024-01-02 03:19:35 -08:00
Christian Sigg
c83fd971a0 Fix jax mlir python dependency build after 537b2aa264
PiperOrigin-RevId: 593370604
2023-12-23 21:02:29 -08:00
Dmitri Gribenko
35b8fdc3b2 Integrate LLVM at llvm/llvm-project@7022a24771
Updates LLVM usage to match
[7022a24771c8](https://github.com/llvm/llvm-project/commit/7022a24771c8)

PiperOrigin-RevId: 592546932
2023-12-20 06:52:01 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences
2023-12-19 06:15:30 +01:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow-up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Kevin Gleason
184e3a8800 Integrate StableHLO at openxla/stablehlo@ab709fe4
PiperOrigin-RevId: 589908773
2023-12-11 12:30:50 -08:00
Jevin Jiang
3651d4c4f5 [XLA:Mosaic] Support tpu.bitcast for i16, i8.
PiperOrigin-RevId: 589881484
2023-12-11 11:14:16 -08:00
Blake Hechtman
1cba270f14 [XLA:MOSAIC] implement downcast from s32 to s8 correctly
PiperOrigin-RevId: 589830744
2023-12-11 08:20:53 -08:00
Peter Hawkins
560187334a Add register_jax_dialects to jaxlib wheel.
Fixes build breakage.
2023-12-06 19:07:04 +00:00
Peter Hawkins
1c80b364d2 Remove stale reference to _site_initialize_0 in wheel build script.
2023-12-06 12:12:15 -05:00
Peter Hawkins
d95084dbc8 Use an explicit MLIR dialect registration, rather than _site_initialize_0.
Remove some special case handling of the SCF dialect, use upstream utilities instead.

PiperOrigin-RevId: 588433245
2023-12-06 08:19:55 -08:00
Peter Hawkins
720ff42cbf [bazel] Add a macro if_building_jaxlib() to guard dependencies that should only be present if building jaxlib.
Cleanup only, NFC intended.

PiperOrigin-RevId: 588074047
2023-12-05 08:05:17 -08:00
Jevin Jiang
9d35b904e1 [XLA:Mosaic] Support expanding lane dim in shapecast: (..., 128) -> (..., m * 128) and handle relayout from (1, 128) to (8, 128) for more general cases.
PiperOrigin-RevId: 588024159
2023-12-05 04:25:10 -08:00
Peter Hawkins
32fb1b4034 Remove the ml_program MLIR dialect from jaxlib.
Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something.

PiperOrigin-RevId: 587323808
2023-12-02 09:29:39 -08:00
Peter Hawkins
a999120514 Improve error message when cudnn is not found.
We infer a missing cudnn if cudnnGetVersion() returns 0, since the stub implementation in TSL will do that if the library isn't found (10a378f499/third_party/tsl/tsl/cuda/cudnn_stub.cc (L58)).

PiperOrigin-RevId: 587056454
2023-12-01 10:52:48 -08:00
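
The check described in a999120514 amounts to treating a version of 0 as "library not found". A minimal sketch, assuming a hypothetical `get_version` callable that stands in for `cudnnGetVersion()`; the function name `check_cudnn` and the error text are invented here and are not the ones used in jaxlib:

```python
def check_cudnn(get_version):
    """Return the reported cuDNN version, or raise if the stub reported 0.

    `get_version` is a hypothetical stand-in for cudnnGetVersion(); per the
    commit message, the stub implementation returns 0 when the library is
    not found.
    """
    version = get_version()
    if version == 0:
        raise RuntimeError(
            "cuDNN could not be loaded (cudnnGetVersion() returned 0). "
            "Check that cuDNN is installed and discoverable."
        )
    return version


# A nonzero value is passed through unchanged.
found = check_cudnn(lambda: 8902)
```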
Peter Hawkins
50c7223ed1 Fix Windows build failure.
The TPU extension didn't build because the MLIR Python binding code requires pybind11 to be included first on Windows, per 9584f58344/mlir/include/mlir-c/Bindings/Python/Interop.h (L24)

PiperOrigin-RevId: 587049246
2023-12-01 10:31:53 -08:00
Shashank Viswanadha
bd46e5c960 Add nb::arg to nanobind definitions to generate better python annotations.
PiperOrigin-RevId: 586721759
2023-11-30 10:39:28 -08:00
Shashank Viswanadha
350b7c56b8 Add python stub files for jaxlib/cpu C++ Python extensions.
PiperOrigin-RevId: 585990748
2023-11-28 08:45:24 -08:00
Adam Paszke
ffbd632fb6 Add type annotations to avoid initializer list issues on macOS
Also remove the vector-avoiding specialization. For some reason
is_same<ssize_t, int64_t> evaluates to true on macOS, but then
the compiler complains that int64_t is a long long, while
ssize_t is only a long.
2023-11-27 18:02:50 +00:00
Tomás Longeri
08648150ab [Mosaic] C++ apply-vector-layout: add support for tpu.region
(already exists in Python)

PiperOrigin-RevId: 584599917
2023-11-22 05:35:29 -08:00
Tomás Longeri
4c9f2aca0c [Mosaic] C++ apply-vector-layout: don't skip unrecognized operations in applyLayoutOp
Although the TODO says to return failure, this is actually done at the end of the function (and this way we handle the case for ops without vector args).

PiperOrigin-RevId: 584575120
2023-11-22 03:25:07 -08:00
Tomás Longeri
8457b02a31 [Mosaic] C++ apply-vector-layout: fix unnecessarily setting out_layout in matmul rule (which never existed in Python)
PiperOrigin-RevId: 584568680
2023-11-22 02:55:16 -08:00
Tomás Longeri
f35ddc8c68 Fix bad cast in tpu_ext.cc
The argument to the cast is of type ssize_t. A mismatch between int64_t and ssize_t occurs on macOS and causes the build to fail:
`error: const_cast from 'const pybind11::ssize_t *' (aka 'const long *') to 'int64_t *' (aka 'long long *') is not allowed`

PiperOrigin-RevId: 584457599
2023-11-21 16:23:27 -08:00
Adam Paszke
038879248d Add a recently added Mosaic Python file to build_wheel.py
PiperOrigin-RevId: 584356541
2023-11-21 10:03:59 -08:00