Copybara import of the project:

--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
This commit is contained in:
Yash Katariya 2023-06-22 18:57:41 -07:00 committed by jax authors
parent 19890086fa
commit fc0dcd15a2
5 changed files with 4 additions and 8 deletions

View File

@ -74,8 +74,8 @@ def get_python_version(python_bin_path):
return major, minor
def check_python_version(python_version):
if python_version < (3, 8):
print("ERROR: JAX requires Python 3.8 or newer, found ", python_version)
if python_version < (3, 9):
print("ERROR: JAX requires Python 3.9 or newer, found ", python_version)
sys.exit(-1)

View File

@ -49,7 +49,7 @@ setup(
author='JAX team',
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.8',
python_requires='>=3.9',
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.1.0'],
extras_require={
'cuda11_pip': [

View File

@ -61,7 +61,7 @@ setup(
author_email='jax-dev@google.com',
packages=find_packages(exclude=["examples", "jax/src/internal_test_util"]),
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
python_requires='>=3.8',
python_requires='>=3.9',
install_requires=[
'ml_dtypes>=0.1.0',
'numpy>=1.21',

View File

@ -79,7 +79,6 @@ config.parse_flags_with_absl()
FLAGS = config.FLAGS
python_version = (sys.version_info[0], sys.version_info[1])
numpy_version = jtu.numpy_version()
def _check_instance(self, x):

View File

@ -14,7 +14,6 @@
from functools import partial
import re
import sys
import unittest
import numpy as np
@ -35,8 +34,6 @@ config.parse_flags_with_absl()
FLAGS = config.FLAGS
python_version = (sys.version_info[0], sys.version_info[1])
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
class DynamicShapeStagingTest(jtu.JaxTestCase):
def test_basic_staging(self):