Merge pull request #17594 from jakevdp:dep-prngkey

PiperOrigin-RevId: 565163390
This commit is contained in:
jax authors 2023-09-13 14:33:56 -07:00
commit 11c2f167a4
6 changed files with 66 additions and 52 deletions

View File

@ -22,6 +22,9 @@ Remember to align the itemized text with the first line of an item within a list
{mod}`jax.extend.random`.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
for type annotations, and `jax.dtypes.issubdtype(arr, jax.dtypes.prng_key)`` for runtime
detection of typed prng keys.
## jaxlib 0.4.16

View File

@ -23,6 +23,7 @@ from typing import Any, Literal, Protocol, Union
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax import random
@ -32,7 +33,7 @@ from jax._src.util import set_module
export = set_module('jax.nn.initializers')
KeyArray = random.KeyArray
KeyArray = jax.Array
Array = Any
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.

View File

@ -1475,10 +1475,9 @@ def custom_numeric(
def custom_random_keys_output():
def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
# TODO(frostig): Don't need this conditional once we always
# enable_custom_prng. We can even assert the isinstance instead.
# Here we handle both new-style and old-style keys; see JEP 9263
def unwrap_keys(keys):
if isinstance(keys, jax.random.KeyArray):
if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key):
return jax._src.prng.random_unwrap(keys)
else:
return keys

View File

@ -130,23 +130,6 @@ For more about jax_threefry_partitionable, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
"""
from jax._src.prng import PRNGKeyArray as _PRNGKeyArray
# TODO(frostig): remove this typechecking workaround. Our move away
# from PRNGKeyArray as a pytree led to Python typechecker breakages in
# several downstream annotations (e.g. annotations in jax-dependent
# libraries that are violated by their callers). It may be that the
# pytree registration decorator invalidated the checks. This will be
# easier to handle after we always enable_custom_prng.
import typing
if typing.TYPE_CHECKING:
PRNGKeyArray = typing.Any
KeyArray = typing.Any
else:
# TODO(frostig): replace with KeyArray from jax._src.random once we
# always enable_custom_prng
PRNGKeyArray = _PRNGKeyArray
KeyArray = PRNGKeyArray
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
@ -201,3 +184,31 @@ from jax._src.random import (
wald as wald,
weibull_min as weibull_min,
)
# Deprecations
from jax._src.prng import PRNGKeyArray as _PRNGKeyArray
_deprecations = {
# Added September 13, 2023:
"PRNGKeyArray": (
"jax.random.PRNGKeyArray is deprecated. Use jax.Array for annotations, and "
"jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of "
"typed prng keys.", _PRNGKeyArray
),
"KeyArray": (
"jax.random.KeyArray is deprecated. Use jax.Array for annotations, and "
"jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of "
"typed prng keys.", _PRNGKeyArray
),
}
import typing
if typing.TYPE_CHECKING:
PRNGKeyArray = typing.Any
KeyArray = typing.Any
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing

View File

@ -1208,7 +1208,7 @@ class PJitTest(jtu.BufferDonationTestCase):
f = pjit(make_keys, in_shardings=P(None), out_shardings=P(None))
out = f(seeds)
self.assertIsInstance(out, jax.random.KeyArray)
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
self.assertEqual(out.shape, input_shape)
out.unsafe_raw_array() # doesn't crash
@ -1841,7 +1841,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
return make_key(seeds)
out = make_keys(seeds)
self.assertIsInstance(out, jax.random.KeyArray)
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
self.assertEqual(out.shape, input_shape)
out.unsafe_raw_array() # doesn't crash
@ -1858,7 +1858,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
return make_key(seeds)
out = make_keys(seeds)
self.assertIsInstance(out, jax.random.KeyArray)
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
self.assertEqual(out.shape, input_shape)
out.unsafe_raw_array() # doesn't crash
@ -1875,7 +1875,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
return make_key(seeds)
out = make_keys(seeds)
self.assertIsInstance(out, jax.random.KeyArray)
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
self.assertEqual(out.shape, input_shape)
out.unsafe_raw_array() # doesn't crash
@ -2247,17 +2247,17 @@ class ArrayPjitTest(jtu.JaxTestCase):
x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
y = jax.device_put(x, s)
self.assertIsInstance(y, jax.random.KeyArray)
self.assertTrue(jax.dtypes.issubdtype(y.dtype, jax.dtypes.prng_key))
self.assertEqual(y.sharding, s)
s1 = SingleDeviceSharding(jax.devices()[1])
z = jax.device_put(x, s1)
self.assertIsInstance(z, jax.random.KeyArray)
self.assertTrue(jax.dtypes.issubdtype(z.dtype, jax.dtypes.prng_key))
self.assertEqual(z.sharding, s1)
out_p = jax.pmap(lambda x: x)(np.arange(jax.device_count()))
a = jax.device_put(x, out_p.sharding)
self.assertIsInstance(a, jax.random.KeyArray)
self.assertTrue(jax.dtypes.issubdtype(a.dtype, jax.dtypes.prng_key))
self.assertEqual(a.sharding, out_p.sharding)
op = xc.OpSharding()
@ -2266,7 +2266,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
gs = GSPMDSharding(tuple(mesh.devices.flat), op)
b = jax.device_put(x, gs)
self.assertIsInstance(b, jax.random.KeyArray)
self.assertTrue(jax.dtypes.issubdtype(b.dtype, jax.dtypes.prng_key))
self.assertEqual(b.sharding, gs)
def test_device_put_on_different_sharding(self):

View File

@ -1733,7 +1733,7 @@ class KeyArrayTest(jtu.JaxTestCase):
def test_construction(self):
key = random.key(42)
self.assertIsInstance(key, random.PRNGKeyArray)
self.assertIsInstance(key, jax_random.PRNGKeyArray)
def test_issubdtype(self):
key = random.key(42)
@ -1743,7 +1743,7 @@ class KeyArrayTest(jtu.JaxTestCase):
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_construction_upgrade_flag(self):
key = random.PRNGKey(42)
self.assertIsInstance(key, random.PRNGKeyArray)
self.assertIsInstance(key, jax_random.PRNGKeyArray)
def make_keys(self, *shape, seed=28):
seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32)
@ -1797,13 +1797,13 @@ class KeyArrayTest(jtu.JaxTestCase):
def test_isinstance(self):
@jax.jit
def f(k):
self.assertIsInstance(k, random.KeyArray)
self.assertIsInstance(k, jax_random.PRNGKeyArray)
return k
k1 = self.make_keys()
k2 = f(k1)
self.assertIsInstance(k1, random.KeyArray)
self.assertIsInstance(k2, random.KeyArray)
self.assertIsInstance(k1, jax_random.PRNGKeyArray)
self.assertIsInstance(k2, jax_random.PRNGKeyArray)
def test_cpp_dispatch_normal(self):
# Ensure we stay on the C++ dispatch path when calling a jitted
@ -1868,10 +1868,10 @@ class KeyArrayTest(jtu.JaxTestCase):
f = partial(prng_internal.random_wrap, impl=prng_internal.threefry_prng_impl)
base_arr = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
keys = jax.vmap(f, in_axes=0)(base_arr)
self.assertIsInstance(keys, random.KeyArray)
self.assertIsInstance(keys, jax_random.PRNGKeyArray)
self.assertEqual(keys.shape, (3,))
keys = jax.vmap(f, in_axes=1)(base_arr.T)
self.assertIsInstance(keys, random.KeyArray)
self.assertIsInstance(keys, jax_random.PRNGKeyArray)
self.assertEqual(keys.shape, (3,))
@jtu.sample_product(use_internal=[False, True])
@ -1968,20 +1968,20 @@ class KeyArrayTest(jtu.JaxTestCase):
ks = self.make_keys(3, 4)
f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
_, out = jax.jit(f)(ks) # doesn't crash
self.assertIsInstance(out, random.KeyArray)
self.assertIsInstance(out, jax_random.PRNGKeyArray)
self.assertEqual(out.shape, (3, 4))
def test_slice(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (2, 4))
def test_dynamic_slice(self):
ks = self.make_keys(3, 4)
index = np.int16(1) # non-default int type to catch type errors.
ys = jax.jit(partial(lax.dynamic_slice_in_dim, slice_size=2))(ks, index)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (2, 4))
def test_dynamic_update_slice(self):
@ -1989,51 +1989,51 @@ class KeyArrayTest(jtu.JaxTestCase):
k = self.make_keys(1, 4)
index = np.int16(1) # non-default int type to catch type errors.
ys = jax.jit(partial(lax.dynamic_update_slice_in_dim, axis=0))(ks, k, index)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (3, 4))
def test_transpose(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: x.T)(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (4, 3))
def test_gather(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: x[1])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (4,))
ks = self.make_keys(3, 4, 5)
ys = jax.jit(lambda x: x[1])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (4, 5))
ys = jax.jit(lambda x: x[1, 2:4])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (2, 5))
ys = jax.jit(lambda x: x[1, 2:4, 3])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (2,))
ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (3, 2, 1))
def test_select(self):
ks = self.make_keys(3, 2)
cs = jnp.array([True, False, False, True, False, True]).reshape(3, 2)
ys = jax.jit(lax.select)(cs, ks, ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (3, 2))
def test_select_scalar_cond(self):
# regression test for https://github.com/google/jax/issues/16422
ks = self.make_keys(3)
ys = lax.select(True, ks, ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (3,))
def test_vmap_of_cond(self):
@ -2106,7 +2106,7 @@ class KeyArrayTest(jtu.JaxTestCase):
custom_result = jax.grad(f)(0.0, key)
self.assertAllClose(default_result, custom_result)
self.assertIsInstance(key_dot, random.PRNGKeyArray)
self.assertIsInstance(key_dot, jax_random.PRNGKeyArray)
self.assertArraysEqual(random.key_data(key_dot), np.uint32(0))
def test_key_array_indexing_0d(self):
@ -2373,7 +2373,7 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
def check_shape(self, func, *args):
like = lambda keys: jnp.ones(keys.shape)
out_key = func(*args)
self.assertIsInstance(out_key, random.KeyArray)
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
out_like_key = func(*tree_util.tree_map(like, args))
self.assertIsInstance(out_like_key, jax.Array)
self.assertEqual(out_key.shape, out_like_key.shape)
@ -2383,11 +2383,11 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.assertIsInstance(out_arr, jax.Array)
out_key = key_func(*key_args)
self.assertIsInstance(out_key, random.KeyArray)
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)
out_key = jax.jit(key_func)(*key_args)
self.assertIsInstance(out_key, random.KeyArray)
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)
@parameterized.parameters([