Deprecate the contents of jax.prng

This commit is contained in:
Jake VanderPlas 2023-08-30 15:13:32 -07:00
parent 1761f7921b
commit 4b89d03147
3 changed files with 69 additions and 25 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.16
* Internal deprecations/removals:
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`.
## jaxlib 0.4.16
## jax 0.4.15 (Aug 30 2023)

View File

@ -15,12 +15,53 @@
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.prng import (
PRNGImpl as PRNGImpl,
seed_with_impl as seed_with_impl,
threefry2x32_p as threefry2x32_p,
threefry_2x32 as threefry_2x32,
threefry_prng_impl as threefry_prng_impl,
rbg_prng_impl as rbg_prng_impl,
unsafe_rbg_prng_impl as unsafe_rbg_prng_impl,
)
from jax._src import prng as _prng
_deprecations = {
# Added August 29, 2023
"PRNGImpl": (
"jax.prng.PRNGImpl is deprecated. Use jax.extend.random.PRNGImpl instead.",
_prng.PRNGImpl,
),
"seed_with_impl": (
"jax.prng.seed_with_impl is deprecated. Use jax.extend.random.seed_with_impl instead.",
_prng.seed_with_impl,
),
"threefry2x32_p": (
"jax.prng.threefry2x32_p is deprecated. Use jax.extend.random.threefry2x32_p instead.",
_prng.threefry2x32_p,
),
"threefry_2x32": (
"jax.prng.threefry_2x32 is deprecated. Use jax.extend.random.threefry_2x32 instead.",
_prng.threefry_2x32,
),
"threefry_prng_impl": (
"jax.prng.threefry_prng_impl is deprecated. Use jax.extend.random.threefry_prng_impl instead.",
_prng.threefry_prng_impl,
),
"rbg_prng_impl": (
"jax.prng.rbg_prng_impl is deprecated. Use jax.extend.random.rbg_prng_impl instead.",
_prng.rbg_prng_impl,
),
"unsafe_rbg_prng_impl": (
"jax.prng.unsafe_rbg_prng_impl is deprecated. Use jax.extend.random.unsafe_rbg_prng_impl instead.",
_prng.unsafe_rbg_prng_impl,
),
}
import typing
if typing.TYPE_CHECKING:
PRNGImpl = _prng.PRNGImpl
seed_with_impl = _prng.seed_with_impl
threefry2x32_p = _prng.threefry2x32_p
threefry_2x32 = _prng.threefry_2x32
threefry_prng_impl = _prng.threefry_prng_impl
rbg_prng_impl = _prng.rbg_prng_impl
unsafe_rbg_prng_impl = _prng.unsafe_rbg_prng_impl
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _prng

View File

@ -32,7 +32,6 @@ import jax
from jax import grad
from jax import lax
from jax import numpy as jnp
from jax import prng
from jax import random
from jax import tree_util
from jax._src import core
@ -66,9 +65,9 @@ def _maybe_unwrap(key):
return unwrap(key) if jnp.issubdtype(key, dtypes.prng_key) else key
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
('rbg', prng.rbg_prng_impl),
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]
PRNG_IMPLS = [('threefry2x32', prng_internal.threefry_prng_impl),
('rbg', prng_internal.rbg_prng_impl),
('unsafe_rbg', prng_internal.unsafe_rbg_prng_impl)]
class OnX64(enum.Enum):
@ -222,24 +221,24 @@ class PrngTest(jtu.JaxTestCase):
return tuple(hex(x.copy()).rstrip("L") for x in result)
expected = ("0x6b200159", "0x99ba4efe")
result = prng.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
result = prng_internal.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
self.assertEqual(expected, result_to_hex(result))
expected = ("0x1cb996fc", "0xbb002be7")
u32_max = np.iinfo(np.uint32).max
result = prng.threefry_2x32(np.uint32([u32_max, u32_max]), np.uint32([u32_max, u32_max]))
result = prng_internal.threefry_2x32(np.uint32([u32_max, u32_max]), np.uint32([u32_max, u32_max]))
self.assertEqual(expected, result_to_hex(result))
expected = ("0xc4923a9c", "0x483df7a0")
result = prng.threefry_2x32(
result = prng_internal.threefry_2x32(
np.uint32([0x13198a2e, 0x03707344]),
np.uint32([0x243f6a88, 0x85a308d3]))
self.assertEqual(expected, result_to_hex(result))
def testThreefry2x32Large(self):
n = 10000000
result = prng.threefry_2x32(
result = prng_internal.threefry_2x32(
(np.uint32(0x13198a2e), np.uint32(0x03707344)),
jnp.concatenate([
jnp.full((n,), 0x243f6a88, jnp.uint32),
@ -251,7 +250,7 @@ class PrngTest(jtu.JaxTestCase):
def testThreefry2x32Empty(self):
# Regression test for an op-by-op crash for empty arrays in CUDA mode.
with jax.disable_jit():
result = prng.threefry_2x32(
result = prng_internal.threefry_2x32(
(np.uint32(0x13198a2e), np.uint32(0x03707344)),
jnp.ones((10, 0,), jnp.uint32))
np.testing.assert_equal(result, np.zeros((10, 0,), dtype=np.uint32))
@ -260,7 +259,7 @@ class PrngTest(jtu.JaxTestCase):
def fail(*args, **kwargs): assert False
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
try:
_ = prng.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
_ = prng_internal.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
finally:
xla.apply_primitive = apply_primitive
@ -524,17 +523,17 @@ class PrngTest(jtu.JaxTestCase):
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_threefry2x32_key(self):
self.check_key_has_impl(random.threefry2x32_key(42),
prng.threefry_prng_impl)
prng_internal.threefry_prng_impl)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_rbg_key(self):
self.check_key_has_impl(random.rbg_key(42),
prng.rbg_prng_impl)
prng_internal.rbg_prng_impl)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_unsafe_rbg_key(self):
self.check_key_has_impl(random.unsafe_rbg_key(42),
prng.unsafe_rbg_prng_impl)
prng_internal.unsafe_rbg_prng_impl)
@parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl}
for ctor in KEY_CTORS
@ -1809,7 +1808,7 @@ class KeyArrayTest(jtu.JaxTestCase):
# -- prng primitives
def test_random_wrap_vmap(self):
f = partial(prng_internal.random_wrap, impl=prng.threefry_prng_impl)
f = partial(prng_internal.random_wrap, impl=prng_internal.threefry_prng_impl)
base_arr = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
keys = jax.vmap(f, in_axes=0)(base_arr)
self.assertIsInstance(keys, random.KeyArray)
@ -2144,7 +2143,7 @@ def _double_threefry_fold_in(key, data):
return jnp.vstack([threefry_fold_in(key[0], data),
threefry_fold_in(key[1], data)])
double_threefry_prng_impl = prng.PRNGImpl(
double_threefry_prng_impl = prng_internal.PRNGImpl(
key_shape=(2, 2),
seed=_double_threefry_seed,
split=_double_threefry_split,
@ -2155,7 +2154,7 @@ double_threefry_prng_impl = prng.PRNGImpl(
@jtu.with_config(jax_default_prng_impl='threefry2x32')
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
def make_key(self, seed):
return prng.seed_with_impl(double_threefry_prng_impl, seed)
return prng_internal.seed_with_impl(double_threefry_prng_impl, seed)
def test_split_shape(self):
key = self.make_key(73)