Merge pull request #14094 from skye:version

PiperOrigin-RevId: 503482368
This commit is contained in:
jax authors 2023-01-20 11:11:38 -08:00
commit b846ce60c6
4 changed files with 16 additions and 8 deletions

View File

@ -6,7 +6,11 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
Remember to align the itemized text with the first line of an item within a list.
-->
## jax 0.4.2
## jax 0.4.3
## jaxlib 0.4.3
## jax 0.4.2 (Jan 20, 2023)
* Breaking changes
* Deleted `jax.experimental.callback`
@ -16,7 +20,11 @@ Remember to align the itemized text with the first line of an item within a list
that can be used to declare whether an instance can be removed or replicated
by JAX optimizations such as dead-code elimination ({jax-issue}`#13980`).
## jaxlib 0.4.2
## jaxlib 0.4.2 (Jan 20, 2023)
* Changes
* Set JAX_USE_PJRT_C_API_ON_TPU=1 to enable new Cloud TPU runtime, featuring
automatic device memory defragmentation.
## jax 0.4.1 (Dec 13, 2022)

View File

@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "9379bf81d201afa483bee61fe481fddd6809c05116529d7bbbc98566f9f10f83",
strip_prefix = "tensorflow-48cedda6908b7c31457dc7c8a297a62b2c273504",
sha256 = "dc7063605dc5281b6b99b1b0b2de250f55bd7fb959b27a86d2f91cae0528fd5d",
strip_prefix = "tensorflow-48f9d3dfcf148e1c2dfcf79b6334c2e2a2783093",
urls = [
"https://github.com/tensorflow/tensorflow/archive/48cedda6908b7c31457dc7c8a297a62b2c273504.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/48f9d3dfcf148e1c2dfcf79b6334c2e2a2783093.tar.gz",
],
)

View File

@ -22,7 +22,7 @@ from jax.config import config
TPU_DRIVER_MODE = 0
def setup_tpu(tpu_driver_version='tpu_driver_20221212'):
def setup_tpu(tpu_driver_version='tpu_driver_20230120'):
"""Sets up Colab to run on TPU.
Note: make sure the Colab Runtime is set to Accelerator: TPU.

View File

@ -19,14 +19,14 @@ import sys
from setuptools import setup, find_packages
_current_jaxlib_version = '0.4.1'
_current_jaxlib_version = '0.4.2'
# The following should be updated with each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.4.1'
_available_cuda_versions = ['11']
_default_cuda_version = '11'
_available_cudnn_versions = ['82', '86']
_default_cudnn_version = '86'
_libtpu_version = '0.1.dev20221212'
_libtpu_version = '0.1.dev20230120'
_dct = {}
with open('jax/version.py', encoding='utf-8') as f: