Add nightly __version__ string if building jaxlib nightly

PiperOrigin-RevId: 447822974
This commit is contained in:
Yash Katariya 2022-05-10 14:05:09 -07:00 committed by jax authors
parent 7d27343506
commit dfb2caf31e
2 changed files with 11 additions and 2 deletions

View File

@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from setuptools import setup
import os
__version__ = None
project_name = 'jaxlib'
with open('jaxlib/version.py') as f:
exec(f.read(), globals())
@ -25,8 +27,15 @@ cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
if cuda_version and cudnn_version:
__version__ += f"+cuda{cuda_version.replace('.', '')}-cudnn{cudnn_version.replace('.', '')}"
nightly = os.environ.get('JAXLIB_NIGHTLY')
if nightly:
project_name = 'jaxlib-nightly'
# Version as `X.Y.Z.dev20220510`
datestring = datetime.datetime.now().strftime('%Y%m%d')
__version__ = f'{__version__}.dev{datestring}'
setup(
name='jaxlib',
name=project_name,
version=__version__,
description='XLA library for JAX',
author='JAX team',

View File

@ -41,7 +41,7 @@ T = lambda x: np.swapaxes(x, -1, -2)
float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
jaxlib_version = tuple(map(int, jax.lib.__version__.split('.')))
jaxlib_version = jax._src.lib.version
class NumpyLinalgTest(jtu.JaxTestCase):