Remove deprecated unsafe_raw_array method from PRNG keys

PiperOrigin-RevId: 595190146
This commit is contained in:
Jake VanderPlas 2024-01-02 13:02:46 -08:00 committed by jax authors
parent e6c890171b
commit fff5ea579a
2 changed files with 3 additions and 9 deletions

View File

@ -17,7 +17,7 @@ Remember to align the itemized text with the first line of an item within a list
devices to create `Sharding`s during lowering. devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical This is a temporary state until we can create `Sharding`s without physical
devices. devices.
* Deprecations * Deprecations & Removals
* A number of previously deprecated functions have been removed, following a * A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`). standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
This includes: This includes:
@ -34,6 +34,8 @@ Remember to align the itemized text with the first line of an item within a list
* from {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`, * from {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`,
`trapz`, and `in1d`. `trapz`, and `in1d`.
* from {mod}`jax.scipy.linalg`: `tril` and `triu`. * from {mod}`jax.scipy.linalg`: `tril` and `triu`.
* The previously-deprecated method `PRNGKeyArray.unsafe_raw_array` has been
removed. Use {func}`jax.random.key_data` instead.
## jaxlib 0.4.24 ## jaxlib 0.4.24

View File

@ -19,7 +19,6 @@ from functools import partial, reduce
import math import math
import operator as op import operator as op
from typing import Any, Callable, NamedTuple from typing import Any, Callable, NamedTuple
import warnings
import numpy as np import numpy as np
@ -308,13 +307,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment] on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment] unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]
def unsafe_raw_array(self):
# deprecated on 13 Sept 2023
raise warnings.warn(
'The `unsafe_raw_array` method of PRNG key arrays is deprecated. '
'Use `jax.random.key_data` instead.', DeprecationWarning, stacklevel=2)
return self._base_array
def addressable_data(self, index: int) -> PRNGKeyArrayImpl: def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index)) return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index))