mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
improve jaxlib wheel building script
This commit is contained in:
parent
f50996df64
commit
b907bcd5ae
@ -3,28 +3,31 @@ set -xev
|
||||
JAXLIB_VERSION=$(sed -n "s/^ \+version=[']\(.*\)['],$/\\1/p" jax/build/setup.py)
|
||||
|
||||
PYTHON_VERSIONS="py2 py3"
|
||||
CUDA_VERSIONS="9.2" # "9.2 10.0"
|
||||
CUDA_VERSIONS="9.0 9.2 10.0"
|
||||
CUDA_VARIANTS="cuda" # "cuda cuda-included"
|
||||
|
||||
mkdir -p dist
|
||||
|
||||
# build the pypi linux packages, tagging with manylinux1 for pypi reasons
|
||||
docker build -t jaxbuild jax/build/
|
||||
for PYTHON_VERSION in $PYTHON_VERSIONS
|
||||
do
|
||||
mkdir -p dist/nocuda/
|
||||
nvidia-docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION nocuda
|
||||
mv dist/*.whl dist/nocuda/jaxlib-${JAXLIB_VERSION}-${PYTHON_VERSION}-none-manylinux1_x86_64.whl
|
||||
done
|
||||
|
||||
# build the cuda linux packages, tagging with linux_x86_64
|
||||
for CUDA_VERSION in $CUDA_VERSIONS
|
||||
do
|
||||
docker build -t jaxbuild jax/build/ --build-arg CUDA_VERSION=$CUDA_VERSION
|
||||
|
||||
for PYTHON_VERSION in $PYTHON_VERSIONS
|
||||
do
|
||||
mkdir -p dist/nocuda/
|
||||
nvidia-docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION nocuda
|
||||
mv dist/*.whl dist/nocuda/jaxlib-${JAXLIB_VERSION}-${PYTHON_VERSION}-none-manylinux1_x86_64.whl
|
||||
|
||||
for CUDA_VARIANT in $CUDA_VARIANTS
|
||||
do
|
||||
mkdir -p dist/cuda${CUDA_VERSION//.}
|
||||
mkdir -p dist/${CUDA_VARIANT}${CUDA_VERSION//.}
|
||||
nvidia-docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION $CUDA_VARIANT
|
||||
mv dist/*.whl dist/cuda${CUDA_VERSION//.}/jaxlib-${JAXLIB_VERSION}-${PYTHON_VERSION}-none-linux_x86_64.whl
|
||||
mv dist/*.whl dist/${CUDA_VARIANT}${CUDA_VERSION//.}/jaxlib-${JAXLIB_VERSION}-${PYTHON_VERSION}-none-linux_x86_64.whl
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
echo "now you might want to run something like:"
|
||||
echo "python3 -m twine upload --repository-url https://test.pypi.org/legacy/ dist/nocuda/*.whl --verbose"
|
||||
|
@ -57,7 +57,6 @@ def grad(fun, argnums=0):
|
||||
g = vjp_py(onp.ones((), onp.result_type(ans)))
|
||||
return g[0] if isinstance(argnums, int) else g
|
||||
|
||||
|
||||
return grad_f
|
||||
|
||||
@curry
|
||||
|
Loading…
x
Reference in New Issue
Block a user