Remove deprecated unsafe_raw_array method from PRNG keys

PiperOrigin-RevId: 595190146
This commit is contained in:
Jake VanderPlas 2024-01-02 13:02:46 -08:00 committed by jax authors
parent e6c890171b
commit fff5ea579a
2 changed files with 3 additions and 9 deletions

View File

@ -17,7 +17,7 @@ Remember to align the itemized text with the first line of an item within a list
devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical
devices.
* Deprecations
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
This includes:
@ -34,6 +34,8 @@ Remember to align the itemized text with the first line of an item within a list
* from {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`,
`trapz`, and `in1d`.
* from {mod}`jax.scipy.linalg`: `tril` and `triu`.
* The previously-deprecated method `PRNGKeyArray.unsafe_raw_array` has been
removed. Use {func}`jax.random.key_data` instead.
## jaxlib 0.4.24

View File

@ -19,7 +19,6 @@ from functools import partial, reduce
import math
import operator as op
from typing import Any, Callable, NamedTuple
import warnings
import numpy as np
@ -308,13 +307,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]
def unsafe_raw_array(self):
# deprecated on 13 Sept 2023
raise warnings.warn(
'The `unsafe_raw_array` method of PRNG key arrays is deprecated. '
'Use `jax.random.key_data` instead.', DeprecationWarning, stacklevel=2)
return self._base_array
def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index))