From 5eff830f0ec04eb4b97f67527da8a80276a2db39 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 1 Apr 2019 08:21:22 -0700 Subject: [PATCH] Move jaxlib version.py into jaxlib, and install it in build/jaxlib as build action. Update jaxlib version check to look in jaxlib.version. --- build/BUILD.bazel | 1 + build/install_xla_in_source_tree.sh | 2 ++ jax/lib/xla_bridge.py | 7 ++++--- jaxlib/BUILD | 5 +++++ {build/jaxlib => jaxlib}/version.py | 0 5 files changed, 12 insertions(+), 3 deletions(-) rename {build/jaxlib => jaxlib}/version.py (100%) diff --git a/build/BUILD.bazel b/build/BUILD.bazel index f70cb3975..c78a1a50a 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -23,6 +23,7 @@ sh_binary( srcs = ["install_xla_in_source_tree.sh"], data = [ "@org_tensorflow//tensorflow/compiler/xla/python:xla_client", + "//jaxlib", "//jaxlib:lapack.so", ], deps = ["@bazel_tools//tools/bash/runfiles"], diff --git a/build/install_xla_in_source_tree.sh b/build/install_xla_in_source_tree.sh index ee16fa467..991b794f0 100755 --- a/build/install_xla_in_source_tree.sh +++ b/build/install_xla_in_source_tree.sh @@ -53,6 +53,7 @@ fi # Copy the XLA dependencies into jax/lib, fixing up some imports to point to the # new location. cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib" +cp -f "$(rlocation __main__/jaxlib/version.py)" "${TARGET}/jaxlib" cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so)" \ "${TARGET}/jaxlib" sed \ @@ -65,3 +66,4 @@ sed \ -e 's/from tensorflow.compiler.xla.python import xla_extension as _xla/from . import xla_extension as _xla/' \ < "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xrt.py)" \ > "${TARGET}/jaxlib/xrt.py" + diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index b0e60eec1..05927b261 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -32,13 +32,14 @@ import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy import jaxlib + # Check the jaxlib version before importing anything else from jaxlib. def _check_jaxlib_version(): minimum_version = (0, 1, 11) - if hasattr(jaxlib, '__version__'): - version = tuple(int(x) for x in jaxlib.__version__.split('.')) + if hasattr(jaxlib, 'version'): + version = tuple(int(x) for x in jaxlib.version.__version__.split('.')) else: - version = (0, 1, 9) # The version before jaxlib.__version__ was added. + version = (0, 1, 9) # The version before jaxlib.version was added. if version < minimum_version: msg = 'jaxlib is version {}, but this version of jax requires version {}.' raise ValueError(msg.format('.'.join(map(str, version)), diff --git a/jaxlib/BUILD b/jaxlib/BUILD index cf16b5406..c46cf828c 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -25,3 +25,8 @@ pyx_library( srcs = ["lapack.pyx"], py_deps = ["@org_tensorflow//third_party/py/numpy"], ) + +py_library( + name = "jaxlib", + srcs = ["version.py"], +) \ No newline at end of file diff --git a/build/jaxlib/version.py b/jaxlib/version.py similarity index 100% rename from build/jaxlib/version.py rename to jaxlib/version.py