Bump jax and jaxlib versions for 0.3.2 release

Also add CPU pjit to changelog
This commit is contained in:
Skye Wanderman-Milne 2022-03-16 14:25:19 -07:00
parent 2c23c947a5
commit d7087abce6
4 changed files with 4 additions and 3 deletions

View File

@ -27,6 +27,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar * {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar
or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`). or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`).
* The standard jax[tpu] install can now be used with Cloud TPU v4 VMs. * The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
* `pjit` now works on CPU (in addition to previous TPU and GPU support).
## jaxlib 0.3.2 (March 16, 2022) ## jaxlib 0.3.2 (March 16, 2022)

View File

@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__version__ = "0.3.2" __version__ = "0.3.3"
_minimum_jaxlib_version = "0.3.0" _minimum_jaxlib_version = "0.3.0"

View File

@ -17,4 +17,4 @@
# reflect the most recent available binaries. # reflect the most recent available binaries.
# __version__ should be increased after releasing the current version # __version__ should be increased after releasing the current version
# (i.e. on main, this is always the next version to be released). # (i.e. on main, this is always the next version to be released).
__version__ = "0.3.2" __version__ = "0.3.3"

View File

@ -16,7 +16,7 @@ from setuptools import setup, find_packages
_current_jaxlib_version = '0.3.2' _current_jaxlib_version = '0.3.2'
# The following should be updated with each new jaxlib release. # The following should be updated with each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.3.0' _latest_jaxlib_version_on_pypi = '0.3.2'
_available_cuda_versions = ['11'] _available_cuda_versions = ['11']
_default_cuda_version = '11' _default_cuda_version = '11'
_available_cudnn_versions = ['82', '805'] _available_cudnn_versions = ['82', '805']