Update jax version and changelog for 0.1.27.

Disable tfrt CPU backend on jaxlib 0.1.68 to work around https://github.com/google/jax/issues/7229.
This commit is contained in:
Peter Hawkins 2021-07-09 15:19:24 -04:00
parent a50dd080f6
commit b393d9a8c1
3 changed files with 13 additions and 5 deletions

View File

@ -8,8 +8,17 @@ Remember to align the itemized text with the first line of an item within a list
PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->
## jax 0.2.17 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main).
## jax 0.2.18 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...main).
## jaxlib 0.1.69 (unreleased)
## jax 0.2.17 (July 9 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17).
* Bug fixes:
* Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68
to work around #7229, which caused wrong outputs on CPU due to a concurrency
problem.
* New features:
* New SciPy function {py:func}`jax.scipy.special.sph_harm`.
* Reverse-mode autodiff functions ({func}`jax.grad`,
@ -20,7 +29,6 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
non-per-example way inside maps (initially only
{func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`).
## jaxlib 0.1.69 (unreleased)
## jax 0.2.16 (June 23 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16).

View File

@ -158,7 +158,7 @@ def register_backend_factory(name, factory, *, priority=0):
if jax.lib._xla_extension_version >= 23:
register_backend_factory('interpreter', xla_client.make_interpreter_client,
priority=-100)
if jax.lib._xla_extension_version >= 24:
if jax.lib._xla_extension_version >= 27:
if FLAGS.jax_cpu_backend_variant == 'stream_executor':
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=False),

View File

@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.16"
__version__ = "0.2.17"
_minimum_jaxlib_version = "0.1.65"