From b393d9a8c19469b5c49b68ab4734aa3f07ad46d1 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 9 Jul 2021 15:19:24 -0400 Subject: [PATCH] Update jax version and changelog for 0.1.27. Disable tfrt CPU backend on jaxlib 0.1.68 to work around https://github.com/google/jax/issues/7229. --- CHANGELOG.md | 14 +++++++++++--- jax/lib/xla_bridge.py | 2 +- jax/version.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8edd7a8b6..d16866c95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,17 @@ Remember to align the itemized text with the first line of an item within a list PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. --> -## jax 0.2.17 (unreleased) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main). +## jax 0.2.18 (unreleased) +* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...main). + +## jaxlib 0.1.69 (unreleased) + +## jax 0.2.17 (July 9 2021) +* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17). +* Bug fixes: + * Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68 + to work around #7229, which caused wrong outputs on CPU due to a concurrency + problem. * New features: * New SciPy function {py:func}`jax.scipy.special.sph_harm`. * Reverse-mode autodiff functions ({func}`jax.grad`, @@ -20,7 +29,6 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. non-per-example way inside maps (initially only {func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`). -## jaxlib 0.1.69 (unreleased) ## jax 0.2.16 (June 23 2021) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16). diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index ce26d37df..305ce7816 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -158,7 +158,7 @@ def register_backend_factory(name, factory, *, priority=0): if jax.lib._xla_extension_version >= 23: register_backend_factory('interpreter', xla_client.make_interpreter_client, priority=-100) - if jax.lib._xla_extension_version >= 24: + if jax.lib._xla_extension_version >= 27: if FLAGS.jax_cpu_backend_variant == 'stream_executor': register_backend_factory('cpu', partial(xla_client.make_cpu_client, use_tfrt=False), diff --git a/jax/version.py b/jax/version.py index 5ae418731..195533a78 100644 --- a/jax/version.py +++ b/jax/version.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.16" +__version__ = "0.2.17" _minimum_jaxlib_version = "0.1.65"