mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add nightly __version__
string if building jaxlib nightly
PiperOrigin-RevId: 447822974
This commit is contained in:
parent
7d27343506
commit
dfb2caf31e
@ -12,10 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
from setuptools import setup
|
||||
import os
|
||||
|
||||
__version__ = None
|
||||
project_name = 'jaxlib'
|
||||
|
||||
with open('jaxlib/version.py') as f:
|
||||
exec(f.read(), globals())
|
||||
@ -25,8 +27,15 @@ cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
|
||||
if cuda_version and cudnn_version:
|
||||
__version__ += f"+cuda{cuda_version.replace('.', '')}-cudnn{cudnn_version.replace('.', '')}"
|
||||
|
||||
nightly = os.environ.get('JAXLIB_NIGHTLY')
|
||||
if nightly:
|
||||
project_name = 'jaxlib-nightly'
|
||||
# Version as `X.Y.Z.dev20220510`
|
||||
datestring = datetime.datetime.now().strftime('%Y%m%d')
|
||||
__version__ = f'{__version__}.dev{datestring}'
|
||||
|
||||
setup(
|
||||
name='jaxlib',
|
||||
name=project_name,
|
||||
version=__version__,
|
||||
description='XLA library for JAX',
|
||||
author='JAX team',
|
||||
|
@ -41,7 +41,7 @@ T = lambda x: np.swapaxes(x, -1, -2)
|
||||
float_types = jtu.dtypes.floating
|
||||
complex_types = jtu.dtypes.complex
|
||||
|
||||
jaxlib_version = tuple(map(int, jax.lib.__version__.split('.')))
|
||||
jaxlib_version = jax._src.lib.version
|
||||
|
||||
|
||||
class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user