mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[random] deprecate named key creation functions
This commit is contained in:
parent
6a551a1efa
commit
22818d664f
@ -10,6 +10,11 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
|
|
||||||
* Deprecations
|
* Deprecations
|
||||||
* Removed the deprecated module `jax.abstract_arrays` and all its contents.
|
* Removed the deprecated module `jax.abstract_arrays` and all its contents.
|
||||||
|
* Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument
|
||||||
|
to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead:
|
||||||
|
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
|
||||||
|
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
|
||||||
|
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
|
||||||
|
|
||||||
# jaxlib 0.4.17
|
# jaxlib 0.4.17
|
||||||
|
|
||||||
|
@ -195,8 +195,8 @@ def _check_default_impl_with_no_custom_prng(impl, name):
|
|||||||
default_name = config.jax_default_prng_impl
|
default_name = config.jax_default_prng_impl
|
||||||
if not config.jax_enable_custom_prng and default_impl is not impl:
|
if not config.jax_enable_custom_prng and default_impl is not impl:
|
||||||
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
|
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
|
||||||
f'to seed an RNG with an implementation "f{name}" '
|
f'to seed an RNG with an implementation "{name}" '
|
||||||
f'differing from the default "f{default_name}".')
|
f'differing from the default "{default_name}".')
|
||||||
|
|
||||||
def threefry2x32_key(seed: int) -> KeyArray:
|
def threefry2x32_key(seed: int) -> KeyArray:
|
||||||
"""Creates a threefry2x32 PRNG key from an integer seed."""
|
"""Creates a threefry2x32 PRNG key from an integer seed."""
|
||||||
|
@ -170,17 +170,17 @@ from jax._src.random import (
|
|||||||
randint as randint,
|
randint as randint,
|
||||||
random_gamma_p as random_gamma_p,
|
random_gamma_p as random_gamma_p,
|
||||||
rayleigh as rayleigh,
|
rayleigh as rayleigh,
|
||||||
rbg_key as rbg_key,
|
rbg_key as _deprecated_rbg_key,
|
||||||
shuffle as shuffle,
|
shuffle as shuffle,
|
||||||
split as split,
|
split as split,
|
||||||
t as t,
|
t as t,
|
||||||
threefry_2x32 as threefry_2x32,
|
threefry_2x32 as threefry_2x32,
|
||||||
threefry2x32_key as threefry2x32_key,
|
threefry2x32_key as _deprecated_threefry2x32_key,
|
||||||
threefry2x32_p as threefry2x32_p,
|
threefry2x32_p as threefry2x32_p,
|
||||||
triangular as triangular,
|
triangular as triangular,
|
||||||
truncated_normal as truncated_normal,
|
truncated_normal as truncated_normal,
|
||||||
uniform as uniform,
|
uniform as uniform,
|
||||||
unsafe_rbg_key as unsafe_rbg_key,
|
unsafe_rbg_key as _deprecated_unsafe_rbg_key,
|
||||||
wald as wald,
|
wald as wald,
|
||||||
weibull_min as weibull_min,
|
weibull_min as weibull_min,
|
||||||
wrap_key_data as wrap_key_data,
|
wrap_key_data as wrap_key_data,
|
||||||
@ -202,12 +202,25 @@ _deprecations = {
|
|||||||
"jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of "
|
"jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of "
|
||||||
"typed prng keys.", _PRNGKeyArray
|
"typed prng keys.", _PRNGKeyArray
|
||||||
),
|
),
|
||||||
|
# Added September 21, 2023
|
||||||
|
"threefry2x32_key": (
|
||||||
|
"jax.random.threefry2x32_key(seed) is deprecated. "
|
||||||
|
"Use jax.random.PRNGKey(seed, 'threefry2x32')", _deprecated_threefry2x32_key),
|
||||||
|
"rbg_key": (
|
||||||
|
"jax.random.rbg_key(seed) is deprecated. "
|
||||||
|
"Use jax.random.PRNGKey(seed, 'rbg')", _deprecated_rbg_key),
|
||||||
|
"unsafe_rbg_key": (
|
||||||
|
"jax.random.unsafe_rbg_key(seed) is deprecated. "
|
||||||
|
"Use jax.random.PRNGKey(seed, 'unsafe_rbg')", _deprecated_unsafe_rbg_key),
|
||||||
}
|
}
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
PRNGKeyArray = typing.Any
|
PRNGKeyArray = typing.Any
|
||||||
KeyArray = typing.Any
|
KeyArray = typing.Any
|
||||||
|
threefry_2x32_key = _deprecated_threefry2x32_key
|
||||||
|
rbg_key = _deprecated_rbg_key
|
||||||
|
unsafe_rbg_key = _deprecated_unsafe_rbg_key
|
||||||
else:
|
else:
|
||||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||||
|
@ -522,18 +522,21 @@ class PrngTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||||
def test_explicit_threefry2x32_key(self):
|
def test_explicit_threefry2x32_key(self):
|
||||||
self.check_key_has_impl(random.threefry2x32_key(42),
|
with self.assertWarnsRegex(DeprecationWarning, "jax.random.threefry2x32_key"):
|
||||||
prng_internal.threefry_prng_impl)
|
self.check_key_has_impl(random.threefry2x32_key(42),
|
||||||
|
prng_internal.threefry_prng_impl)
|
||||||
|
|
||||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||||
def test_explicit_rbg_key(self):
|
def test_explicit_rbg_key(self):
|
||||||
self.check_key_has_impl(random.rbg_key(42),
|
with self.assertWarnsRegex(DeprecationWarning, "jax.random.rbg_key"):
|
||||||
prng_internal.rbg_prng_impl)
|
self.check_key_has_impl(random.rbg_key(42),
|
||||||
|
prng_internal.rbg_prng_impl)
|
||||||
|
|
||||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||||
def test_explicit_unsafe_rbg_key(self):
|
def test_explicit_unsafe_rbg_key(self):
|
||||||
self.check_key_has_impl(random.unsafe_rbg_key(42),
|
with self.assertWarnsRegex(DeprecationWarning, "jax.random.unsafe_rbg_key"):
|
||||||
prng_internal.unsafe_rbg_prng_impl)
|
self.check_key_has_impl(random.unsafe_rbg_key(42),
|
||||||
|
prng_internal.unsafe_rbg_prng_impl)
|
||||||
|
|
||||||
@parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl}
|
@parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl}
|
||||||
for ctor in KEY_CTORS
|
for ctor in KEY_CTORS
|
||||||
@ -579,7 +582,7 @@ class PrngTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
class ThreefryPrngTest(jtu.JaxTestCase):
|
class ThreefryPrngTest(jtu.JaxTestCase):
|
||||||
@parameterized.parameters([{'make_key': ctor} for ctor in [
|
@parameterized.parameters([{'make_key': ctor} for ctor in [
|
||||||
random.threefry2x32_key,
|
jax_random.threefry2x32_key,
|
||||||
partial(random.PRNGKey, impl='threefry2x32'),
|
partial(random.PRNGKey, impl='threefry2x32'),
|
||||||
partial(random.key, impl='threefry2x32')]])
|
partial(random.key, impl='threefry2x32')]])
|
||||||
def test_seed_no_implicit_transfers(self, make_key):
|
def test_seed_no_implicit_transfers(self, make_key):
|
||||||
@ -640,7 +643,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
|||||||
f'{expected_freq}\n{actual_freq}')
|
f'{expected_freq}\n{actual_freq}')
|
||||||
|
|
||||||
def make_key(self, seed):
|
def make_key(self, seed):
|
||||||
return random.threefry2x32_key(seed)
|
return random.PRNGKey(seed, impl='threefry2x32')
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
num=(None, 6, (6,), (2, 3), (2, 3, 4)),
|
num=(None, 6, (6,), (2, 3), (2, 3, 4)),
|
||||||
@ -2296,7 +2299,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
|||||||
@jtu.with_config(jax_default_prng_impl='rbg')
|
@jtu.with_config(jax_default_prng_impl='rbg')
|
||||||
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||||
def make_key(self, seed):
|
def make_key(self, seed):
|
||||||
return random.rbg_key(seed)
|
return random.PRNGKey(seed, impl='rbg')
|
||||||
|
|
||||||
def test_split_shape(self):
|
def test_split_shape(self):
|
||||||
key = self.make_key(73)
|
key = self.make_key(73)
|
||||||
@ -2372,7 +2375,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
|||||||
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
|
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
|
||||||
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
|
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
|
||||||
def make_key(self, seed):
|
def make_key(self, seed):
|
||||||
return random.unsafe_rbg_key(seed)
|
return random.PRNGKey(seed, impl="unsafe_rbg")
|
||||||
|
|
||||||
|
|
||||||
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
|
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user