rocm_jax/docs/build_custom_gpu.sh
Frederic Bastien 83ffcc9c7d Current status + build script fixes
2024-04-08 04:48:28 -07:00

set -e
python -m pip install pybind11==2.10.1
mkdir -p build
touch build/__init__.py
pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())")
# Query the Python headers and extension-module suffix from sysconfig instead of
# "${python_executable}3-config", which is usually missing inside virtualenvs.
python_include_path=$(python -c "import sysconfig; print(sysconfig.get_paths()['include'])")
ext_suffix=$(python -c "import sysconfig; print(sysconfig.get_config_var('EXT_SUFFIX'))")
echo "pybind_include_path=$pybind_include_path"
echo "python_include_path=$python_include_path"
# Compile the CUDA kernels for sm_70/75/80/86; the flags are quoted so bash never globs [...].
nvcc --threads 4 -Xcompiler -Wall --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 \
  '--generate-code=arch=compute_70,code=[compute_70,sm_70]' '--generate-code=arch=compute_75,code=[compute_75,sm_75]' \
  '--generate-code=arch=compute_80,code=[compute_80,sm_80]' '--generate-code=arch=compute_86,code=[compute_86,sm_86]' \
  -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
# Compile the pybind11 bindings.
c++ -I/usr/local/cuda/include -I"$pybind_include_path" -I"$python_include_path" -O3 -DNDEBUG -fPIC \
  -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
# Link both objects into a Python extension module, then strip its symbols.
c++ -fPIC -O3 -DNDEBUG -flto -shared -o "build/gpu_ops${ext_suffix}" build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o \
  -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
strip "build/gpu_ops${ext_suffix}"
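The nvcc step hard-codes four --generate-code flags, one per GPU architecture from sm_70 to sm_86. As a minimal sketch, assuming bash 4+ (for mapfile) and a hypothetical helper named gencode_flags that is not part of the original script, the same flags can be generated from a list of compute capabilities so that other GPUs can be targeted without editing the nvcc line by hand:

```shell
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical helper, not part of build_custom_gpu.sh: print one nvcc
# --generate-code flag per compute capability given as "MAJOR.MINOR",
# embedding both PTX (compute_XY) and SASS (sm_XY) like the flags above.
gencode_flags() {
  local cc arch
  for cc in "$@"; do
    arch="${cc//./}"  # "8.6" -> "86"
    printf '%s\n' "--generate-code=arch=compute_${arch},code=[compute_${arch},sm_${arch}]"
  done
}

# Collect the flags into an array so each one reaches nvcc as a single,
# unglobbed argument.
mapfile -t gencode < <(gencode_flags 7.0 7.5 8.0 8.6)
printf '%s\n' "${gencode[@]}"
```

In the build script, "${gencode[@]}" could then stand in for the four literal flags on the nvcc line. On recent NVIDIA drivers, `nvidia-smi --query-gpu=compute_cap --format=csv,noheader` reports the compute capability of each local GPU, which could be passed to gencode_flags to compile only for the installed hardware and shorten the build.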