Move _src/tree_util.py into a separate Bazel target.

Fix a type error in api.py revealed by the split.

PiperOrigin-RevId: 515745227
This commit is contained in:
Peter Hawkins 2023-03-10 14:51:08 -08:00 committed by jax authors
parent a722ec08f9
commit a32a7ff903
4 changed files with 19 additions and 7 deletions

View File

@ -121,7 +121,6 @@ py_library_providing_imports_info(
"_src/random.py",
"_src/sharding.py",
"_src/stages.py",
"_src/tree_util.py",
"_src/typing.py",
] + glob(
[
@ -181,6 +180,7 @@ py_library_providing_imports_info(
":profiler",
":source_info_util",
":traceback_util",
":tree_util",
":util",
":version",
":xla_bridge",
@ -278,6 +278,17 @@ pytype_library(
],
)
pytype_library(
name = "tree_util",
srcs = ["_src/tree_util.py"],
visibility = [":internal"] + jax_visibility("tree_util"),
deps = [
":traceback_util",
":util",
"//jax/_src/lib",
],
)
pytype_library(
name = "traceback_util",
srcs = ["_src/traceback_util.py"],

View File

@ -2106,7 +2106,7 @@ def pmap(
class PmapCallInfo(NamedTuple):
flat_fun: lu.WrappedFun
in_tree: PyTreeDef
out_tree: PyTreeDef
out_tree: Callable[[], PyTreeDef]
flat_args: Sequence[Any]
donated_invars: Sequence[bool]
in_axes_flat: Sequence[Optional[int]]

View File

@ -25,7 +25,6 @@ from jax._src import core
from jax._src import dispatch
from jax._src import array
from jax._src import sharding
from jax.tree_util import PyTreeDef
from jax._src.interpreters import pxla
from jax.interpreters import xla
from jax._src import pjit as pjit_lib
@ -39,12 +38,11 @@ import numpy as np
# This needs to be top-level for the jax compilation cache.
@functools.partial(jax.pmap, axis_name='hosts')
def _psum(x: PyTreeDef) -> PyTreeDef:
def _psum(x: Any) -> Any:
return jax.lax.psum(x, 'hosts')
def broadcast_one_to_all(in_tree: PyTreeDef,
is_source: Optional[bool] = None) -> PyTreeDef:
def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any:
"""Broadcast data from a source host (host 0 by default) to all other hosts.
Args:
@ -127,7 +125,7 @@ def _handle_array_process_allgather(inp, tiled):
return np.asarray(out.addressable_data(0))
def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
def process_allgather(in_tree: Any, tiled: bool = False) -> Any:
"""Gather data from across processes.
Args:

View File

@ -1057,6 +1057,9 @@ jax_test(
"gpu": 10,
"tpu": 10,
},
deps = [
"//jax:tree_util",
],
)
jax_test(