Fix pytype errors.

PiperOrigin-RevId: 509984207
This commit is contained in:
Peter Hawkins 2023-02-15 18:11:55 -08:00 committed by jax authors
parent 37d4ad910a
commit 768960b4e4
3 changed files with 12 additions and 8 deletions

View File

@ -26,6 +26,7 @@ import collections
import functools
from functools import partial
import inspect
import typing
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload)
@ -2386,7 +2387,7 @@ def _cpp_pmap(
map_bind_continuation, top_trace, fun_, tracers, params = (
core.map_bind_with_continuation(pxla.xla_pmap_p, p.flat_fun,
*p.flat_args, **params))
execute: Optional[functools.partial] = None
execute: Optional[Callable] = None
if isinstance(top_trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
out = map_bind_continuation(execute(*tracers))
@ -2399,12 +2400,13 @@ def _cpp_pmap(
out = tree_unflatten(out_pytree_def, out_flat)
### Decide whether we can support the C++ fast path
use_fastpath = (
execute is not None and
# We don't support JAX extension backends.
isinstance(execute, pxla.ExecuteReplicated) and
use_fastpath = False
if execute is not None and isinstance(execute, pxla.ExecuteReplicated):
execute_replicated = typing.cast(pxla.ExecuteReplicated, execute)
use_fastpath = (
# TODO(sharadmv): Enable effects in replicated computation
not execute.has_unordered_effects and not execute.has_host_callbacks and
not execute_replicated.has_unordered_effects
and not execute_replicated.has_host_callbacks and
# No tracers in the outputs. Checking for ShardedDeviceArray should be
# sufficient, but we use the more general `DeviceArray`.
all(
@ -2413,7 +2415,7 @@ def _cpp_pmap(
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
execute_replicated = execute
execute_replicated = typing.cast(pxla.ExecuteReplicated, execute)
out_handler = execute_replicated.out_handler
in_handler = execute_replicated.in_handler
out_indices = [tuple(s.devices_indices_map(a.shape).values())

View File

@ -962,7 +962,7 @@ def xla_pmap_impl_lazy(
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
):
) -> Callable:
if (config.jax_disable_jit and config.jax_eager_pmap and
not is_explicit_global_axis_size and not any(d for d in donated_invars)
and not all(g is not None for g in global_arg_shapes)):

View File

@ -280,6 +280,8 @@ class XlaExecutable(Executable):
class XlaLowering(Lowering):
"""Adapts our various internal XLA-backed computations into a ``Lowering``."""
compile_args: Dict[str, Any]
def hlo(self) -> xc.XlaComputation:
"""Return an HLO representation of this computation."""
raise NotImplementedError("must override")