Bump jax and jaxlib versions for 0.3.2 release

Also add CPU pjit to changelog
This commit is contained in:
Skye Wanderman-Milne 2022-03-16 14:25:19 -07:00
parent 2c23c947a5
commit d7087abce6
4 changed files with 4 additions and 3 deletions

View File

@ -27,6 +27,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar
or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`).
* The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
* `pjit` now works on CPU (in addition to previous TPU and GPU support).
## jaxlib 0.3.2 (March 16, 2022)

View File

@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.3.2"
__version__ = "0.3.3"
_minimum_jaxlib_version = "0.3.0"

View File

@ -17,4 +17,4 @@
# reflect the most recent available binaries.
# __version__ should be increased after releasing the current version
# (i.e. on main, this is always the next version to be released).
__version__ = "0.3.2"
__version__ = "0.3.3"

View File

@ -16,7 +16,7 @@ from setuptools import setup, find_packages
_current_jaxlib_version = '0.3.2'
# The following should be updated with each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.3.0'
_latest_jaxlib_version_on_pypi = '0.3.2'
_available_cuda_versions = ['11']
_default_cuda_version = '11'
_available_cudnn_versions = ['82', '805']