mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix pytype errors.
PiperOrigin-RevId: 509984207
This commit is contained in:
parent
37d4ad910a
commit
768960b4e4
@ -26,6 +26,7 @@ import collections
|
||||
import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import typing
|
||||
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
|
||||
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload)
|
||||
|
||||
@ -2386,7 +2387,7 @@ def _cpp_pmap(
|
||||
map_bind_continuation, top_trace, fun_, tracers, params = (
|
||||
core.map_bind_with_continuation(pxla.xla_pmap_p, p.flat_fun,
|
||||
*p.flat_args, **params))
|
||||
execute: Optional[functools.partial] = None
|
||||
execute: Optional[Callable] = None
|
||||
if isinstance(top_trace, core.EvalTrace):
|
||||
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
|
||||
out = map_bind_continuation(execute(*tracers))
|
||||
@ -2399,12 +2400,13 @@ def _cpp_pmap(
|
||||
out = tree_unflatten(out_pytree_def, out_flat)
|
||||
|
||||
### Decide whether we can support the C++ fast path
|
||||
use_fastpath = (
|
||||
execute is not None and
|
||||
# We don't support JAX extension backends.
|
||||
isinstance(execute, pxla.ExecuteReplicated) and
|
||||
use_fastpath = False
|
||||
if execute is not None and isinstance(execute, pxla.ExecuteReplicated):
|
||||
execute_replicated = typing.cast(pxla.ExecuteReplicated, execute)
|
||||
use_fastpath = (
|
||||
# TODO(sharadmv): Enable effects in replicated computation
|
||||
not execute.has_unordered_effects and not execute.has_host_callbacks and
|
||||
not execute_replicated.has_unordered_effects
|
||||
and not execute_replicated.has_host_callbacks and
|
||||
# No tracers in the outputs. Checking for ShardedDeviceArray should be
|
||||
# sufficient, but we use the more general `DeviceArray`.
|
||||
all(
|
||||
@ -2413,7 +2415,7 @@ def _cpp_pmap(
|
||||
|
||||
### If we can use the fastpath, we return required info to the caller.
|
||||
if use_fastpath:
|
||||
execute_replicated = execute
|
||||
execute_replicated = typing.cast(pxla.ExecuteReplicated, execute)
|
||||
out_handler = execute_replicated.out_handler
|
||||
in_handler = execute_replicated.in_handler
|
||||
out_indices = [tuple(s.devices_indices_map(a.shape).values())
|
||||
|
@ -962,7 +962,7 @@ def xla_pmap_impl_lazy(
|
||||
donated_invars: Sequence[bool],
|
||||
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
|
||||
is_explicit_global_axis_size: bool,
|
||||
):
|
||||
) -> Callable:
|
||||
if (config.jax_disable_jit and config.jax_eager_pmap and
|
||||
not is_explicit_global_axis_size and not any(d for d in donated_invars)
|
||||
and not all(g is not None for g in global_arg_shapes)):
|
||||
|
@ -280,6 +280,8 @@ class XlaExecutable(Executable):
|
||||
class XlaLowering(Lowering):
|
||||
"""Adapts our various internal XLA-backed computations into a ``Lowering``."""
|
||||
|
||||
compile_args: Dict[str, Any]
|
||||
|
||||
def hlo(self) -> xc.XlaComputation:
|
||||
"""Return an HLO representation of this computation."""
|
||||
raise NotImplementedError("must override")
|
||||
|
Loading…
x
Reference in New Issue
Block a user