mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Move jaxlib version.py into jaxlib, and install it in build/jaxlib as build action.
Update jaxlib version check to look in jaxlib.version.
This commit is contained in:
parent
650db36d30
commit
5eff830f0e
@ -23,6 +23,7 @@ sh_binary(
|
||||
srcs = ["install_xla_in_source_tree.sh"],
|
||||
data = [
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
|
||||
"//jaxlib",
|
||||
"//jaxlib:lapack.so",
|
||||
],
|
||||
deps = ["@bazel_tools//tools/bash/runfiles"],
|
||||
|
@ -53,6 +53,7 @@ fi
|
||||
# Copy the XLA dependencies into jax/lib, fixing up some imports to point to the
|
||||
# new location.
|
||||
cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation __main__/jaxlib/version.py)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so)" \
|
||||
"${TARGET}/jaxlib"
|
||||
sed \
|
||||
@ -65,3 +66,4 @@ sed \
|
||||
-e 's/from tensorflow.compiler.xla.python import xla_extension as _xla/from . import xla_extension as _xla/' \
|
||||
< "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xrt.py)" \
|
||||
> "${TARGET}/jaxlib/xrt.py"
|
||||
|
||||
|
@ -32,13 +32,14 @@ import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy
|
||||
|
||||
import jaxlib
|
||||
|
||||
|
||||
# Check the jaxlib version before importing anything else from jaxlib.
|
||||
def _check_jaxlib_version():
|
||||
minimum_version = (0, 1, 11)
|
||||
if hasattr(jaxlib, '__version__'):
|
||||
version = tuple(int(x) for x in jaxlib.__version__.split('.'))
|
||||
if hasattr(jaxlib, 'version'):
|
||||
version = tuple(int(x) for x in jaxlib.version.__version__.split('.'))
|
||||
else:
|
||||
version = (0, 1, 9) # The version before jaxlib.__version__ was added.
|
||||
version = (0, 1, 9) # The version before jaxlib.version was added.
|
||||
if version < minimum_version:
|
||||
msg = 'jaxlib is version {}, but this version of jax requires version {}.'
|
||||
raise ValueError(msg.format('.'.join(map(str, version)),
|
||||
|
@ -25,3 +25,8 @@ pyx_library(
|
||||
srcs = ["lapack.pyx"],
|
||||
py_deps = ["@org_tensorflow//third_party/py/numpy"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "jaxlib",
|
||||
srcs = ["version.py"],
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user