Move jaxlib version.py into jaxlib, and install it in build/jaxlib as build action.

Update jaxlib version check to look in jaxlib.version.
This commit is contained in:
Peter Hawkins 2019-04-01 08:21:22 -07:00
parent 650db36d30
commit 5eff830f0e
5 changed files with 12 additions and 3 deletions

View File

@ -23,6 +23,7 @@ sh_binary(
srcs = ["install_xla_in_source_tree.sh"],
data = [
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
"//jaxlib",
"//jaxlib:lapack.so",
],
deps = ["@bazel_tools//tools/bash/runfiles"],

View File

@ -53,6 +53,7 @@ fi
# Copy the XLA dependencies into jax/lib, fixing up some imports to point to the
# new location.
cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/version.py)" "${TARGET}/jaxlib"
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so)" \
"${TARGET}/jaxlib"
sed \
@ -65,3 +66,4 @@ sed \
-e 's/from tensorflow.compiler.xla.python import xla_extension as _xla/from . import xla_extension as _xla/' \
< "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xrt.py)" \
> "${TARGET}/jaxlib/xrt.py"

View File

@ -32,13 +32,14 @@ import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy
import jaxlib
# Check the jaxlib version before importing anything else from jaxlib.
def _check_jaxlib_version():
minimum_version = (0, 1, 11)
if hasattr(jaxlib, '__version__'):
version = tuple(int(x) for x in jaxlib.__version__.split('.'))
if hasattr(jaxlib, 'version'):
version = tuple(int(x) for x in jaxlib.version.__version__.split('.'))
else:
version = (0, 1, 9) # The version before jaxlib.__version__ was added.
version = (0, 1, 9) # The version before jaxlib.version was added.
if version < minimum_version:
msg = 'jaxlib is version {}, but this version of jax requires version {}.'
raise ValueError(msg.format('.'.join(map(str, version)),

View File

@ -25,3 +25,8 @@ pyx_library(
srcs = ["lapack.pyx"],
py_deps = ["@org_tensorflow//third_party/py/numpy"],
)
py_library(
name = "jaxlib",
srcs = ["version.py"],
)