85 Commits

Author SHA1 Message Date
jax authors
2a828d5d6b Merge pull request #23467 from abhinavgoel95:patch-4
PiperOrigin-RevId: 685872800
2024-10-14 16:29:24 -07:00
George Necula
5fabd34e7e [jax2tf] Remove non-native serialization test from jax_to_ir_test
PiperOrigin-RevId: 683124315
2024-10-07 04:21:38 -07:00
Abhinav Goel
ea7ad92a86
fixes for mypy 2024-10-02 13:46:10 -07:00
Abhinav Goel
b5bb30329d
changed default 2024-09-25 15:43:59 -07:00
Abhinav Goel
b6553ba892
Update pgo_nsys_converter.py 2024-09-05 13:39:58 -07:00
Ilia Sergachev
e8730ddfe0 [NFC] Remove unused argument, fix help string. 2024-09-02 13:40:37 +02:00
Michael Goldfarb
d2b1ebd0aa Update pgo_nsys_converter.py to use the NVTX kern sum report when available. 2024-08-26 17:27:23 +00:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Sergei Lebedev
56745818a6 Added basic support for int2/uint2 dtypes to JAX
#21369

PiperOrigin-RevId: 649366888
2024-07-04 04:13:24 -07:00
jax authors
96cf5d53c8 Merge pull request #21916 from ROCm:ci_pjrt
PiperOrigin-RevId: 646793145
2024-06-26 02:43:21 -07:00
Yash Katariya
175183775b Replace jax.xla_computation with the AOT API and add a way to unaccelerate the deprecation in jax tests.
PiperOrigin-RevId: 644535402
2024-06-18 15:47:24 -07:00
Ruturaj4
a00d030248 [ROCM] nits and fixes 2024-06-18 20:21:23 +00:00
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
jax authors
fb65ba4adf Add a config for using Clang on Windows.
PiperOrigin-RevId: 631112031
2024-05-06 10:39:28 -07:00
Sergei Lebedev
a13efc2815 Added int4 and uint4 to dtype-specific tests
I probably missed some cases, so this PR is really just the first step in
making sure we have good *int4 coverage.
2024-04-18 15:20:20 +01:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Jake VanderPlas
2eff1f0f3f Guard script against execution on import 2024-03-07 18:32:18 -08:00
Jake VanderPlas
6d67aa2242 Fix mypy errors 2024-03-07 18:18:04 -08:00
Jake VanderPlas
0e79f95fdb lint: fix unused import 2024-03-07 18:12:42 -08:00
jax authors
fa17dacbc0 Merge pull request #20042 from abhinavgoel95:patch-1
PiperOrigin-RevId: 613685424
2024-03-07 13:28:57 -08:00
Abhinav Goel
a7a9f85535
Added license information 2024-03-07 11:55:55 -08:00
Jake VanderPlas
d8a4ea42cc Use shorter error message for jax.tools.colab_tpu.setup_tpu() 2024-03-05 11:20:38 -08:00
Abhinav Goel
2480ca383e
respond to reviewer's comments 2024-03-04 11:42:01 -08:00
Abhinav Goel
5a732ad89f
Adding script to convert NVIDIA nsys profiles to pbtxt 2024-03-01 09:14:11 -08:00
Parker Schuh
23b9c2a22f Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies.
PiperOrigin-RevId: 595529249
2024-01-03 16:12:23 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Jieying Luo
462ef165c4 [PJRT C API] Change build wheel script to build a separate package for cuda kernels.
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:

|                      |size|wheel name                                                               |
|----------------------|----|-------------------------------------------------------------------------|
|jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl           |
|cuda pjrt              |73M|jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl                    |
|cuda kernels           |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl|

The size of jaxlib with cuda kernels and pjrt is 119M.

The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new pacakge.

PiperOrigin-RevId: 579861480
2023-11-06 09:13:44 -08:00
Peter Hawkins
caee898fd0 Fix jaxlib build failure after upstream MLIR Python binding changes.
https://github.com/llvm/llvm-project/pull/68853 changed the structure of
the upstream MLIR Python bindings, breaking the jaxlib build. Update our
build scripts to match.
2023-10-23 14:27:52 +00:00
Peter Hawkins
fa8159681d Clean up build_wheel.py and build_gpu_plugin_wheel.py.
* Use pathlib.Path object-oriented paths.
* Change copy_files() helper to copy many files in one call.
* Make copy_files() also make the output directory, if needed.
* Format file with pyink --pyink-indentation=2
2023-09-29 20:08:42 +00:00
Jieying Luo
91fbf9da26 [PJRT C API] Set up jax xla cuda package.
Add a build wheel, pyproject.toml and setup.py.

The directory structure in jax repo is:
jax/
└── plugins/
     └── cuda/
          ├── __init__.py
          ├── pyproject.toml
          └── setup.py

Installed package structure is:
jax_plugins/
     └── xla_cuda_cu12/
           ├── __init__.py
           └── xla_cuda_plugin.so

The major cuda version will be part of the package name.

The plugin wheel can be built with command:
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"

PiperOrigin-RevId: 565187954
2023-09-13 16:03:53 -07:00
Richard Levasseur
f891cbf64b Load Python rules from rules_python
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Peter Hawkins
02e43e3510 Import py_binary() rule from @rules_python.
PiperOrigin-RevId: 548129348
2023-07-14 08:09:07 -07:00
John Cater
db8716701f Migrate exec_tools back to tools.
PiperOrigin-RevId: 534549617
2023-05-23 14:00:34 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
George Necula
023bfa84c2 [jax2tf] Fix test that requires non-native serialization
PiperOrigin-RevId: 517803582
2023-03-19 12:16:06 -07:00
Yash Katariya
58e46b48e6 Prepare for jax and jaxlib 0.4.4 release
PiperOrigin-RevId: 510152471
2023-02-16 08:37:15 -08:00
Ashish Shenoy
f71a55c554 Rename tensorflow core target variable to tensorflow_core
PiperOrigin-RevId: 508148106
2023-02-08 12:11:59 -08:00
Skye Wanderman-Milne
8ab158574d Update WORKSPACE and setup.py for jax/jaxlib 0.4.3 release 2023-02-07 15:45:28 -08:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
Skye Wanderman-Milne
3f4bd5f449 Updates for jax + jaxlib 0.4.2 release 2023-01-20 19:04:46 +00:00
Yash Katariya
c4d590b1b6 Update values for release 0.4.1
PiperOrigin-RevId: 494889744
2022-12-12 19:04:38 -08:00
Yash Katariya
0118f8d568 Prepare for jax and jaxlib 0.4.0 release
PiperOrigin-RevId: 493733609
2022-12-07 16:02:24 -08:00
Yash Katariya
a683186570 Use the 11/09 libtpu build for jaxlib release since that passes all the tests.
PiperOrigin-RevId: 488543322
2022-11-14 20:37:41 -08:00
Yash Katariya
f36084acd3 Update the values for jaxlib release (again)
PiperOrigin-RevId: 488522992
2022-11-14 18:31:08 -08:00
Yash Katariya
0da02dd41c Update the values needed for a jaxlib release
PiperOrigin-RevId: 488508360
2022-11-14 17:08:59 -08:00
Yash Katariya
7600cc8a8e Make jax.Array default to False for external colab.
PiperOrigin-RevId: 488360010
2022-11-14 07:28:00 -08:00