mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #17594 from jakevdp:dep-prngkey
PiperOrigin-RevId: 565163390
This commit is contained in:
commit
11c2f167a4
@ -22,6 +22,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
{mod}`jax.extend.random`.
|
||||
* The internal submodule path `jax.linear_util` has been deprecated. Use
|
||||
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
|
||||
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
|
||||
for type annotations, and `jax.dtypes.issubdtype(arr, jax.dtypes.prng_key)`` for runtime
|
||||
detection of typed prng keys.
|
||||
|
||||
## jaxlib 0.4.16
|
||||
|
||||
|
@ -23,6 +23,7 @@ from typing import Any, Literal, Protocol, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import random
|
||||
@ -32,7 +33,7 @@ from jax._src.util import set_module
|
||||
|
||||
export = set_module('jax.nn.initializers')
|
||||
|
||||
KeyArray = random.KeyArray
|
||||
KeyArray = jax.Array
|
||||
Array = Any
|
||||
# TODO: Import or define these to match
|
||||
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
|
||||
|
@ -1475,10 +1475,9 @@ def custom_numeric(
|
||||
|
||||
def custom_random_keys_output():
|
||||
def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
|
||||
# TODO(frostig): Don't need this conditional once we always
|
||||
# enable_custom_prng. We can even assert the isinstance instead.
|
||||
# Here we handle both new-style and old-style keys; see JEP 9263
|
||||
def unwrap_keys(keys):
|
||||
if isinstance(keys, jax.random.KeyArray):
|
||||
if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key):
|
||||
return jax._src.prng.random_unwrap(keys)
|
||||
else:
|
||||
return keys
|
||||
|
@ -130,23 +130,6 @@ For more about jax_threefry_partitionable, see
|
||||
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
|
||||
"""
|
||||
|
||||
from jax._src.prng import PRNGKeyArray as _PRNGKeyArray
|
||||
# TODO(frostig): remove this typechecking workaround. Our move away
|
||||
# from PRNGKeyArray as a pytree led to Python typechecker breakages in
|
||||
# several downstream annotations (e.g. annotations in jax-dependent
|
||||
# libraries that are violated by their callers). It may be that the
|
||||
# pytree registration decorator invalidated the checks. This will be
|
||||
# easier to handle after we always enable_custom_prng.
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
PRNGKeyArray = typing.Any
|
||||
KeyArray = typing.Any
|
||||
else:
|
||||
# TODO(frostig): replace with KeyArray from jax._src.random once we
|
||||
# always enable_custom_prng
|
||||
PRNGKeyArray = _PRNGKeyArray
|
||||
KeyArray = PRNGKeyArray
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
@ -201,3 +184,31 @@ from jax._src.random import (
|
||||
wald as wald,
|
||||
weibull_min as weibull_min,
|
||||
)
|
||||
|
||||
|
||||
# Deprecations
|
||||
from jax._src.prng import PRNGKeyArray as _PRNGKeyArray
|
||||
|
||||
_deprecations = {
|
||||
# Added September 13, 2023:
|
||||
"PRNGKeyArray": (
|
||||
"jax.random.PRNGKeyArray is deprecated. Use jax.Array for annotations, and "
|
||||
"jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of "
|
||||
"typed prng keys.", _PRNGKeyArray
|
||||
),
|
||||
"KeyArray": (
|
||||
"jax.random.KeyArray is deprecated. Use jax.Array for annotations, and "
|
||||
"jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of "
|
||||
"typed prng keys.", _PRNGKeyArray
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
PRNGKeyArray = typing.Any
|
||||
KeyArray = typing.Any
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
|
@ -1208,7 +1208,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
f = pjit(make_keys, in_shardings=P(None), out_shardings=P(None))
|
||||
|
||||
out = f(seeds)
|
||||
self.assertIsInstance(out, jax.random.KeyArray)
|
||||
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
|
||||
@ -1841,7 +1841,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
return make_key(seeds)
|
||||
|
||||
out = make_keys(seeds)
|
||||
self.assertIsInstance(out, jax.random.KeyArray)
|
||||
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
|
||||
@ -1858,7 +1858,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
return make_key(seeds)
|
||||
|
||||
out = make_keys(seeds)
|
||||
self.assertIsInstance(out, jax.random.KeyArray)
|
||||
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
|
||||
@ -1875,7 +1875,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
return make_key(seeds)
|
||||
|
||||
out = make_keys(seeds)
|
||||
self.assertIsInstance(out, jax.random.KeyArray)
|
||||
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
|
||||
@ -2247,17 +2247,17 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
|
||||
y = jax.device_put(x, s)
|
||||
self.assertIsInstance(y, jax.random.KeyArray)
|
||||
self.assertTrue(jax.dtypes.issubdtype(y.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(y.sharding, s)
|
||||
|
||||
s1 = SingleDeviceSharding(jax.devices()[1])
|
||||
z = jax.device_put(x, s1)
|
||||
self.assertIsInstance(z, jax.random.KeyArray)
|
||||
self.assertTrue(jax.dtypes.issubdtype(z.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(z.sharding, s1)
|
||||
|
||||
out_p = jax.pmap(lambda x: x)(np.arange(jax.device_count()))
|
||||
a = jax.device_put(x, out_p.sharding)
|
||||
self.assertIsInstance(a, jax.random.KeyArray)
|
||||
self.assertTrue(jax.dtypes.issubdtype(a.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(a.sharding, out_p.sharding)
|
||||
|
||||
op = xc.OpSharding()
|
||||
@ -2266,7 +2266,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
gs = GSPMDSharding(tuple(mesh.devices.flat), op)
|
||||
b = jax.device_put(x, gs)
|
||||
self.assertIsInstance(b, jax.random.KeyArray)
|
||||
self.assertTrue(jax.dtypes.issubdtype(b.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(b.sharding, gs)
|
||||
|
||||
def test_device_put_on_different_sharding(self):
|
||||
|
@ -1733,7 +1733,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_construction(self):
|
||||
key = random.key(42)
|
||||
self.assertIsInstance(key, random.PRNGKeyArray)
|
||||
self.assertIsInstance(key, jax_random.PRNGKeyArray)
|
||||
|
||||
def test_issubdtype(self):
|
||||
key = random.key(42)
|
||||
@ -1743,7 +1743,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_construction_upgrade_flag(self):
|
||||
key = random.PRNGKey(42)
|
||||
self.assertIsInstance(key, random.PRNGKeyArray)
|
||||
self.assertIsInstance(key, jax_random.PRNGKeyArray)
|
||||
|
||||
def make_keys(self, *shape, seed=28):
|
||||
seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32)
|
||||
@ -1797,13 +1797,13 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
def test_isinstance(self):
|
||||
@jax.jit
|
||||
def f(k):
|
||||
self.assertIsInstance(k, random.KeyArray)
|
||||
self.assertIsInstance(k, jax_random.PRNGKeyArray)
|
||||
return k
|
||||
|
||||
k1 = self.make_keys()
|
||||
k2 = f(k1)
|
||||
self.assertIsInstance(k1, random.KeyArray)
|
||||
self.assertIsInstance(k2, random.KeyArray)
|
||||
self.assertIsInstance(k1, jax_random.PRNGKeyArray)
|
||||
self.assertIsInstance(k2, jax_random.PRNGKeyArray)
|
||||
|
||||
def test_cpp_dispatch_normal(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling a jitted
|
||||
@ -1868,10 +1868,10 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
f = partial(prng_internal.random_wrap, impl=prng_internal.threefry_prng_impl)
|
||||
base_arr = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
|
||||
keys = jax.vmap(f, in_axes=0)(base_arr)
|
||||
self.assertIsInstance(keys, random.KeyArray)
|
||||
self.assertIsInstance(keys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(keys.shape, (3,))
|
||||
keys = jax.vmap(f, in_axes=1)(base_arr.T)
|
||||
self.assertIsInstance(keys, random.KeyArray)
|
||||
self.assertIsInstance(keys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(keys.shape, (3,))
|
||||
|
||||
@jtu.sample_product(use_internal=[False, True])
|
||||
@ -1968,20 +1968,20 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
ks = self.make_keys(3, 4)
|
||||
f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
|
||||
_, out = jax.jit(f)(ks) # doesn't crash
|
||||
self.assertIsInstance(out, random.KeyArray)
|
||||
self.assertIsInstance(out, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(out.shape, (3, 4))
|
||||
|
||||
def test_slice(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (2, 4))
|
||||
|
||||
def test_dynamic_slice(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
index = np.int16(1) # non-default int type to catch type errors.
|
||||
ys = jax.jit(partial(lax.dynamic_slice_in_dim, slice_size=2))(ks, index)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (2, 4))
|
||||
|
||||
def test_dynamic_update_slice(self):
|
||||
@ -1989,51 +1989,51 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
k = self.make_keys(1, 4)
|
||||
index = np.int16(1) # non-default int type to catch type errors.
|
||||
ys = jax.jit(partial(lax.dynamic_update_slice_in_dim, axis=0))(ks, k, index)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (3, 4))
|
||||
|
||||
def test_transpose(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
ys = jax.jit(lambda x: x.T)(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (4, 3))
|
||||
|
||||
def test_gather(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
ys = jax.jit(lambda x: x[1])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (4,))
|
||||
|
||||
ks = self.make_keys(3, 4, 5)
|
||||
|
||||
ys = jax.jit(lambda x: x[1])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (4, 5))
|
||||
|
||||
ys = jax.jit(lambda x: x[1, 2:4])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (2, 5))
|
||||
|
||||
ys = jax.jit(lambda x: x[1, 2:4, 3])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (2,))
|
||||
|
||||
ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (3, 2, 1))
|
||||
|
||||
def test_select(self):
|
||||
ks = self.make_keys(3, 2)
|
||||
cs = jnp.array([True, False, False, True, False, True]).reshape(3, 2)
|
||||
ys = jax.jit(lax.select)(cs, ks, ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (3, 2))
|
||||
|
||||
def test_select_scalar_cond(self):
|
||||
# regression test for https://github.com/google/jax/issues/16422
|
||||
ks = self.make_keys(3)
|
||||
ys = lax.select(True, ks, ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
|
||||
self.assertEqual(ys.shape, (3,))
|
||||
|
||||
def test_vmap_of_cond(self):
|
||||
@ -2106,7 +2106,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
custom_result = jax.grad(f)(0.0, key)
|
||||
|
||||
self.assertAllClose(default_result, custom_result)
|
||||
self.assertIsInstance(key_dot, random.PRNGKeyArray)
|
||||
self.assertIsInstance(key_dot, jax_random.PRNGKeyArray)
|
||||
self.assertArraysEqual(random.key_data(key_dot), np.uint32(0))
|
||||
|
||||
def test_key_array_indexing_0d(self):
|
||||
@ -2373,7 +2373,7 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
def check_shape(self, func, *args):
|
||||
like = lambda keys: jnp.ones(keys.shape)
|
||||
out_key = func(*args)
|
||||
self.assertIsInstance(out_key, random.KeyArray)
|
||||
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
|
||||
out_like_key = func(*tree_util.tree_map(like, args))
|
||||
self.assertIsInstance(out_like_key, jax.Array)
|
||||
self.assertEqual(out_key.shape, out_like_key.shape)
|
||||
@ -2383,11 +2383,11 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(out_arr, jax.Array)
|
||||
|
||||
out_key = key_func(*key_args)
|
||||
self.assertIsInstance(out_key, random.KeyArray)
|
||||
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
|
||||
self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)
|
||||
|
||||
out_key = jax.jit(key_func)(*key_args)
|
||||
self.assertIsInstance(out_key, random.KeyArray)
|
||||
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
|
||||
self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)
|
||||
|
||||
@parameterized.parameters([
|
||||
|
Loading…
x
Reference in New Issue
Block a user