diff --git a/CHANGELOG.md b/CHANGELOG.md index a4724c46d..8e36dcc01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,18 +8,25 @@ Remember to align the itemized text with the first line of an item within a list PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. --> -## jaxlib 0.3.6 (Unreleased) -## jax 0.3.7 (Unreleased) +## jax 0.3.7 (April 15, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.6...main). + commits](https://github.com/google/jax/compare/jax-v0.3.6...jax-v0.3.7). +* Changes: + * Fixed a performance problem if the indices passed to + {func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`). + * {func}`jax.scipy.special.expit` and {func}`jax.scipy.special.logit` now + require their arguments to be scalars or JAX arrays. They also now promote + integer arguments to floating point. +## jaxlib 0.3.7 (April 15, 2202) ## jax 0.3.6 (April 12, 2022) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.5...jax-v0.3.6). * Changes: - * Upgraded libtpu wheel to the fixed version. Fixes [#10218](https://github.com/google/jax/issues/10218). + * Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU + pod. Fixes [#10218](https://github.com/google/jax/issues/10218). * Deprecations: * {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278` for an alternative API. diff --git a/WORKSPACE b/WORKSPACE index c975cc843..2edc01f14 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # and update the sha256 with the result. http_archive( name = "org_tensorflow", - sha256 = "a491d6c2fac467956809d100fdeaeaada35103c724acebba1168f7cfd47f1209", - strip_prefix = "tensorflow-0d5668cbdc6b46d099bd3abd93374c09b2e8121f", + sha256 = "9f8bb53e42cd3994bbd2396065c4d8cd87602d596c0ac0fcfbb1bb40d6cc06cc", + strip_prefix = "tensorflow-e74ef072ecd54ca54f3940ce9b98af796ded2a1a", urls = [ - "https://github.com/tensorflow/tensorflow/archive/0d5668cbdc6b46d099bd3abd93374c09b2e8121f.tar.gz", + "https://github.com/tensorflow/tensorflow/archive/e74ef072ecd54ca54f3940ce9b98af796ded2a1a.tar.gz", ], ) diff --git a/jaxlib/version.py b/jaxlib/version.py index c8fd70f51..202d6eddc 100644 --- a/jaxlib/version.py +++ b/jaxlib/version.py @@ -17,4 +17,4 @@ # reflect the most recent available binaries. # __version__ should be increased after releasing the current version # (i.e. on main, this is always the next version to be released). -__version__ = "0.3.6" +__version__ = "0.3.7" diff --git a/setup.py b/setup.py index 8cac7ff36..4d977626d 100644 --- a/setup.py +++ b/setup.py @@ -14,14 +14,14 @@ from setuptools import setup, find_packages -_current_jaxlib_version = '0.3.5' +_current_jaxlib_version = '0.3.7' # The following should be updated with each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.3.5' _available_cuda_versions = ['11'] _default_cuda_version = '11' _available_cudnn_versions = ['82', '805'] _default_cudnn_version = '82' -_libtpu_version = '0.1.dev20220412' +_libtpu_version = '0.1.dev20220415' _dct = {} with open('jax/version.py') as f: