672 Commits

Author SHA1 Message Date
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
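To illustrate the kind of rewrite this produces, here is a minimal sketch. The function names are hypothetical and not taken from the JAX sources. It assumes Python 3.9 or newer, where the builtin `list` and `dict` can be subscripted in annotations (PEP 585) in place of their `typing` aliases:

```python
from typing import Dict, List


# Spelling before the upgrade: generic aliases imported from `typing`.
def count_old(words: List[str]) -> Dict[str, int]:
    counts: Dict[str, int] = {}
    for w in words:
        counts[w] = counts.get(w, 0) + 1
    return counts


# Spelling after the upgrade: builtin `list` and `dict` are subscriptable
# on Python 3.9+, so the `typing` aliases are no longer needed.
def count_new(words: list[str]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for w in words:
        counts[w] = counts.get(w, 0) + 1
    return counts
```

Only the annotations differ between the two functions, so their runtime behaviour is the same.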
Kevin Gleason
184e3a8800 Integrate StableHLO at openxla/stablehlo@ab709fe4
PiperOrigin-RevId: 589908773
2023-12-11 12:30:50 -08:00
Jevin Jiang
3651d4c4f5 [XLA:Mosaic] Support tpu.bitcast for i16, i8.
PiperOrigin-RevId: 589881484
2023-12-11 11:14:16 -08:00
Blake Hechtman
1cba270f14 [XLA:MOSAIC] implement downcast from s32 to s8 correctly
PiperOrigin-RevId: 589830744
2023-12-11 08:20:53 -08:00
Peter Hawkins
560187334a Add register_jax_dialects to jaxlib wheel.
Fixes build breakage.
2023-12-06 19:07:04 +00:00
Peter Hawkins
1c80b364d2 Remove stale reference to _site_initialize_0 in wheel build script.
2023-12-06 12:12:15 -05:00
Peter Hawkins
d95084dbc8 Use an explicit MLIR dialect registration, rather than _site_initialize_0.
Remove some special case handling of the SCF dialect, use upstream utilities instead.

PiperOrigin-RevId: 588433245
2023-12-06 08:19:55 -08:00
Peter Hawkins
720ff42cbf [bazel] Add a macro if_building_jaxlib() to guard dependencies that should only be present if building jaxlib.
Cleanup only, NFC intended.

PiperOrigin-RevId: 588074047
2023-12-05 08:05:17 -08:00
Jevin Jiang
9d35b904e1 [XLA:Mosaic] Support expanding lane dim in shapecast: (..., 128) -> (..., m * 128) and handle relayout from (1, 128) to (8, 128) for more general cases.
PiperOrigin-RevId: 588024159
2023-12-05 04:25:10 -08:00
Peter Hawkins
32fb1b4034 Remove the ml_program MLIR dialect from jaxlib.
Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something.

PiperOrigin-RevId: 587323808
2023-12-02 09:29:39 -08:00
Peter Hawkins
a999120514 Improve error message when cudnn is not found.
We infer a missing cudnn if cudnnGetVersion() returns 0, since the stub implementation in TSL will do that if the library isn't found (10a378f499/third_party/tsl/tsl/cuda/cudnn_stub.cc (L58)).

PiperOrigin-RevId: 587056454
2023-12-01 10:52:48 -08:00
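The inference described above can be sketched as follows. This is an illustrative Python stand-in, not the actual jaxlib code: `check_cudnn` and its `get_version` parameter are hypothetical names, and `get_version` plays the role of `cudnnGetVersion()`.

```python
from typing import Callable


def check_cudnn(get_version: Callable[[], int]) -> int:
    """Return the reported cuDNN version, or raise if the library looks missing.

    A return value of 0 is treated as "library not found", following the
    stub behaviour described in the commit message.
    """
    version = get_version()
    if version == 0:
        raise RuntimeError("Unable to load cuDNN. Is it installed?")
    return version
```

The error text here is a placeholder; the wording of the real message is not given in the commit.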
Peter Hawkins
50c7223ed1 Fix Windows build failure.
The TPU extension didn't build because the MLIR Python binding code requires pybind11 to be included first on Windows, per 9584f58344/mlir/include/mlir-c/Bindings/Python/Interop.h (L24)

PiperOrigin-RevId: 587049246
2023-12-01 10:31:53 -08:00
Shashank Viswanadha
bd46e5c960 Add nb::arg to nanobind definitions to generate better python annotations.
PiperOrigin-RevId: 586721759
2023-11-30 10:39:28 -08:00
Shashank Viswanadha
350b7c56b8 Add python stub files for jaxlib/cpu C++ Python extensions.
PiperOrigin-RevId: 585990748
2023-11-28 08:45:24 -08:00
Adam Paszke
ffbd632fb6 Add type annotations to avoid initializer list issues on macOS
Also remove the vector-avoiding specialization. For some reason
is_same<ssize_t, int64_t> evaluates to true on macOS, but then
the compiler complains that int64_t is a long long, while
ssize_t is only a long.
2023-11-27 18:02:50 +00:00
Tomás Longeri
08648150ab [Mosaic] C++ apply-vector-layout: add support for tpu.region
(already exists in Python)

PiperOrigin-RevId: 584599917
2023-11-22 05:35:29 -08:00
Tomás Longeri
4c9f2aca0c [Mosaic] C++ apply-vector-layout: don't skip unrecognized operations in applyLayoutOp
Although the TODO says to return failure, this is actually done at the end of the function (and this way we handle the case for ops without vector args).

PiperOrigin-RevId: 584575120
2023-11-22 03:25:07 -08:00
Tomás Longeri
8457b02a31 [Mosaic] C++ apply-vector-layout: fix unnecessarily setting out_layout in matmul rule (which never existed in Python)
PiperOrigin-RevId: 584568680
2023-11-22 02:55:16 -08:00
Tomás Longeri
f35ddc8c68 Fix bad cast in tpu_ext.cc
The argument to the cast is of type ssize_t. Mismatch between int64_t and ssize_t happens in Mac and causes build to fail:
`error: const_cast from 'const pybind11::ssize_t *' (aka 'const long *') to 'int64_t *' (aka 'long long *') is not allowed`

PiperOrigin-RevId: 584457599
2023-11-21 16:23:27 -08:00
Adam Paszke
038879248d Add a recently added Mosaic Python file to build_wheel.py
PiperOrigin-RevId: 584356541
2023-11-21 10:03:59 -08:00
Tomás Longeri
f12216908d [Mosaic] In Python bindings, fix getDefaultInsertionPoint and change import path for mlir.ir
PiperOrigin-RevId: 584258833
2023-11-21 02:10:20 -08:00
Tomás Longeri
f602fbe997 [Mosaic] Python bindings for VectorLayout, VRegDataBounds, assemble, disassemble, relayout and apply_layout_op
PiperOrigin-RevId: 584220887
2023-11-20 22:44:40 -08:00
Peter Hawkins
49c80e68d1 Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
Improve the documentation of lax.eig().

Fixes https://github.com/google/jax/issues/18226

PiperOrigin-RevId: 584170564
2023-11-20 17:32:16 -08:00
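One plausible way to avoid handing non-finite input to an iterative solver is to check the matrix first and return NaN placeholders instead of calling the solver. The sketch below only illustrates that idea. The commit message does not say how the fix is implemented, and `guarded_eig` and `solver` are hypothetical names.

```python
import math


def is_all_finite(matrix: list[list[float]]) -> bool:
    """True if every entry is neither NaN nor infinite."""
    return all(math.isfinite(x) for row in matrix for x in row)


def guarded_eig(matrix, solver):
    """Call `solver` only on finite input; otherwise return NaN eigenvalues."""
    if not is_all_finite(matrix):
        # Assumption: one NaN placeholder per eigenvalue of an n x n matrix.
        return [math.nan] * len(matrix)
    return solver(matrix)
```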
Sharad Vikram
f82a8309bd [Mosaic] Fix lint failure in apply_vector_layout
PiperOrigin-RevId: 584119457
2023-11-20 13:49:54 -08:00
Krasimir Georgiev
9287a6369d Integrate LLVM at llvm/llvm-project@9bdbb8226e
Updates LLVM usage to match
[9bdbb8226e70](https://github.com/llvm/llvm-project/commit/9bdbb8226e70)

PiperOrigin-RevId: 584091615
2023-11-20 12:01:57 -08:00
Peter Hawkins
41f0b336e3 Add minimum version checks for cublas and cusparse.
Split code to determine CUDA library versions out of py_extension() module and into a cc_library(), because it fixes a linking problem in Google's build. (Long story, not worth it.)

Fixes https://github.com/google/jax/issues/8289

PiperOrigin-RevId: 583544218
2023-11-17 19:30:41 -08:00
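A minimum-version check of this kind can be sketched as below. The version numbers and the `check_min_versions` helper are hypothetical placeholders, not the thresholds or code used by jaxlib.

```python
def check_min_versions(found: dict[str, int], required: dict[str, int]) -> list[str]:
    """Return one error string per library that is missing or too old."""
    errors = []
    for name, minimum in required.items():
        version = found.get(name, 0)  # a library that was not found counts as 0
        if version < minimum:
            errors.append(
                f"{name} version {version} is older than the minimum {minimum}"
            )
    return errors


# Hypothetical minimums, for illustration only.
REQUIRED = {"cublas": 110000, "cusparse": 11000}
```

Collecting every failure before reporting lets a single error message list all outdated libraries.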
Peter Hawkins
8e8dc263bc Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result.
In most cases these are more succinct.

This change does not update Pallas/Mosaic.

PiperOrigin-RevId: 583448254
2023-11-17 11:47:14 -08:00
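The difference between the two spellings can be shown with a pure-Python mock. `AThingOp` and `athing` below are stand-ins that mirror the placeholder names in the commit title. They do not use the real MLIR bindings, and the string-valued `result` is purely illustrative.

```python
class AThingOp:
    """Mock of a generated op class: constructing it yields an op with a `.result`."""

    def __init__(self, lhs, rhs):
        self.operands = (lhs, rhs)
        self.result = f"athing({lhs}, {rhs})"


def athing(lhs, rhs):
    """Mock of the generated convenience function: build the op, return its result."""
    return AThingOp(lhs, rhs).result


verbose = AThingOp("%a", "%b").result  # older, more verbose spelling
concise = athing("%a", "%b")           # convenience-function spelling
```

Both spellings produce the same value; the convenience function only removes the trailing `.result`.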
Jieying Luo
d6c5910105 [PJRT C API] Move cuda_plugin_extension from jaxlib to jax-cuda-plugin (the package for cuda kernels).
PiperOrigin-RevId: 583406466
2023-11-17 09:11:46 -08:00
George Necula
c1f54d447e Move back_compat_test_util.py to jax._src.internal_test_util.
Until now the backwards compatibility tests for exporting JAX functions with custom calls were part of the jax2tf test suite. But these tests are independent of TF, and we need to write such tests for Pallas and other projects that should not depend on jax2tf.

Here we move the test utilities out of jax2tf.
This is needed to enable writing Pallas backwards compatibility tests.

We rename back_compat_test_util.py to export_back_compat_test_util.py for clarity.

In a subsequent move we will move the actual backwards compatibility tests themselves out of jax2tf.

PiperOrigin-RevId: 583312085
2023-11-17 02:05:30 -08:00
jax authors
7657a0fb15 Merge pull request #18539 from NeilGirdhar:ruff
PiperOrigin-RevId: 583105786
2023-11-16 11:15:19 -08:00
Jieying Luo
43732e3fd4 Change the definition of the config to run bazel test for cuda plugin to match //jax:build_jaxlib.
When build_cuda_plugin_from_source is true, the cuda plugin is built from source; this is used when running `bazel test` without preinstalled jax cuda packages.

PiperOrigin-RevId: 583057751
2023-11-16 08:44:22 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff
2023-11-15 22:35:52 -05:00
Jieying Luo
88685d8de0 Support bazel test without bazel build for CUDA PJRT plugin.
- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fall back to the local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin.

The following command will test with cuda plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```

Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.

PiperOrigin-RevId: 582728477
2023-11-15 10:38:19 -08:00
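The fallback to a local `.so` path can be sketched as a search over candidate locations. The helper name and the candidate paths are hypothetical, and the real `jax_plugins/cuda/__init__.py` may be structured differently.

```python
from pathlib import Path
from typing import Optional


def find_plugin_library(candidates: list[Path]) -> Optional[Path]:
    """Return the first candidate path that exists, or None if none do."""
    for path in candidates:
        if path.exists():
            return path
    return None
```

A caller would list the installed-package location first and the local build output second, and skip plugin registration when the result is `None`.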
Jevin Jiang
8a64d9af40 [XLA:Mosaic] Support arbitrary aligned shape for tpu.bitcast and support bitcast with bitwidth change in element.
PiperOrigin-RevId: 582524212
2023-11-14 20:25:47 -08:00
Jieying Luo
ec21e04201 [PJRT C API] Rename the folder "plugins" to "jax_plugins".
With this change, the existing plugin discovery mechanism can discover local plugins without pip install.

Update jax_plugins/cuda/__init__.py to return without registering the plugin if the .so file does not exist.

PiperOrigin-RevId: 582431300
2023-11-14 13:56:13 -08:00
Peter Hawkins
95e2d3fc2b [JAX:GPU] Generalize gesvdj kernel to iterate over the unbatched Jacobi kernel in cases that we cannot use the batched kernel.
If gesvdj() is preferable to gesvd() absent a batch dimension, then even when there is a batch dimension we should prefer a loop of gesvdj() over a loop of gesvd().

PiperOrigin-RevId: 582279549
2023-11-14 04:52:15 -08:00
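The selection logic argued for above can be sketched as a small dispatch function. The kernel labels and both size limits are assumptions for illustration. The 1024 limit echoes the neighbouring commit cb182b8b22, and the 32 limit for the batched Jacobi kernel is an assumed restriction that is not stated in this log.

```python
# Assumed limits, for illustration only.
BATCHED_JACOBI_MAX_DIM = 32
JACOBI_MAX_DIM = 1024


def choose_svd_kernel(batch: int, m: int, n: int) -> str:
    """Pick a (hypothetical) kernel label for `batch` matrices of shape m x n."""
    largest = max(m, n)
    if batch > 1 and largest <= BATCHED_JACOBI_MAX_DIM:
        return "gesvdj_batched"
    if largest <= JACOBI_MAX_DIM:
        # Too large for the batched kernel: loop the unbatched Jacobi kernel.
        return "gesvdj_loop" if batch > 1 else "gesvdj"
    return "gesvd_loop" if batch > 1 else "gesvd"
```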
Peter Hawkins
cb182b8b22 Use a Jacobi SVD solver for unbatched SVDs up to 1024x1024 on NVIDIA GPUs.
The unbatched Jacobi solver is faster for small-moderate matrices, and the unbatched kernel doesn't have size restrictions.

Timings on T4 GPU:

Before:

------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           263587 ns       242274 ns         2780
svd/m:2/n:1           335561 ns       298238 ns         2303
svd/m:5/n:1           337784 ns       299841 ns         2304
svd/m:10/n:1          339184 ns       300703 ns         2311
svd/m:100/n:1         359826 ns       320088 ns         2159
svd/m:500/n:1         376124 ns       338660 ns         2076
svd/m:800/n:1         375779 ns       335590 ns         2060
svd/m:1000/n:1        419171 ns       341487 ns         2072
svd/m:1/n:2           307564 ns       270663 ns         2544
svd/m:2/n:2           320928 ns       283601 ns         2487
svd/m:5/n:2           377373 ns       344228 ns         2035
svd/m:10/n:2          380557 ns       349412 ns         1953
svd/m:100/n:2         435465 ns       403496 ns         1722
svd/m:500/n:2         444610 ns       410913 ns         1680
svd/m:800/n:2         454493 ns       416495 ns         1665
svd/m:1000/n:2        492110 ns       420539 ns         1665
svd/m:1/n:5           307316 ns       275833 ns         2531
svd/m:2/n:5           374318 ns       341432 ns         2086
svd/m:5/n:5           512928 ns       470293 ns         1361
svd/m:10/n:5          589330 ns       537070 ns         1353
svd/m:100/n:5         620164 ns       580166 ns         1193
svd/m:500/n:5         636424 ns       593692 ns         1180
svd/m:800/n:5         635545 ns       595016 ns         1181
svd/m:1000/n:5        672443 ns       597387 ns         1115
svd/m:1/n:10          310013 ns       273998 ns         2520
svd/m:2/n:10          370451 ns       334489 ns         2105
svd/m:5/n:10          560037 ns       522223 ns         1274
svd/m:10/n:10         572868 ns       535388 ns         1304
svd/m:100/n:10        959802 ns       918258 ns          765
svd/m:500/n:10        955958 ns       909778 ns          758
svd/m:800/n:10        924104 ns       879512 ns          777
svd/m:1000/n:10       950140 ns       883493 ns          775
svd/m:1/n:100         351237 ns       315554 ns         2198
svd/m:2/n:100         426883 ns       390089 ns         1792
svd/m:5/n:100         601557 ns       564493 ns         1255
svd/m:10/n:100        920819 ns       880011 ns          787
svd/m:100/n:100      7902281 ns      7229220 ns           95
svd/m:500/n:100      9720727 ns      9040679 ns           79
svd/m:800/n:100      9856378 ns      8998050 ns           79
svd/m:1000/n:100     9721017 ns      9086414 ns           79
svd/m:1/n:500         371171 ns       334217 ns         2117
svd/m:2/n:500         449165 ns       411499 ns         1700
svd/m:5/n:500         620354 ns       581866 ns         1185
svd/m:10/n:500        892375 ns       847239 ns          833
svd/m:100/n:500      9564810 ns      8867540 ns           79
svd/m:500/n:500    111924035 ns    104078023 ns            7
svd/m:800/n:500    147777319 ns    142730412 ns            5
svd/m:1000/n:500   154205084 ns    149740209 ns            5
svd/m:1/n:800         372122 ns       334212 ns         2119
svd/m:2/n:800         456672 ns       419260 ns         1680
svd/m:5/n:800         691208 ns       626003 ns         1190
svd/m:10/n:800       1017694 ns       941480 ns          730
svd/m:100/n:800      9892683 ns      9091043 ns           76
svd/m:500/n:800    144134235 ns    139129722 ns            5
svd/m:800/n:800    342790246 ns    333299774 ns            2
svd/m:1000/n:800   432820082 ns    427978978 ns            2
svd/m:1/n:1000        372785 ns       335745 ns         1805
svd/m:2/n:1000        451946 ns       413341 ns         1668
svd/m:5/n:1000        618475 ns       577213 ns         1169
svd/m:10/n:1000       907729 ns       863335 ns          808
svd/m:100/n:1000     9868543 ns      9116870 ns           76
svd/m:500/n:1000   156777811 ns    152042065 ns            5
svd/m:800/n:1000   429704070 ns    424677592 ns            2
svd/m:1000/n:1000  654864311 ns    642693162 ns            1

After:

------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           265980 ns       245433 ns         2791
svd/m:2/n:1           340203 ns       302783 ns         2288
svd/m:5/n:1           337807 ns       301916 ns         2286
svd/m:10/n:1          338064 ns       302441 ns         2297
svd/m:100/n:1         335444 ns       298440 ns         2327
svd/m:500/n:1         338025 ns       302096 ns         2272
svd/m:800/n:1         328382 ns       291740 ns         2252
svd/m:1000/n:1        397494 ns       310905 ns         2239
svd/m:1/n:2           310464 ns       274507 ns         2535
svd/m:2/n:2           319999 ns       284247 ns         2515
svd/m:5/n:2           373435 ns       335919 ns         2069
svd/m:10/n:2          376327 ns       339327 ns         2056
svd/m:100/n:2         385061 ns       349258 ns         2003
svd/m:500/n:2         392352 ns       355735 ns         1932
svd/m:800/n:2         410736 ns       370677 ns         1881
svd/m:1000/n:2        494326 ns       405603 ns         1721
svd/m:1/n:5           316735 ns       277292 ns         2538
svd/m:2/n:5           383748 ns       342218 ns         2077
svd/m:5/n:5           494204 ns       454309 ns         1476
svd/m:10/n:5          547017 ns       508184 ns         1371
svd/m:100/n:5         514537 ns       476761 ns         1460
svd/m:500/n:5         544656 ns       504877 ns         1381
svd/m:800/n:5         642590 ns       599314 ns         1159
svd/m:1000/n:5        706166 ns       621209 ns         1106
svd/m:1/n:10          310825 ns       274374 ns         2511
svd/m:2/n:10          381316 ns       344202 ns         2094
svd/m:5/n:10          565469 ns       526759 ns         1266
svd/m:10/n:10         576111 ns       537286 ns         1299
svd/m:100/n:10        653250 ns       613392 ns         1137
svd/m:500/n:10        690532 ns       645828 ns         1080
svd/m:800/n:10        763924 ns       723677 ns          959
svd/m:1000/n:10       940342 ns       855517 ns          818
svd/m:1/n:100         306134 ns       271533 ns         2526
svd/m:2/n:100         374680 ns       339298 ns         2071
svd/m:5/n:100         576926 ns       539062 ns         1228
svd/m:10/n:100        656806 ns       615171 ns         1123
svd/m:100/n:100      3295164 ns      3138621 ns          223
svd/m:500/n:100      4269347 ns      4166000 ns          168
svd/m:800/n:100      4656541 ns      4522247 ns          154
svd/m:1000/n:100     6479223 ns      6354578 ns          112
svd/m:1/n:500         329966 ns       289083 ns         2440
svd/m:2/n:500         407535 ns       366794 ns         1947
svd/m:5/n:500         567367 ns       522809 ns         1336
svd/m:10/n:500        712307 ns       657608 ns         1065
svd/m:100/n:500      4262986 ns      4169907 ns          167
svd/m:500/n:500     28824720 ns     28650258 ns           25
svd/m:800/n:500     29330139 ns     28677269 ns           25
svd/m:1000/n:500    30848037 ns     30089216 ns           23
svd/m:1/n:800         328620 ns       289181 ns         2329
svd/m:2/n:800         419052 ns       379483 ns         1876
svd/m:5/n:800         587366 ns       546979 ns         1269
svd/m:10/n:800        830762 ns       787923 ns          893
svd/m:100/n:800      4763633 ns      4595738 ns          152
svd/m:500/n:800     30447861 ns     29949714 ns           24
svd/m:800/n:800     94188958 ns     93488372 ns            8
svd/m:1000/n:800    94701529 ns     93394677 ns            7
svd/m:1/n:1000        351102 ns       313099 ns         2218
svd/m:2/n:1000        446543 ns       407807 ns         1708
svd/m:5/n:1000        661152 ns       616174 ns         1129
svd/m:10/n:1000       915743 ns       873397 ns          802
svd/m:100/n:1000     6434730 ns      6282779 ns          113
svd/m:500/n:1000    30244321 ns     29684290 ns           24
svd/m:800/n:1000    92727423 ns     91477078 ns            8
svd/m:1000/n:1000  169500709 ns    168358420 ns            4
PiperOrigin-RevId: 582041508
2023-11-13 12:04:13 -08:00
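To make the before/after comparison easier to read, the following sketch computes speedups for a few square sizes. The numbers are copied from the wall-clock "Time" column of the two tables in the commit message; the helper name is arbitrary.

```python
# (m, n) -> (before_ns, after_ns), copied from the "Time" column of the tables.
TIMINGS_NS = {
    (100, 100): (7902281, 3295164),
    (500, 500): (111924035, 28824720),
    (800, 800): (342790246, 94188958),
    (1000, 1000): (654864311, 169500709),
}


def speedup(before_ns: int, after_ns: int) -> float:
    """Ratio of old to new time; values above 1 mean the new code is faster."""
    return before_ns / after_ns


for (m, n), (before, after) in TIMINGS_NS.items():
    print(f"svd/m:{m}/n:{n}: {speedup(before, after):.2f}x")
```

For these rows the ratio is roughly 2.4x at 100x100 and between about 3.6x and 3.9x for the larger sizes.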
Tomás Longeri
1b79395d32 [Mosaic] Add missing MLIR_CAPI_EXPORTED to some C API functions introduced in cl/580405900 and cl/579355000
PiperOrigin-RevId: 581021172
2023-11-09 14:19:08 -08:00
George Necula
5001a21bad Move primitive_harness.py to jax._src.internal_test_util.test_harnesses.
The primitive_harness.py defines a set of about 7000 test harnesses, each with a JAX callable and a recipe for generating the arguments for the callable. Note that the test harness does not define any expected behavior. The test harnesses can be used in several kinds of tests.

Initially these harnesses were designed to test the completeness of the jax2tf lowering: for each test harness we convert it to TF and then we test that the result of invoking it is the same as for JAX native. Since then we have found other uses of test harnesses.

  * E.g., shape_poly_test.py tests that we can apply `jax.vmap` to each test harness and that we get a JAX callable that can be traced shape polymorphically, using a dimension variable for the batch dimension.
  * E.g., multi_platform_lowering_test.py tests that we can generate multi-platform lowering for each test harness.
  * E.g., the TFLite team is using the test harnesses to check the completeness of the TFLite lowering.

Since the test harnesses are useful for non-jax2tf uses we hereby moved them to jax._src.internal_test_util.test_harnesses. (We also renamed the module from primitive_harness to test_harnesses.)

This change is necessary to move some tests out of jax2tf: multi_platform_lowering_test.py, shape_poly_test.py.

PiperOrigin-RevId: 581016785
2023-11-09 13:58:00 -08:00
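The harness idea described above (a callable plus a recipe for generating its arguments, with no expected result attached) can be sketched in plain Python. The `Harness` class and the example entries below are hypothetical and much simpler than the real test_harnesses module.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class Harness:
    name: str
    fun: Callable[..., Any]                              # the callable under test
    make_args: Callable[[random.Random], Sequence[Any]]  # recipe for its arguments

    def args(self, seed: int = 0) -> Sequence[Any]:
        """Generate arguments reproducibly from a seed."""
        return self.make_args(random.Random(seed))


HARNESSES = [
    Harness("add_scalars", lambda x, y: x + y, lambda rng: (rng.random(), rng.random())),
    Harness("neg_scalar", lambda x: -x, lambda rng: (rng.random(),)),
]


def check_deterministic(h: Harness) -> bool:
    """One possible consumer: the callable returns the same result on the same args."""
    args = h.args()
    return h.fun(*args) == h.fun(*args)
```

Because a harness carries no expected output, each kind of test supplies its own check, as `check_deterministic` does here.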
Tomás Longeri
a23aac5566 [Mosaic][NFC] Remove static inline for functions in anonymous namespace in source file
PiperOrigin-RevId: 580974002
2023-11-09 11:32:20 -08:00
Peter Hawkins
3ee506d09a [Mosaic] Fix compilation failure in TPU dialect under MSVC.
The TPU MLIR dialect gets built on all platforms, so it has to compile on Windows.

Fixes https://github.com/google/jax/issues/18455

PiperOrigin-RevId: 580933582
2023-11-09 09:29:02 -08:00
Tomás Longeri
250486f13e [Mosaic] Expose C API for apply-vector-layout's assemble, disassemble, relayout and applyLayoutOp
PiperOrigin-RevId: 580405900
2023-11-07 22:32:02 -08:00
Tomás Longeri
52b01414c6 [Mosaic][NFC] C++ apply_vector_layout: Do not pass RewriteContext to functions that only need target_shape
Notably, this includes `assemble`, `disassemble` and `relayout`.

PiperOrigin-RevId: 580388074
2023-11-07 20:46:06 -08:00
Jevin Jiang
ecceac3a88 [XLA:Mosaic] Add support for arith.bitcast for both scalar and vector.
PiperOrigin-RevId: 580252662
2023-11-07 11:49:49 -08:00
Adam Paszke
e66f4e94c4 [Mosaic] Add support for extracting the first element of a vector as a scalar
PiperOrigin-RevId: 580169469
2023-11-07 07:20:48 -08:00
Tomás Longeri
c5d6df4557 [Mosaic] apply_vector_layout C++: Fix check in vector.broadcast rule
PiperOrigin-RevId: 580115242
2023-11-07 03:06:43 -08:00
Tomás Longeri
c186928a3e [Mosaic] Don't link CAPIIR into _tpu_ext, link into jaxlib_mlir_capi_shared_library instead
PiperOrigin-RevId: 579881376
2023-11-06 10:14:43 -08:00
Jieying Luo
462ef165c4 [PJRT C API] Change build wheel script to build a separate package for cuda kernels.
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:

|                       |size|wheel name                                                               |
|-----------------------|----|-------------------------------------------------------------------------|
|jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl           |
|cuda pjrt              |73M |jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl     |
|cuda kernels           |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl|

The size of jaxlib with cuda kernels and pjrt is 119M.

The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new package.

PiperOrigin-RevId: 579861480
2023-11-06 09:13:44 -08:00
Adam Paszke
a90cfc6466 [Mosaic] Add a missing MLIR dialect search prefix
Mosaic is broken by a recent MLIR change if we don't do it.

PiperOrigin-RevId: 579851049
2023-11-06 08:28:21 -08:00
Tomás Longeri
1c1dd7c8c7 [Mosaic] Expose C API for VectorLayout, VRegDataBounds
This is in preparation for Python bindings

PiperOrigin-RevId: 579355000
2023-11-03 18:24:16 -07:00