diff --git a/jax/BUILD b/jax/BUILD index 8a86a923f..20cd8abaf 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -121,7 +121,6 @@ py_library_providing_imports_info( "_src/random.py", "_src/sharding.py", "_src/stages.py", - "_src/tree_util.py", "_src/typing.py", ] + glob( [ @@ -181,6 +180,7 @@ py_library_providing_imports_info( ":profiler", ":source_info_util", ":traceback_util", + ":tree_util", ":util", ":version", ":xla_bridge", @@ -278,6 +278,17 @@ pytype_library( ], ) +pytype_library( + name = "tree_util", + srcs = ["_src/tree_util.py"], + visibility = [":internal"] + jax_visibility("tree_util"), + deps = [ + ":traceback_util", + ":util", + "//jax/_src/lib", + ], +) + pytype_library( name = "traceback_util", srcs = ["_src/traceback_util.py"], diff --git a/jax/_src/api.py b/jax/_src/api.py index 594e104d0..efbd53f59 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2106,7 +2106,7 @@ def pmap( class PmapCallInfo(NamedTuple): flat_fun: lu.WrappedFun in_tree: PyTreeDef - out_tree: PyTreeDef + out_tree: Callable[[], PyTreeDef] flat_args: Sequence[Any] donated_invars: Sequence[bool] in_axes_flat: Sequence[Optional[int]] diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 0a467056b..78d6c5d66 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -25,7 +25,6 @@ from jax._src import core from jax._src import dispatch from jax._src import array from jax._src import sharding -from jax.tree_util import PyTreeDef from jax._src.interpreters import pxla from jax.interpreters import xla from jax._src import pjit as pjit_lib @@ -39,12 +38,11 @@ import numpy as np # This needs to be top-level for the jax compilation cache. @functools.partial(jax.pmap, axis_name='hosts') -def _psum(x: PyTreeDef) -> PyTreeDef: +def _psum(x: Any) -> Any: return jax.lax.psum(x, 'hosts') -def broadcast_one_to_all(in_tree: PyTreeDef, - is_source: Optional[bool] = None) -> PyTreeDef: +def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any: """Broadcast data from a source host (host 0 by default) to all other hosts. Args: @@ -127,7 +125,7 @@ def _handle_array_process_allgather(inp, tiled): return np.asarray(out.addressable_data(0)) -def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef: +def process_allgather(in_tree: Any, tiled: bool = False) -> Any: """Gather data from across processes. Args: diff --git a/tests/BUILD b/tests/BUILD index 0694cc563..72ec2f0c7 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1057,6 +1057,9 @@ jax_test( "gpu": 10, "tpu": 10, }, + deps = [ + "//jax:tree_util", + ], ) jax_test(