mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add dummy binary build target, move WORKSPACE up
This commit is contained in:
parent
c03e5e80c5
commit
d347d65c5c
@ -8,8 +8,6 @@ then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cp build/WORKSPACE .
|
||||
|
||||
tmp=/tmp/jax-build # could mktemp -d but this way we can cache results
|
||||
mkdir -p ${tmp}
|
||||
|
||||
@ -28,7 +26,10 @@ export PATH="${bazel_dir}/bin:$PATH"
|
||||
handle_temporary_bazel_0_19_1_bug=1 # TODO(mattjj): remove with bazel 0.19.2
|
||||
|
||||
## get and configure tensorflow for building xla with gpu support
|
||||
git clone https://github.com/tensorflow/tensorflow.git
|
||||
if [[ ! -d tensorflow ]]
|
||||
then
|
||||
git clone https://github.com/tensorflow/tensorflow.git
|
||||
fi
|
||||
pushd tensorflow
|
||||
export PYTHON_BIN_PATH=${PYTHON_BIN_PATH:-$(which python)}
|
||||
export PYTHON_LIB_PATH=${SP_DIR:-$(python -m site --user-site)}
|
||||
@ -65,14 +66,14 @@ bazel_build_opt="-c opt --config=cuda"
|
||||
if [ -n $handle_temporary_bazel_0_19_1_bug ]
|
||||
then
|
||||
set +e
|
||||
bazel ${bazel_opt} build ${bazel_build_opt} examples:interactive 2> /dev/null
|
||||
bazel ${bazel_opt} build ${bazel_build_opt} jax:build_jax 2> /dev/null
|
||||
sed -i 's/toolchain_identifier = "local"/toolchain_identifier = "local_linux"/' ${bazel_output_base}/external/local_config_cc/BUILD
|
||||
set -e
|
||||
fi
|
||||
bazel ${bazel_opt} build ${bazel_build_opt} examples:interactive
|
||||
bazel ${bazel_opt} build ${bazel_build_opt} jax:build_jax
|
||||
|
||||
## extract the pieces we need
|
||||
runfiles_prefix="execroot/__main__/bazel-out/k8-opt/bin/examples/interactive.runfiles/org_tensorflow/tensorflow"
|
||||
runfiles_prefix="execroot/__main__/bazel-out/k8-opt/bin/jax/build_jax.runfiles/org_tensorflow/tensorflow"
|
||||
cp ${bazel_output_base}/${runfiles_prefix}/libtensorflow_framework.so jax/lib/
|
||||
cp ${bazel_output_base}/${runfiles_prefix}/compiler/xla/xla_data_pb2.py jax/lib/
|
||||
cp ${bazel_output_base}/${runfiles_prefix}/compiler/xla/python/{xla_client.py,pywrap_xla.py,_pywrap_xla.so} jax/lib/
|
||||
@ -84,6 +85,5 @@ sed -i '/from tensorflow.compiler.xla.service import hlo_pb2/d' jax/lib/xla_clie
|
||||
|
||||
## clean up
|
||||
rm -f bazel-* # symlinks
|
||||
rm -f WORKSPACE
|
||||
rm -rf tensorflow
|
||||
# rm -rf ${tmp}
|
||||
|
@ -1,4 +1,6 @@
|
||||
# JAX is Autograd and XLA
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
py_library(
|
||||
name = "libjax",
|
||||
srcs = glob(
|
||||
@ -36,3 +38,10 @@ py_library(
|
||||
srcs = ["experimental/lapax.py"],
|
||||
deps = [":libjax"],
|
||||
)
|
||||
|
||||
# this is a dummy target for building purposes
|
||||
py_binary(
|
||||
name = "build_jax",
|
||||
srcs = ["core.py"],
|
||||
deps = [":libjax"],
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user