diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e246017d..f798c0f16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`). * The standard jax[tpu] install can now be used with Cloud TPU v4 VMs. + * `pjit` now works on CPU (in addition to previous TPU and GPU support). ## jaxlib 0.3.2 (March 16, 2022) diff --git a/jax/version.py b/jax/version.py index a2ff0009a..1c68a4f67 100644 --- a/jax/version.py +++ b/jax/version.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.3.2" +__version__ = "0.3.3" _minimum_jaxlib_version = "0.3.0" diff --git a/jaxlib/version.py b/jaxlib/version.py index f27b31442..6d073ebfd 100644 --- a/jaxlib/version.py +++ b/jaxlib/version.py @@ -17,4 +17,4 @@ # reflect the most recent available binaries. # __version__ should be increased after releasing the current version # (i.e. on main, this is always the next version to be released). -__version__ = "0.3.2" +__version__ = "0.3.3" diff --git a/setup.py b/setup.py index 9cda00478..5a7fb96e7 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ from setuptools import setup, find_packages _current_jaxlib_version = '0.3.2' # The following should be updated with each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.3.0' +_latest_jaxlib_version_on_pypi = '0.3.2' _available_cuda_versions = ['11'] _default_cuda_version = '11' _available_cudnn_versions = ['82', '805']