Merge pull request #17225 from jakevdp:prng-flag

PiperOrigin-RevId: 559243191
This commit is contained in:
jax authors 2023-08-22 15:41:10 -07:00
commit bf29c5e5f1
4 changed files with 46 additions and 2 deletions

View File

@ -62,6 +62,7 @@ from jax._src.config import (
numpy_dtype_promotion as numpy_dtype_promotion,
numpy_rank_promotion as numpy_rank_promotion,
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
legacy_prng_key as legacy_prng_key,
transfer_guard as transfer_guard,
transfer_guard_host_to_device as transfer_guard_host_to_device,
transfer_guard_device_to_device as transfer_guard_device_to_device,

View File

@ -794,6 +794,13 @@ distributed_debug = config.define_bool_state(
'computations. Logging is performed with `logging` at WARNING '
'level.'))
legacy_prng_key = config.define_enum_state(
name='jax_legacy_prng_key',
enum_values=['allow', 'warn', 'error'],
default='allow',
help=('Specify the behavior when raw PRNG keys are passed to '
'jax.random APIs.')
)
enable_custom_prng = config.define_bool_state(
name='jax_enable_custom_prng',

View File

@ -72,12 +72,25 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
if isinstance(key, prng.PRNGKeyArray):
return key, False
elif _arraylike(key):
if config.jax_enable_custom_prng:
# Call random_wrap here to surface errors for invalid keys.
wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
if config.jax_legacy_prng_key == 'error':
raise ValueError(
'Legacy uint32 key array passed as key to jax.random function. '
'Please create keys using jax.random.key(). If use of a raw key array '
'was intended, set jax_legacy_prng_key="allow".')
elif config.jax_legacy_prng_key == 'warn':
warnings.warn(
'Legacy uint32 key array passed as key to jax.random function. '
'Please create keys using jax.random.key(). If use of a raw key array '
'was intended, set jax_legacy_prng_key="allow".', stacklevel=2)
elif config.jax_enable_custom_prng:
# TODO(jakevdp): possibly remove this warning condition.
warnings.warn(
'Raw arrays as random keys to jax.random functions are deprecated. '
'Assuming valid threefry2x32 key for now.',
FutureWarning)
return prng.random_wrap(key, impl=default_prng_impl()), True
return wrapped_key, True
else:
raise TypeError(f'unexpected PRNG key type {type(key)}')

View File

@ -553,6 +553,29 @@ class PrngTest(jtu.JaxTestCase):
def f(seed): return make_key(seed)
jax.vjp(f, 1) # doesn't crash
def test_legacy_prng_key_flag(self):
raw_key = jnp.zeros(2, dtype='uint32')
invalid_key = jnp.zeros(1, dtype='float32')
msg = "Legacy uint32 key array passed as key to jax.random function."
with jax.legacy_prng_key('allow'):
# TODO(jakevdp): remove when enable_custom_prng no longer issues warnings
with jax.enable_custom_prng(False):
with self.assertNoWarnings():
random.uniform(raw_key)
with jax.legacy_prng_key('warn'):
with self.assertWarnsRegex(UserWarning, msg):
random.uniform(raw_key)
with jax.legacy_prng_key('error'):
with self.assertRaisesRegex(ValueError, msg):
random.uniform(raw_key)
# Invalid key error should take precedence.
with self.assertRaisesRegex(TypeError, "JAX encountered invalid PRNG key data"):
random.uniform(invalid_key)
class ThreefryPrngTest(jtu.JaxTestCase):
@parameterized.parameters([{'make_key': ctor} for ctor in [