mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
add jax.extend.random.wrap_key_data
This commit is contained in:
parent
4e227a328a
commit
a69f134cde
@ -153,6 +153,7 @@ py_library_providing_imports_info(
|
||||
[
|
||||
"*.py",
|
||||
"_src/debugger/**/*.py",
|
||||
"_src/extend/**/*.py",
|
||||
"_src/image/**/*.py",
|
||||
"_src/lax/**/*.py",
|
||||
"_src/nn/**/*.py",
|
||||
|
13
jax/_src/extend/__init__.py
Normal file
13
jax/_src/extend/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
23
jax/_src/extend/random.py
Normal file
23
jax/_src/extend/random.py
Normal file
@ -0,0 +1,23 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src.typing import Array
|
||||
|
||||
def wrap_key_data(key_bits_array: Array, *, impl: Optional[str] = None):
|
||||
impl_obj = random.resolve_prng_impl(impl)
|
||||
return prng.random_wrap(key_bits_array, impl=impl_obj)
|
@ -24,3 +24,7 @@ from jax._src.prng import (
|
||||
rbg_prng_impl as rbg_prng_impl,
|
||||
unsafe_rbg_prng_impl as unsafe_rbg_prng_impl,
|
||||
)
|
||||
|
||||
from jax._src.extend.random import (
|
||||
wrap_key_data as wrap_key_data,
|
||||
)
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
import jax.extend as jex
|
||||
|
||||
from jax._src import prng
|
||||
@ -22,6 +23,7 @@ from jax._src import test_util as jtu
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class ExtendTest(jtu.JaxTestCase):
|
||||
def test_symbols(self):
|
||||
# Assume these are tested in random_test.py, only check equivalence
|
||||
@ -33,5 +35,33 @@ class ExtendTest(jtu.JaxTestCase):
|
||||
self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl)
|
||||
self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl)
|
||||
|
||||
|
||||
class RandomTest(jtu.JaxTestCase):
|
||||
def test_wrap_key_default(self):
|
||||
key1 = jax.random.key(17)
|
||||
data = jax.random.key_data(key1)
|
||||
key2 = jex.random.wrap_key_data(data)
|
||||
self.assertEqual(key1.dtype, key2.dtype)
|
||||
self.assertArraysEqual(jax.random.key_data(key1),
|
||||
jax.random.key_data(key2))
|
||||
|
||||
impl = config.jax_default_prng_impl
|
||||
key3 = jex.random.wrap_key_data(data, impl=impl)
|
||||
self.assertEqual(key1.dtype, key3.dtype)
|
||||
self.assertArraysEqual(jax.random.key_data(key1),
|
||||
jax.random.key_data(key3))
|
||||
|
||||
def test_wrap_key_explicit(self):
|
||||
key1 = jax.random.key(17, impl='rbg')
|
||||
data = jax.random.key_data(key1)
|
||||
key2 = jex.random.wrap_key_data(data, impl='rbg')
|
||||
self.assertEqual(key1.dtype, key2.dtype)
|
||||
self.assertArraysEqual(jax.random.key_data(key1),
|
||||
jax.random.key_data(key2))
|
||||
|
||||
key3 = jex.random.wrap_key_data(data, impl='unsafe_rbg')
|
||||
self.assertNotEqual(key1.dtype, key3.dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user