rocm_jax/build/install_xla_in_source_tree.sh
Peter Hawkins ed3e2308c1 Add support for linear algebra ops on GPU using Cusolver:
* LU decomposition
* Symmetric (Hermitian) eigendecomposition
* Singular value decomposition.

Make LU decomposition tests less sensitive to the exact decomposition; check that we have a decomposition, not precisely the same one scipy returns.
2019-08-02 11:16:15 -04:00

75 lines
3.1 KiB
Bash
Executable File

#!/bin/bash
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Script that installs JAX's XLA dependencies inside the JAX source tree.
# --- begin runfiles.bash initialization ---
# Copy-pasted from Bazel's Bash runfiles library (tools/bash/runfiles/runfiles.bash).
set -euo pipefail
if [[ ! -d "${RUNFILES_DIR:-/dev/null}" && ! -f "${RUNFILES_MANIFEST_FILE:-/dev/null}" ]]; then
if [[ -f "$0.runfiles_manifest" ]]; then
export RUNFILES_MANIFEST_FILE="$0.runfiles_manifest"
elif [[ -f "$0.runfiles/MANIFEST" ]]; then
export RUNFILES_MANIFEST_FILE="$0.runfiles/MANIFEST"
elif [[ -f "$0.runfiles/bazel_tools/tools/bash/runfiles/runfiles.bash" ]]; then
export RUNFILES_DIR="$0.runfiles"
fi
fi
if [[ -f "${RUNFILES_DIR:-/dev/null}/bazel_tools/tools/bash/runfiles/runfiles.bash" ]]; then
source "${RUNFILES_DIR}/bazel_tools/tools/bash/runfiles/runfiles.bash"
elif [[ -f "${RUNFILES_MANIFEST_FILE:-/dev/null}" ]]; then
source "$(grep -m1 "^bazel_tools/tools/bash/runfiles/runfiles.bash " \
"$RUNFILES_MANIFEST_FILE" | cut -d ' ' -f 2-)"
else
echo >&2 "ERROR: cannot find @bazel_tools//tools/bash/runfiles:runfiles.bash"
exit 1
fi
# --- end runfiles.bash initialization ---
if [[ $# -ne 1 ]]; then
echo "Usage: $0 <target directory>"
exit 1
fi
TARGET="$1"
if [[ ! -d "${TARGET}/jaxlib" ]]; then
echo "Target directory ${TARGET} does not have a jaxlib directory"
exit 1
fi
# Copy the XLA dependencies into jax/lib, fixing up some imports to point to the
# new location.
cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/pytree.so)" "${TARGET}/jaxlib"
if [[ -x "$(rlocation __main__/jaxlib/cusolver_kernels.so)" ]]; then
cp -f "$(rlocation __main__/jaxlib/cusolver_kernels.so)" "${TARGET}/jaxlib"
fi
cp -f "$(rlocation __main__/jaxlib/version.py)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/cusolver.py)" "${TARGET}/jaxlib"
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so)" \
"${TARGET}/jaxlib"
sed \
-e 's/from tensorflow.compiler.xla.python import xla_extension as _xla/from . import xla_extension as _xla/' \
-e 's/from tensorflow.compiler.xla.python.xla_extension import ops/from .xla_extension import ops/' \
< "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_client.py)" \
> "${TARGET}/jaxlib/xla_client.py"
sed \
-e 's/from tensorflow.compiler.xla.python import xla_client/from . import xla_client/' \
-e 's/from tensorflow.compiler.xla.python import xla_extension as _xla/from . import xla_extension as _xla/' \
< "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xrt.py)" \
> "${TARGET}/jaxlib/xrt.py"