Merge pull request #7294 from hawkinsp:py36

PiperOrigin-RevId: 384994957
This commit is contained in:
jax authors 2021-07-15 13:19:23 -07:00
commit 6aa20d8f8f
8 changed files with 23 additions and 12 deletions

View File

@ -50,18 +50,19 @@ jobs:
num_generated_cases: 10
use-latest-jaxlib: false
- name-prefix: "with internal numpy/scipy"
python-version: 3.6
python-version: 3.7
os: ubuntu-latest
enable-x64: 1
# Test with numpy version that matches Google-internal version
package-overrides: "numpy==1.17.5 scipy==1.2.1"
# Test with numpy version that matches Google's internal version
package-overrides: "numpy==1.19.5 scipy==1.2.1"
num_generated_cases: 10
use-latest-jaxlib: false
- name-prefix: "with 3.7"
python-version: 3.7
os: ubuntu-latest
enable-x64: 0
package-overrides: "none"
# Test with the oldest legal NumPy version.
package-overrides: "numpy==1.17.5 scipy==1.2.1"
num_generated_cases: 8
# Test against latest jaxlib
use-latest-jaxlib: true

View File

@ -11,12 +11,22 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.18 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...main).
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* Bug fixes:
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
not used with invalid `axis` value, or with an empty reduction dimension.
({jax-issue}`#7196`)
## jaxlib 0.1.70 (unreleased)
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
## jaxlib 0.1.69 (July 9 2021)
* Fix bugs in TFRT CPU backend that results in incorrect results.

View File

@ -73,8 +73,8 @@ def get_python_version(python_bin_path):
return major, minor
def check_python_version(python_version):
if python_version < (3, 6):
print("ERROR: JAX requires Python 3.6 or newer, found ", python_version)
if python_version < (3, 7):
print("ERROR: JAX requires Python 3.7 or newer, found ", python_version)
sys.exit(-1)

View File

@ -3,7 +3,7 @@ set -xev
source "$(dirname $(realpath $0))/build_jaxlib_wheels_helpers.sh"
PYTHON_VERSIONS="3.6.8 3.7.2 3.8.0 3.9.0"
PYTHON_VERSIONS="3.7.2 3.8.0 3.9.0"
CUDA_VERSIONS="10.1 10.2 11.0 11.1"
CUDA_VARIANTS="cuda" # "cuda-included"

View File

@ -45,7 +45,6 @@ build_jax () {
rm -fr dist
build_jax 3.6.8 1.17.3 1.2.0
build_jax 3.7.2 1.17.3 1.2.0
build_jax 3.8.0 1.17.3 1.3.2
build_jax 3.9.0 1.19.4 1.5.4

View File

@ -31,7 +31,7 @@ setup(
author='JAX team',
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
python_requires='>=3.6',
python_requires='>=3.7',
install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
url='https://github.com/google/jax',
license='Apache-2.0',

View File

@ -34,7 +34,7 @@ setup(
author_email='jax-dev@google.com',
packages=find_packages(exclude=["examples"]),
package_data={'jax': ['py.typed']},
python_requires='>=3.6',
python_requires='>=3.7',
install_requires=[
'numpy>=1.17',
'absl-py',
@ -63,7 +63,6 @@ setup(
url='https://github.com/google/jax',
license='Apache-2.0',
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",

View File

@ -167,7 +167,9 @@ class TestLBFGS(jtu.JaxTestCase):
jax_res = min_op(x0)
# Note that without bounds, L-BFGS-B is just L-BFGS
scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x
with jtu.ignore_warning(category=DeprecationWarning,
message=".*tostring.*is deprecated.*"):
scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x
if func.__name__ == 'matyas':
# scipy performs badly for Matyas, compare to true minimum instead