From f9775a2cedff6078038a82f11a57de6ea405a32e Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 16 Mar 2022 10:17:42 -0700 Subject: [PATCH] Update CHANGELOG and setup.py for jax + jaxlib 0.3.2 releases --- CHANGELOG.md | 13 ++++++++++--- setup.py | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf13212c3..2e246017d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,15 @@ Remember to align the itemized text with the first line of an item within a list PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. --> -## jax 0.3.2 (Unreleased) +## jax 0.3.3 (Unreleased) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.1...main). + commits](https://github.com/google/jax/compare/jax-v0.3.2...main). + +## jaxlib 0.3.3 (Unreleased) + +## jax 0.3.2 (March 16, 2022) +* [GitHub + commits](https://github.com/google/jax/compare/jax-v0.3.1...jax-v0.3.2). * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use @@ -20,9 +26,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. optimized alternatives to `jax.lax.top_k`. * {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`). + * The standard jax[tpu] install can now be used with Cloud TPU v4 VMs. -## jaxlib 0.3.1 (Unreleased) +## jaxlib 0.3.2 (March 16, 2022) * Changes * ``XlaComputation.as_hlo_text()`` now supports printing large constants by passing boolean flag ``print_large_constants=True``. diff --git a/setup.py b/setup.py index 7508dc8ba..9cda00478 100644 --- a/setup.py +++ b/setup.py @@ -14,14 +14,14 @@ from setuptools import setup, find_packages -_current_jaxlib_version = '0.3.0' +_current_jaxlib_version = '0.3.2' # The following should be updated with each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.3.0' _available_cuda_versions = ['11'] _default_cuda_version = '11' _available_cudnn_versions = ['82', '805'] _default_cudnn_version = '82' -_libtpu_version = '0.1.dev20220128' +_libtpu_version = '0.1.dev20220315' _dct = {} with open('jax/version.py') as f: