mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add the nightly dev version to __version__
of jaxlib.
PiperOrigin-RevId: 448001375
This commit is contained in:
parent
90f926ac6b
commit
46d034baab
@ -18,10 +18,13 @@
|
||||
# Most users should not run this script directly; use build.py instead.
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import functools
|
||||
import glob
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
@ -274,6 +277,22 @@ def prepare_wheel(sources_path):
|
||||
patch_copy_tpu_client_py(jaxlib_dir)
|
||||
|
||||
|
||||
def edit_jaxlib_version(sources_path):
|
||||
version_regex = re.compile(r'__version__ = \"(.*)\"')
|
||||
|
||||
version_file = pathlib.Path(sources_path) / "jaxlib" / "version.py"
|
||||
content = version_file.read_text()
|
||||
|
||||
version_num = version_regex.search(content).group(1)
|
||||
|
||||
datestring = datetime.datetime.now().strftime('%Y%m%d')
|
||||
nightly_version = f'{version_num}.dev{datestring}'
|
||||
|
||||
content = content.replace(f'__version__ = "{version_num}"',
|
||||
f'__version__ = "{nightly_version}"')
|
||||
version_file.write_text(content)
|
||||
|
||||
|
||||
def build_wheel(sources_path, output_path, cpu):
|
||||
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
|
||||
platform_name, cpu_name = {
|
||||
@ -288,6 +307,8 @@ def build_wheel(sources_path, output_path, cpu):
|
||||
f"{sys.version_info.minor}")
|
||||
platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
|
||||
cwd = os.getcwd()
|
||||
if os.environ.get('JAXLIB_NIGHTLY'):
|
||||
edit_jaxlib_version(sources_path)
|
||||
os.chdir(sources_path)
|
||||
subprocess.run([sys.executable, "setup.py", "bdist_wheel",
|
||||
python_tag_arg, platform_tag_arg], check=True)
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
from setuptools import setup
|
||||
import os
|
||||
|
||||
@ -30,9 +29,6 @@ if cuda_version and cudnn_version:
|
||||
nightly = os.environ.get('JAXLIB_NIGHTLY')
|
||||
if nightly:
|
||||
project_name = 'jaxlib-nightly'
|
||||
# Version as `X.Y.Z.dev20220510`
|
||||
datestring = datetime.datetime.now().strftime('%Y%m%d')
|
||||
__version__ = f'{__version__}.dev{datestring}'
|
||||
|
||||
setup(
|
||||
name=project_name,
|
||||
|
Loading…
x
Reference in New Issue
Block a user