From 768960b4e47b356a4396c03acb4e0c5dacb1850d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 15 Feb 2023 18:11:55 -0800 Subject: [PATCH] Fix pytype errors. PiperOrigin-RevId: 509984207 --- jax/_src/api.py | 16 +++++++++------- jax/_src/interpreters/pxla.py | 2 +- jax/_src/stages.py | 2 ++ 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index bb65aeafa..3d656c395 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -26,6 +26,7 @@ import collections import functools from functools import partial import inspect +import typing from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload) @@ -2386,7 +2387,7 @@ def _cpp_pmap( map_bind_continuation, top_trace, fun_, tracers, params = ( core.map_bind_with_continuation(pxla.xla_pmap_p, p.flat_fun, *p.flat_args, **params)) - execute: Optional[functools.partial] = None + execute: Optional[Callable] = None if isinstance(top_trace, core.EvalTrace): execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params) out = map_bind_continuation(execute(*tracers)) @@ -2399,12 +2400,13 @@ def _cpp_pmap( out = tree_unflatten(out_pytree_def, out_flat) ### Decide whether we can support the C++ fast path - use_fastpath = ( - execute is not None and - # We don't support JAX extension backends. - isinstance(execute, pxla.ExecuteReplicated) and + use_fastpath = False + if execute is not None and isinstance(execute, pxla.ExecuteReplicated): + execute_replicated = typing.cast(pxla.ExecuteReplicated, execute) + use_fastpath = ( # TODO(sharadmv): Enable effects in replicated computation - not execute.has_unordered_effects and not execute.has_host_callbacks and + not execute_replicated.has_unordered_effects + and not execute_replicated.has_host_callbacks and # No tracers in the outputs. Checking for ShardedDeviceArray should be # sufficient, but we use the more general `DeviceArray`. all( @@ -2413,7 +2415,7 @@ def _cpp_pmap( ### If we can use the fastpath, we return required info to the caller. if use_fastpath: - execute_replicated = execute + execute_replicated = typing.cast(pxla.ExecuteReplicated, execute) out_handler = execute_replicated.out_handler in_handler = execute_replicated.in_handler out_indices = [tuple(s.devices_indices_map(a.shape).values()) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index b4fd9b7a4..97e27283f 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -962,7 +962,7 @@ def xla_pmap_impl_lazy( donated_invars: Sequence[bool], global_arg_shapes: Sequence[Optional[Tuple[int, ...]]], is_explicit_global_axis_size: bool, -): +) -> Callable: if (config.jax_disable_jit and config.jax_eager_pmap and not is_explicit_global_axis_size and not any(d for d in donated_invars) and not all(g is not None for g in global_arg_shapes)): diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 0f145e800..aa95cac79 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -280,6 +280,8 @@ class XlaExecutable(Executable): class XlaLowering(Lowering): """Adapts our various internal XLA-backed computations into a ``Lowering``.""" + compile_args: Dict[str, Any] + def hlo(self) -> xc.XlaComputation: """Return an HLO representation of this computation.""" raise NotImplementedError("must override")