Copybara import of the project:

--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
This commit is contained in:
Yash Katariya 2023-05-02 15:32:57 -07:00 committed by jax authors
parent 519a96305b
commit 6506ee2a40
5 changed files with 8 additions and 4 deletions

View File

@ -74,8 +74,8 @@ def get_python_version(python_bin_path):
return major, minor
def check_python_version(python_version):
if python_version < (3, 9):
print("ERROR: JAX requires Python 3.9 or newer, found ", python_version)
if python_version < (3, 8):
print("ERROR: JAX requires Python 3.8 or newer, found ", python_version)
sys.exit(-1)

View File

@ -45,7 +45,7 @@ setup(
author='JAX team',
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.9',
python_requires='>=3.8',
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.0.3'],
url='https://github.com/google/jax',
license='Apache-2.0',

View File

@ -61,7 +61,7 @@ setup(
author_email='jax-dev@google.com',
packages=find_packages(exclude=["examples", "jax/src/internal_test_util"]),
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
python_requires='>=3.9',
python_requires='>=3.8',
install_requires=[
'ml_dtypes>=0.0.3',
'numpy>=1.21',

View File

@ -80,6 +80,7 @@ config.parse_flags_with_absl()
FLAGS = config.FLAGS
python_version = (sys.version_info[0], sys.version_info[1])
numpy_version = jtu.numpy_version()
def _check_instance(self, x):

View File

@ -14,6 +14,7 @@
from functools import partial
import re
import sys
import unittest
import numpy as np
@ -34,6 +35,8 @@ config.parse_flags_with_absl()
FLAGS = config.FLAGS
python_version = (sys.version_info[0], sys.version_info[1])
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
class DynamicShapeStagingTest(jtu.JaxTestCase):
def test_basic_staging(self):