mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Deprecate the contents of jax.prng
This commit is contained in:
parent
1761f7921b
commit
4b89d03147
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.16
|
||||
|
||||
* Internal deprecations/removals:
|
||||
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
|
||||
{mod}`jax.extend.random`.
|
||||
|
||||
## jaxlib 0.4.16
|
||||
|
||||
## jax 0.4.15 (Aug 30 2023)
|
||||
|
59
jax/prng.py
59
jax/prng.py
@ -15,12 +15,53 @@
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.prng import (
|
||||
PRNGImpl as PRNGImpl,
|
||||
seed_with_impl as seed_with_impl,
|
||||
threefry2x32_p as threefry2x32_p,
|
||||
threefry_2x32 as threefry_2x32,
|
||||
threefry_prng_impl as threefry_prng_impl,
|
||||
rbg_prng_impl as rbg_prng_impl,
|
||||
unsafe_rbg_prng_impl as unsafe_rbg_prng_impl,
|
||||
)
|
||||
from jax._src import prng as _prng
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Added August 29, 2023
|
||||
"PRNGImpl": (
|
||||
"jax.prng.PRNGImpl is deprecated. Use jax.extend.random.PRNGImpl instead.",
|
||||
_prng.PRNGImpl,
|
||||
),
|
||||
"seed_with_impl": (
|
||||
"jax.prng.seed_with_impl is deprecated. Use jax.extend.random.seed_with_impl instead.",
|
||||
_prng.seed_with_impl,
|
||||
),
|
||||
"threefry2x32_p": (
|
||||
"jax.prng.threefry2x32_p is deprecated. Use jax.extend.random.threefry2x32_p instead.",
|
||||
_prng.threefry2x32_p,
|
||||
),
|
||||
"threefry_2x32": (
|
||||
"jax.prng.threefry_2x32 is deprecated. Use jax.extend.random.threefry_2x32 instead.",
|
||||
_prng.threefry_2x32,
|
||||
),
|
||||
"threefry_prng_impl": (
|
||||
"jax.prng.threefry_prng_impl is deprecated. Use jax.extend.random.threefry_prng_impl instead.",
|
||||
_prng.threefry_prng_impl,
|
||||
),
|
||||
"rbg_prng_impl": (
|
||||
"jax.prng.rbg_prng_impl is deprecated. Use jax.extend.random.rbg_prng_impl instead.",
|
||||
_prng.rbg_prng_impl,
|
||||
),
|
||||
"unsafe_rbg_prng_impl": (
|
||||
"jax.prng.unsafe_rbg_prng_impl is deprecated. Use jax.extend.random.unsafe_rbg_prng_impl instead.",
|
||||
_prng.unsafe_rbg_prng_impl,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
PRNGImpl = _prng.PRNGImpl
|
||||
seed_with_impl = _prng.seed_with_impl
|
||||
threefry2x32_p = _prng.threefry2x32_p
|
||||
threefry_2x32 = _prng.threefry_2x32
|
||||
threefry_prng_impl = _prng.threefry_prng_impl
|
||||
rbg_prng_impl = _prng.rbg_prng_impl
|
||||
unsafe_rbg_prng_impl = _prng.unsafe_rbg_prng_impl
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
del _prng
|
||||
|
@ -32,7 +32,6 @@ import jax
|
||||
from jax import grad
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import prng
|
||||
from jax import random
|
||||
from jax import tree_util
|
||||
from jax._src import core
|
||||
@ -66,9 +65,9 @@ def _maybe_unwrap(key):
|
||||
return unwrap(key) if jnp.issubdtype(key, dtypes.prng_key) else key
|
||||
|
||||
|
||||
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
|
||||
('rbg', prng.rbg_prng_impl),
|
||||
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]
|
||||
PRNG_IMPLS = [('threefry2x32', prng_internal.threefry_prng_impl),
|
||||
('rbg', prng_internal.rbg_prng_impl),
|
||||
('unsafe_rbg', prng_internal.unsafe_rbg_prng_impl)]
|
||||
|
||||
|
||||
class OnX64(enum.Enum):
|
||||
@ -222,24 +221,24 @@ class PrngTest(jtu.JaxTestCase):
|
||||
return tuple(hex(x.copy()).rstrip("L") for x in result)
|
||||
|
||||
expected = ("0x6b200159", "0x99ba4efe")
|
||||
result = prng.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
|
||||
result = prng_internal.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
|
||||
|
||||
self.assertEqual(expected, result_to_hex(result))
|
||||
|
||||
expected = ("0x1cb996fc", "0xbb002be7")
|
||||
u32_max = np.iinfo(np.uint32).max
|
||||
result = prng.threefry_2x32(np.uint32([u32_max, u32_max]), np.uint32([u32_max, u32_max]))
|
||||
result = prng_internal.threefry_2x32(np.uint32([u32_max, u32_max]), np.uint32([u32_max, u32_max]))
|
||||
self.assertEqual(expected, result_to_hex(result))
|
||||
|
||||
expected = ("0xc4923a9c", "0x483df7a0")
|
||||
result = prng.threefry_2x32(
|
||||
result = prng_internal.threefry_2x32(
|
||||
np.uint32([0x13198a2e, 0x03707344]),
|
||||
np.uint32([0x243f6a88, 0x85a308d3]))
|
||||
self.assertEqual(expected, result_to_hex(result))
|
||||
|
||||
def testThreefry2x32Large(self):
|
||||
n = 10000000
|
||||
result = prng.threefry_2x32(
|
||||
result = prng_internal.threefry_2x32(
|
||||
(np.uint32(0x13198a2e), np.uint32(0x03707344)),
|
||||
jnp.concatenate([
|
||||
jnp.full((n,), 0x243f6a88, jnp.uint32),
|
||||
@ -251,7 +250,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
def testThreefry2x32Empty(self):
|
||||
# Regression test for an op-by-op crash for empty arrays in CUDA mode.
|
||||
with jax.disable_jit():
|
||||
result = prng.threefry_2x32(
|
||||
result = prng_internal.threefry_2x32(
|
||||
(np.uint32(0x13198a2e), np.uint32(0x03707344)),
|
||||
jnp.ones((10, 0,), jnp.uint32))
|
||||
np.testing.assert_equal(result, np.zeros((10, 0,), dtype=np.uint32))
|
||||
@ -260,7 +259,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
def fail(*args, **kwargs): assert False
|
||||
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
|
||||
try:
|
||||
_ = prng.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
|
||||
_ = prng_internal.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
|
||||
finally:
|
||||
xla.apply_primitive = apply_primitive
|
||||
|
||||
@ -524,17 +523,17 @@ class PrngTest(jtu.JaxTestCase):
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_explicit_threefry2x32_key(self):
|
||||
self.check_key_has_impl(random.threefry2x32_key(42),
|
||||
prng.threefry_prng_impl)
|
||||
prng_internal.threefry_prng_impl)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_explicit_rbg_key(self):
|
||||
self.check_key_has_impl(random.rbg_key(42),
|
||||
prng.rbg_prng_impl)
|
||||
prng_internal.rbg_prng_impl)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_explicit_unsafe_rbg_key(self):
|
||||
self.check_key_has_impl(random.unsafe_rbg_key(42),
|
||||
prng.unsafe_rbg_prng_impl)
|
||||
prng_internal.unsafe_rbg_prng_impl)
|
||||
|
||||
@parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl}
|
||||
for ctor in KEY_CTORS
|
||||
@ -1809,7 +1808,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
# -- prng primitives
|
||||
|
||||
def test_random_wrap_vmap(self):
|
||||
f = partial(prng_internal.random_wrap, impl=prng.threefry_prng_impl)
|
||||
f = partial(prng_internal.random_wrap, impl=prng_internal.threefry_prng_impl)
|
||||
base_arr = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
|
||||
keys = jax.vmap(f, in_axes=0)(base_arr)
|
||||
self.assertIsInstance(keys, random.KeyArray)
|
||||
@ -2144,7 +2143,7 @@ def _double_threefry_fold_in(key, data):
|
||||
return jnp.vstack([threefry_fold_in(key[0], data),
|
||||
threefry_fold_in(key[1], data)])
|
||||
|
||||
double_threefry_prng_impl = prng.PRNGImpl(
|
||||
double_threefry_prng_impl = prng_internal.PRNGImpl(
|
||||
key_shape=(2, 2),
|
||||
seed=_double_threefry_seed,
|
||||
split=_double_threefry_split,
|
||||
@ -2155,7 +2154,7 @@ double_threefry_prng_impl = prng.PRNGImpl(
|
||||
@jtu.with_config(jax_default_prng_impl='threefry2x32')
|
||||
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
def make_key(self, seed):
|
||||
return prng.seed_with_impl(double_threefry_prng_impl, seed)
|
||||
return prng_internal.seed_with_impl(double_threefry_prng_impl, seed)
|
||||
|
||||
def test_split_shape(self):
|
||||
key = self.make_key(73)
|
||||
|
Loading…
x
Reference in New Issue
Block a user