mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #7294 from hawkinsp:py36
PiperOrigin-RevId: 384994957
This commit is contained in:
commit
6aa20d8f8f
9
.github/workflows/ci-build.yaml
vendored
9
.github/workflows/ci-build.yaml
vendored
@ -50,18 +50,19 @@ jobs:
|
|||||||
num_generated_cases: 10
|
num_generated_cases: 10
|
||||||
use-latest-jaxlib: false
|
use-latest-jaxlib: false
|
||||||
- name-prefix: "with internal numpy/scipy"
|
- name-prefix: "with internal numpy/scipy"
|
||||||
python-version: 3.6
|
python-version: 3.7
|
||||||
os: ubuntu-latest
|
os: ubuntu-latest
|
||||||
enable-x64: 1
|
enable-x64: 1
|
||||||
# Test with numpy version that matches Google-internal version
|
# Test with numpy version that matches Google's internal version
|
||||||
package-overrides: "numpy==1.17.5 scipy==1.2.1"
|
package-overrides: "numpy==1.19.5 scipy==1.2.1"
|
||||||
num_generated_cases: 10
|
num_generated_cases: 10
|
||||||
use-latest-jaxlib: false
|
use-latest-jaxlib: false
|
||||||
- name-prefix: "with 3.7"
|
- name-prefix: "with 3.7"
|
||||||
python-version: 3.7
|
python-version: 3.7
|
||||||
os: ubuntu-latest
|
os: ubuntu-latest
|
||||||
enable-x64: 0
|
enable-x64: 0
|
||||||
package-overrides: "none"
|
# Test with the oldest legal NumPy version.
|
||||||
|
package-overrides: "numpy==1.17.5 scipy==1.2.1"
|
||||||
num_generated_cases: 8
|
num_generated_cases: 8
|
||||||
# Test against latest jaxlib
|
# Test against latest jaxlib
|
||||||
use-latest-jaxlib: true
|
use-latest-jaxlib: true
|
||||||
|
10
CHANGELOG.md
10
CHANGELOG.md
@ -11,12 +11,22 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
## jax 0.2.18 (unreleased)
|
## jax 0.2.18 (unreleased)
|
||||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...main).
|
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...main).
|
||||||
|
|
||||||
|
* Breaking changes:
|
||||||
|
* Support for Python 3.6 has been dropped, per the
|
||||||
|
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||||
|
Please upgrade to a supported Python version.
|
||||||
|
|
||||||
* Bug fixes:
|
* Bug fixes:
|
||||||
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
|
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
|
||||||
not used with invalid `axis` value, or with an empty reduction dimension.
|
not used with invalid `axis` value, or with an empty reduction dimension.
|
||||||
({jax-issue}`#7196`)
|
({jax-issue}`#7196`)
|
||||||
|
|
||||||
## jaxlib 0.1.70 (unreleased)
|
## jaxlib 0.1.70 (unreleased)
|
||||||
|
* Breaking changes:
|
||||||
|
* Support for Python 3.6 has been dropped, per the
|
||||||
|
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||||
|
Please upgrade to a supported Python version.
|
||||||
|
|
||||||
|
|
||||||
## jaxlib 0.1.69 (July 9 2021)
|
## jaxlib 0.1.69 (July 9 2021)
|
||||||
* Fix bugs in TFRT CPU backend that results in incorrect results.
|
* Fix bugs in TFRT CPU backend that results in incorrect results.
|
||||||
|
@ -73,8 +73,8 @@ def get_python_version(python_bin_path):
|
|||||||
return major, minor
|
return major, minor
|
||||||
|
|
||||||
def check_python_version(python_version):
|
def check_python_version(python_version):
|
||||||
if python_version < (3, 6):
|
if python_version < (3, 7):
|
||||||
print("ERROR: JAX requires Python 3.6 or newer, found ", python_version)
|
print("ERROR: JAX requires Python 3.7 or newer, found ", python_version)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ set -xev
|
|||||||
|
|
||||||
source "$(dirname $(realpath $0))/build_jaxlib_wheels_helpers.sh"
|
source "$(dirname $(realpath $0))/build_jaxlib_wheels_helpers.sh"
|
||||||
|
|
||||||
PYTHON_VERSIONS="3.6.8 3.7.2 3.8.0 3.9.0"
|
PYTHON_VERSIONS="3.7.2 3.8.0 3.9.0"
|
||||||
CUDA_VERSIONS="10.1 10.2 11.0 11.1"
|
CUDA_VERSIONS="10.1 10.2 11.0 11.1"
|
||||||
CUDA_VARIANTS="cuda" # "cuda-included"
|
CUDA_VARIANTS="cuda" # "cuda-included"
|
||||||
|
|
||||||
|
@ -45,7 +45,6 @@ build_jax () {
|
|||||||
|
|
||||||
|
|
||||||
rm -fr dist
|
rm -fr dist
|
||||||
build_jax 3.6.8 1.17.3 1.2.0
|
|
||||||
build_jax 3.7.2 1.17.3 1.2.0
|
build_jax 3.7.2 1.17.3 1.2.0
|
||||||
build_jax 3.8.0 1.17.3 1.3.2
|
build_jax 3.8.0 1.17.3 1.3.2
|
||||||
build_jax 3.9.0 1.19.4 1.5.4
|
build_jax 3.9.0 1.19.4 1.5.4
|
||||||
|
@ -31,7 +31,7 @@ setup(
|
|||||||
author='JAX team',
|
author='JAX team',
|
||||||
author_email='jax-dev@google.com',
|
author_email='jax-dev@google.com',
|
||||||
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
|
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
|
||||||
python_requires='>=3.6',
|
python_requires='>=3.7',
|
||||||
install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||||
url='https://github.com/google/jax',
|
url='https://github.com/google/jax',
|
||||||
license='Apache-2.0',
|
license='Apache-2.0',
|
||||||
|
3
setup.py
3
setup.py
@ -34,7 +34,7 @@ setup(
|
|||||||
author_email='jax-dev@google.com',
|
author_email='jax-dev@google.com',
|
||||||
packages=find_packages(exclude=["examples"]),
|
packages=find_packages(exclude=["examples"]),
|
||||||
package_data={'jax': ['py.typed']},
|
package_data={'jax': ['py.typed']},
|
||||||
python_requires='>=3.6',
|
python_requires='>=3.7',
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'numpy>=1.17',
|
'numpy>=1.17',
|
||||||
'absl-py',
|
'absl-py',
|
||||||
@ -63,7 +63,6 @@ setup(
|
|||||||
url='https://github.com/google/jax',
|
url='https://github.com/google/jax',
|
||||||
license='Apache-2.0',
|
license='Apache-2.0',
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Programming Language :: Python :: 3.6",
|
|
||||||
"Programming Language :: Python :: 3.7",
|
"Programming Language :: Python :: 3.7",
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
|
@ -167,7 +167,9 @@ class TestLBFGS(jtu.JaxTestCase):
|
|||||||
jax_res = min_op(x0)
|
jax_res = min_op(x0)
|
||||||
|
|
||||||
# Note that without bounds, L-BFGS-B is just L-BFGS
|
# Note that without bounds, L-BFGS-B is just L-BFGS
|
||||||
scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x
|
with jtu.ignore_warning(category=DeprecationWarning,
|
||||||
|
message=".*tostring.*is deprecated.*"):
|
||||||
|
scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x
|
||||||
|
|
||||||
if func.__name__ == 'matyas':
|
if func.__name__ == 'matyas':
|
||||||
# scipy performs badly for Matyas, compare to true minimum instead
|
# scipy performs badly for Matyas, compare to true minimum instead
|
||||||
|
Loading…
x
Reference in New Issue
Block a user