mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[jex] replace extend.random.PRNGImpl
with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque object representing the defined implementation. This result can still be passed to `jax.random.key` and `wrap_key_data`. PiperOrigin-RevId: 578349699
This commit is contained in:
parent
49fedb1c52
commit
16d082b002
@ -28,7 +28,7 @@
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
PRNGImpl
|
||||
define_prng_impl
|
||||
seed_with_impl
|
||||
threefry2x32_p
|
||||
threefry_2x32
|
||||
|
34
jax/_src/extend/random.py
Normal file
34
jax/_src/extend/random.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Hashable
|
||||
|
||||
from jax import Array
|
||||
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
|
||||
Shape = tuple[int, ...]
|
||||
|
||||
def define_prng_impl(*,
|
||||
key_shape: Shape,
|
||||
seed: Callable[[Array], Array],
|
||||
split: Callable[[Array, Shape], Array],
|
||||
random_bits: Callable[[Array, int, Shape], Array],
|
||||
fold_in: Callable[[Array, int], Array],
|
||||
name: str = '<unnamed>',
|
||||
tag: str = '?') -> Hashable:
|
||||
return random.PRNGSpec(prng.PRNGImpl(
|
||||
key_shape, seed, split, random_bits, fold_in,
|
||||
name=name, tag=tag))
|
@ -145,8 +145,12 @@ class PRNGSpec:
|
||||
return self._impl == other._impl
|
||||
|
||||
|
||||
def resolve_prng_impl(
|
||||
impl_spec: Optional[Union[str, PRNGSpec, PRNGImpl]]) -> PRNGImpl:
|
||||
# TODO(frostig,vanderplas): remove PRNGImpl from this union when it's
|
||||
# no longer in the public API because `default_prng_impl` is gone
|
||||
PRNGSpecDesc = Union[str, PRNGSpec, PRNGImpl]
|
||||
|
||||
|
||||
def resolve_prng_impl(impl_spec: Optional[PRNGSpecDesc]) -> PRNGImpl:
|
||||
if impl_spec is None:
|
||||
return default_prng_impl()
|
||||
if type(impl_spec) is PRNGImpl:
|
||||
@ -169,7 +173,8 @@ def resolve_prng_impl(
|
||||
raise TypeError(f'unrecognized type {t} for specifying PRNG implementation.')
|
||||
|
||||
|
||||
def _key(ctor_name: str, seed: int | ArrayLike, impl_spec: Optional[str] ) -> KeyArray:
|
||||
def _key(ctor_name: str, seed: int | ArrayLike,
|
||||
impl_spec: Optional[PRNGSpecDesc]) -> KeyArray:
|
||||
impl = resolve_prng_impl(impl_spec)
|
||||
if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key):
|
||||
raise TypeError(
|
||||
@ -180,7 +185,8 @@ def _key(ctor_name: str, seed: int | ArrayLike, impl_spec: Optional[str] ) -> Ke
|
||||
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
|
||||
return prng.random_seed(seed, impl=impl)
|
||||
|
||||
def key(seed: int | ArrayLike, *, impl: Optional[str] = None) -> KeyArray:
|
||||
def key(seed: int | ArrayLike, *,
|
||||
impl: Optional[PRNGSpecDesc] = None) -> KeyArray:
|
||||
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
|
||||
|
||||
The result is a scalar array with a key that indicates the default PRNG
|
||||
@ -198,7 +204,8 @@ def key(seed: int | ArrayLike, *, impl: Optional[str] = None) -> KeyArray:
|
||||
"""
|
||||
return _key('key', seed, impl)
|
||||
|
||||
def PRNGKey(seed: int | ArrayLike, *, impl: Optional[str] = None) -> KeyArray:
|
||||
def PRNGKey(seed: int | ArrayLike, *,
|
||||
impl: Optional[PRNGSpecDesc] = None) -> KeyArray:
|
||||
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
|
||||
|
||||
The resulting key carries the default PRNG implementation, as
|
||||
@ -321,7 +328,8 @@ def key_data(keys: KeyArrayLike) -> Array:
|
||||
return _key_data(keys)
|
||||
|
||||
|
||||
def wrap_key_data(key_bits_array: Array, *, impl: Optional[str] = None):
|
||||
def wrap_key_data(key_bits_array: Array, *,
|
||||
impl: Optional[PRNGSpecDesc] = None):
|
||||
"""Wrap an array of key data bits into a PRNG key array.
|
||||
|
||||
Args:
|
||||
|
@ -15,11 +15,11 @@
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.extend.random import (
|
||||
define_prng_impl as define_prng_impl,
|
||||
)
|
||||
|
||||
from jax._src.prng import (
|
||||
# TODO(frostig,vanderplas): expose a define_prng_impl instead of the
|
||||
# PRNGImpl constructor, to leave some room for us to register or check input,
|
||||
# or to change what output type we return.
|
||||
PRNGImpl as PRNGImpl,
|
||||
random_seed as random_seed,
|
||||
seed_with_impl as seed_with_impl,
|
||||
threefry2x32_p as threefry2x32_p,
|
||||
|
@ -31,7 +31,6 @@ class ExtendTest(jtu.JaxTestCase):
|
||||
|
||||
def test_symbols(self):
|
||||
# Assume these are tested in random_test.py, only check equivalence
|
||||
self.assertIs(jex.random.PRNGImpl, prng.PRNGImpl)
|
||||
self.assertIs(jex.random.seed_with_impl, prng.seed_with_impl)
|
||||
self.assertIs(jex.random.threefry2x32_p, prng.threefry2x32_p)
|
||||
self.assertIs(jex.random.threefry_2x32, prng.threefry_2x32)
|
||||
@ -61,21 +60,25 @@ class RandomTest(jtu.JaxTestCase):
|
||||
def no_rule(*args, **kwargs):
|
||||
assert False, 'unreachable'
|
||||
|
||||
impl = jex.random.PRNGImpl(shape, seed_rule, no_rule, no_rule, no_rule)
|
||||
impl = jex.random.define_prng_impl(
|
||||
key_shape=shape, seed=seed_rule, split=no_rule, fold_in=no_rule,
|
||||
random_bits=no_rule)
|
||||
k = jax.random.key(42, impl=impl)
|
||||
self.assertEqual(k.shape, ())
|
||||
self.assertEqual(impl, jax.random.key_impl(k)._impl)
|
||||
self.assertEqual(impl, jax.random.key_impl(k))
|
||||
|
||||
def test_key_wrap_with_custom_impl(self):
|
||||
def no_rule(*args, **kwargs):
|
||||
assert False, 'unreachable'
|
||||
|
||||
shape = (4, 2, 7)
|
||||
impl = jex.random.PRNGImpl(shape, no_rule, no_rule, no_rule, no_rule)
|
||||
impl = jex.random.define_prng_impl(
|
||||
key_shape=shape, seed=no_rule, split=no_rule, fold_in=no_rule,
|
||||
random_bits=no_rule)
|
||||
data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32'))
|
||||
k = jax.random.wrap_key_data(data, impl=impl)
|
||||
self.assertEqual(k.shape, (3,))
|
||||
self.assertEqual(impl, jax.random.key_impl(k)._impl)
|
||||
self.assertEqual(impl, jax.random.key_impl(k))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user