Finish jax and jaxlib 0.4.16 release

PiperOrigin-RevId: 566477931
This commit is contained in:
Yash Katariya 2023-09-18 19:08:43 -07:00 committed by jax authors
parent cf3fc55da3
commit dcc465b4de
3 changed files with 31 additions and 35 deletions

View File

@ -6,39 +6,11 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
Remember to align the itemized text with the first line of an item within a list. Remember to align the itemized text with the first line of an item within a list.
--> -->
## jax 0.4.16 # jax 0.4.17
* Deprecations # jaxlib 0.4.17
* The following previously-deprecated functions have been removed after a
three-month deprecation period:
* `jax.abstract_arrays.ShapedArray`: use `jax.core.ShapedArray`.
* `jax.abstract_arrays.raise_to_shaped`: use `jax.core.raise_to_shaped`.
* `jax.numpy.alltrue`: use `jax.numpy.all`.
* `jax.numpy.sometrue`: use `jax.numpy.any`.
* `jax.numpy.product`: use `jax.numpy.prod`.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
* Deprecations/removals: ## jax 0.4.16 (Sept 18, 2023)
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
for type annotations, and `jax.dtypes.issubdtype(arr, jax.dtypes.prng_key)`` for runtime
detection of typed prng keys.
* The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use
{func}`jax.random.key_data` instead.
* `jax.experimental.pjit.with_sharding_constraint` is deprecated. Use
`jax.lax.with_sharding_constraint` instead.
## jaxlib 0.4.16
* Bug fixes:
* Fixed a crash on Windows due to a fatal LLVM error related to out-of-order
sections and IMAGE_REL_AMD64_ADDR32NB relocations
(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).
## jax 0.4.15 (Aug 30 2023)
* Changes * Changes
* Added {class}`jax.numpy.ufunc`, as well as {func}`jax.numpy.frompyfunc`, which can convert * Added {class}`jax.numpy.ufunc`, as well as {func}`jax.numpy.frompyfunc`, which can convert
@ -89,8 +61,27 @@ Remember to align the itemized text with the first line of an item within a list
HLO lowering rules for custom JAX primitives have been deprecated. Custom HLO lowering rules for custom JAX primitives have been deprecated. Custom
primitives should be defined using the StableHLO lowering utilities in primitives should be defined using the StableHLO lowering utilities in
`jax.interpreters.mlir` instead. `jax.interpreters.mlir` instead.
* The following previously-deprecated functions have been removed after a
three-month deprecation period:
* `jax.abstract_arrays.ShapedArray`: use `jax.core.ShapedArray`.
* `jax.abstract_arrays.raise_to_shaped`: use `jax.core.raise_to_shaped`.
* `jax.numpy.alltrue`: use `jax.numpy.all`.
* `jax.numpy.sometrue`: use `jax.numpy.any`.
* `jax.numpy.product`: use `jax.numpy.prod`.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
* Internal deprecations/removals: * Deprecations/removals:
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
for type annotations, and `jax.dtypes.issubdtype(arr, jax.dtypes.prng_key)`` for runtime
detection of typed prng keys.
* The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use
{func}`jax.random.key_data` instead.
* `jax.experimental.pjit.with_sharding_constraint` is deprecated. Use
`jax.lax.with_sharding_constraint` instead.
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype` * The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`
have been removed. Opaque dtypes have been renamed to Extended dtypes; use have been removed. Opaque dtypes have been renamed to Extended dtypes; use
`jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14). `jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14).
@ -100,13 +91,18 @@ Remember to align the itemized text with the first line of an item within a list
* The internal submodule path `jax.linear_util` has been deprecated. Use * The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`) {mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
## jaxlib 0.4.15 (Aug 30 2023) ## jaxlib 0.4.16 (Sept 18, 2023)
* Changes: * Changes:
* Sparse CSR matrix multiplications via the experimental jax sparse APIs * Sparse CSR matrix multiplications via the experimental jax sparse APIs
no longer uses a deterministic algorithm on NVIDIA GPUs. This change was no longer uses a deterministic algorithm on NVIDIA GPUs. This change was
made to improve compatibility with CUDA 12.2.1. made to improve compatibility with CUDA 12.2.1.
* Bug fixes:
* Fixed a crash on Windows due to a fatal LLVM error related to out-of-order
sections and IMAGE_REL_AMD64_ADDR32NB relocations
(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).
## jax 0.4.14 (July 27, 2023) ## jax 0.4.14 (July 27, 2023)
* Changes * Changes

View File

@ -21,7 +21,7 @@ import os
import pathlib import pathlib
import subprocess import subprocess
_version = "0.4.16" _version = "0.4.17"
# The following line is overwritten by build scripts in distributions & # The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail. # releases. Do not modify this manually, or jax/jaxlib build will fail.
_release_version: str | None = None _release_version: str | None = None

View File

@ -24,7 +24,7 @@ project_name = 'jax'
_current_jaxlib_version = '0.4.16' _current_jaxlib_version = '0.4.16'
# The following should be updated with each new jaxlib release. # The following should be updated with each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.4.14' _latest_jaxlib_version_on_pypi = '0.4.16'
_available_cuda11_cudnn_versions = ['86'] _available_cuda11_cudnn_versions = ['86']
_default_cuda11_cudnn_version = '86' _default_cuda11_cudnn_version = '86'
_default_cuda12_cudnn_version = '89' _default_cuda12_cudnn_version = '89'