mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration. PiperOrigin-RevId: 554861284
This commit is contained in:
parent
b024e01440
commit
afd56c15d9
@ -23,6 +23,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
for details and for mechanisms to override the default.
|
||||
* The option `--jax_coordination_service` has been removed. It is now always
|
||||
`True`.
|
||||
* `jax.jaxpr_util` has been removed from the public JAX namespace.
|
||||
|
||||
## jaxlib 0.4.15
|
||||
|
||||
|
12
jax/BUILD
12
jax/BUILD
@ -197,6 +197,7 @@ py_library_providing_imports_info(
|
||||
":dtypes",
|
||||
":effects",
|
||||
":environment_info",
|
||||
":jaxpr_util",
|
||||
":lazy_loader",
|
||||
":mesh",
|
||||
":mlir",
|
||||
@ -401,6 +402,17 @@ pytype_strict_library(
|
||||
srcs = ["_src/lazy_loader.py"],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "jaxpr_util",
|
||||
srcs = ["_src/jaxpr_util.py"],
|
||||
deps = [
|
||||
":core",
|
||||
":source_info_util",
|
||||
":util",
|
||||
"//jax/_src/lib",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "mesh",
|
||||
srcs = ["_src/mesh.py"],
|
||||
|
@ -266,9 +266,14 @@ py_test(
|
||||
] + py_deps("tensorflow_core"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
py_test(
|
||||
name = "jaxpr_util_test",
|
||||
srcs = ["jaxpr_util_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:jaxpr_util",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
@ -12,14 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import os
|
||||
import gzip
|
||||
import json
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax import jaxpr_util, jit, make_jaxpr, numpy as jnp
|
||||
from jax import jit, make_jaxpr, numpy as jnp
|
||||
from jax._src import jaxpr_util
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import test_util as jtu
|
||||
from jax import config
|
||||
|
Loading…
x
Reference in New Issue
Block a user