From 1aca76fc133a36903cf154a499b1cc319d340a48 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 11 Mar 2025 08:29:45 -0700 Subject: [PATCH 01/28] Update `:build_jaxlib` flag to control whether we should add `py_import` dependencies to the test targets. This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only. There are three options for running the tests: 1) `build_jaxlib=true`: the tests depend on JAX targets. 2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder. 3) `build_jaxlib=wheel`: the tests depend on the py_import targets. PiperOrigin-RevId: 735765819 --- build/gpu-test-requirements.txt | 13 +++++++ build/requirements.in | 1 + build/requirements_lock_3_10.txt | 58 +++++++++++++++++++++++++++++ build/requirements_lock_3_11.txt | 58 +++++++++++++++++++++++++++++ build/requirements_lock_3_12.txt | 58 +++++++++++++++++++++++++++++ build/requirements_lock_3_13.txt | 58 +++++++++++++++++++++++++++++ build/requirements_lock_3_13_ft.txt | 58 +++++++++++++++++++++++++++++ jax/BUILD | 21 ++++++++--- jax_plugins/cuda/BUILD.bazel | 22 +++++++++-- jaxlib/jax.bzl | 14 ++++++- jaxlib/tools/BUILD.bazel | 33 ++++++++++++++++ 11 files changed, 384 insertions(+), 10 deletions(-) create mode 100644 build/gpu-test-requirements.txt diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt new file mode 100644 index 000000000..54f5d2ab3 --- /dev/null +++ b/build/gpu-test-requirements.txt @@ -0,0 +1,13 @@ +# NVIDIA CUDA dependencies +# Note that the wheels are downloaded only when the targets in bazel command +# contain dependencies on these wheels. +nvidia-cublas-cu12>=12.1.3.1 +nvidia-cuda-cupti-cu12>=12.1.105 +nvidia-cuda-nvcc-cu12>=12.6.85 +nvidia-cuda-runtime-cu12>=12.1.105 +nvidia-cudnn-cu12>=9.1,<10.0 +nvidia-cufft-cu12>=11.0.2.54 +nvidia-cusolver-cu12>=11.4.5.107 +nvidia-cusparse-cu12>=12.1.0.106 +nvidia-nccl-cu12>=2.18.1 +nvidia-nvjitlink-cu12>=12.1.105 diff --git a/build/requirements.in b/build/requirements.in index d4e13d943..ec7fc71b0 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -2,6 +2,7 @@ # test deps # -r test-requirements.txt +-r gpu-test-requirements.txt # # build deps diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 290c7e732..dd7b6a55f 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -380,6 +380,64 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # via -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # via -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # via -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index f73065950..656458004 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -375,6 +375,64 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index feebc33dc..2f5a25e2b 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -375,6 +375,64 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 0a32888f6..f9635ec2b 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -434,6 +434,64 @@ numpy==2.1.2 ; python_version >= "3.13" \ # matplotlib # ml-dtypes # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index dfefaf042..129874cba 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -390,6 +390,64 @@ numpy==2.2.1 ; python_version >= "3.13" \ # matplotlib # ml-dtypes # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac diff --git a/jax/BUILD b/jax/BUILD index fb2e59864..3f75d60a0 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -14,7 +14,7 @@ # JAX is Autograd and XLA -load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", @@ -45,17 +45,26 @@ package( licenses(["notice"]) -# If this flag is true, jaxlib should be built by bazel. If false, then we do not build jaxlib and -# assume it has been installed, e.g., by `pip`. -bool_flag( +# The flag controls whether jaxlib should be built by Bazel. +# If ":build_jaxlib=true", then jaxlib will be built. +# If ":build_jaxlib=false", then jaxlib is not built. It is assumed that the pre-built jaxlib wheel +# is available in the "dist" folder. +# If ":build_jaxlib=wheel", then jaxlib wheel will be built as a py_import rule attribute. +# The py_import rule unpacks the wheel and provides its content as a py_library. +string_flag( name = "build_jaxlib", - build_setting_default = True, + build_setting_default = "true", + values = [ + "true", + "false", + "wheel", + ], ) config_setting( name = "enable_jaxlib_build", flag_values = { - ":build_jaxlib": "True", + ":build_jaxlib": "true", }, ) diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index 79aebcd86..1f4e5a08d 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -49,7 +49,7 @@ py_library_providing_imports_info( config_setting( name = "disable_jaxlib_for_cpu_build", flag_values = { - "//jax:build_jaxlib": "False", + "//jax:build_jaxlib": "false", "@local_config_cuda//:enable_cuda": "False", }, ) @@ -57,7 +57,23 @@ config_setting( config_setting( name = "disable_jaxlib_for_cuda12_build", flag_values = { - "//jax:build_jaxlib": "False", + "//jax:build_jaxlib": "false", "@local_config_cuda//:enable_cuda": "True", }, -) \ No newline at end of file +) + +config_setting( + name = "enable_py_import_for_cpu_build", + flag_values = { + "//jax:build_jaxlib": "wheel", + "@local_config_cuda//:enable_cuda": "False", + }, +) + +config_setting( + name = "enable_py_import_for_cuda12_build", + flag_values = { + "//jax:build_jaxlib": "wheel", + "@local_config_cuda//:enable_cuda": "True", + }, +) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 58a83d9b0..c4ac8d00f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -224,7 +224,15 @@ def if_building_jaxlib( "@pypi_jax_cuda12_plugin//:pkg", "@pypi_jax_cuda12_pjrt//:pkg", ], - if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"]): + if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"], + if_py_import = [ + "//jaxlib/tools:jaxlib_py_import", + "//jaxlib/tools:jax_cuda_plugin_py_import", + "//jaxlib/tools:jax_cuda_pjrt_py_import", + ], + if_py_import_for_cpu = [ + "//jaxlib/tools:jaxlib_py_import", + ]): """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources. This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase. @@ -234,12 +242,16 @@ def if_building_jaxlib( if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of gpu-enabled builds if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds + if_py_import: the py_import targets to depend on in case of gpu-enabled builds + if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds """ return select({ "//jax:enable_jaxlib_build": if_building, "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu, "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building, + "//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu, + "//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import, }) # buildifier: disable=function-docstring diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index baf996d50..5b24d2359 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -18,6 +18,10 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@xla//third_party/py:py_import.bzl", + "py_import", +) load( "@xla//third_party/py:py_manylinux_compliance_test.bzl", "verify_manylinux_compliance_test", @@ -228,6 +232,18 @@ string_flag( build_setting_default = "dist", ) +NVIDIA_WHEELS_DEPS = [ + "@pypi_nvidia_cublas_cu12//:whl", + "@pypi_nvidia_cuda_cupti_cu12//:whl", + "@pypi_nvidia_cuda_runtime_cu12//:whl", + "@pypi_nvidia_cudnn_cu12//:whl", + "@pypi_nvidia_cufft_cu12//:whl", + "@pypi_nvidia_cusolver_cu12//:whl", + "@pypi_nvidia_cusparse_cu12//:whl", + "@pypi_nvidia_nccl_cu12//:whl", + "@pypi_nvidia_nvjitlink_cu12//:whl", +] + jax_wheel( name = "jaxlib_wheel", no_abi = False, @@ -235,6 +251,11 @@ jax_wheel( wheel_name = "jaxlib", ) +py_import( + name = "jaxlib_py_import", + wheel = ":jaxlib_wheel", +) + jax_wheel( name = "jaxlib_wheel_editable", editable = True, @@ -252,6 +273,12 @@ jax_wheel( wheel_name = "jax_cuda12_plugin", ) +py_import( + name = "jax_cuda_plugin_py_import", + wheel = ":jax_cuda_plugin_wheel", + wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), +) + jax_wheel( name = "jax_cuda_plugin_wheel_editable", editable = True, @@ -290,6 +317,12 @@ jax_wheel( wheel_name = "jax_cuda12_pjrt", ) +py_import( + name = "jax_cuda_pjrt_py_import", + wheel = ":jax_cuda_pjrt_wheel", + wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), +) + jax_wheel( name = "jax_cuda_pjrt_wheel_editable", editable = True, From 30a9e1b3bfb39e1cf86cfdbf5bf5e97e238169ef Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 11 Mar 2025 09:52:35 -0700 Subject: [PATCH 02/28] [Mosaic GPU] Add support for .cta_group::2 MMA with n=512 on Blackwell This one is particularly annoying, because we have to break up the MMA into two collective N=256 MMAs. However, TensorCore only updates a contiguous chunk of columns in TMEM and so after executing two of those we end up with a TMEM layout that looks like this: ``` Contributing CTA | 0 | 1 | 0 | 1 | N local | 0:128 | 0:128 | 128:256 | 128:256 | N | 0:128 | 256:384 | 128:256 | 384:512 | ``` You can see that the TMEM columns no longer monotonically go over all columns until N=512, but they include a number of jumps! We could fix this on the load side, by ensuring that each CTA in the group does a strided load along the tiled dimension, but that just seems more trouble than it's worth (and is not that well supported by TMA unless we increase the number of striding levels). Instead, we encode this weirdness in the TMEM layout we use and make sure to rearrange the data properly while loading the tiles into registers. PiperOrigin-RevId: 735791426 --- jax/experimental/mosaic/gpu/core.py | 2 +- .../mosaic/gpu/examples/matmul_blackwell.py | 4 +- jax/experimental/mosaic/gpu/tcgen05.py | 154 ++++++++++++------ jax/experimental/mosaic/gpu/utils.py | 4 + tests/mosaic/gpu_test.py | 2 +- 5 files changed, 115 insertions(+), 51 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index c92074dc2..66e19bb5f 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -260,7 +260,7 @@ def _construct_smem_reftree( dynamic_smem, c(dynamic_smem_offset, index), [], ) if layout is None: - layout = tcgen05._infer_tmem_layout(shape) + layout = tcgen05._infer_tmem_layout(shape, collective) num_cols = layout.cols_in_shape(shape) delayed_warp_init.append( functools.partial( diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index d15cecbdc..6af394d00 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -230,8 +230,8 @@ def main(unused_argv): tile_n *= 2 if m < tile_m or n < tile_n: continue - if kwargs["collective"] and tile_n >= 512: - continue # TODO(apaszke): Support 512 + if tile_n > 512: + continue if (m // tile_m) % kwargs["grid_tile_m"]: continue try: diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 7a349f50c..e5a2d3aa5 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -83,6 +83,7 @@ def mma( accumulate: ir.Value | bool = True, collective: bool = False, ): + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) if isinstance(accumulate, bool): accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate) @@ -112,6 +113,10 @@ def mma( raise ValueError( f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}" ) + if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective)): + raise ValueError( + f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}" + ) f32 = ir.F32Type.get() if element_type == f32 or element_type == ir.BF16Type.get(): if d.dtype != f32: @@ -136,11 +141,7 @@ def mma( raise ValueError(f"N must be a multiple of 8, got: {n}") elif n > 256 and n != 512: raise ValueError("Only N below 256 or N=512 are supported") - if num_cta == 2 and n > 256: - raise NotImplementedError( - "N is too big for collective MMA. Only up to 256 is supported." - ) - n_group_elems = min(n, 256) + n_group_elems = min(n, 256 // num_cta) if m % m_group_elems: raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}") if k % k_group_elems: @@ -179,6 +180,7 @@ def mma( # Step 4. Issue the instructions. true = arith.constant(ir.IntegerType.get_signless(1), 1) + n_collective_group_elems = n_group_elems * num_cta for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups): a_offset = mi * a_m_group_stride + ki * a_k_group_stride a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) @@ -188,9 +190,9 @@ def mma( raise NotImplementedError("D needs to be sliced") acc = accumulate if ki == 0 else true _do_mma( - d.slice( - slice(None), utils.ds(ni * n_group_elems, n_group_elems) - ).address, + arith.addi( + d.address, arith.constant(i32, ni * n_collective_group_elems) + ), a_mk, b_nk, d_type=ir.F32Type.get(), @@ -377,8 +379,15 @@ class TMEMLayout: +------------------+------------------+ | [0:64, 64:128] | [64:128, 64:128] | +------------------+------------------+ + + The above is further complicated by column_tile_stride, which is used to + swizzle the ordering of column tiles. That is, if column_tile_stride is 2, + we will first lay out all tiles that have the column index 0, 2, 4, and so on + until we run out of tiles. Only then we lay out the tiles with column index + 1, 3, etc. """ elements_in_tile: tuple[int, int] + column_tile_stride: int = 1 def __post_init__(self): row_tiling = self.elements_in_tile[0] @@ -405,7 +414,7 @@ class TMEMLayout: return num_tiles // tiles_in_row * cols_in_tile -def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout: +def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout: if shape[0] > TMEM_ROWS: raise ValueError( "Can only infer TMEM layout for shapes with at most 128 rows, got:" @@ -421,7 +430,15 @@ def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout: "Can only infer TMEM layout for shapes with row count that's a power of" f" 2, got: {shape[0]}" ) - return TMEMLayout(elements_in_tile=(shape[0], 1)) + if shape[1] % 8: + raise ValueError( + "Can only infer TMEM layout for shapes with column count that's a" + f" multiple of 8, got: {shape[1]}" + ) + if collective and shape[1] == 512: + return TMEMLayout(elements_in_tile=(shape[0], 128), column_tile_stride=2) + else: + return TMEMLayout(elements_in_tile=(shape[0], 8)) @dataclasses.dataclass(frozen=True) @@ -432,7 +449,14 @@ class TMEMRef: layout: TMEMLayout @classmethod - def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layout: TMEMLayout | None = None): + def from_alloc( + cls, + tmem_addr_ref: ir.Value, + shape: tuple[int, int], + dtype, + collective: bool | None = None, + layout: TMEMLayout | None = None, + ): i32 = ir.IntegerType.get_signless(32) if not ir.MemRefType.isinstance(tmem_addr_ref.type): raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}") @@ -449,7 +473,11 @@ class TMEMRef: if shape[0] < 32: raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}") if layout is None: - layout = _infer_tmem_layout(shape) + if collective is None: + raise ValueError( + "collective argument must be provided when TMEM layout is inferred" + ) + layout = _infer_tmem_layout(shape, collective) else: layout.check_shape(shape) # TODO: Do we have to do this?? @@ -461,12 +489,17 @@ class TMEMRef: base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) if any(is_squeezed): raise ValueError("TMEM can only be sliced, not indexed") - if self.layout.elements_in_tile[0] != TMEM_ROWS: + if self.layout != TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): raise NotImplementedError( - f"Slicing only implemented for refs with tiling of {TMEM_ROWS} rows" + "Slicing only implemented for refs with standard layout, got:" + f" {self.layout}" ) if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS: raise NotImplementedError("TMEM cannot be sliced along rows") + if slice_shape[1] % 8: + raise NotImplementedError( + "TMEM column slice length must be a multiple of 8" + ) col_idx = base_idx[1] if not isinstance(col_idx, ir.Value): col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx) @@ -484,48 +517,75 @@ class TMEMRef: raise ValueError("TMEM loads only support slicing") if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape: raise NotImplementedError("Slicing of TMEM not impelmented yet") - if self.layout.elements_in_tile[0] != TMEM_ROWS: - raise NotImplementedError( - f"Loads only implemented for refs with tiling of {TMEM_ROWS} rows" - ) if self.shape[1] % 8: raise NotImplementedError if self.dtype != ir.F32Type.get(): raise NotImplementedError(self.dtype) layout = _m128_256bit_32bit_layout(self.shape) regs_shape = layout.registers_shape(self.shape) - num = self.shape[1] // 8 - # TODO(apaszke): Make the tiling configurable through the args too. - if num <= 32: - num_tiling = num - elif num == 64: - num_tiling = 32 - else: - raise NotImplementedError(num) - registers = np.empty(regs_shape, dtype=object) - # We load 16 lanes at a time, but need 32 in total. - for row_group in range(2): - addr_row = arith.addi(self.address, arith.constant(i32, (row_group * 16) << 16)) - regs = [] - cols_per_num_tile = 8 # This depends on the 16x256b below. - for num_group in range(num // num_tiling): - addr_row_col = arith.addi( - addr_row, - arith.constant(i32, num_tiling * num_group * cols_per_num_tile), + if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): + # load_32xcols returns a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + registers = _load_32xcols( + self.address, self.shape[1], self.dtype + ).T.reshape(regs_shape) + elif self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2): + if self.shape[1] % 128 != 0: + raise ValueError( + f"TMEM layout {self.layout} is not compatible with shape {self.shape}" ) - regs += tmem_load(addr_row_col, "16x256b", num_tiling) - regs = [llvm.bitcast(self.dtype, r) for r in regs] - vector_regs = [] - undef = llvm.mlir_undef(ir.VectorType.get((2,), self.dtype)) - for r_low, r_high in zip(regs[::2], regs[1::2]): - high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) - vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) - vector_regs.append(vreg) - # Dimension 4 is the one where we split 32 rows into tiles of 8. - regs_slice = (slice(None),) * 4 + (slice(row_group * 2, (row_group + 1) * 2),) - registers[regs_slice] = np.asarray(vector_regs, dtype=object).reshape(registers[regs_slice].shape) + num_column_tiles = self.shape[1] // 128 + column_tile_stride = self.layout.column_tile_stride + num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride) + tiles = [] + for col_tile_base in range(num_strided_col_groups): + for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride): + tiles.append( + _load_32xcols( + arith.addi(self.address, arith.constant(i32, col_tile * 128)), + cols=128, + dtype=self.dtype, + ) + ) + registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape) + else: + raise NotImplementedError( + f"Loads only implemented for refs with standard layout, got: {self.layout}" + ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) +def _load_32xcols(base_addr, cols, dtype): + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + i32 = ir.IntegerType.get_signless(32) + assert cols % 8 == 0 + cols_per_num_tile = 8 + load_shape = "16x256b" + num = cols // 8 + if num <= 32: + num_tiling = num + elif num == 64: + num_tiling = 32 + else: + raise NotImplementedError(num) + vector_regs = np.ndarray((4, num), dtype=object) + # We load 16 lanes at a time, but need 32 in total. + for row_group in range(2): + addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16)) + regs = [] + for num_group in range(num // num_tiling): + addr_row_col = arith.addi( + addr_row, + arith.constant(i32, num_tiling * num_group * cols_per_num_tile), + ) + regs += tmem_load(addr_row_col, load_shape, num_tiling) + regs = [llvm.bitcast(dtype, r) for r in regs] + undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) + for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)): + high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) + vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) + vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + return vector_regs + def _m128_256bit_32bit_layout(shape: tuple[int, ...]): if len(shape) != 2: diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index f90f7ff08..080397bbb 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1201,3 +1201,7 @@ def bitcast(x: ir.Value, new_type: ir.Type): assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape) return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x)) raise ValueError(f"Can't bitcast {x.type} to {new_type}") + + +def ceil_div(x: int, y: int): + return (x + y - 1) // y diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index cc654eb2b..f6b94b777 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1026,7 +1026,7 @@ class TCGen05Test(TestCase): in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32 out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation m=(256,), # TODO(apaszke): 64, 192, 256 - n=(128, 256), # TODO(apaszke): 512, 192, other non-power-of-2 + n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 k_steps=(1, 2), swizzle=(32, 64, 128,), ) From 4ae3211ea242e7c735c60fda24953c658124c3a8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 11 Mar 2025 09:20:11 -0700 Subject: [PATCH 03/28] jax.disable_jit: ensure while_loop behaves similarly to non-disable_jit version --- jax/_src/lax/control_flow/loops.py | 4 ++-- tests/lax_control_flow_test.py | 22 ++++++++++++++++++++++ tests/lax_scipy_special_functions_test.py | 9 +++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index c7dee3e71..3084fa722 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1364,9 +1364,9 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric], raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.") if config.disable_jit.value: try: - val = init_val + val = tree_map(lax.asarray, init_val) while cond_fun(val): - val = body_fun(val) + val = tree_map(lax.asarray, body_fun(val)) return val except core.ConcretizationTypeError: # Can't run this while_loop in Python (e.g. because there's a vmap diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 15fc37805..39da45079 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2502,6 +2502,28 @@ class LaxControlFlowTest(jtu.JaxTestCase): x, n = jnp.arange(3), jnp.arange(4) jax.vmap(jax.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash + def test_disable_jit_while_loop_with_mutation(self): + # https://github.com/jax-ml/jax/issues/27019 + + def body_fun(carry): + x, y = carry + x += 1 # in-place if x is mutable + return x, y + x + + def cond_fun(carry): + x, _ = carry + return x < 10 + + def f(): + val = np.array(1.0) # mutable value + return jax.lax.while_loop(cond_fun, body_fun, (val, val))[1] + + with jax.disable_jit(False): + result_jit = f() + with jax.disable_jit(True): + result_nojit = f() + self.assertEqual(result_jit, result_nojit) + @parameterized.named_parameters( {"testcase_name": f"_{shape}_{axis=}", "shape": shape, "axis": axis} diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 96d48dcd3..8932869f1 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -278,6 +278,15 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase): with jax.checking_leaks(): lsp_special.expi(jnp.ones(())) + def testExpiDisableJit(self): + # Regression test for https://github.com/jax-ml/jax/issues/27019 + x = jnp.array([-0.5]) + with jax.disable_jit(True): + result_nojit = lsp_special.expi(x) + with jax.disable_jit(False): + result_jit = lsp_special.expi(x) + self.assertAllClose(result_jit, result_nojit) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 6f7ce9d0486b3f9f9ae3f0f5dcaf62800004fdf5 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 11 Mar 2025 10:27:37 -0700 Subject: [PATCH 04/28] Skip ASAN tests for the big Mosaic GPU tests They are timing out. PiperOrigin-RevId: 735804647 --- tests/mosaic/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 4b231939a..4ec2dbf3b 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -39,7 +39,10 @@ jax_multiplatform_test( ], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, shard_count = 16, - tags = ["multiaccelerator"], + tags = [ + "multiaccelerator", + "noasan", # Times out. + ], deps = [ "//jax:mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), From d191927b24cbe83944552ee0222f8384dc7cff75 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Tue, 11 Mar 2025 10:36:32 -0700 Subject: [PATCH 05/28] Fix syntax error and typos for composite primitive docstring. PiperOrigin-RevId: 735808000 --- docs/jax.lax.rst | 1 + jax/_src/lax/lax.py | 34 ++++++++++++++++++---------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 5f51cdb3b..9db79f591 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -58,6 +58,7 @@ Operators clz collapse complex + composite concatenate conj conv diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 99760099d..d4131e69b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1489,14 +1489,14 @@ def composite( ): """Composite with semantics defined by the decomposition function. - A composite is a higher-order JAX function that encapsulates an operation mad + A composite is a higher-order JAX function that encapsulates an operation made up (composed) of other JAX functions. The semantics of the op are implemented by the ``decomposition`` function. In other words, the defined composite function can be replaced with its decomposed implementation without changing the semantics of the encapsulated operation. The compiler can recognize specific composite operations by their ``name``, - ``version``, ``kawargs``, and dtypes to emit more efficient code, potentially + ``version``, ``kwargs``, and dtypes to emit more efficient code, potentially leveraging hardware-specific instructions or optimizations. If the compiler doesn't recognize the composite, it falls back to compiling the ``decomposition`` function. @@ -1505,11 +1505,11 @@ def composite( be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could recognize the "tangent" composite and emit a single ``tangent`` instruction instead of three separate instructions (``sin``, ``divide``, and ``cos``). - With compilers for hardwares without dedicated tangent support, it would fall - back to compiling the decomposition. + For hardware without dedicated tangent support, it would fall back to + compiling the decomposition. - This is useful for preserving high level abstraction that would otherwise be - lost while lowering which allows for easier pattern-matching in low-level IR. + This is useful for preserving high-level abstractions that would otherwise be + lost while lowering, which allows for easier pattern-matching in low-level IR. Args: decomposition: function that implements the semantics of the composite op. @@ -1517,19 +1517,20 @@ def composite( version: optional int to indicate semantic changes to the composite. Returns: - out: callable composite function. Note that positional arguments to this - function should be interpreted as inputs and keyword arguments should be - interpreted as attributes of the op. Any keyword arguments that are passed - with ``None`` as a value will be omitted from the - ``composite_attributes``. + Callable: Returns a composite function. Note that positional arguments to + this function should be interpreted as inputs and keyword arguments should + be interpreted as attributes of the op. Any keyword arguments that are + passed with ``None`` as a value will be omitted from the + ``composite_attributes``. Examples: Tangent kernel: + >>> def my_tangent_composite(x): ... return lax.composite( - ... lambda x: lax.sin(x) / lax.cos(x), name='my.tangent' + ... lambda x: lax.sin(x) / lax.cos(x), name="my.tangent" ... )(x) - ... + >>> >>> pi = jnp.pi >>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi]) >>> with jnp.printoptions(precision=3, suppress=True): @@ -1538,9 +1539,10 @@ def composite( [ 0. 1. -1. 0.] [ 0. 1. -1. 0.] - The recommended way to create composites is via a decorator. Use `/` and `*` - in the function signature to be explicit about positional and keyword - arguments respectively: + The recommended way to create composites is via a decorator. Use ``/`` and + ``*`` in the function signature to be explicit about positional and keyword + arguments, respectively: + >>> @partial(lax.composite, name="my.softmax") ... def my_softmax_composite(x, /, *, axis): ... return jax.nn.softmax(x, axis) From 7ac6355262982a21b8a1aace47eec5f40fbc3e40 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 11 Mar 2025 11:41:34 -0700 Subject: [PATCH 06/28] Add TPU test jobs to the new CI continuous and nightly/release test workflows Also, modify the TPU presubmit workflow to reuse the `build_artifacts.yml` and `pytest_tpu.yml` PiperOrigin-RevId: 735832964 --- .github/workflows/cloud-tpu-ci-presubmit.yml | 85 ++++------ .github/workflows/pytest_tpu.yml | 149 ++++++++++++++++++ .github/workflows/wheel_tests_continuous.yml | 26 +++ .../workflows/wheel_tests_nightly_release.yml | 25 +++ ci/envs/default.env | 12 +- ci/run_pytest_tpu.sh | 50 +++--- ci/utilities/install_wheels_locally.sh | 11 ++ 7 files changed, 281 insertions(+), 77 deletions(-) create mode 100644 .github/workflows/pytest_tpu.yml diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml index 4dc8e55fb..a92e3cc19 100644 --- a/.github/workflows/cloud-tpu-ci-presubmit.yml +++ b/.github/workflows/cloud-tpu-ci-presubmit.yml @@ -3,6 +3,7 @@ # This job currently runs as a non-blocking presubmit. It is experimental and is currently being # tested to get to a stable state before we enable it as a blocking presubmit. name: CI - Cloud TPU (presubmit) + on: workflow_dispatch: inputs: @@ -33,64 +34,32 @@ concurrency: cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} jobs: - cloud-tpu-test: + build-jax-artifacts: if: github.event.repository.fork == false -# Begin Presubmit Naming Check - name modification requires internal check to be updated + uses: ./.github/workflows/build_artifacts.yml strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - tpu: [ - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} - ] - python-version: ["3.10"] - name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})" -# End Presubmit Naming Check github-tpu-presubmits - env: - JAXCI_PYTHON: python${{ matrix.python-version }} - JAXCI_TPU_CORES: ${{ matrix.tpu.cores }} + fail-fast: false # don't cancel all jobs on failure + matrix: + artifact: ["jax", "jaxlib"] + with: + runner: "linux-x86-n2-16" + artifact: ${{ matrix.artifact }} + python: "3.10" + clone_main_xla: 1 + upload_artifacts_to_gcs: true + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - runs-on: ${{ matrix.tpu.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" - - timeout-minutes: 60 - - defaults: - run: - shell: bash -ex {0} - steps: - # https://opensource.google/documentation/reference/github/services#actions - # mandates using a specific commit for non-Google actions. We use - # https://github.com/sethvargo/ratchet to pin specific versions. - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - # Checkout XLA at head, if we're building jaxlib at head. - - name: Checkout XLA at head - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: openxla/xla - path: xla - # We need to mark the GitHub workspace as safe as otherwise git commands will fail. - - name: Mark GitHub workspace as safe - run: | - git config --global --add safe.directory "$GITHUB_WORKSPACE" - - name: Install JAX test requirements - run: | - $JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt - - name: Build jaxlib at head with latest XLA - run: | - # Build and install jaxlib at head - $JAXCI_PYTHON build/build.py build --wheels=jaxlib \ - --python_version=${{ matrix.python-version }} \ - --bazel_options=--config=rbe_linux_x86_64 \ - --local_xla_path="$(pwd)/xla" \ - --verbose - - # Install libtpu - $JAXCI_PYTHON -m uv pip install --pre libtpu \ - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Install jaxlib wheel and run tests - run: ./ci/run_pytest_tpu.sh \ No newline at end of file + run-pytest-tpu: + if: github.event.repository.fork == false + needs: [build-jax-artifacts] + uses: ./.github/workflows/pytest_tpu.yml + # Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "TPU test (jaxlib=head, v5e-8)" + with: + runner: "linux-x86-ct5lp-224-8tpu" + cores: "8" + tpu-type: "v5e-8" + python: "3.10" + libtpu-version-type: "nightly" + gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }} + # End Presubmit Naming Check github-tpu-presubmits \ No newline at end of file diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml new file mode 100644 index 000000000..2341bfb79 --- /dev/null +++ b/.github/workflows/pytest_tpu.yml @@ -0,0 +1,149 @@ +# CI - Pytest TPU +# +# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via +# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest TPU tests. +# +# It consists of the following job: +# run-tests: +# - Downloads the jaxlib wheel from a GCS bucket. +# - Sets up the libtpu wheels. +# - Executes the `run_pytest_cpu.sh` script, which performs the following actions: +# - Installs the downloaded jaxlib wheel. +# - Runs the TPU tests with Pytest. +name: CI - Pytest TPU + +on: + workflow_call: + inputs: + # Note that the values for runners, cores, and tpu-type are linked to each other. + # For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the + # following mapping: + # {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + # {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-ct5lp-224-8tpu" + cores: + description: "How many TPU cores should the test use?" + type: string + required: true + default: "8" + tpu-type: + description: "Which TPU type is used for testing?" + type: string + required: true + default: "v5e-8" + python: + description: "Which Python version should be used for testing?" + type: string + required: true + default: "3.12" + run-full-tpu-test-suite: + description: "Should the full TPU test suite be run?" + type: string + required: false + default: "0" + libtpu-version-type: + description: "Which libtpu version should be used for testing?" + type: string + required: false + # Choices are: + # - "nightly": Use the nightly libtpu wheel. + # - "pypi_latest": Use the latest libtpu wheel from PyPI. + # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel. + default: "nightly" + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + required: true + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: boolean + required: false + default: false + +jobs: + run-tests: + defaults: + run: + shell: bash + runs-on: ${{ inputs.runner }} + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + # Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})" + # End Presubmit Naming Check github-tpu-presubmits + + env: + LIBTPU_OLDEST_VERSION_DATE: 20241205 + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}" + JAXCI_TPU_CORES: "${{ inputs.cores }}" + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set env vars for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t + python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + # Python wheels follow a naming convention: standard wheels use the pattern + # `*-cp-cp-*`, while free-threaded wheels use + # `*-cp-cpt-*`. + echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV + - name: Download JAX wheels from GCS + id: download-wheel-artifacts + # Set continue-on-error to true to prevent actions from failing the workflow if this step + # fails. Instead, we verify the outcome in the step below so that we can print a more + # informative error message. + continue-on-error: true + run: | + mkdir -p $(pwd)/dist + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + - name: Skip the test run if the wheel artifacts were not downloaded successfully + if: steps.download-wheel-artifacts.outcome == 'failure' + run: | + echo "Failed to download wheel artifacts from GCS. Please check if the wheels were" + echo "built successfully by the artifact build jobs and are available in the GCS bucket." + echo "Skipping the test run." + exit 1 + - name: Install Python dependencies + run: | + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt + - name: Set up libtpu wheels + run: | + if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then + echo "Using nightly libtpu" + $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then + echo "Using latest libtpu from PyPI" + # Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh` + # script will install the latest libtpu wheel from PyPI. + echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV + elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then + echo "Using oldest supported libtpu" + $JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + else + echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}" + exit 1 + fi + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Pytest TPU tests + timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }} + run: ./ci/run_pytest_tpu.sh diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 5c818bf56..f12d8a7f0 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -142,4 +142,30 @@ jobs: python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} # GCS upload URI is the same for both artifact build jobs + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + + run-pytest-tpu: + # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated + # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we + # still want to run the tests for other platforms. + if: ${{ !cancelled() }} + needs: [build-jax-artifact, build-jaxlib-artifact] + uses: ./.github/workflows/pytest_tpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.10",] + tpu-specs: [ + # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available + {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + ] + name: "TPU tests (jax=head, jaxlib=head)" + with: + runner: ${{ matrix.tpu-specs.runner }} + cores: ${{ matrix.tpu-specs.cores }} + tpu-type: ${{ matrix.tpu-specs.type }} + python: ${{ matrix.python }} + run-full-tpu-test-suite: "1" + libtpu-version-type: "nightly" gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} \ No newline at end of file diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index b88b000e4..4845fb0f2 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -58,4 +58,29 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{inputs.gcs_download_uri}} + + run-pytest-tpu: + uses: ./.github/workflows/pytest_tpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.10","3.11", "3.12", "3.13"] + tpu-specs: [ + # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available + {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + ] + libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] + exclude: + - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} + - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} + name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + with: + runner: ${{ matrix.tpu-specs.runner }} + cores: ${{ matrix.tpu-specs.cores }} + tpu-type: ${{ matrix.tpu-specs.type }} + python: ${{ matrix.python }} + run-full-tpu-test-suite: "1" + libtpu-version-type: ${{ matrix.libtpu-version-type }} gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file diff --git a/ci/envs/default.env b/ci/envs/default.env index 7a2448944..a5a5d56eb 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -74,4 +74,14 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} # JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels # on the system. By default, it is set to match the version of the hermetic # Python used by Bazel for building the wheels. -export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} \ No newline at end of file +export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} + +# When set to 1, the full TPU test suite is run. Otherwise, a subset of tests +# is run. +export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0} + +# We use this environment variable to control which additional wheels to install +# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest +# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the +# list of valid values and their behavior. +export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""} \ No newline at end of file diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index feaccea8e..9b4a3bbfd 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -52,23 +52,37 @@ export JAX_SKIP_SLOW_TESTS=true echo "Running TPU tests..." -# Run single-accelerator tests in parallel -JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ - --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ - --maxfail=20 -m "not multiaccelerator" \ - tests/pallas/ops_test.py \ - tests/pallas/export_back_compat_pallas_test.py \ - tests/pallas/export_pallas_test.py \ - tests/pallas/tpu_ops_test.py \ - tests/pallas/tpu_pallas_test.py \ - tests/pallas/tpu_pallas_random_test.py \ - tests/pallas/tpu_pallas_async_test.py \ - tests/pallas/tpu_pallas_state_test.py +if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then + # Run single-accelerator tests in parallel + JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ + --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ + --maxfail=20 -m "not multiaccelerator" tests examples -# Run Pallas printing tests, which need to run with I/O capturing disabled. -TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + # Run Pallas printing tests, which need to run with I/O capturing disabled. + TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \ + tests/pallas/tpu_pallas_test.py::PallasCallPrintTest -# Run multi-accelerator across all chips -"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \ - tests/pjit_test.py \ - tests/pallas/tpu_pallas_distributed_test.py \ No newline at end of file + # Run multi-accelerator across all chips + "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests +else + # Run single-accelerator tests in parallel + JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ + --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ + --maxfail=20 -m "not multiaccelerator" \ + tests/pallas/ops_test.py \ + tests/pallas/export_back_compat_pallas_test.py \ + tests/pallas/export_pallas_test.py \ + tests/pallas/tpu_ops_test.py \ + tests/pallas/tpu_pallas_test.py \ + tests/pallas/tpu_pallas_random_test.py \ + tests/pallas/tpu_pallas_async_test.py \ + tests/pallas/tpu_pallas_state_test.py + + # Run Pallas printing tests, which need to run with I/O capturing disabled. + TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + + # Run multi-accelerator across all chips + "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \ + tests/pjit_test.py \ + tests/pallas/tpu_pallas_distributed_test.py +fi \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 41274b95f..f98f7658a 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -17,8 +17,19 @@ # Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python # binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to # avoid using the Windows version of `find` on Msys. + WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) ) +for i in "${!WHEELS[@]}"; do + if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then + if [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "tpu_pypi" ]]; then + # Append [tpu] to the jax wheel name to download the latest libtpu wheel + # from PyPI. + WHEELS[$i]="${WHEELS[$i]}[tpu]" + fi + fi +done + if [[ -z "${WHEELS[@]}" ]]; then echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" exit 1 From 82b2591b219e8797aa4e98ad83a6758aece765d9 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 10 Apr 2024 23:14:43 +0300 Subject: [PATCH 07/28] Fix scipy.special.gammainc/gammaincc evaluation at boundary points --- jax/_src/lax/special.py | 30 +++++++++++-------- jax/_src/scipy/stats/gamma.py | 3 +- .../jax2tf/tests/jax2tf_limitations.py | 12 ++++---- tests/lax_scipy_special_functions_test.py | 20 +++++++++++++ 4 files changed, 45 insertions(+), 20 deletions(-) diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index f6cd8bc0d..4a36bf186 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -303,15 +303,16 @@ def _igamma_series(ax, x, a, enabled, dtype, mode): def igamma_impl(a, x, *, dtype): is_nan = bitwise_or(_isnan(a), _isnan(x)) - x_is_zero = eq(x, _const(x, 0)) x_is_infinity = eq(x, _const(x, float('inf'))) - domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0))) - use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a)) + a_is_zero = eq(a, _const(a, 0)) + x_is_zero = eq(x, _const(x, 0)) + domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)]) + + use_igammac = bitwise_and(ge(x, _const(x, 1)), gt(x, a)) ax = a * log(x) - x - lgamma(a) underflow = lt(ax, -log(dtypes.finfo(dtype).max)) ax = exp(ax) - enabled = bitwise_not( - _reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan])) + enabled = bitwise_not(_reduce(bitwise_or, [x_is_zero, domain_error, underflow, is_nan, x_is_infinity])) output = select( use_igammac, @@ -323,8 +324,7 @@ def igamma_impl(a, x, *, dtype): ) output = select(x_is_zero, full_like(a, 0), output) output = select(x_is_infinity, full_like(a, 1), output) - output = select(bitwise_or(domain_error, is_nan), - full_like(a, float('nan')), output) + output = select(domain_error, full_like(a, float('nan')), output) return output def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode): @@ -433,11 +433,15 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode): raise ValueError(f"Invalid mode: {mode}") def igammac_impl(a, x, *, dtype): - out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0))) + is_nan = bitwise_or(_isnan(a), _isnan(x)) + a_is_zero = eq(a, _const(a, 0)) + x_is_zero = eq(x, _const(x, 0)) + x_is_infinity = eq(x, _const(x, float('inf'))) + domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)]) use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a)) ax = a * log(x) - x - lgamma(a) underflow = lt(ax, -log(dtypes.finfo(dtype).max)) - enabled = bitwise_not(bitwise_or(out_of_range, underflow)) + enabled = bitwise_not(_reduce(bitwise_or, [domain_error, underflow, is_nan, x_is_infinity, a_is_zero])) ax = exp(ax) igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma), @@ -445,10 +449,10 @@ def igammac_impl(a, x, *, dtype): igammac_cf_call = _igammac_continued_fraction(ax, x, a, bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE) - result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call) - x_is_infinity = eq(x, _const(x, float('inf'))) - result = select(x_is_infinity, full_like(result, 0), result) - return select(out_of_range, full_like(a, 1), result) + output = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call) + output = select(bitwise_or(x_is_infinity, a_is_zero), full_like(output, 0), output) + output = select(domain_error, full_like(a, float('nan')), output) + return output def igamma_grad_a_impl(a, x, *, dtype): is_nan = bitwise_or(_isnan(a), _isnan(x)) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index 4343c0802..97d73a3ee 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -198,7 +198,8 @@ def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> - :func:`jax.scipy.stats.gamma.logsf` """ x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale) - return gammaincc(a, lax.div(lax.sub(x, loc), scale)) + y = lax.div(lax.sub(x, loc), scale) + return jnp.where(lax.lt(y, _lax_const(y, 0)), 1, gammaincc(a, y)) def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index c3b9e96dc..63f019b31 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -865,15 +865,15 @@ class Jax2TfLimitation(test_harnesses.Limitation): def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): # noqa: F811 arg1, arg2 = args - # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN + # lax.igammac returns nan. when arg1 <= 0; tf.math.igammac returns 1 special_cases = (arg1 <= 0.) | (arg2 <= 0) nr_special_cases = np.count_nonzero(special_cases) tst.assertAllClose( - np.full((nr_special_cases,), 1., dtype=dtype), + np.full((nr_special_cases,), np.nan, dtype=dtype), result_jax[special_cases], err_msg=err_msg) tst.assertAllClose( - np.full((nr_special_cases,), np.nan, dtype=dtype), + np.full((nr_special_cases,), 1, dtype=dtype), result_tf[special_cases], err_msg=err_msg) # non-special cases are equal @@ -892,12 +892,12 @@ class Jax2TfLimitation(test_harnesses.Limitation): custom_numeric(dtypes=[np.float64], tol=1e-9), custom_numeric(devices="gpu", tol=1e-3), custom_numeric( + modes=("compiled",), custom_assert=custom_assert, - devices=("cpu", "gpu"), + devices=("cpu", "gpu", "tpu"), description=( "May return different results at undefined points " - "(both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or " - "JAX returns 1 and TF returns `NaN`")), + "(both arguments less or equal 0). JAX returns `NaN` and TF returns 1")), ] @classmethod diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 8932869f1..2c09252f4 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -287,6 +287,26 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase): result_jit = lsp_special.expi(x) self.assertAllClose(result_jit, result_nojit) + def testGammaIncBoundaryValues(self): + dtype = jax.numpy.zeros(0).dtype # default float dtype. + nan = float('nan') + inf = float('inf') + args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan]).astype(dtype), + np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf]).astype(dtype)] + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 + self._CheckAgainstNumpy(osp_special.gammainc, lsp_special.gammainc, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol) + + def testGammaIncCBoundaryValues(self): + dtype = jax.numpy.zeros(0).dtype # default float dtype. + nan = float('nan') + inf = float('inf') + args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan, 1]).astype(dtype), + np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf, -1]).astype(dtype)] + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 + self._CheckAgainstNumpy(osp_special.gammaincc, lsp_special.gammaincc, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 0db14aa34257a938d9d14c7a97122dd13196d132 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 11 Mar 2025 12:32:58 -0700 Subject: [PATCH 08/28] Add NVIDIA wheel requirements only for Linux builds. PiperOrigin-RevId: 735850240 --- build/BUILD.bazel | 6 +++--- build/gpu-test-requirements.txt | 20 ++++++++++---------- build/requirements_lock_3_10.txt | 20 ++++++++++---------- build/requirements_lock_3_11.txt | 20 ++++++++++---------- build/requirements_lock_3_12.txt | 20 ++++++++++---------- build/requirements_lock_3_13.txt | 20 ++++++++++---------- build/requirements_lock_3_13_ft.txt | 20 ++++++++++---------- 7 files changed, 63 insertions(+), 63 deletions(-) diff --git a/build/BUILD.bazel b/build/BUILD.bazel index cf43fdab0..f088cd58a 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -29,7 +29,7 @@ compile_pip_requirements( requirements_in = "requirements.in", requirements_txt = REQUIREMENTS, generate_hashes = True, - data = ["test-requirements.txt"] + data = ["test-requirements.txt", "gpu-test-requirements.txt"] ) compile_pip_requirements( @@ -44,7 +44,7 @@ compile_pip_requirements( requirements_in = "requirements.in", requirements_txt = REQUIREMENTS, generate_hashes = False, - data = ["test-requirements.txt"] + data = ["test-requirements.txt", "gpu-test-requirements.txt"] ) compile_pip_requirements( @@ -58,7 +58,7 @@ compile_pip_requirements( requirements_in = "requirements.in", requirements_txt = REQUIREMENTS, generate_hashes = False, - data = ["test-requirements.txt"] + data = ["test-requirements.txt", "gpu-test-requirements.txt"] ) py_library( diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt index 54f5d2ab3..ff43f91ba 100644 --- a/build/gpu-test-requirements.txt +++ b/build/gpu-test-requirements.txt @@ -1,13 +1,13 @@ # NVIDIA CUDA dependencies # Note that the wheels are downloaded only when the targets in bazel command # contain dependencies on these wheels. -nvidia-cublas-cu12>=12.1.3.1 -nvidia-cuda-cupti-cu12>=12.1.105 -nvidia-cuda-nvcc-cu12>=12.6.85 -nvidia-cuda-runtime-cu12>=12.1.105 -nvidia-cudnn-cu12>=9.1,<10.0 -nvidia-cufft-cu12>=11.0.2.54 -nvidia-cusolver-cu12>=11.4.5.107 -nvidia-cusparse-cu12>=12.1.0.106 -nvidia-nccl-cu12>=2.18.1 -nvidia-nvjitlink-cu12>=12.1.105 +nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux" +nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux" +nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux" +nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux" +nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux" +nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux" +nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux" +nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux" +nvidia-nccl-cu12>=2.18.1 ; sys_platform == "linux" +nvidia-nvjitlink-cu12>=12.1.105 ; sys_platform == "linux" diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index dd7b6a55f..b3bc7aff3 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -380,7 +380,7 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy -nvidia-cublas-cu12==12.8.3.14 \ +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 @@ -388,48 +388,48 @@ nvidia-cublas-cu12==12.8.3.14 \ # via -r build/test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 \ +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 \ +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 \ +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 \ +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 \ +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 \ +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 \ +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via # via -r build/test-requirements.txt # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 \ +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 \ +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 656458004..f1e0574f0 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -375,7 +375,7 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy -nvidia-cublas-cu12==12.8.3.14 \ +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 @@ -383,48 +383,48 @@ nvidia-cublas-cu12==12.8.3.14 \ # -r build/test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 \ +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 \ +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 \ +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 \ +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 \ +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 \ +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 \ +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via # -r build/test-requirements.txt # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 \ +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 \ +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 2f5a25e2b..931cd9070 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -375,7 +375,7 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy -nvidia-cublas-cu12==12.8.3.14 \ +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 @@ -383,48 +383,48 @@ nvidia-cublas-cu12==12.8.3.14 \ # -r build/test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 \ +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 \ +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 \ +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 \ +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 \ +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 \ +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 \ +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via # -r build/test-requirements.txt # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 \ +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 \ +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index f9635ec2b..ab3784bbe 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -434,7 +434,7 @@ numpy==2.1.2 ; python_version >= "3.13" \ # matplotlib # ml-dtypes # scipy -nvidia-cublas-cu12==12.8.3.14 \ +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 @@ -442,48 +442,48 @@ nvidia-cublas-cu12==12.8.3.14 \ # -r build/test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 \ +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 \ +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 \ +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 \ +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 \ +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 \ +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 \ +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via # -r build/test-requirements.txt # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 \ +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 \ +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index 129874cba..e7a2968e9 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -390,7 +390,7 @@ numpy==2.2.1 ; python_version >= "3.13" \ # matplotlib # ml-dtypes # scipy -nvidia-cublas-cu12==12.8.3.14 \ +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 @@ -398,48 +398,48 @@ nvidia-cublas-cu12==12.8.3.14 \ # -r build/test-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 \ +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 \ +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 \ +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 \ +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 \ +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 \ +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 \ +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via # -r build/test-requirements.txt # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 \ +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 \ +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 From eff612a3b6bb62a67b911906bf92c0608dbba551 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 11 Mar 2025 12:35:47 -0700 Subject: [PATCH 09/28] Fix the assumption that pages_per_seq is already a multiple of num_kv_pages_per_blk. PiperOrigin-RevId: 735851301 --- .../pallas/ops/tpu/ragged_paged_attention.py | 24 ++++++++++++------- .../pallas/tpu_ragged_paged_attention_test.py | 8 ++----- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 30cb20733..6600d7650 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -43,14 +43,22 @@ class MultiPageAsyncCopyDescriptor: ): self._vmem_buf = vmem_buf seq_id, kv_pages_start = offset - self._async_copies = [ - pltpu.make_async_copy( - pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]], - vmem_buf.at[i], - sem, - ) - for i in range(vmem_buf.shape[0]) - ] + pages_per_seq = page_indices_ref.shape[1] + self._async_copies = [] + # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert + # a bunch of if-ops. Check the performance when we have benchmarking setup. + for i in range(vmem_buf.shape[0]): + page_idx = kv_pages_start + i + page_idx = jax.lax.select( + page_idx < pages_per_seq, page_idx, pages_per_seq - 1 + ) + self._async_copies.append( + pltpu.make_async_copy( + pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]], + vmem_buf.at[i], + sem, + ) + ) def start(self): """Starts the async copies.""" diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index cca8e3bc8..bffcebc52 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -64,10 +64,6 @@ class PagedAttentionKernelTest(jtu.JaxTestCase): max_num_seq = max(len(seq_lens), max_num_seq) max_kv_len = max(kv_lens) pages_per_seq = ceil_div(max_kv_len, page_size) - pages_per_seq = ( - ceil_div(pages_per_seq, num_kv_pages_per_block) - * num_kv_pages_per_block - ) num_q_heads, num_kv_heads = num_heads cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) @@ -130,8 +126,8 @@ class PagedAttentionKernelTest(jtu.JaxTestCase): num_seqs=num_seqs, ) tols = { - "float32": 1e-1, - "bfloat16": 2e-1, + "float32": 0.15, + "bfloat16": 0.2, } tol = tols[jnp.dtype(dtype).name] self.assertAllClose(output, expected, atol=tol, rtol=tol) From e0545a71eb80d76d55586d996cc0231a98b33f3c Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 11 Mar 2025 13:42:29 -0700 Subject: [PATCH 10/28] Remove installation of NVIDIA wheels for CPU tests PiperOrigin-RevId: 735875073 --- .github/workflows/pytest_cpu.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index e64f81809..1cfc7a883 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -116,6 +116,9 @@ jobs: exit 1 - name: Install Python dependencies run: | + # Remove installation of NVIDIA wheels for CPU tests. + sed -i 's/-r gpu-test-requirements.txt/# -r gpu-test-requirements.txt/g' build/requirements.in + # TODO(srnitin): Remove after uv is installed in the Windows Dockerfile $JAXCI_PYTHON -m pip install uv~=0.5.30 # python 3.13t cannot compile zstandard 0.23.0 due to From 67aa997f84791dda2f28694d3a9958da709d9ddc Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 11 Mar 2025 13:44:37 -0700 Subject: [PATCH 11/28] Increase the number of iterations in a test that compares rolled versus unrolled HLO for length. A change that avoids duplicating subcomputations in XLA causes this test to fail, but we can make it work again by increasing the number of iterations. PiperOrigin-RevId: 735875835 --- tests/lax_control_flow_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 39da45079..3871a87a7 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2445,7 +2445,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): assert b.shape == () return c, b - xs = jnp.ones((5, 3)) + xs = jnp.ones((20, 3)) c = jnp.ones(4) scan = lambda c, xs: lax.scan(f, c, xs) From 99c91060321b4bac2629b8d87d3022b4fa8b806c Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 11 Mar 2025 13:49:57 -0700 Subject: [PATCH 12/28] [Mosaic GPU] Replace `WGMMAFragLayout` with `TiledLayout` in the mlir dialect and use it in layout inference. `WGMMAFragLayout` will be completely removed soon. PiperOrigin-RevId: 735877661 --- .../mosaic/gpu/dialect_lowering.py | 17 +++- .../mosaic/gpu/layout_inference.py | 26 ++++--- jax/experimental/mosaic/gpu/layouts.py | 77 +++++++++++++++---- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 33 ++++---- tests/mosaic/gpu_layout_inference_test.py | 6 +- 5 files changed, 112 insertions(+), 47 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 368b47df4..560816319 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -259,14 +259,15 @@ def _vector_load_op_lowering_rule( is_signed=is_signed, vec_size=strided_layout.vec_size, ) - elif layouts.is_wgmma_fragmented_layout(out_layout_attr): + elif layouts.from_layout_attr(out_layout_attr) == fa.TILED_LAYOUT_WGMMA: layout = ir.MemRefType(vector_load_op.base.type).layout swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout) transformed_ref = transform_memref(vector_load_op.base, transforms) fragmented_array = fa.FragmentedArray.load_tiled( transformed_ref, swizzle=swizzle, - is_signed=is_signed + is_signed=is_signed, + layout=fa.TILED_LAYOUT_WGMMA, ) else: raise ValueError( @@ -634,7 +635,10 @@ def _mgpu_wgmma_op_lowering_rule( *inference_utils.in_layouts(wgmma_op), *inference_utils.out_layouts(wgmma_op), ) - if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)): + is_supported_layout = ( + lambda l: layouts.from_tiled_layout_attr(l) == fa.TILED_LAYOUT_WGMMA + ) + if not all(map(is_supported_layout, fa_layouts)): raise ValueError("Layout mismatch") wgmma_layout = fa_layouts[0] @@ -667,7 +671,12 @@ def _mgpu_wgmma_op_lowering_rule( new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) - return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)] + return [ + _fragmented_array_to_ir( + new_acc.value.to_layout(fa.TILED_LAYOUT_WGMMA), + wgmma_op.accumulator.type, + ) + ] @_register_lowering(mgpu.ArriveExpectTxOp) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 044e7537d..cd493f0ab 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -63,7 +63,7 @@ def _choose_representative_layout( Given the input set of possible layouts, this function extracts a single representative layout. Currently, this function only works with strided, - splat, and WGMMA fragmented layouts. + splat, and tiled layouts. Returns: A single layout that can be used to annotate the operation, or None if the @@ -86,18 +86,18 @@ def _choose_representative_layout( ) ) - wgmma_layouts: list[fa.WGMMAFragLayout] = list( + tiled_layouts: list[fa.TiledLayout] = list( map( layouts_lib.from_layout_attr, - filter(layouts_lib.is_wgmma_fragmented_layout, layouts), + filter(layouts_lib.is_tiled_layout, layouts), ) ) - if len(splat_layouts) + len(strided_layouts) + len(wgmma_layouts) != len( + if len(splat_layouts) + len(strided_layouts) + len(tiled_layouts) != len( layouts ): raise ValueError( - f"Expected only strided, splat, and wgmma layouts, got {layouts}" + f"Expected only strided, splat, and tiled layouts, got {layouts}" ) if len(splat_layouts) > 1: @@ -112,13 +112,19 @@ def _choose_representative_layout( "is not supported." ) - if (wgmma_layouts and strided_layouts): + if len(tiled_layouts) > 1: raise NotImplementedError( - "Mixing strided and WGMMA layouts is not supported." + "Finding a representative layout for several distinct tiled layouts " + "is not supported." ) - if wgmma_layouts: - return layouts_lib.to_layout_attr(wgmma_layouts[0]) + if tiled_layouts and strided_layouts: + raise NotImplementedError( + "Mixing strided and tiled layouts is not supported." + ) + + if tiled_layouts: + return layouts_lib.to_layout_attr(tiled_layouts[0]) if strided_layouts: [strided_layout] = strided_layouts @@ -333,7 +339,7 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts: @partial(_add_layout_inference_rule, mgpu.WGMMAOp) def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts: - layout = layouts_lib.to_layout_attr(fa.WGMMAFragLayout()) + layout = layouts_lib.to_layout_attr(fa.TILED_LAYOUT_WGMMA) if ir.VectorType.isinstance(wgmma_op.a.type): return [layout, layout], [layout] diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 334ebeddd..5c3b23119 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -94,11 +94,67 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool: return bool(_strided_fragmented_layout_attr_pattern.search(str(attr))) +_tiled_layout_attr_pattern = re.compile( + r"^#mosaic_gpu.TiledLayout<\[(?P.*)\]," + r" warp_dim\s*=\s*(?P[-\d]+)," + r" lane_dims\s*=\s*\[(?P.*)\]," + r" vector_dim\s*=\s*(?P[-\d]+)>$" +) + + +def to_tiled_layout_attr( + layout: fa.TiledLayout, +) -> ir.Attribute: + """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout.""" + + tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]" + tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]" + return ir.Attribute.parse( + f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim}," + f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>" + ) + + +_list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[") + + +def from_tiled_layout_attr( + attr: ir.Attribute, +) -> fa.TiledLayout: + """Constructs a TiledLayout from a #mosaic_gpu.TiledLayout attribute. + + Raises: + ValueError: If the attribute is not a #mosaic_gpu.TiledLayout + attribute. + """ + match = _tiled_layout_attr_pattern.fullmatch(str(attr)) + if not match: + raise ValueError( + f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}" + ) + + tiling_str = match.group("tiling") + tile_strings = [] + if len(tiling_str) > 2: + tile_strings = _list_of_lists_delimiter.split(tiling_str[1:-1]) + tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings) + return fa.TiledLayout( + tiling=fa.Tiling(tiles), + warp_dim=int(match.group("warp_dim")), + lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")), + vector_dim=int(match.group("vector_dim")) + ) + + +def is_tiled_layout(attr: ir.Attribute) -> bool: + return bool(_tiled_layout_attr_pattern.search(str(attr))) + + def to_layout_attr( layout: ( fa.WGSplatFragLayout | fa.WGStridedFragLayout - | fa.WGMMAFragLayout + | fa.TiledLayout | fa.WGMMARowFragLayout ), ) -> ir.Attribute: @@ -108,8 +164,8 @@ def to_layout_attr( return to_splat_fragmented_layout_attr(layout) case fa.WGStridedFragLayout(): return to_strided_fragmented_layout_attr(layout) - case fa.WGMMAFragLayout(): - return ir.Attribute.parse("#mosaic_gpu.WGMMAFragLayout") + case fa.TiledLayout(): + return to_tiled_layout_attr(layout) case fa.WGMMARowFragLayout(): return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout") case _: @@ -118,15 +174,6 @@ def to_layout_attr( ) -_wgmma_fragmented_layout_attr_pattern = re.compile( - r"^#mosaic_gpu.WGMMAFragLayout$" -) - - -def is_wgmma_fragmented_layout(attr: ir.Attribute) -> bool: - return bool(_wgmma_fragmented_layout_attr_pattern.search(str(attr))) - - _wgmma_row_fragmented_layout_attr_pattern = re.compile( r"^#mosaic_gpu.WGMMARowFragLayout$" ) @@ -141,7 +188,7 @@ def from_layout_attr( ) -> ( fa.WGSplatFragLayout | fa.WGStridedFragLayout - | fa.WGMMAFragLayout + | fa.TiledLayout | fa.WGMMARowFragLayout ): """Constructs a layout from an MLIR attribute.""" @@ -149,8 +196,8 @@ def from_layout_attr( return from_splat_fragmented_layout_attr(attr) elif is_strided_fragmented_layout(attr): return from_strided_fragmented_layout_attr(attr) - elif is_wgmma_fragmented_layout(attr): - return fa.WGMMAFragLayout() + elif is_tiled_layout(attr): + return from_tiled_layout_attr(attr) elif is_wgmma_row_fragmented_layout(attr): return fa.WGMMARowFragLayout() else: diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index dbc829832..0882986fc 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -128,7 +128,6 @@ def MosaicGPU_WGStridedFragLayout : AttrDef { let summary = "Annotates an array that is the result of a splat."; let description = [{ @@ -143,20 +142,6 @@ def MosaicGPU_WGSplatFragLayout : AttrDef { - let summary = "2D array that can be tiled by supported WGMMA shapes."; - let description = [{ - This layout annotates arrays that are fragmented across all threads in a - warpgroup that is executing a WGMMA operation. The shape of the array is - (m, n) where: - - m % 64 == 0 - - n % 8 == 0 - }]; - - let mnemonic = "WGMMAFragLayout"; - let assemblyFormat = ""; -} - def MosaicGPU_WGMMARowFragLayout : AttrDef { let summary = "1D array that is a row that can be tiled by supported WGMMA shapes."; let description = [{ @@ -169,6 +154,24 @@ def MosaicGPU_WGMMARowFragLayout : AttrDef { + let summary = "A layout derived from a tiling expression."; + let description = [{ + See mosaic/gpu/fragmented_array.py -> TiledLayout for more details. + }]; + + let parameters = (ins + "::mlir::ArrayAttr":$tiling, + "int":$warp_dim, + "::mlir::ArrayAttr":$lane_dims, + "int":$vector_dim + ); + let mnemonic = "TiledLayout"; + let assemblyFormat = "`<` $tiling `,` `warp_dim` `=` $warp_dim `,` " + "`lane_dims` `=` $lane_dims `,` `vector_dim` `=` $vector_dim `>`"; +} + + // Note: This duplicates the Dimension enum in mlir/Dialect/GPU/IR/GPUOps.td // but it was not possible to reuse that definition. Including that file // pulls in ops definitions that we don't want and they fail to compile. diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 91debfe57..ea83d1583 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -210,7 +210,7 @@ class LayoutInferenceTest(parameterized.TestCase): for layout in [ mgpu.WGSplatFragLayout(shape), mgpu.WGStridedFragLayout(shape, vec_size=4), - mgpu.WGMMAFragLayout(), + mgpu.TILED_LAYOUT_WGMMA, ] ) def test_infer_layout_from_yield_op_in_layouts_for_for_op( @@ -278,7 +278,7 @@ class LayoutInferenceTest(parameterized.TestCase): mgpu.infer_layout(self.module) - wgmma_layout = layouts.to_layout_attr(mgpu.WGMMAFragLayout()) + wgmma_layout = layouts.to_layout_attr(mgpu.TILED_LAYOUT_WGMMA) self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout]) self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout]) @@ -312,7 +312,7 @@ class LayoutInferenceTest(parameterized.TestCase): @parameterized.parameters( mgpu.WGStridedFragLayout((32, 4), vec_size=1), - mgpu.WGMMAFragLayout(), + mgpu.TILED_LAYOUT_WGMMA, ) def test_infer_layout_picks_non_splat_layout_over_splat_layout( self, layout From 4df691ec0095e725a2077b7d516bec0bc08b471e Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Tue, 11 Mar 2025 14:12:08 -0700 Subject: [PATCH 13/28] Remove unsupported mac x86 CI build options PiperOrigin-RevId: 735885305 --- .bazelrc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.bazelrc b/.bazelrc index 04bfcf2d7..fb938169b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -251,12 +251,6 @@ build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base build:ci_linux_aarch64_cuda --config=cuda --config=build_cuda_with_nvcc build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -# Mac x86 CI configs -build:ci_darwin_x86_64 --macos_minimum_os=11.0 -build:ci_darwin_x86_64 --config=macos_cache_push -build:ci_darwin_x86_64 --verbose_failures=true -build:ci_darwin_x86_64 --color=yes - # Mac Arm64 CI configs build:ci_darwin_arm64 --macos_minimum_os=11.0 build:ci_darwin_arm64 --config=macos_cache_push From 13eb8d3ae79fb59e01c97c851782898b113eb71e Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 11 Mar 2025 14:40:48 -0700 Subject: [PATCH 14/28] Upgrade `ml-dtypes` version in `py3.10`-`py3.13` hermetic python lock files. This change is needed to add testing of int2/uint2 dtypes via bazel in presubmit (see https://github.com/jax-ml/jax/pull/21395). PiperOrigin-RevId: 735895293 --- build/requirements_lock_3_10.txt | 43 +++++++++++++++++------------ build/requirements_lock_3_11.txt | 43 +++++++++++++++++------------ build/requirements_lock_3_12.txt | 43 +++++++++++++++++------------ build/requirements_lock_3_13.txt | 47 +++++++++++++++++--------------- 4 files changed, 100 insertions(+), 76 deletions(-) diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index b3bc7aff3..6ed6b59aa 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -304,24 +304,31 @@ mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.4.0 \ - --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ - --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ - --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ - --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ - --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ - --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ - --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ - --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ - --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ - --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ - --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ - --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ - --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ - --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ - --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ - --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ - --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via -r build/requirements.in mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index f1e0574f0..8446e8361 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -299,24 +299,31 @@ mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.4.0 \ - --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ - --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ - --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ - --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ - --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ - --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ - --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ - --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ - --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ - --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ - --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ - --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ - --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ - --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ - --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ - --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ - --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via -r build/requirements.in mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 931cd9070..0436ab6dd 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -299,24 +299,31 @@ mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.4.0 \ - --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ - --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ - --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ - --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ - --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ - --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ - --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ - --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ - --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ - --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ - --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ - --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ - --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ - --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ - --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ - --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ - --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via -r build/requirements.in mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index ab3784bbe..e74d40b79 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -347,28 +347,31 @@ mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.5.0 \ - --hash=sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe \ - --hash=sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f \ - --hash=sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128 \ - --hash=sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab \ - --hash=sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15 \ - --hash=sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215 \ - --hash=sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f \ - --hash=sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf \ - --hash=sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4 \ - --hash=sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8 \ - --hash=sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859 \ - --hash=sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a \ - --hash=sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff \ - --hash=sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8 \ - --hash=sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02 \ - --hash=sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856 \ - --hash=sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07 \ - --hash=sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f \ - --hash=sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0 \ - --hash=sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7 \ - --hash=sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599 +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via -r build/requirements.in mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ From 29bfd00f9cab468affd634b3ec0ee43895a5e6a8 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 11 Mar 2025 15:05:20 -0700 Subject: [PATCH 15/28] [Pallas TPU] Fix preferred_element_type propagation in dot_general with const PiperOrigin-RevId: 735903687 --- jax/_src/pallas/mosaic/lowering.py | 34 ++++++++++++++++++++++++++++-- tests/pallas/tpu_ops_test.py | 21 ++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 9bc20ed2c..762141270 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1853,7 +1853,13 @@ def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape): def _dot_general_lowering_rule( - ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_ + ctx: LoweringRuleContext, + x, + y, + dimension_numbers, + precision, + preferred_element_type, + **_, ): (lhs_dims, rhs_dims), _ = dimension_numbers (aval_out,) = ctx.avals_out @@ -1894,10 +1900,34 @@ def _dot_general_lowering_rule( x = vector.broadcast(bcast_shape, x) if ctx.avals_in[1].shape != bcast_shape: y = vector.broadcast(bcast_shape, y) + red_dtype = ( + preferred_element_type if preferred_element_type else lhs_aval.dtype + ) red_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, - lhs_aval.update(shape=(lhs_aval.shape[0],)), + lhs_aval.update(shape=(lhs_aval.shape[0],), dtype=red_dtype), ) + + if lhs_aval.dtype != red_dtype: + lhs_type = aval_to_ir_type( + ctx.lowering_context.dynamic_shape_replacement_fn, + lhs_aval.update(shape=lhs_aval.shape, dtype=red_dtype), + ) + if red_dtype == jnp.float32: + x = arith.extf(lhs_type, x) + else: + raise NotImplementedError(f"Unsupported {preferred_element_type=}") + + if rhs_aval.dtype != red_dtype: + rhs_type = aval_to_ir_type( + ctx.lowering_context.dynamic_shape_replacement_fn, + rhs_aval.update(shape=rhs_aval.shape, dtype=red_dtype), + ) + if red_dtype == jnp.float32: + y = arith.extf(rhs_type, y) + else: + raise NotImplementedError(f"Unsupported {preferred_element_type=}") + acc = arith.ConstantOp( red_type, ir.DenseElementsAttr.get_splat(red_type, val) ) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 2011602d8..c4d600d23 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -470,6 +470,27 @@ class OpsTest(PallasBaseTest): expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x)) np.testing.assert_array_equal(out, expected) + def test_reduce_with_const(self): + m = 1 + d = 1024 + x = jnp.ones((m, d), jnp.bfloat16) + + def dot(x, y): + return jax.lax.dot_general( + x, + y, + (((1,), (1,)), ((), ())), + preferred_element_type=jnp.float32, + ) + + def kernel(x, out): + out[:] = dot(x[:], jnp.ones((1, d), jnp.bfloat16)) + + run = pl.pallas_call(kernel, jax.ShapeDtypeStruct((m, 1), jnp.float32)) + output = run(x) + expected = dot(x[:], jnp.ones((1, d), jnp.bfloat16)) + np.testing.assert_array_equal(output, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True From f45cbf334262f0c32e8569762f408f8c7813fa0e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 11 Mar 2025 15:24:54 -0700 Subject: [PATCH 16/28] Fix a bug where `full` and `use_mesh` outside jit did not work because the `shard` passed to `make_array_from_callback` was sharded on all devices instead of just 1 device. This is because `convert_element_type` returning an output on all devices of the mesh because of the surrounding `use_mesh` context. PiperOrigin-RevId: 735909962 --- jax/_src/lax/lax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index d4131e69b..e6cbbf245 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3016,6 +3016,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, isinstance(fill_value, array.ArrayImpl) and sharding._is_concrete): broadcast_shape = sharding.shard_shape(shape) shard = broadcast(fill_value, broadcast_shape) + shard = shard.addressable_data(0) return array.make_array_from_callback(shape, sharding, lambda _: shard) if sharding is not None and not sharding._is_concrete: From c6b164dc092f6c2d1d7d8be7088e495722993449 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 11 Mar 2025 16:35:06 -0700 Subject: [PATCH 17/28] [Pallas/Fuser] Add custom evaluate to allow/disallow transposes PiperOrigin-RevId: 735931978 --- jax/BUILD | 1 + jax/_src/pallas/fuser/BUILD | 26 ++++++++ jax/_src/pallas/fuser/__init__.py | 1 + jax/_src/pallas/fuser/block_spec.py | 25 +++----- jax/_src/pallas/fuser/custom_evaluate.py | 82 ++++++++++++++++++++++++ jax/_src/pallas/fuser/fuser_utils.py | 33 ++++++++++ jax/experimental/pallas/fuser.py | 1 + 7 files changed, 152 insertions(+), 17 deletions(-) create mode 100644 jax/_src/pallas/fuser/custom_evaluate.py create mode 100644 jax/_src/pallas/fuser/fuser_utils.py diff --git a/jax/BUILD b/jax/BUILD index 3f75d60a0..5e9cbd7a1 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -690,6 +690,7 @@ pytype_strict_library( deps = [ ":pallas", # build_cleaner: keep "//jax/_src/pallas/fuser:block_spec", + "//jax/_src/pallas/fuser:custom_evaluate", "//jax/_src/pallas/fuser:fusable", "//jax/_src/pallas/fuser:fusion", "//jax/_src/pallas/fuser:jaxpr_fusion", diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index 18e136623..66bbac33a 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -32,6 +32,7 @@ pytype_strict_library( ], deps = [ ":block_spec", + ":custom_evaluate", ":fusable", ":fusion", ":jaxpr_fusion", @@ -44,6 +45,7 @@ pytype_strict_library( "block_spec.py", ], deps = [ + ":fuser_utils", "//jax", "//jax:ad_util", "//jax:api_util", @@ -119,3 +121,27 @@ pytype_strict_library( "//jax/_src/pallas", ], ) + +pytype_strict_library( + name = "custom_evaluate", + srcs = ["custom_evaluate.py"], + deps = [ + ":fuser_utils", + "//jax", + "//jax:core", + "//jax:source_info_util", + "//jax:tree_util", + "//jax:util", + ], +) + +pytype_strict_library( + name = "fuser_utils", + srcs = ["fuser_utils.py"], + deps = [ + "//jax:api_util", + "//jax:core", + "//jax:partial_eval", + "//jax:tree_util", + ], +) diff --git a/jax/_src/pallas/fuser/__init__.py b/jax/_src/pallas/fuser/__init__.py index a9f6ce390..3295c8f10 100644 --- a/jax/_src/pallas/fuser/__init__.py +++ b/jax/_src/pallas/fuser/__init__.py @@ -16,6 +16,7 @@ from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_val from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec +from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate from jax._src.pallas.fuser.fusable import fusable as fusable from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 83b485107..d0767aeeb 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -26,15 +26,14 @@ from typing import Any, Callable, Protocol, Sequence import jax from jax import lax from jax._src import ad_util -from jax._src import api_util from jax._src import core from jax._src import custom_derivatives -from jax._src import linear_util as lu from jax._src import pjit from jax._src import tree_util from jax._src import util from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core +from jax._src.pallas.fuser import fuser_utils import jax.numpy as jnp import numpy as np @@ -226,18 +225,6 @@ def _unwrap_block_spec_scalar_prefetch( return out_block_spec -def _make_jaxpr(f, *args, **kwargs): - flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - flat_avals = [core.get_aval(x) for x in flat_args] - debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs) - flat_fun, out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(f, debug_info=debug_info), in_tree - ) - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree_thunk() - return jaxpr, consts, in_tree, out_tree - - def pull_block_spec( f: Callable, out_block_specs: pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...], @@ -246,7 +233,9 @@ def pull_block_spec( grid: tuple[int | jax.Array, ...] | None = None, ): def wrapped(*args, **kwargs): - jaxpr, consts, in_tree, out_tree_ = _make_jaxpr(f, *args, **kwargs) + jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr( + f, *args, **kwargs + ) # TODO(sharadmv): handle these consts better, they should correspond to # scalar prefetch. del consts, out_tree_ @@ -563,7 +552,9 @@ def make_kernel_function( def get_fusion_values( fusion: Callable, *args, **kwargs ) -> tuple[Callable, tuple[jax.Array, ...], tuple[jax.Array, ...]]: - jaxpr, values, in_tree, out_tree = _make_jaxpr(fusion, *args, **kwargs) + jaxpr, values, in_tree, out_tree = fuser_utils.make_jaxpr( + fusion, *args, **kwargs + ) assert len(values) == len(jaxpr.constvars), (jaxpr, values) out_usages = tuple({Usage.REGULAR} for _ in jaxpr.outvars) read_usage_env = compute_usage(jaxpr, out_usages) @@ -1325,7 +1316,7 @@ def push_block_spec( flat_block_specs, in_tree_ = tree_util.tree_flatten( (in_spec_args, in_spec_kwargs) ) - jaxpr, _, in_tree, out_tree = _make_jaxpr(f, *args, **kwargs) + jaxpr, _, in_tree, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs) if in_tree != in_tree_: raise ValueError(f'Expected {in_tree} PyTree, got {in_tree_}') out_bs = _push_block_spec_jaxpr(jaxpr, *flat_block_specs) diff --git a/jax/_src/pallas/fuser/custom_evaluate.py b/jax/_src/pallas/fuser/custom_evaluate.py new file mode 100644 index 000000000..fff0f7d7e --- /dev/null +++ b/jax/_src/pallas/fuser/custom_evaluate.py @@ -0,0 +1,82 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for evaluating functions under certain constraints.""" +import dataclasses +from typing import Any + +from jax import lax +from jax._src import core +from jax._src import source_info_util +from jax._src import tree_util +from jax._src import util +from jax._src.pallas.fuser import fuser_utils + + +@dataclasses.dataclass +class CustomEvaluateSettings: + allow_transpose: bool = True + + +def evaluate(f, *, allow_transpose: bool = True): + def wrapped(*args, **kwargs): + jaxpr, consts, _, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs) + settings = CustomEvaluateSettings(allow_transpose=allow_transpose) + flat_args = tree_util.tree_leaves(args) + out_flat = _custom_evaluate_jaxpr(settings, jaxpr, consts, *flat_args) + return tree_util.tree_unflatten(out_tree, out_flat) + + return wrapped + + +# Disallow most higher-order primitives for now. +disallowed_primitives = {lax.scan_p, lax.while_p, lax.cond_p} + + +def _custom_evaluate_jaxpr( + settings: CustomEvaluateSettings, jaxpr: core.Jaxpr, consts, *args +): + def read(v: core.Atom) -> Any: + return v.val if isinstance(v, core.Literal) else env[v] + + def write(v: core.Var, val: Any) -> None: + env[v] = val + + env: dict[core.Var, Any] = {} + util.safe_map(write, jaxpr.constvars, consts) + util.safe_map(write, jaxpr.invars, args) + lu = core.last_used(jaxpr) + for eqn in jaxpr.eqns: + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + + if eqn.primitive in disallowed_primitives: + raise NotImplementedError(f'Primitive {eqn.primitive} not supported.') + if not settings.allow_transpose and eqn.primitive is lax.transpose_p: + raise ValueError('Transpose not allowed.') + name_stack = ( + source_info_util.current_name_stack() + eqn.source_info.name_stack + ) + traceback = eqn.source_info.traceback + with source_info_util.user_context( + traceback, name_stack=name_stack + ), eqn.ctx.manager: + ans = eqn.primitive.bind( + *subfuns, *util.safe_map(read, eqn.invars), **bind_params + ) + if eqn.primitive.multiple_results: + util.safe_map(write, eqn.outvars, ans) + else: + write(eqn.outvars[0], ans) + core.clean_up_dead_vars(eqn, env, lu) + return util.safe_map(read, jaxpr.outvars) diff --git a/jax/_src/pallas/fuser/fuser_utils.py b/jax/_src/pallas/fuser/fuser_utils.py new file mode 100644 index 000000000..ff44725bb --- /dev/null +++ b/jax/_src/pallas/fuser/fuser_utils.py @@ -0,0 +1,33 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Basic utils for fuser internals.""" +from jax._src import api_util +from jax._src import core +from jax._src import linear_util as lu +from jax._src import tree_util +from jax._src.interpreters import partial_eval as pe + + + +def make_jaxpr(f, *args, **kwargs): + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + flat_avals = [core.get_aval(x) for x in flat_args] + debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs) + flat_fun, out_tree_thunk = api_util.flatten_fun( + lu.wrap_init(f, debug_info=debug_info), in_tree + ) + jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + out_tree = out_tree_thunk() + return jaxpr, consts, in_tree, out_tree diff --git a/jax/experimental/pallas/fuser.py b/jax/experimental/pallas/fuser.py index 28b62f4f0..729a447b7 100644 --- a/jax/experimental/pallas/fuser.py +++ b/jax/experimental/pallas/fuser.py @@ -18,6 +18,7 @@ from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_val from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec +from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate from jax._src.pallas.fuser.fusable import fusable as fusable from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse From 3a26804c684876fe1293f705f169b54d4ca5df18 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 11 Mar 2025 17:34:05 -0700 Subject: [PATCH 18/28] Rename `get_ty` to `typeof` which is an alias of `get_aval` PiperOrigin-RevId: 735946640 --- jax/__init__.py | 2 +- jax/_src/core.py | 2 +- tests/mutable_array_test.py | 4 ++-- tests/pjit_test.py | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index 950c3ed4b..ae3bac4ad 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -79,7 +79,7 @@ from jax._src.lib import xla_client as _xc Device = _xc.Device del _xc -from jax._src.core import get_ty as get_ty +from jax._src.core import typeof as typeof from jax._src.api import effects_barrier as effects_barrier from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401 diff --git a/jax/_src/core.py b/jax/_src/core.py index 9d8edeb8b..b17e26255 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1576,7 +1576,7 @@ def get_aval(x): return get_aval(x.__jax_array__()) raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type") -get_ty = get_aval +typeof = get_aval def is_concrete(x): return to_concrete_value(x) is not None diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index c510c2cfa..4c6a8eb7a 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -216,8 +216,8 @@ class MutableArrayTest(jtu.JaxTestCase): @jax.jit def f(x_ref): - self.assertEqual(core.get_ty(x_ref).sharding.spec, - core.get_ty(x_ref[...]).sharding.spec) + self.assertEqual(core.typeof(x_ref).sharding.spec, + core.typeof(x_ref[...]).sharding.spec) y = x_ref[...] + 1 return y diff --git a/tests/pjit_test.py b/tests/pjit_test.py index bd7954d60..4cd1af9d3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4883,11 +4883,11 @@ class ShardingInTypesTest(jtu.JaxTestCase): arr = jax.device_put(np_inp, s) def f(x): - self.assertEqual(jax.get_ty(x).sharding.spec, s.spec) + self.assertEqual(jax.typeof(x).sharding.spec, s.spec) x = x * 2 - self.assertEqual(jax.get_ty(x).sharding.spec, s.spec) + self.assertEqual(jax.typeof(x).sharding.spec, s.spec) x = x * x - self.assertEqual(jax.get_ty(x).sharding.spec, s.spec) + self.assertEqual(jax.typeof(x).sharding.spec, s.spec) return x # Eager mode From 66a6eb299e50284758fe77d40f3192ae4eae705d Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 11 Mar 2025 18:21:19 -0700 Subject: [PATCH 19/28] add autodiff rules for jax.lax.ragged_all_to_all collective also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future. PiperOrigin-RevId: 735957604 --- jax/_src/lax/lax.py | 2 +- jax/_src/lax/parallel.py | 212 +++++++++++++++++++++++--------- tests/ragged_collective_test.py | 74 +++++++++++ 3 files changed, 229 insertions(+), 59 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e6cbbf245..12706426b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8197,7 +8197,7 @@ _zeros: Callable = partial(full_like, fill_value=0) def _zero(x): x_aval = core.get_aval(x) return full_like(x, shape=(), fill_value=0, - sharding=x_aval.sharding.with_spec(P())) + sharding=x_aval.sharding.with_spec(P())) _ones: Callable = partial(full_like, fill_value=1) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index b556042fe..764e4dcbe 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -22,6 +22,7 @@ from functools import partial import itertools import math +import jax from jax import tree_util from jax._src import core from jax._src import dispatch @@ -459,78 +460,135 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, def ragged_all_to_all( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *, axis_name, axis_index_groups = None): - """Ragged version of :func:`all_to_all`. + """Ragged version of :func:`all_to_all` collective. - For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent - and the outermost (ragged) dimension. ``axis_index_groups`` is default to all - replicas (e.g. there is only one group and covers all axis indices). + We say data are "ragged" when they can be represented as a list of arrays + whose shapes differ only in the size of the leading axis. For example, these + data are ragged, comprising four component arrays:: - Ragged arrays are defined by a set of three arrays: - * ``data``: the ``data`` array is "ragged" along its outermost dimension, - along which each indexed element has variable size. - * ``offsets``: the ``offsets`` array indexes the outermost dimension of the - ``data`` array, and represents the starting offset of each ragged element of - the ``data`` array. - * ``sizes``: the ``sizes`` array represents the size of each ragged element of - the ``data`` array, where the size is specified in units of sub-elements. A - sub-element is defined as the suffix of the ``data`` array shape obtained by - removing the outermost "ragged" dimension. - The ``offsets`` and ``sizes`` arrays must have the same size. + ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)] - # Example ragged tensor - data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}} - offsets: [3] = {0, 1, 4} - sizes: [3] = {1, 3, 4} + We often instead want a contiguous representation, e.g. for batching. But + because the shapes of the components differ, we can't apply ``jnp.stack`` to + represent these data by a single rectangular array with the leading axis + indexing the component arrays. So instead of stacking, we concatenate along + the leading axis and keep track of offsets and sizes. - # Index 'data' at 'offsets'[0], 'sizes'[0]' - {a,b,c} + That is, we can represent ragged data contiguously using a triple of dense + arrays ``(data, offsets, sizes)``: + * ``data``: the concatenated component arrays, + * ``offsets``: 1D array of indices into the leading axis of ``data`` + indicating where the data for each component array begins, + * ``sizes``: 1D array of sizes of the leading axis of each component array. + We refer to this triple as a ragged array. (Offsets can't be computed from + sizes in general to allow for internal padding.) - # Index 'data' at 'offsets'[1], 'sizes'[1]' - {d,e,f},{g,h,i},{j,k,l} + For example:: + data: f32[8,3] = jnp.array([ + [a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x], + ]) + offsets: i32[3] = jnp.array([0, 1, 4]) + sizes: i32[3] = jnp.array([1, 3, 4]) - # Index 'data' at 'offsets'[2], 'sizes'[2]' - {m,n,o},{p,q,r},{s,t,u},{v,w,x} + # To extract the first component array, of type f32[1,3] + data[offsets[0]:offsets[0]+sizes[0]] + # To extract the second component array, of type f32[3,3] + data[offsets[1]:offsets[1]+sizes[1]] - ``output_offsets`` must be sharded in a way that each replica has offsets in - the target replica output perspective. + # To extract the third component array, of type f32[4,3] + data[offsets[2]:offsets[2]+sizes[2]] - For i-th output offset, the current replica will send - `operand[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to `i`-th - replica that will be written to - `output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th - replica ``output``. + The ``ragged_all_to_all`` collective operation communicates slices of ragged + arrays between devices. Each caller is both a sender and a receiver. The + ``input_offsets`` and ``send_sizes`` arguments indicate the slices of the + caller's ``operand`` to be sent. Received results are returned in an array + that has the same value of the argument ``output`` except with received values + written at some slices. The ``output_offsets`` argument does *not* indicate + the offsets at which all the received results are written; instead, + ``output_offsets`` indicates the offsets at which the *sent* slices are + written on their corresponding receivers. The sizes of received slices are + indicated by ``recv_sizes``. See below for details. - For example, if we have 2 replicas: + The arrays ``input_offsets``, ``send_sizes``,``output_offsets``, and + ``recv_sizes`` must all be the same length, and that length must be divisible + by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and + ``recv_sizes`` must satisfy:: - replica 0: - operand: [1, 2, 2] - output: [0, 0, 0, 0] - input_offsets: [0, 1] - send_sizes: [1, 2] - output_offsets: [0, 0] - recv_sizes: [1, 1] + jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True)) - replica 1: - operand: [3, 4, 0] - output: [0, 0, 0, 0] - input_offsets: [0, 1] - send_sizes: [1, 1] - output_offsets: [1, 2] - recv_sizes: [2, 1] + Specifically, given a call:: - replica 0's result will be: [1, 3, 0, 0] - replica 1's result will be: [2, 2, 4, 0] + result = ragged_all_to_all(operand, output, input_offsets, send_sizes, + output_offsets, recv_sizes, axis_name) + + the caller sends data like:: + + assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes) + N = len(input_offsets) + slices_per_device, leftover = divmod(N, lax.axis_size(axis_name)) + assert not leftover + + for i in range(N): + dst_idx = i // slices_per_device + SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]], + axis_name=axis_name, to_axis_index=dst_idx) + + and receives data in ``result`` like:: + + result = output + output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True) + for i in range(N): + src_idx = i // slices_per_device + result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i] + ].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx)) + + where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local + ``output_offsets`` does not indicate the offsets at which its local ``result`` + is updated; instead, it indicates where the corresponding sent slices are + written on their destination instances. To compute the local offsets at which + received data are written, we apply an ``all_to_all`` on ``output_offsets``. + + For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with + these arguments in each mapped function instance:: + + axis index 0: + operand = [1, 2, 2] + output = [0, 0, 0, 0] + input_offsets = [0, 1] + send_sizes = [1, 2] + output_offsets = [0, 0] + recv_sizes = [1, 1] + + axis index 1: + operand = [3, 4, 0] + output = [0, 0, 0, 0] + input_offsets = [0, 1] + send_sizes = [1, 1] + output_offsets = [1, 2] + recv_sizes = [2, 1] + + then:: + + axis index 0: + result = [1, 3, 0, 0] + + axis index 1: + result = [2, 2, 4, 0] Args: - operand: array with ragged dimension along its outermost dimension. - output: array of ragged input offsets. - input_offsets: array of ragged input send sizes. - send_sizes: array of ragged output data. - output_offsets: array of ragged offsets in the target replica output. - recv_sizes: array of ragged output receive sizes. - axis_name: hashable Python object used to name a pmapped axis (see the - :func:`jax.pmap` documentation for more details). + operand: data array of shape (N, A, B, ...) representing concatenated + (possibly padded) ragged data to be sent. + output: data array of shape (M, A, B, ...) to update with received data. + input_offsets: 1D integer array of shape (K,) representing the offsets of + leading-axis slices into ``operand`` to be sent. + send_sizes: 1D integer array array of shape (K,) representing the sizes of + leading-axis slices into ``operand`` to be sent. + output_offsets: 1D integer array of shape (K,) representing where the + corresponding sent data is written on each corresponding receiver. + recv_sizes: 1D integer array of shape (K,) representing sizes of + leading-axis slices into ``output`` to update with received data. + axis_name: name of the mapped axis over which to perform the communication. axis_index_groups: optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the first two and last two replicas). Groups must cover all axis indices @@ -538,7 +596,10 @@ def ragged_all_to_all( behavior is undefined. Returns: - array with shape equal to ``output``. + Array of shape (M, A, B, ...) with the same value as the ``output`` except + with received data written into slices starting at + ``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with size + ``recv_sizes``. """ if not isinstance(axis_name, (tuple, list)): @@ -1210,8 +1271,43 @@ def _ragged_all_to_all_effectful_abstract_eval( effects = {*map(core.NamedAxisEffect, axis_name)} return out_aval, effects +def _ragged_all_to_all_jvp(primals, tangents, **params): + operand, output, *sizes_and_offsets = primals + operand_dot, output_dot, *_ = tangents + result = ragged_all_to_all_p.bind( + operand, output, *sizes_and_offsets, **params) + if type(operand_dot) is type(output_dot) is ad.Zero: + result_dot = ad.Zero.from_primal_value(result) + else: + operand_dot = ad.instantiate_zeros(operand_dot) + output_dot = ad.instantiate_zeros(output_dot) + result_dot = ragged_all_to_all_p.bind( + operand_dot, output_dot, *sizes_and_offsets, **params) + return result, result_dot + +def _ragged_all_to_all_transpose( + t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, + *, axis_name, axis_index_groups): + if type(t) is ad.Zero: + operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None + output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None + else: + zero = ad.zeros_like_aval(operand.aval) + output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True) + input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True) + operand_t = ragged_all_to_all_p.bind( + t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes, + axis_name=axis_name, axis_index_groups=axis_index_groups) + mask = jax.numpy.cumsum( + jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ + .at[output_offsets_ + recv_sizes].add(-1)) + output_t = jax.numpy.where(mask, 0, t) + return [operand_t, output_t] + [None] * 4 + ragged_all_to_all_p = core.Primitive('ragged_all_to_all') ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval) +ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp +ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 48f3d062b..844892adc 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -125,6 +125,80 @@ class RaggedCollectiveTest(jtu.JaxTestCase): c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32) ) + @parameterized.named_parameters( + dict( + testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2) + ), + ) + def test_ragged_all_to_all_grad(self, axis_name, mesh_axes): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + operand = jax.device_put( + jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.float32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + output = jax.device_put( + jnp.zeros((2, 4), dtype=jnp.float32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + input_offsets = jax.device_put( + jnp.array([[0, 1], [0, 1]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + send_sizes = jax.device_put( + jnp.array([[1, 2], [1, 1]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + output_offsets = jax.device_put( + jnp.array([[0, 0], [1, 2]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + recv_sizes = jax.device_put( + jnp.array([[1, 1], [2, 1]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_rep=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + operand = operand.reshape(operand.shape[1:]) + output = output.reshape(output.shape[1:]) + input_offsets = input_offsets.reshape(input_offsets.shape[1:]) + send_sizes = send_sizes.reshape(send_sizes.shape[1:]) + output_offsets = output_offsets.reshape(output_offsets.shape[1:]) + recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:]) + return lax.ragged_all_to_all( + operand, + output, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=axis_name, + ) + + args = input_offsets, send_sizes, output_offsets, recv_sizes + jtu.check_grads(lambda op, out: fwd(op, out, *args), (operand, output), order=1) + @parameterized.named_parameters( dict( testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4) From ff751ecc7b77ba836ed5fa92a26828875f92bc5e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 11 Mar 2025 20:02:07 -0700 Subject: [PATCH 20/28] Run single python version for v4-8 and min & max for v5e-8 for TPU tests in nightly/release test workflow PiperOrigin-RevId: 735975004 --- .github/workflows/wheel_tests_nightly_release.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 4845fb0f2..b482f75e7 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -75,6 +75,18 @@ jobs: exclude: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} + # Run a single Python version for v4-8. + - tpu-specs.type: "v4-8" + python: "3.10" + - tpu-specs.type: "v4-8" + python: "3.11" + - tpu-specs.type: "v4-8" + python: "3.12" + # Run min and max Python versions for v5e-8 + - tpu-specs.type: "v5e-8" + python: "3.11" + - tpu-specs.type: "v5e-8" + python: "3.12" name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: runner: ${{ matrix.tpu-specs.runner }} From 74b4d868e3751c1b4efa315ff8cf771faeb0b663 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 11 Mar 2025 20:48:55 -0700 Subject: [PATCH 21/28] Add support for scratch buffers in `jax_triton`. This is required to use device-side TMA descriptors. PiperOrigin-RevId: 735985603 --- jaxlib/gpu/triton_kernels.cc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 22397ff90..e8a72d44e 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -493,15 +493,7 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) { param.value))); } } - // Triton's kernel ABI expects an additional scratchpad global memory. - // For now it is only used for on-device creation of TMA descriptors, which - // we do not use yet, so we are just replacing this argument with a null - // pointer. - // TODO: b/381242007 - Allocate a proper buffer if we want to use - // device-side TMA APIs. - void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns. - params.push_back(&scratch_ptr); - + params.push_back(buffers++); // Scratch buffer. return kernel_.Launch(stream, grid_, params.data()); } From 61ba2b2603535a7a3bf3754da531f36fa957fade Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 12 Mar 2025 04:51:33 -0700 Subject: [PATCH 22/28] Update XLA dependency to use revision http://github.com/openxla/xla/commit/c270a6ce45df7f7bb3024f2e4df56b688d76ebd6. PiperOrigin-RevId: 736088162 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 6710de12a..5bdf1f541 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "fae64d49aa41e774922ca46e94cd754c800b6240" -XLA_SHA256 = "846ce8037cc0cba5135bff0bfd6fd02810e72b42ce0928002c595c97bf7b3603" +XLA_COMMIT = "c270a6ce45df7f7bb3024f2e4df56b688d76ebd6" +XLA_SHA256 = "b2f7d0293fc62bb670d0b58c5847108652eac4d9e6c7e420bed2029e74af6f2d" def repo(): tf_http_archive( From a6ab6bbc20accd61c39f6c02ce160dee49a15d55 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 12 Mar 2025 05:19:55 -0700 Subject: [PATCH 23/28] Ignore Pallas TPU tests when testing with the oldest supported libtpu I missed adding this in from https://github.com/jax-ml/jax/blob/main/.github/workflows/cloud-tpu-ci-nightly.yml when I added the TPU jobs to the new CI workflows PiperOrigin-RevId: 736094492 --- .github/workflows/pytest_tpu.yml | 2 ++ ci/run_pytest_tpu.sh | 11 ++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 2341bfb79..a105a2feb 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -135,6 +135,8 @@ jobs: echo "Using oldest supported libtpu" $JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + + echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV else echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}" exit 1 diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 9b4a3bbfd..5d8aa9ed6 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -53,10 +53,19 @@ export JAX_SKIP_SLOW_TESTS=true echo "Running TPU tests..." if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then + # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic + # TPU does not guarantee anything about forward compatibility (unless + # jax.export is used) and the 12 week compatibility window accumulates way + # too many failures. + IGNORE_FLAGS="" + if [ "${libtpu_version_type:-""}" == "oldest_supported_libtpu" ]; then + IGNORE_FLAGS="--ignore=tests/pallas" + fi + # Run single-accelerator tests in parallel JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ - --maxfail=20 -m "not multiaccelerator" tests examples + --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples # Run Pallas printing tests, which need to run with I/O capturing disabled. TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \ From d89835acbacec938971400d6fa54ea6dd5efe76c Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 12 Mar 2025 07:12:05 -0700 Subject: [PATCH 24/28] Fix matrix exclude syntax in TPU tests block Also, skip Python 3.13 for now due to missing dependency error. PiperOrigin-RevId: 736120590 --- .../workflows/wheel_tests_nightly_release.yml | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index b482f75e7..574cc0628 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -65,7 +65,9 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - python: ["3.10","3.11", "3.12", "3.13"] + # Skip Python 3.13 as it fails due to missing TensorFlow wheels (used for + # profiler_test.py, build/collect-profile-requirements.txt) for that version (b/402590302) + python: ["3.10", "3.11", "3.12"] tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, @@ -76,17 +78,16 @@ jobs: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} # Run a single Python version for v4-8. - - tpu-specs.type: "v4-8" + - tpu-specs: + type: "v4-8" python: "3.10" - - tpu-specs.type: "v4-8" + - tpu-specs: + type: "v4-8" python: "3.11" - - tpu-specs.type: "v4-8" - python: "3.12" # Run min and max Python versions for v5e-8 - - tpu-specs.type: "v5e-8" + - tpu-specs: + type: "v5e-8" python: "3.11" - - tpu-specs.type: "v5e-8" - python: "3.12" name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: runner: ${{ matrix.tpu-specs.runner }} From e33f3fc48bcc6388b1b3d91db366c02c428aac05 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 12 Mar 2025 08:17:44 -0700 Subject: [PATCH 25/28] [pallas:mosaic_gpu] Added support for reductions to the WG lowering Note that * we have no easy way of testing multi-reductions at the moment; * `reduce_max` assumes WGMMA_ROW layout which is not currently supported by the dialect lowering AFAICT. PiperOrigin-RevId: 736138554 --- jax/_src/pallas/mosaic_gpu/lowering.py | 54 +++++++++++++++++++ .../mosaic/gpu/dialect_lowering.py | 42 ++++++++++++--- .../mosaic/gpu/fragmented_array.py | 2 +- .../mosaic/gpu/layout_inference.py | 6 +++ tests/pallas/mosaic_gpu_test.py | 29 +++++++++- 5 files changed, 124 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 43b9008bb..0333a9b03 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1543,6 +1543,60 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError(f"Unsupported layout {x.layout}") +def _reduce_lowering_rule_wg( + kind: vector_dialect.CombiningKind, + acc: object, + ctx: LoweringRuleContext, + x, + *, + axes, +) -> ir.OpView: + [x_aval] = ctx.avals_in + [out_aval] = ctx.avals_out + x = _ensure_ir_value(x, x_aval.dtype) + out_type = mgpu_utils.dtype_to_ir_type(out_aval.dtype) + if not out_aval.shape: + # Special-case: reducing to a scalar. + if x_aval.ndim != 1: + # TODO(slebedev): Flatten to 1D, since vector.reduction only supports + # 1D inputs. + raise NotImplementedError("Only 1D inputs are supported") + return vector_dialect.ReductionOp(out_type, kind, x) + acc = vector_dialect.splat( + ir.VectorType.get(out_aval.shape, out_type), + _ensure_ir_value(acc, out_aval.dtype), + ) + return vector_dialect.MultiDimReductionOp(kind, x, acc, axes) + + +@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup) +def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): + op = _reduce_lowering_rule_wg( + vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes + ) + op.attributes["offset"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), ctx.module_ctx.smem_used_bytes + ) + return op.result + + +@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup) +def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): + [x_aval] = ctx.avals_in + if jnp.issubdtype(x_aval.dtype, jnp.floating): + kind = vector_dialect.CombiningKind.MAXIMUMF + acc = float("-inf") + elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger): + kind = vector_dialect.CombiningKind.MAXSI + acc = np.iinfo(x_aval.dtype).max + elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger): + kind = vector_dialect.CombiningKind.MAXUI + acc = np.iinfo(x_aval.dtype).max + else: + raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") + return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result + + @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): i32 = ir.IntegerType.get_signless(32) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 560816319..55e8c5583 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -320,6 +320,34 @@ def _vector_splat_op_lowering_rule( return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)] +@_register_lowering(vector.ReductionOp) +def _vector_reduction_op_lowering_rule( + ctx: LoweringContext, op: vector.ReductionOp +) -> Sequence[ir.Value]: + del ctx # Unused. + [layout] = inference_utils.in_layouts(op) + () = inference_utils.out_layouts(op) + element_type = ir.VectorType(op.vector.type).element_type + is_signed = False if ir.IntegerType.isinstance(element_type) else None + a = _fragmented_array_from_ir(op.vector, layout, is_signed) + match str(op.kind): + case "#vector.kind": + smem = ir.Attribute.parse("#gpu.address_space") + scratch = _slice_smem( + ir.MemRefType.get([4], element_type, memory_space=smem), + arith.constant(None, op.attributes["offset"]), + ) + result = a.reduce_sum(scratch) + case ( + "#vector.kind" | "#vector.kind" | "#vector.kind" + ): + # TODO(slebedev): Implement this and remove the raise below. + raise NotImplementedError(f"Unsupported reduction kind: {op.kind}") + case _: + raise NotImplementedError(f"Unsupported reduction kind: {op.kind}") + return [_fragmented_array_to_ir(result, op.result.type)] + + def memref_layout_to_swizzle_and_transforms( layout: ir.Attribute, ) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]: @@ -713,16 +741,17 @@ def _mgpu_slice_smem_op_lowering_rule( ctx: LoweringContext, op: SliceSMEMOp ) -> Sequence[ir.Value]: del ctx + return [_slice_smem(op.result.type, op.offset)] + + +def _slice_smem(result: ir.Type, offset: ir.Value): i8 = ir.IntegerType.get_signless(8) smem = ir.Attribute.parse("#gpu.address_space") - smem_base = gpu.dynamic_shared_memory( ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem) ) - - offset = arith.index_cast(ir.IndexType.get(), op.offset) - - return [memref.view(op.result.type, smem_base, offset, [])] + offset = arith.index_cast(ir.IndexType.get(), offset) + return memref.view(result, smem_base, offset, []) @_register_lowering(scf.ForOp) @@ -866,7 +895,8 @@ def _should_lower(op: ir.OpView) -> bool: def lower_mgpu_dialect( - module: ir.Module, launch_context: launch_context.LaunchContext | None + module: ir.Module, + launch_context: launch_context.LaunchContext | None, ): # TODO(apaszke,bchetioui): Make sure the layouts match. # TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index a52eb329d..5cacab511 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1389,7 +1389,7 @@ class FragmentedArray: if isinstance(self.layout, WGSplatFragLayout): [reg] = self.registers.flat if ir.FloatType.isinstance(self.mlir_dtype): - op = arith.mulf + op = mulf elif ir.IntegerType.isinstance(self.mlir_dtype): op = arith.muli else: diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index cd493f0ab..f0f0998b8 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -336,6 +336,12 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts: return [], [layout] +@partial(_add_layout_inference_rule, vector.ReductionOp) +def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts: + if layout := inference_utils.value_layout(op.vector): + return [layout], [] + return None + @partial(_add_layout_inference_rule, mgpu.WGMMAOp) def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts: diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0a3af26de..8a7b6c98b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -184,6 +184,23 @@ class PallasCallTest(PallasTest): y = jnp.flip(x).reshape(1, 256) np.testing.assert_array_equal(kernel(x, y), x + y[0]) + @parameterized.product( + shape=[(128,)], thread_semantics=[*plgpu.ThreadSemantics] + ) + def test_reduce_sum(self, shape, thread_semantics): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + compiler_params=plgpu.GPUCompilerParams( + thread_semantics=thread_semantics + ), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape) + + x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), jnp.sum(x)) + def test_reshape(self): shape1, shape2 = (128,), (2, 16, 4) @@ -200,10 +217,14 @@ class PallasCallTest(PallasTest): x = jnp.arange(math.prod(shape1)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) - def test_add_xy_indexed(self): + @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) + def test_add_xy_indexed(self, thread_semantics): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + compiler_params=plgpu.GPUCompilerParams( + thread_semantics=thread_semantics + ), ) def kernel(x_ref, y_ref, o_ref): idx = _sum_same_dtype(y_ref[...]) @@ -1078,10 +1099,14 @@ class PallasCallTest(PallasTest): self.assertIn("acc % 2", output()) - def test_cond_returning_array(self): + @parameterized.parameters([*plgpu.ThreadSemantics]) + def test_cond_returning_array(self, thread_semantics): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + compiler_params=plgpu.GPUCompilerParams( + thread_semantics=thread_semantics + ), ) def kernel(x_ref, o_ref): acc = _sum_same_dtype(x_ref[...]) From 8b7cfcb33c4a6431aedb6793ef9b9179f8f336bb Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 12 Mar 2025 08:21:16 -0700 Subject: [PATCH 26/28] Fix integer overflow in workspace size computations for experimental.rnn.*. PiperOrigin-RevId: 736139471 --- jaxlib/gpu/rnn.cc | 4 +++- jaxlib/gpu/rnn_kernels.cc | 5 +++-- jaxlib/gpu/rnn_kernels.h | 8 +++++--- tests/experimental_rnn_test.py | 32 ++++++++++++++++++++++++++++++-- 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index c88b164e6..eaa815d33 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" #include "jaxlib/absl_status_casters.h" @@ -29,7 +31,7 @@ namespace nb = nanobind; nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32, - int workspace_size, int reserve_space_size) { + size_t workspace_size, size_t reserve_space_size) { return PackDescriptor(RnnDescriptor{ input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout, bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size}); diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index 89a6d0a30..e9820bc31 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -15,6 +15,7 @@ limitations under the License. #include "jaxlib/gpu/rnn_kernels.h" +#include #include #include @@ -71,7 +72,7 @@ template <> namespace JAX_GPU_NAMESPACE { -static absl::StatusOr> +static absl::StatusOr> DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, @@ -174,7 +175,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, return std::make_pair(workSpaceSize, reserveSpaceSize); } -absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( +absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32) { diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index 468c02eac..e95b77883 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef JAXLIB_GPU_RNN_KERNELS_H_ #define JAXLIB_GPU_RNN_KERNELS_H_ +#include + #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" @@ -34,12 +36,12 @@ struct RnnDescriptor { float dropout; int bidirectional; int cudnn_allow_tf32; - int workspace_size; - int reserve_space_size; + size_t workspace_size; + size_t reserve_space_size; }; // Return (workspace size, reserve space size). -absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( +absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32); diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index 376a9b1a1..7fa3b93f3 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -213,8 +213,36 @@ class RnnTest(jtu.JaxTestCase): k = jax.random.split(jax.random.PRNGKey(1), 4) stablehlo = jax.jit(f).lower(*k).as_text("stablehlo") - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"', - stablehlo) + if jtu.jaxlib_version() <= (0, 5, 2): + self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"', + stablehlo) + else: + self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', + stablehlo) + + @jtu.run_on_devices("cuda") + def test_no_workspace_overflow(self): + if jtu.jaxlib_version() <= (0, 5, 2): + self.skipTest("Older versions fail because of integer overflow.") + + # Problem sizes known to cause overflows on older versions. + batch_size, max_seq_length, input_size = 256, 500, 512 + num_layers, hidden_size = 1, 256 + num_params = rnn.get_num_params_in_lstm( + input_size, hidden_size, num_layers, True) + x = jax.ShapeDtypeStruct( + (batch_size, max_seq_length, input_size), jnp.float32) + h_0 = jax.ShapeDtypeStruct( + (2 * num_layers, batch_size, hidden_size), jnp.float32) + c_0 = jax.ShapeDtypeStruct( + (2 * num_layers, batch_size, hidden_size), jnp.float32) + weights = jax.ShapeDtypeStruct((num_params,), jnp.float32) + seq_lengths = jax.ShapeDtypeStruct((batch_size,), jnp.int32) + fun = jax.jit(partial( + rnn.lstm, input_size=input_size, hidden_size=hidden_size, + num_layers=num_layers, dropout=0.0, bidirectional=True)) + fun.lower(x, h_0, c_0, weights, seq_lengths) # Doesn't crash. + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From abcc7fdf4c18a2e20a31355c64fc767867703c1c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 12 Mar 2025 08:28:21 -0700 Subject: [PATCH 27/28] [sharding_in_types] Initial commit to add `varying_manual_axes: frozenset[AxisName]` to ShapedArray. Also add `jax_varying_axes_in_types` config to hide this option under while we develop it. PiperOrigin-RevId: 736141670 --- jax/_src/config.py | 9 +++++++++ jax/_src/core.py | 23 +++++++++++++++++------ jax/_src/pallas/core.py | 2 +- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 00f65726a..1e46fb8bd 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -235,6 +235,7 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, use_direct_linearize.value, + varying_axes_in_types.value, softmax_custom_jvp.value, disable_jit.value, debug_key_reuse.value, @@ -1084,6 +1085,14 @@ use_direct_linearize = bool_state( help=('Use direct linearization instead JVP followed by partial eval'), include_in_jit_key=True) +varying_axes_in_types = bool_state( + name='jax_varying_axes_in_types', + default=False, + help=('Adds varying manual axes to ShapedArray to track which mesh axes the' + ' array is varying over. This will help to remove the efficient' + ' transpose rewrite machinery in shard_map'), + include_in_jit_key=True) + data_dependent_tracing_fallback = bool_state( name='jax_data_dependent_tracing_fallback', default=False, diff --git a/jax/_src/core.py b/jax/_src/core.py index b17e26255..e53aec755 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1893,14 +1893,17 @@ def get_sharding(sharding, shape): class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'sharding'] # inherits slots from parent + __slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent array_abstraction_level = 2 - def __init__(self, shape, dtype, weak_type=False, *, sharding=None): + def __init__(self, shape, dtype, weak_type=False, *, sharding=None, + varying_manual_axes: frozenset[AxisName] = frozenset()): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type self.sharding = get_sharding(sharding, self.shape) + if config.varying_axes_in_types.value: + self.varying_manual_axes = varying_manual_axes def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1911,6 +1914,9 @@ class ShapedArray(UnshapedArray): weak_type = self.weak_type if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding + if 'varying_manual_axes' not in kwargs: + kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes', + frozenset()) return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -1927,17 +1933,22 @@ class ShapedArray(UnshapedArray): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type - and self.sharding == other.sharding) + and self.sharding == other.sharding + and (getattr(self, 'varying_manual_axes', frozenset()) == + getattr(other, 'varying_manual_axes', frozenset()))) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) - return hash((self.shape, self.dtype, self.weak_type, self.sharding)) + return hash((self.shape, self.dtype, self.weak_type, self.sharding, + getattr(self, 'varying_manual_axes', frozenset()))) def to_tangent_aval(self): - return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, sharding=self.sharding) + return ShapedArray( + self.shape, primal_dtype_to_tangent_dtype(self.dtype), + self.weak_type, sharding=self.sharding, + varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset())) def str_short(self, short_dtypes=False, mesh_axis_types=False): dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index ad6ce6ab4..466f6037a 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -343,7 +343,7 @@ class BlockSpec: if self.block_shape is None: block_shape = array_aval.shape else: - block_shape = self.block_shape + block_shape = self.block_shape # type: ignore if len(array_aval.shape) != len(block_shape): raise ValueError( f"Block shape for {origin} (= {block_shape}) " From db8ba1b598defcb2b02e69134f38c81f3a73f2ce Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 12 Mar 2025 17:06:35 +0000 Subject: [PATCH 28/28] Change to run CI --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 0aca7cf58..c2e8657a6 100644 --- a/README.md +++ b/README.md @@ -456,3 +456,4 @@ For details about the JAX API, see the For getting started as a JAX developer, see the [developer documentation](https://jax.readthedocs.io/en/latest/developer.html). +