1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.

Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.

PiperOrigin-RevId: 554861284
This commit is contained in:
Peter Hawkins 2023-08-08 10:08:19 -07:00 committed by jax authors
parent b024e01440
commit afd56c15d9
5 changed files with 23 additions and 4 deletions

@ -23,6 +23,7 @@ Remember to align the itemized text with the first line of an item within a list
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
* `jax.jaxpr_util` has been removed from the public JAX namespace.
## jaxlib 0.4.15

@ -197,6 +197,7 @@ py_library_providing_imports_info(
":dtypes",
":effects",
":environment_info",
":jaxpr_util",
":lazy_loader",
":mesh",
":mlir",
@ -401,6 +402,17 @@ pytype_strict_library(
srcs = ["_src/lazy_loader.py"],
)
pytype_strict_library(
name = "jaxpr_util",
srcs = ["_src/jaxpr_util.py"],
deps = [
":core",
":source_info_util",
":util",
"//jax/_src/lib",
],
)
pytype_strict_library(
name = "mesh",
srcs = ["_src/mesh.py"],

@ -266,9 +266,14 @@ py_test(
] + py_deps("tensorflow_core"),
)
jax_test(
py_test(
name = "jaxpr_util_test",
srcs = ["jaxpr_util_test.py"],
deps = [
"//jax",
"//jax:jaxpr_util",
"//jax:test_util",
],
)
jax_test(

@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import os
import gzip
import json
from absl.testing import absltest
import jax
from jax import jaxpr_util, jit, make_jaxpr, numpy as jnp
from jax import jit, make_jaxpr, numpy as jnp
from jax._src import jaxpr_util
from jax._src.lib import xla_client
from jax._src import test_util as jtu
from jax import config