mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Split more targets out the main JAX Bazel target.
Namely: * abstract_arrays * ad_util * api_util * interpreters/partial_eval * lax_reference PiperOrigin-RevId: 520618715
This commit is contained in:
parent
3135fbcd7f
commit
47177e1417
77
jax/BUILD
77
jax/BUILD
@ -89,11 +89,8 @@ py_library_providing_imports_info(
|
||||
name = "jax",
|
||||
srcs = [
|
||||
"_src/__init__.py",
|
||||
"_src/abstract_arrays.py",
|
||||
"_src/ad_checkpoint.py",
|
||||
"_src/ad_util.py",
|
||||
"_src/api.py",
|
||||
"_src/api_util.py",
|
||||
"_src/array.py",
|
||||
"_src/callback.py",
|
||||
"_src/checkify.py",
|
||||
@ -104,7 +101,12 @@ py_library_providing_imports_info(
|
||||
"_src/dispatch.py",
|
||||
"_src/dlpack.py",
|
||||
"_src/flatten_util.py",
|
||||
"_src/lax_reference.py",
|
||||
"_src/interpreters/__init__.py",
|
||||
"_src/interpreters/ad.py",
|
||||
"_src/interpreters/batching.py",
|
||||
"_src/interpreters/mlir.py",
|
||||
"_src/interpreters/pxla.py",
|
||||
"_src/interpreters/xla.py",
|
||||
"_src/maps.py",
|
||||
"_src/pjit.py",
|
||||
"_src/prng.py",
|
||||
@ -117,7 +119,6 @@ py_library_providing_imports_info(
|
||||
"*.py",
|
||||
"_src/debugger/**/*.py",
|
||||
"_src/image/**/*.py",
|
||||
"_src/interpreters/**/*.py",
|
||||
"_src/lax/**/*.py",
|
||||
"_src/nn/**/*.py",
|
||||
"_src/numpy/**/*.py",
|
||||
@ -162,6 +163,9 @@ py_library_providing_imports_info(
|
||||
),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":abstract_arrays",
|
||||
":ad_util",
|
||||
":api_util",
|
||||
":basearray",
|
||||
":cloud_tpu_init",
|
||||
":config",
|
||||
@ -173,6 +177,7 @@ py_library_providing_imports_info(
|
||||
":lazy_loader",
|
||||
":mesh",
|
||||
":monitoring",
|
||||
":partial_eval",
|
||||
":path",
|
||||
":pretty_printer",
|
||||
":profiler",
|
||||
@ -185,7 +190,42 @@ py_library_providing_imports_info(
|
||||
":version",
|
||||
":xla_bridge",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
|
||||
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + jax_extra_deps,
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "abstract_arrays",
|
||||
srcs = ["_src/abstract_arrays.py"],
|
||||
deps = [
|
||||
":ad_util",
|
||||
":core",
|
||||
":traceback_util",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "ad_util",
|
||||
srcs = ["_src/ad_util.py"],
|
||||
deps = [
|
||||
":core",
|
||||
":traceback_util",
|
||||
":tree_util",
|
||||
":typing",
|
||||
":util",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "api_util",
|
||||
srcs = ["_src/api_util.py"],
|
||||
deps = [
|
||||
":abstract_arrays",
|
||||
":config",
|
||||
":core",
|
||||
":traceback_util",
|
||||
":tree_util",
|
||||
":util",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
@ -266,6 +306,16 @@ pytype_library(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "lax_reference",
|
||||
srcs = ["_src/lax_reference.py"],
|
||||
visibility = [":internal"] + jax_visibility("lax_reference"),
|
||||
deps = [
|
||||
":core",
|
||||
":util",
|
||||
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "lazy_loader",
|
||||
srcs = ["_src/lazy_loader.py"],
|
||||
@ -287,6 +337,21 @@ pytype_strict_library(
|
||||
srcs = ["_src/monitoring.py"],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "partial_eval",
|
||||
srcs = ["_src/interpreters/partial_eval.py"],
|
||||
deps = [
|
||||
":api_util",
|
||||
":config",
|
||||
":core",
|
||||
":effects",
|
||||
":profiler",
|
||||
":source_info_util",
|
||||
":tree_util",
|
||||
":util",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "path",
|
||||
srcs = ["_src/path.py"],
|
||||
|
@ -16,15 +16,14 @@ from __future__ import annotations
|
||||
import types
|
||||
from typing import Any, Callable, Dict, TypeVar, Union, cast
|
||||
|
||||
from jax.tree_util import register_pytree_node
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import traceback_util
|
||||
from jax._src.core import (lattice_join, Primitive, valid_jaxtype,
|
||||
raise_to_shaped, get_aval)
|
||||
from jax._src.util import safe_map
|
||||
from jax._src.tree_util import register_pytree_node
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.util import safe_map
|
||||
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
T = TypeVar('T')
|
||||
|
@ -27,7 +27,7 @@ from weakref import ref
|
||||
import numpy as np
|
||||
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax._src.config import config
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import effects
|
||||
|
@ -410,7 +410,10 @@ jax_test(
|
||||
"tpu": 30,
|
||||
"iree": 40,
|
||||
},
|
||||
deps = ["//jax:internal_test_util"] + py_deps("numpy"),
|
||||
deps = [
|
||||
"//jax:internal_test_util",
|
||||
"//jax:lax_reference",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
Loading…
x
Reference in New Issue
Block a user