mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Copybara import of the project:
-- 57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>: [Roll forward] Update required Python version to 3.9 PiperOrigin-RevId: 542728213
This commit is contained in:
parent
19890086fa
commit
fc0dcd15a2
@ -74,8 +74,8 @@ def get_python_version(python_bin_path):
|
||||
return major, minor
|
||||
|
||||
def check_python_version(python_version):
|
||||
if python_version < (3, 8):
|
||||
print("ERROR: JAX requires Python 3.8 or newer, found ", python_version)
|
||||
if python_version < (3, 9):
|
||||
print("ERROR: JAX requires Python 3.9 or newer, found ", python_version)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
|
@ -49,7 +49,7 @@ setup(
|
||||
author='JAX team',
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib', 'jaxlib.xla_extension'],
|
||||
python_requires='>=3.8',
|
||||
python_requires='>=3.9',
|
||||
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.1.0'],
|
||||
extras_require={
|
||||
'cuda11_pip': [
|
||||
|
2
setup.py
2
setup.py
@ -61,7 +61,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=find_packages(exclude=["examples", "jax/src/internal_test_util"]),
|
||||
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
|
||||
python_requires='>=3.8',
|
||||
python_requires='>=3.9',
|
||||
install_requires=[
|
||||
'ml_dtypes>=0.1.0',
|
||||
'numpy>=1.21',
|
||||
|
@ -79,7 +79,6 @@ config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
python_version = (sys.version_info[0], sys.version_info[1])
|
||||
numpy_version = jtu.numpy_version()
|
||||
|
||||
def _check_instance(self, x):
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
from functools import partial
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
@ -35,8 +34,6 @@ config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
python_version = (sys.version_info[0], sys.version_info[1])
|
||||
|
||||
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
|
||||
class DynamicShapeStagingTest(jtu.JaxTestCase):
|
||||
def test_basic_staging(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user