diff --git a/jax/BUILD b/jax/BUILD index 0fca7b130..ae255add4 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -17,6 +17,7 @@ load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load( "//jaxlib:jax.bzl", + "jax_extend_internal_users", "jax_extra_deps", "jax_internal_packages", "jax_test_util_visibility", @@ -63,6 +64,15 @@ package_group( ] + jax_internal_packages, ) +package_group( + name = "jax_extend_users", + packages = [ + # Intentionally avoid jax dependencies on jax.extend. + # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html + "//third_party/py/jax/tests/...", + ] + jax_extend_internal_users, +) + package_group( name = "mosaic_users", packages = [ @@ -872,6 +882,13 @@ pytype_library( ], ) +pytype_library( + name = "extend", + srcs = glob(["extend/**/*.py"]), + visibility = [":jax_extend_users"], + deps = [":jax"], +) + pytype_library( name = "mosaic", srcs = [ diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index e5e935d6c..994108b14 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -27,3 +27,7 @@ this package offers **no compatibility guarantee** across releases. Breaking changes will be announced via the `JAX project changelog `_. """ + +from jax.extend import ( + random as random, +) diff --git a/jax/extend/random.py b/jax/extend/random.py new file mode 100644 index 000000000..080fcdd26 --- /dev/null +++ b/jax/extend/random.py @@ -0,0 +1,26 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.prng import ( + PRNGImpl as PRNGImpl, + seed_with_impl as seed_with_impl, + threefry2x32_p as threefry2x32_p, + threefry_2x32 as threefry_2x32, + threefry_prng_impl as threefry_prng_impl, + rbg_prng_impl as rbg_prng_impl, + unsafe_rbg_prng_impl as unsafe_rbg_prng_impl, +) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index bd823b62b..f92bc9526 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -36,6 +36,7 @@ tf_exec_properties = _tf_exec_properties tf_cuda_tests_tags = _tf_cuda_tests_tags jax_internal_packages = [] +jax_extend_internal_users = [] mosaic_internal_users = [] pallas_gpu_internal_users = [] pallas_tpu_internal_users = [] diff --git a/setup.cfg b/setup.cfg index 34cdc9641..c864e1706 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ per-file-ignores = jax/dlpack.py:F401 jax/dtypes.py:F401 jax/errors.py:F401 + jax/extend/*.py:F401 jax/flatten_util.py:F401 jax/interpreters/ad.py:F401 jax/interpreters/batching.py:F401 diff --git a/tests/BUILD b/tests/BUILD index 9a9714061..1af03063a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -109,6 +109,12 @@ jax_test( ], ) +jax_test( + name = "extend_test", + srcs = ["extend_test.py"], + deps = ["//jax:extend"], +) + jax_test( name = "fft_test", srcs = ["fft_test.py"], diff --git a/tests/extend_test.py b/tests/extend_test.py new file mode 100644 index 000000000..2794973a6 --- /dev/null +++ b/tests/extend_test.py @@ -0,0 +1,37 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax.extend as jex + +from jax._src import prng +from jax._src import test_util as jtu + +from jax import config +config.parse_flags_with_absl() + +class ExtendTest(jtu.JaxTestCase): + def test_symbols(self): + # Assume these are tested in random_test.py, only check equivalence + self.assertIs(jex.random.PRNGImpl, prng.PRNGImpl) + self.assertIs(jex.random.seed_with_impl, prng.seed_with_impl) + self.assertIs(jex.random.threefry2x32_p, prng.threefry2x32_p) + self.assertIs(jex.random.threefry_2x32, prng.threefry_2x32) + self.assertIs(jex.random.threefry_prng_impl, prng.threefry_prng_impl) + self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl) + self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())