Fix type annotation for tree_util.default_registry

This commit is contained in:
Jake VanderPlas 2023-08-16 15:07:48 -07:00
parent bb4daa38c2
commit 9aca944891
4 changed files with 7 additions and 10 deletions

View File

@ -1868,10 +1868,9 @@ def _cpp_pmap(
return out, fastpath_data
# TODO(jakevdp): remove ignore[arg-type] below once default_registry is always defined
cpp_mapped_f = pmap_lib.pmap( # type: ignore
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
pytree_registry=tree_util.default_registry) # type: ignore[arg-type]
pytree_registry=tree_util.default_registry)
_pmap_cache_clears.add(cpp_mapped_f)
pmap_f = wraps(fun)(cpp_mapped_f)

View File

@ -2794,7 +2794,7 @@ class MeshExecutable(stages.XlaExecutable):
return outs, fastpath_data
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.default_registry) # type: ignore
tree_util.default_registry)
def check_arg_avals_for_call(ref_avals, arg_avals,

View File

@ -255,11 +255,11 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
return outs, fastpath_data
cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
fun, cache_miss, static_argnums, static_argnames, # type: ignore
donate_argnums, tree_util.default_registry, # type: ignore
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
cpp_pjit_f = xc._xla.pjit(
getattr(fun, "__name__", "<unnamed function>"),
fun, cache_miss, static_argnums, static_argnames,
donate_argnums, tree_util.default_registry,
_get_cpp_global_cache(pjit_has_explicit_sharding))
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun

View File

@ -38,8 +38,6 @@ U = TypeVar("U", bound=type[Any])
Leaf = Any
PyTreeDef = pytree.PyTreeDef
# TODO(phawkins): make this unconditional when jaxlib 0.4.14 is the minimum.
default_registry: pytree.PyTreeRegistry | None
default_registry = pytree.default_registry()
# Set __module__ and __name__, which allow this registry to be pickled by
# reference.