create jax.extend.random as a copy of jax.prng

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
This commit is contained in:
Roy Frostig 2023-08-24 14:40:10 -07:00 committed by jax authors
parent 48921a1b31
commit a71c0e6ecc
7 changed files with 92 additions and 0 deletions

View File

@ -17,6 +17,7 @@
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load(
"//jaxlib:jax.bzl",
"jax_extend_internal_users",
"jax_extra_deps",
"jax_internal_packages",
"jax_test_util_visibility",
@ -63,6 +64,15 @@ package_group(
] + jax_internal_packages,
)
package_group(
name = "jax_extend_users",
packages = [
# Intentionally avoid jax dependencies on jax.extend.
# See https://jax.readthedocs.io/en/latest/jep/15856-jex.html
"//third_party/py/jax/tests/...",
] + jax_extend_internal_users,
)
package_group(
name = "mosaic_users",
packages = [
@ -872,6 +882,13 @@ pytype_library(
],
)
pytype_library(
name = "extend",
srcs = glob(["extend/**/*.py"]),
visibility = [":jax_extend_users"],
deps = [":jax"],
)
pytype_library(
name = "mosaic",
srcs = [

View File

@ -27,3 +27,7 @@ this package offers **no compatibility guarantee** across releases.
Breaking changes will be announced via the
`JAX project changelog <https://jax.readthedocs.io/en/latest/changelog.html>`_.
"""
from jax.extend import (
random as random,
)

26
jax/extend/random.py Normal file
View File

@ -0,0 +1,26 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.prng import (
PRNGImpl as PRNGImpl,
seed_with_impl as seed_with_impl,
threefry2x32_p as threefry2x32_p,
threefry_2x32 as threefry_2x32,
threefry_prng_impl as threefry_prng_impl,
rbg_prng_impl as rbg_prng_impl,
unsafe_rbg_prng_impl as unsafe_rbg_prng_impl,
)

View File

@ -36,6 +36,7 @@ tf_exec_properties = _tf_exec_properties
tf_cuda_tests_tags = _tf_cuda_tests_tags
jax_internal_packages = []
jax_extend_internal_users = []
mosaic_internal_users = []
pallas_gpu_internal_users = []
pallas_tpu_internal_users = []

View File

@ -34,6 +34,7 @@ per-file-ignores =
jax/dlpack.py:F401
jax/dtypes.py:F401
jax/errors.py:F401
jax/extend/*.py:F401
jax/flatten_util.py:F401
jax/interpreters/ad.py:F401
jax/interpreters/batching.py:F401

View File

@ -109,6 +109,12 @@ jax_test(
],
)
jax_test(
name = "extend_test",
srcs = ["extend_test.py"],
deps = ["//jax:extend"],
)
jax_test(
name = "fft_test",
srcs = ["fft_test.py"],

37
tests/extend_test.py Normal file
View File

@ -0,0 +1,37 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax.extend as jex
from jax._src import prng
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
class ExtendTest(jtu.JaxTestCase):
def test_symbols(self):
# Assume these are tested in random_test.py, only check equivalence
self.assertIs(jex.random.PRNGImpl, prng.PRNGImpl)
self.assertIs(jex.random.seed_with_impl, prng.seed_with_impl)
self.assertIs(jex.random.threefry2x32_p, prng.threefry2x32_p)
self.assertIs(jex.random.threefry_2x32, prng.threefry_2x32)
self.assertIs(jex.random.threefry_prng_impl, prng.threefry_prng_impl)
self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl)
self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())