mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #17225 from jakevdp:prng-flag
PiperOrigin-RevId: 559243191
This commit is contained in:
commit
bf29c5e5f1
@ -62,6 +62,7 @@ from jax._src.config import (
|
||||
numpy_dtype_promotion as numpy_dtype_promotion,
|
||||
numpy_rank_promotion as numpy_rank_promotion,
|
||||
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
|
||||
legacy_prng_key as legacy_prng_key,
|
||||
transfer_guard as transfer_guard,
|
||||
transfer_guard_host_to_device as transfer_guard_host_to_device,
|
||||
transfer_guard_device_to_device as transfer_guard_device_to_device,
|
||||
|
@ -794,6 +794,13 @@ distributed_debug = config.define_bool_state(
|
||||
'computations. Logging is performed with `logging` at WARNING '
|
||||
'level.'))
|
||||
|
||||
legacy_prng_key = config.define_enum_state(
|
||||
name='jax_legacy_prng_key',
|
||||
enum_values=['allow', 'warn', 'error'],
|
||||
default='allow',
|
||||
help=('Specify the behavior when raw PRNG keys are passed to '
|
||||
'jax.random APIs.')
|
||||
)
|
||||
|
||||
enable_custom_prng = config.define_bool_state(
|
||||
name='jax_enable_custom_prng',
|
||||
|
@ -72,12 +72,25 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
|
||||
if isinstance(key, prng.PRNGKeyArray):
|
||||
return key, False
|
||||
elif _arraylike(key):
|
||||
if config.jax_enable_custom_prng:
|
||||
# Call random_wrap here to surface errors for invalid keys.
|
||||
wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
|
||||
if config.jax_legacy_prng_key == 'error':
|
||||
raise ValueError(
|
||||
'Legacy uint32 key array passed as key to jax.random function. '
|
||||
'Please create keys using jax.random.key(). If use of a raw key array '
|
||||
'was intended, set jax_legacy_prng_key="allow".')
|
||||
elif config.jax_legacy_prng_key == 'warn':
|
||||
warnings.warn(
|
||||
'Legacy uint32 key array passed as key to jax.random function. '
|
||||
'Please create keys using jax.random.key(). If use of a raw key array '
|
||||
'was intended, set jax_legacy_prng_key="allow".', stacklevel=2)
|
||||
elif config.jax_enable_custom_prng:
|
||||
# TODO(jakevdp): possibly remove this warning condition.
|
||||
warnings.warn(
|
||||
'Raw arrays as random keys to jax.random functions are deprecated. '
|
||||
'Assuming valid threefry2x32 key for now.',
|
||||
FutureWarning)
|
||||
return prng.random_wrap(key, impl=default_prng_impl()), True
|
||||
return wrapped_key, True
|
||||
else:
|
||||
raise TypeError(f'unexpected PRNG key type {type(key)}')
|
||||
|
||||
|
@ -553,6 +553,29 @@ class PrngTest(jtu.JaxTestCase):
|
||||
def f(seed): return make_key(seed)
|
||||
jax.vjp(f, 1) # doesn't crash
|
||||
|
||||
def test_legacy_prng_key_flag(self):
|
||||
raw_key = jnp.zeros(2, dtype='uint32')
|
||||
invalid_key = jnp.zeros(1, dtype='float32')
|
||||
msg = "Legacy uint32 key array passed as key to jax.random function."
|
||||
|
||||
with jax.legacy_prng_key('allow'):
|
||||
# TODO(jakevdp): remove when enable_custom_prng no longer issues warnings
|
||||
with jax.enable_custom_prng(False):
|
||||
with self.assertNoWarnings():
|
||||
random.uniform(raw_key)
|
||||
|
||||
with jax.legacy_prng_key('warn'):
|
||||
with self.assertWarnsRegex(UserWarning, msg):
|
||||
random.uniform(raw_key)
|
||||
|
||||
with jax.legacy_prng_key('error'):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
random.uniform(raw_key)
|
||||
|
||||
# Invalid key error should take precedence.
|
||||
with self.assertRaisesRegex(TypeError, "JAX encountered invalid PRNG key data"):
|
||||
random.uniform(invalid_key)
|
||||
|
||||
|
||||
class ThreefryPrngTest(jtu.JaxTestCase):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in [
|
||||
|
Loading…
x
Reference in New Issue
Block a user