mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[random] deprecate named key creation functions
This commit is contained in:
parent
6a551a1efa
commit
22818d664f
@ -10,6 +10,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
* Deprecations
|
||||
* Removed the deprecated module `jax.abstract_arrays` and all its contents.
|
||||
* Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument
|
||||
to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead:
|
||||
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
|
||||
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
|
||||
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
|
||||
|
||||
# jaxlib 0.4.17
|
||||
|
||||
|
@ -195,8 +195,8 @@ def _check_default_impl_with_no_custom_prng(impl, name):
|
||||
default_name = config.jax_default_prng_impl
|
||||
if not config.jax_enable_custom_prng and default_impl is not impl:
|
||||
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
|
||||
f'to seed an RNG with an implementation "f{name}" '
|
||||
f'differing from the default "f{default_name}".')
|
||||
f'to seed an RNG with an implementation "{name}" '
|
||||
f'differing from the default "{default_name}".')
|
||||
|
||||
def threefry2x32_key(seed: int) -> KeyArray:
|
||||
"""Creates a threefry2x32 PRNG key from an integer seed."""
|
||||
|
@ -170,17 +170,17 @@ from jax._src.random import (
|
||||
randint as randint,
|
||||
random_gamma_p as random_gamma_p,
|
||||
rayleigh as rayleigh,
|
||||
rbg_key as rbg_key,
|
||||
rbg_key as _deprecated_rbg_key,
|
||||
shuffle as shuffle,
|
||||
split as split,
|
||||
t as t,
|
||||
threefry_2x32 as threefry_2x32,
|
||||
threefry2x32_key as threefry2x32_key,
|
||||
threefry2x32_key as _deprecated_threefry2x32_key,
|
||||
threefry2x32_p as threefry2x32_p,
|
||||
triangular as triangular,
|
||||
truncated_normal as truncated_normal,
|
||||
uniform as uniform,
|
||||
unsafe_rbg_key as unsafe_rbg_key,
|
||||
unsafe_rbg_key as _deprecated_unsafe_rbg_key,
|
||||
wald as wald,
|
||||
weibull_min as weibull_min,
|
||||
wrap_key_data as wrap_key_data,
|
||||
@ -202,12 +202,25 @@ _deprecations = {
|
||||
"jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of "
|
||||
"typed prng keys.", _PRNGKeyArray
|
||||
),
|
||||
# Added September 21, 2023
|
||||
"threefry2x32_key": (
|
||||
"jax.random.threefry2x32_key(seed) is deprecated. "
|
||||
"Use jax.random.PRNGKey(seed, 'threefry2x32')", _deprecated_threefry2x32_key),
|
||||
"rbg_key": (
|
||||
"jax.random.rbg_key(seed) is deprecated. "
|
||||
"Use jax.random.PRNGKey(seed, 'rbg')", _deprecated_rbg_key),
|
||||
"unsafe_rbg_key": (
|
||||
"jax.random.unsafe_rbg_key(seed) is deprecated. "
|
||||
"Use jax.random.PRNGKey(seed, 'unsafe_rbg')", _deprecated_unsafe_rbg_key),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
PRNGKeyArray = typing.Any
|
||||
KeyArray = typing.Any
|
||||
threefry_2x32_key = _deprecated_threefry2x32_key
|
||||
rbg_key = _deprecated_rbg_key
|
||||
unsafe_rbg_key = _deprecated_unsafe_rbg_key
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
|
@ -522,16 +522,19 @@ class PrngTest(jtu.JaxTestCase):
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_explicit_threefry2x32_key(self):
|
||||
with self.assertWarnsRegex(DeprecationWarning, "jax.random.threefry2x32_key"):
|
||||
self.check_key_has_impl(random.threefry2x32_key(42),
|
||||
prng_internal.threefry_prng_impl)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_explicit_rbg_key(self):
|
||||
with self.assertWarnsRegex(DeprecationWarning, "jax.random.rbg_key"):
|
||||
self.check_key_has_impl(random.rbg_key(42),
|
||||
prng_internal.rbg_prng_impl)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_explicit_unsafe_rbg_key(self):
|
||||
with self.assertWarnsRegex(DeprecationWarning, "jax.random.unsafe_rbg_key"):
|
||||
self.check_key_has_impl(random.unsafe_rbg_key(42),
|
||||
prng_internal.unsafe_rbg_prng_impl)
|
||||
|
||||
@ -579,7 +582,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
|
||||
class ThreefryPrngTest(jtu.JaxTestCase):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in [
|
||||
random.threefry2x32_key,
|
||||
jax_random.threefry2x32_key,
|
||||
partial(random.PRNGKey, impl='threefry2x32'),
|
||||
partial(random.key, impl='threefry2x32')]])
|
||||
def test_seed_no_implicit_transfers(self, make_key):
|
||||
@ -640,7 +643,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
f'{expected_freq}\n{actual_freq}')
|
||||
|
||||
def make_key(self, seed):
|
||||
return random.threefry2x32_key(seed)
|
||||
return random.PRNGKey(seed, impl='threefry2x32')
|
||||
|
||||
@jtu.sample_product(
|
||||
num=(None, 6, (6,), (2, 3), (2, 3, 4)),
|
||||
@ -2296,7 +2299,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
@jtu.with_config(jax_default_prng_impl='rbg')
|
||||
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
def make_key(self, seed):
|
||||
return random.rbg_key(seed)
|
||||
return random.PRNGKey(seed, impl='rbg')
|
||||
|
||||
def test_split_shape(self):
|
||||
key = self.make_key(73)
|
||||
@ -2372,7 +2375,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
|
||||
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
|
||||
def make_key(self, seed):
|
||||
return random.unsafe_rbg_key(seed)
|
||||
return random.PRNGKey(seed, impl="unsafe_rbg")
|
||||
|
||||
|
||||
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user