mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Move _src/tree_util.py into a separate Bazel target.
Fix a type error in api.py revealed by the split. PiperOrigin-RevId: 515745227
This commit is contained in:
parent
a722ec08f9
commit
a32a7ff903
13
jax/BUILD
13
jax/BUILD
@ -121,7 +121,6 @@ py_library_providing_imports_info(
|
||||
"_src/random.py",
|
||||
"_src/sharding.py",
|
||||
"_src/stages.py",
|
||||
"_src/tree_util.py",
|
||||
"_src/typing.py",
|
||||
] + glob(
|
||||
[
|
||||
@ -181,6 +180,7 @@ py_library_providing_imports_info(
|
||||
":profiler",
|
||||
":source_info_util",
|
||||
":traceback_util",
|
||||
":tree_util",
|
||||
":util",
|
||||
":version",
|
||||
":xla_bridge",
|
||||
@ -278,6 +278,17 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "tree_util",
|
||||
srcs = ["_src/tree_util.py"],
|
||||
visibility = [":internal"] + jax_visibility("tree_util"),
|
||||
deps = [
|
||||
":traceback_util",
|
||||
":util",
|
||||
"//jax/_src/lib",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "traceback_util",
|
||||
srcs = ["_src/traceback_util.py"],
|
||||
|
@ -2106,7 +2106,7 @@ def pmap(
|
||||
class PmapCallInfo(NamedTuple):
|
||||
flat_fun: lu.WrappedFun
|
||||
in_tree: PyTreeDef
|
||||
out_tree: PyTreeDef
|
||||
out_tree: Callable[[], PyTreeDef]
|
||||
flat_args: Sequence[Any]
|
||||
donated_invars: Sequence[bool]
|
||||
in_axes_flat: Sequence[Optional[int]]
|
||||
|
@ -25,7 +25,6 @@ from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
from jax.tree_util import PyTreeDef
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax._src import pjit as pjit_lib
|
||||
@ -39,12 +38,11 @@ import numpy as np
|
||||
|
||||
# This needs to be top-level for the jax compilation cache.
|
||||
@functools.partial(jax.pmap, axis_name='hosts')
|
||||
def _psum(x: PyTreeDef) -> PyTreeDef:
|
||||
def _psum(x: Any) -> Any:
|
||||
return jax.lax.psum(x, 'hosts')
|
||||
|
||||
|
||||
def broadcast_one_to_all(in_tree: PyTreeDef,
|
||||
is_source: Optional[bool] = None) -> PyTreeDef:
|
||||
def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any:
|
||||
"""Broadcast data from a source host (host 0 by default) to all other hosts.
|
||||
|
||||
Args:
|
||||
@ -127,7 +125,7 @@ def _handle_array_process_allgather(inp, tiled):
|
||||
return np.asarray(out.addressable_data(0))
|
||||
|
||||
|
||||
def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
|
||||
def process_allgather(in_tree: Any, tiled: bool = False) -> Any:
|
||||
"""Gather data from across processes.
|
||||
|
||||
Args:
|
||||
|
@ -1057,6 +1057,9 @@ jax_test(
|
||||
"gpu": 10,
|
||||
"tpu": 10,
|
||||
},
|
||||
deps = [
|
||||
"//jax:tree_util",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
Loading…
x
Reference in New Issue
Block a user