diff --git a/CHANGELOG.md b/CHANGELOG.md index 78ac98c9d..48c2435cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,7 @@ Remember to align the itemized text with the first line of an item within a list devices to create `Sharding`s during lowering. This is a temporary state until we can create `Sharding`s without physical devices. -* Deprecations +* Deprecations & Removals * A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see {ref}`api-compatibility`). This includes: @@ -34,6 +34,8 @@ Remember to align the itemized text with the first line of an item within a list * from {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`, `trapz`, and `in1d`. * from {mod}`jax.scipy.linalg`: `tril` and `triu`. + * The previously-deprecated method `PRNGKeyArray.unsafe_raw_array` has been + removed. Use {func}`jax.random.key_data` instead. ## jaxlib 0.4.24 diff --git a/jax/_src/prng.py b/jax/_src/prng.py index e6af55718..df76b9d5f 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -19,7 +19,6 @@ from functools import partial, reduce import math import operator as op from typing import Any, Callable, NamedTuple -import warnings import numpy as np @@ -308,13 +307,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray): on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment] unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment] - def unsafe_raw_array(self): - # deprecated on 13 Sept 2023 - raise warnings.warn( - 'The `unsafe_raw_array` method of PRNG key arrays is deprecated. ' - 'Use `jax.random.key_data` instead.', DeprecationWarning, stacklevel=2) - return self._base_array - def addressable_data(self, index: int) -> PRNGKeyArrayImpl: return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index))