add jax.extend.random.wrap_key_data

This commit is contained in:
Roy Frostig 2023-08-25 17:36:15 -07:00
parent 4e227a328a
commit a69f134cde
5 changed files with 71 additions and 0 deletions

View File

@ -153,6 +153,7 @@ py_library_providing_imports_info(
[
"*.py",
"_src/debugger/**/*.py",
"_src/extend/**/*.py",
"_src/image/**/*.py",
"_src/lax/**/*.py",
"_src/nn/**/*.py",

View File

@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

23
jax/_src/extend/random.py Normal file
View File

@ -0,0 +1,23 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from jax._src import prng
from jax._src import random
from jax._src.typing import Array
def wrap_key_data(key_bits_array: Array, *, impl: Optional[str] = None):
impl_obj = random.resolve_prng_impl(impl)
return prng.random_wrap(key_bits_array, impl=impl_obj)

View File

@ -24,3 +24,7 @@ from jax._src.prng import (
rbg_prng_impl as rbg_prng_impl,
unsafe_rbg_prng_impl as unsafe_rbg_prng_impl,
)
from jax._src.extend.random import (
wrap_key_data as wrap_key_data,
)

View File

@ -14,6 +14,7 @@
from absl.testing import absltest
import jax
import jax.extend as jex
from jax._src import prng
@ -22,6 +23,7 @@ from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
class ExtendTest(jtu.JaxTestCase):
def test_symbols(self):
# Assume these are tested in random_test.py, only check equivalence
@ -33,5 +35,33 @@ class ExtendTest(jtu.JaxTestCase):
self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl)
self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl)
class RandomTest(jtu.JaxTestCase):
def test_wrap_key_default(self):
key1 = jax.random.key(17)
data = jax.random.key_data(key1)
key2 = jex.random.wrap_key_data(data)
self.assertEqual(key1.dtype, key2.dtype)
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key2))
impl = config.jax_default_prng_impl
key3 = jex.random.wrap_key_data(data, impl=impl)
self.assertEqual(key1.dtype, key3.dtype)
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key3))
def test_wrap_key_explicit(self):
key1 = jax.random.key(17, impl='rbg')
data = jax.random.key_data(key1)
key2 = jex.random.wrap_key_data(data, impl='rbg')
self.assertEqual(key1.dtype, key2.dtype)
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key2))
key3 = jex.random.wrap_key_data(data, impl='unsafe_rbg')
self.assertNotEqual(key1.dtype, key3.dtype)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())