mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix type annotation for tree_util.default_registry
This commit is contained in:
parent
bb4daa38c2
commit
9aca944891
@ -1868,10 +1868,9 @@ def _cpp_pmap(
|
||||
|
||||
return out, fastpath_data
|
||||
|
||||
# TODO(jakevdp): remove ignore[arg-type] below once default_registry is always defined
|
||||
cpp_mapped_f = pmap_lib.pmap( # type: ignore
|
||||
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
|
||||
pytree_registry=tree_util.default_registry) # type: ignore[arg-type]
|
||||
pytree_registry=tree_util.default_registry)
|
||||
_pmap_cache_clears.add(cpp_mapped_f)
|
||||
|
||||
pmap_f = wraps(fun)(cpp_mapped_f)
|
||||
|
@ -2794,7 +2794,7 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
return outs, fastpath_data
|
||||
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
tree_util.default_registry) # type: ignore
|
||||
tree_util.default_registry)
|
||||
|
||||
|
||||
def check_arg_avals_for_call(ref_avals, arg_avals,
|
||||
|
@ -255,11 +255,11 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
|
||||
return outs, fastpath_data
|
||||
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
|
||||
fun, cache_miss, static_argnums, static_argnames, # type: ignore
|
||||
donate_argnums, tree_util.default_registry, # type: ignore
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
getattr(fun, "__name__", "<unnamed function>"),
|
||||
fun, cache_miss, static_argnums, static_argnames,
|
||||
donate_argnums, tree_util.default_registry,
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding))
|
||||
|
||||
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
|
||||
cpp_pjitted_f._fun = fun
|
||||
|
@ -38,8 +38,6 @@ U = TypeVar("U", bound=type[Any])
|
||||
Leaf = Any
|
||||
PyTreeDef = pytree.PyTreeDef
|
||||
|
||||
# TODO(phawkins): make this unconditional when jaxlib 0.4.14 is the minimum.
|
||||
default_registry: pytree.PyTreeRegistry | None
|
||||
default_registry = pytree.default_registry()
|
||||
# Set __module__ and __name__, which allow this registry to be pickled by
|
||||
# reference.
|
||||
|
Loading…
x
Reference in New Issue
Block a user