Update :build_jaxlib flag to control whether we should add py_import dependencies to the test targets.

This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.

There are three options for running the tests:

1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.

PiperOrigin-RevId: 735765819
This commit is contained in:
jax authors 2025-03-11 08:29:45 -07:00
parent 7fd32ecc04
commit 1aca76fc13
11 changed files with 384 additions and 10 deletions

View File

@ -0,0 +1,13 @@
# NVIDIA CUDA dependencies
# Note that the wheels are downloaded only when the targets in bazel command
# contain dependencies on these wheels.
nvidia-cublas-cu12>=12.1.3.1
nvidia-cuda-cupti-cu12>=12.1.105
nvidia-cuda-nvcc-cu12>=12.6.85
nvidia-cuda-runtime-cu12>=12.1.105
nvidia-cudnn-cu12>=9.1,<10.0
nvidia-cufft-cu12>=11.0.2.54
nvidia-cusolver-cu12>=11.4.5.107
nvidia-cusparse-cu12>=12.1.0.106
nvidia-nccl-cu12>=2.18.1
nvidia-nvjitlink-cu12>=12.1.105

View File

@ -2,6 +2,7 @@
# test deps
#
-r test-requirements.txt
-r gpu-test-requirements.txt
#
# build deps

View File

@ -380,6 +380,64 @@ numpy==2.0.0 ; python_version <= "3.12" \
# ml-dtypes
# opt-einsum
# scipy
nvidia-cublas-cu12==12.8.3.14 \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# via -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# via -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# via -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.3.0 \
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549

View File

@ -375,6 +375,64 @@ numpy==2.0.0 ; python_version <= "3.12" \
# ml-dtypes
# opt-einsum
# scipy
nvidia-cublas-cu12==12.8.3.14 \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.3.0 \
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549

View File

@ -375,6 +375,64 @@ numpy==2.0.0 ; python_version <= "3.12" \
# ml-dtypes
# opt-einsum
# scipy
nvidia-cublas-cu12==12.8.3.14 \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.3.0 \
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549

View File

@ -434,6 +434,64 @@ numpy==2.1.2 ; python_version >= "3.13" \
# matplotlib
# ml-dtypes
# scipy
nvidia-cublas-cu12==12.8.3.14 \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.4.0 \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac

View File

@ -390,6 +390,64 @@ numpy==2.2.1 ; python_version >= "3.13" \
# matplotlib
# ml-dtypes
# scipy
nvidia-cublas-cu12==12.8.3.14 \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.4.0 \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac

View File

@ -14,7 +14,7 @@
# JAX is Autograd and XLA
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
@ -45,17 +45,26 @@ package(
licenses(["notice"])
# If this flag is true, jaxlib should be built by bazel. If false, then we do not build jaxlib and
# assume it has been installed, e.g., by `pip`.
bool_flag(
# The flag controls whether jaxlib should be built by Bazel.
# If ":build_jaxlib=true", then jaxlib will be built.
# If ":build_jaxlib=false", then jaxlib is not built. It is assumed that the pre-built jaxlib wheel
# is available in the "dist" folder.
# If ":build_jaxlib=wheel", then jaxlib wheel will be built as a py_import rule attribute.
# The py_import rule unpacks the wheel and provides its content as a py_library.
string_flag(
name = "build_jaxlib",
build_setting_default = True,
build_setting_default = "true",
values = [
"true",
"false",
"wheel",
],
)
config_setting(
name = "enable_jaxlib_build",
flag_values = {
":build_jaxlib": "True",
":build_jaxlib": "true",
},
)

View File

@ -49,7 +49,7 @@ py_library_providing_imports_info(
config_setting(
name = "disable_jaxlib_for_cpu_build",
flag_values = {
"//jax:build_jaxlib": "False",
"//jax:build_jaxlib": "false",
"@local_config_cuda//:enable_cuda": "False",
},
)
@ -57,7 +57,23 @@ config_setting(
config_setting(
name = "disable_jaxlib_for_cuda12_build",
flag_values = {
"//jax:build_jaxlib": "False",
"//jax:build_jaxlib": "false",
"@local_config_cuda//:enable_cuda": "True",
},
)
)
config_setting(
name = "enable_py_import_for_cpu_build",
flag_values = {
"//jax:build_jaxlib": "wheel",
"@local_config_cuda//:enable_cuda": "False",
},
)
config_setting(
name = "enable_py_import_for_cuda12_build",
flag_values = {
"//jax:build_jaxlib": "wheel",
"@local_config_cuda//:enable_cuda": "True",
},
)

View File

@ -224,7 +224,15 @@ def if_building_jaxlib(
"@pypi_jax_cuda12_plugin//:pkg",
"@pypi_jax_cuda12_pjrt//:pkg",
],
if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"]):
if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"],
if_py_import = [
"//jaxlib/tools:jaxlib_py_import",
"//jaxlib/tools:jax_cuda_plugin_py_import",
"//jaxlib/tools:jax_cuda_pjrt_py_import",
],
if_py_import_for_cpu = [
"//jaxlib/tools:jaxlib_py_import",
]):
"""Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources.
This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase.
@ -234,12 +242,16 @@ def if_building_jaxlib(
if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of
gpu-enabled builds
if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds
if_py_import: the py_import targets to depend on in case of gpu-enabled builds
if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds
"""
return select({
"//jax:enable_jaxlib_build": if_building,
"//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu,
"//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building,
"//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu,
"//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import,
})
# buildifier: disable=function-docstring

View File

@ -18,6 +18,10 @@ load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load(
"@xla//third_party/py:py_import.bzl",
"py_import",
)
load(
"@xla//third_party/py:py_manylinux_compliance_test.bzl",
"verify_manylinux_compliance_test",
@ -228,6 +232,18 @@ string_flag(
build_setting_default = "dist",
)
NVIDIA_WHEELS_DEPS = [
"@pypi_nvidia_cublas_cu12//:whl",
"@pypi_nvidia_cuda_cupti_cu12//:whl",
"@pypi_nvidia_cuda_runtime_cu12//:whl",
"@pypi_nvidia_cudnn_cu12//:whl",
"@pypi_nvidia_cufft_cu12//:whl",
"@pypi_nvidia_cusolver_cu12//:whl",
"@pypi_nvidia_cusparse_cu12//:whl",
"@pypi_nvidia_nccl_cu12//:whl",
"@pypi_nvidia_nvjitlink_cu12//:whl",
]
jax_wheel(
name = "jaxlib_wheel",
no_abi = False,
@ -235,6 +251,11 @@ jax_wheel(
wheel_name = "jaxlib",
)
py_import(
name = "jaxlib_py_import",
wheel = ":jaxlib_wheel",
)
jax_wheel(
name = "jaxlib_wheel_editable",
editable = True,
@ -252,6 +273,12 @@ jax_wheel(
wheel_name = "jax_cuda12_plugin",
)
py_import(
name = "jax_cuda_plugin_py_import",
wheel = ":jax_cuda_plugin_wheel",
wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
)
jax_wheel(
name = "jax_cuda_plugin_wheel_editable",
editable = True,
@ -290,6 +317,12 @@ jax_wheel(
wheel_name = "jax_cuda12_pjrt",
)
py_import(
name = "jax_cuda_pjrt_py_import",
wheel = ":jax_cuda_pjrt_wheel",
wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
)
jax_wheel(
name = "jax_cuda_pjrt_wheel_editable",
editable = True,