diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt new file mode 100644 index 000000000..54f5d2ab3 --- /dev/null +++ b/build/gpu-test-requirements.txt @@ -0,0 +1,13 @@ +# NVIDIA CUDA dependencies +# Note that the wheels are downloaded only when the targets in bazel command +# contain dependencies on these wheels. +nvidia-cublas-cu12>=12.1.3.1 +nvidia-cuda-cupti-cu12>=12.1.105 +nvidia-cuda-nvcc-cu12>=12.6.85 +nvidia-cuda-runtime-cu12>=12.1.105 +nvidia-cudnn-cu12>=9.1,<10.0 +nvidia-cufft-cu12>=11.0.2.54 +nvidia-cusolver-cu12>=11.4.5.107 +nvidia-cusparse-cu12>=12.1.0.106 +nvidia-nccl-cu12>=2.18.1 +nvidia-nvjitlink-cu12>=12.1.105 diff --git a/build/requirements.in b/build/requirements.in index d4e13d943..ec7fc71b0 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -2,6 +2,7 @@ # test deps # -r test-requirements.txt +-r gpu-test-requirements.txt # # build deps diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 290c7e732..dd7b6a55f 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -380,6 +380,64 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # via -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # via -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # via -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index f73065950..656458004 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -375,6 +375,64 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index feebc33dc..2f5a25e2b 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -375,6 +375,64 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 0a32888f6..f9635ec2b 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -434,6 +434,64 @@ numpy==2.1.2 ; python_version >= "3.13" \ # matplotlib # ml-dtypes # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index dfefaf042..129874cba 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -390,6 +390,64 @@ numpy==2.2.1 ; python_version >= "3.13" \ # matplotlib # ml-dtypes # scipy +nvidia-cublas-cu12==12.8.3.14 \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac diff --git a/jax/BUILD b/jax/BUILD index fb2e59864..3f75d60a0 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -14,7 +14,7 @@ # JAX is Autograd and XLA -load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", @@ -45,17 +45,26 @@ package( licenses(["notice"]) -# If this flag is true, jaxlib should be built by bazel. If false, then we do not build jaxlib and -# assume it has been installed, e.g., by `pip`. -bool_flag( +# The flag controls whether jaxlib should be built by Bazel. +# If ":build_jaxlib=true", then jaxlib will be built. +# If ":build_jaxlib=false", then jaxlib is not built. It is assumed that the pre-built jaxlib wheel +# is available in the "dist" folder. +# If ":build_jaxlib=wheel", then jaxlib wheel will be built as a py_import rule attribute. +# The py_import rule unpacks the wheel and provides its content as a py_library. +string_flag( name = "build_jaxlib", - build_setting_default = True, + build_setting_default = "true", + values = [ + "true", + "false", + "wheel", + ], ) config_setting( name = "enable_jaxlib_build", flag_values = { - ":build_jaxlib": "True", + ":build_jaxlib": "true", }, ) diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index 79aebcd86..1f4e5a08d 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -49,7 +49,7 @@ py_library_providing_imports_info( config_setting( name = "disable_jaxlib_for_cpu_build", flag_values = { - "//jax:build_jaxlib": "False", + "//jax:build_jaxlib": "false", "@local_config_cuda//:enable_cuda": "False", }, ) @@ -57,7 +57,23 @@ config_setting( config_setting( name = "disable_jaxlib_for_cuda12_build", flag_values = { - "//jax:build_jaxlib": "False", + "//jax:build_jaxlib": "false", "@local_config_cuda//:enable_cuda": "True", }, -) \ No newline at end of file +) + +config_setting( + name = "enable_py_import_for_cpu_build", + flag_values = { + "//jax:build_jaxlib": "wheel", + "@local_config_cuda//:enable_cuda": "False", + }, +) + +config_setting( + name = "enable_py_import_for_cuda12_build", + flag_values = { + "//jax:build_jaxlib": "wheel", + "@local_config_cuda//:enable_cuda": "True", + }, +) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 58a83d9b0..c4ac8d00f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -224,7 +224,15 @@ def if_building_jaxlib( "@pypi_jax_cuda12_plugin//:pkg", "@pypi_jax_cuda12_pjrt//:pkg", ], - if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"]): + if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"], + if_py_import = [ + "//jaxlib/tools:jaxlib_py_import", + "//jaxlib/tools:jax_cuda_plugin_py_import", + "//jaxlib/tools:jax_cuda_pjrt_py_import", + ], + if_py_import_for_cpu = [ + "//jaxlib/tools:jaxlib_py_import", + ]): """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources. This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase. @@ -234,12 +242,16 @@ def if_building_jaxlib( if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of gpu-enabled builds if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds + if_py_import: the py_import targets to depend on in case of gpu-enabled builds + if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds """ return select({ "//jax:enable_jaxlib_build": if_building, "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu, "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building, + "//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu, + "//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import, }) # buildifier: disable=function-docstring diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index baf996d50..5b24d2359 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -18,6 +18,10 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@xla//third_party/py:py_import.bzl", + "py_import", +) load( "@xla//third_party/py:py_manylinux_compliance_test.bzl", "verify_manylinux_compliance_test", @@ -228,6 +232,18 @@ string_flag( build_setting_default = "dist", ) +NVIDIA_WHEELS_DEPS = [ + "@pypi_nvidia_cublas_cu12//:whl", + "@pypi_nvidia_cuda_cupti_cu12//:whl", + "@pypi_nvidia_cuda_runtime_cu12//:whl", + "@pypi_nvidia_cudnn_cu12//:whl", + "@pypi_nvidia_cufft_cu12//:whl", + "@pypi_nvidia_cusolver_cu12//:whl", + "@pypi_nvidia_cusparse_cu12//:whl", + "@pypi_nvidia_nccl_cu12//:whl", + "@pypi_nvidia_nvjitlink_cu12//:whl", +] + jax_wheel( name = "jaxlib_wheel", no_abi = False, @@ -235,6 +251,11 @@ jax_wheel( wheel_name = "jaxlib", ) +py_import( + name = "jaxlib_py_import", + wheel = ":jaxlib_wheel", +) + jax_wheel( name = "jaxlib_wheel_editable", editable = True, @@ -252,6 +273,12 @@ jax_wheel( wheel_name = "jax_cuda12_plugin", ) +py_import( + name = "jax_cuda_plugin_py_import", + wheel = ":jax_cuda_plugin_wheel", + wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), +) + jax_wheel( name = "jax_cuda_plugin_wheel_editable", editable = True, @@ -290,6 +317,12 @@ jax_wheel( wheel_name = "jax_cuda12_pjrt", ) +py_import( + name = "jax_cuda_pjrt_py_import", + wheel = ":jax_cuda_pjrt_wheel", + wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), +) + jax_wheel( name = "jax_cuda_pjrt_wheel_editable", editable = True,