[JAX] Add support for multiple pytree registries.

We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.

One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.

PiperOrigin-RevId: 549301796
This commit is contained in:
Peter Hawkins 2023-07-19 06:47:46 -07:00 committed by jax authors
parent f97dca79a2
commit cdb48134e5
7 changed files with 145 additions and 47 deletions

View File

@ -62,10 +62,12 @@ from jax._src.api_util import (
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import PmapSharding
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
from jax._src import util
@ -1850,8 +1852,13 @@ def _cpp_pmap(
return out, fastpath_data
cpp_mapped_f = pmap_lib.pmap(
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg)
if xla_extension_version >= 169:
cpp_mapped_f = pmap_lib.pmap( # type: ignore
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
pytree_registry=tree_util.default_registry)
else:
cpp_mapped_f = pmap_lib.pmap(
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg) # type: ignore
_pmap_cache_clears.add(cpp_mapped_f)
pmap_f = wraps(fun)(cpp_mapped_f)

View File

@ -45,6 +45,7 @@ from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
@ -57,6 +58,7 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
@ -2845,7 +2847,12 @@ class MeshExecutable(stages.XlaExecutable):
fastpath_data = None
return outs, fastpath_data
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], []) # type: ignore
if xla_extension_version >= 169:
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.default_registry) # type: ignore
else:
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [], []) # type: ignore
def check_arg_avals_for_call(ref_avals, arg_avals,

View File

@ -61,7 +61,6 @@ from jax._src.lax.utils import (
standard_named_shape_rule,
standard_primitive,
)
from jax._src.lib import pytree
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
@ -4257,7 +4256,7 @@ def infeed(token, shape=None, partitions=None):
`token` is used to sequence infeed and outfeed effects.
`partitions` may be specified inside a `sharded_jit` function.
"""
flat_shapes, treedef = pytree.flatten(shape)
flat_shapes, treedef = tree_util.tree_flatten(shape)
for shape in flat_shapes:
if not isinstance(shape, ShapedArray):
raise TypeError("shape argument to infeed must be a pytree of "
@ -4323,7 +4322,7 @@ def outfeed(token, xs, partitions = None):
if type(partitions) != tuple: # pylint: disable=unidiomatic-typecheck
raise ValueError(f"'partitions' argument to outfeed should be a tuple, "
f"got {partitions}")
flat_xs, _ = pytree.flatten(xs)
flat_xs, _ = tree_util.tree_flatten(xs)
return outfeed_p.bind(token, *flat_xs, partitions=partitions)
def _outfeed_abstract_eval(token, *xs, partitions):

View File

@ -32,6 +32,7 @@ from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import traceback_util
from jax._src import api
from jax._src import xla_bridge as xb
@ -52,6 +53,7 @@ from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
@ -253,11 +255,19 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
return outs, fastpath_data
cpp_pjit_f = xc._xla.pjit( # type: ignore
if xla_extension_version >= 169:
cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
fun, cache_miss, static_argnums, static_argnames, # type: ignore
donate_argnums, tree_util.default_registry, # type: ignore
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
else:
cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
fun, cache_miss, static_argnums, static_argnames, # type: ignore
donate_argnums, _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn
@ -1194,8 +1204,13 @@ def _pjit_call_impl(*args, jaxpr,
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, None, None)
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
_get_cpp_global_cache(has_explicit_sharding))(*args)
if xla_extension_version >= 169:
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.default_registry,
_get_cpp_global_cache(has_explicit_sharding))(*args)
else:
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore
_get_cpp_global_cache(has_explicit_sharding))(*args)
pjit_p.def_impl(_pjit_call_impl)

View File

@ -26,6 +26,7 @@ import warnings
from jax._src import traceback_util
from jax._src.lib import pytree
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip
from jax._src.util import unzip2
@ -38,6 +39,16 @@ U = TypeVar("U", bound=type[Any])
Leaf = Any
PyTreeDef = pytree.PyTreeDef
# TODO(phawkins): make this unconditional when jaxlib 0.4.14 is the minimum.
default_registry: Optional[pytree.PyTreeRegistry]
if xla_extension_version >= 169:
default_registry = pytree.default_registry()
# Set __module__ and __name__, which allow this registry to be pickled by
# reference.
default_registry.__module__ = __name__
default_registry.__name__ = "default_registry"
else:
default_registry = None
def tree_flatten(tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None
@ -54,11 +65,15 @@ def tree_flatten(tree: Any,
flattening step. It should return a boolean, with true stopping the
traversal and the whole subtree being treated as a leaf, and false
indicating the flattening should traverse the current object.
Returns:
A pair where the first element is a list of leaf values and the second
element is a treedef representing the structure of the flattened tree.
"""
return pytree.flatten(tree, is_leaf)
if default_registry:
return default_registry.flatten(tree, is_leaf)
else:
return pytree.flatten(tree, is_leaf) # type: ignore
def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
@ -68,8 +83,8 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
Args:
treedef: the treedef to reconstruct
leaves: the iterable of leaves to use for reconstruction. The iterable
must match the leaves of the treedef.
leaves: the iterable of leaves to use for reconstruction. The iterable must
match the leaves of the treedef.
Returns:
The reconstructed pytree, containing the ``leaves`` placed in the structure
@ -77,30 +92,48 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
"""
return treedef.unflatten(leaves)
def tree_leaves(tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None
) -> list[Leaf]:
"""Gets the leaves of a pytree."""
return pytree.flatten(tree, is_leaf)[0]
if default_registry:
return default_registry.flatten(tree, is_leaf)[0]
else:
return pytree.flatten(tree, is_leaf)[0] # type: ignore
def tree_structure(tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTreeDef:
is_leaf: Optional[Callable[[Any],
bool]] = None) -> PyTreeDef:
"""Gets the treedef for a pytree."""
return pytree.flatten(tree, is_leaf)[1]
if default_registry:
return default_registry.flatten(tree, is_leaf)[1]
else:
return pytree.flatten(tree, is_leaf)[1] # type: ignore
def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
"""Makes a tuple treedef from an iterable of child treedefs."""
return pytree.tuple(list(treedefs))
if default_registry:
return pytree.tuple(default_registry, list(treedefs)) # type: ignore
else:
return pytree.tuple(list(treedefs)) # type: ignore
def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]:
return treedef.children()
def treedef_is_leaf(treedef: PyTreeDef) -> bool:
return treedef.num_nodes == 1
def treedef_is_strict_leaf(treedef: PyTreeDef) -> bool:
return treedef.num_nodes == 1 and treedef.num_leaves == 1
def all_leaves(iterable: Iterable[Any],
is_leaf: Optional[Callable[[Any], bool]] = None) -> bool:
"""Tests whether all elements in the given iterable are all leaves.
@ -120,7 +153,10 @@ def all_leaves(iterable: Iterable[Any],
A boolean indicating if all elements in the input are leaves.
"""
if is_leaf is None:
return pytree.all_leaves(iterable)
if default_registry:
return pytree.all_leaves(default_registry, iterable) # type: ignore
else:
return pytree.all_leaves(iterable) # type: ignore
else:
lst = list(iterable)
return lst == tree_leaves(lst, is_leaf)
@ -140,17 +176,20 @@ def register_pytree_node(nodetype: type[T],
nodetype: a Python type to treat as an internal pytree node.
flatten_func: a function to be used during flattening, taking a value of
type ``nodetype`` and returning a pair, with (1) an iterable for the
children to be flattened recursively, and (2) some hashable auxiliary
data to be stored in the treedef and to be passed to the
``unflatten_func``.
children to be flattened recursively, and (2) some hashable auxiliary data
to be stored in the treedef and to be passed to the ``unflatten_func``.
unflatten_func: a function taking two arguments: the auxiliary data that was
returned by ``flatten_func`` and stored in the treedef, and the
unflattened children. The function should return an instance of
``nodetype``.
"""
pytree.register_node(nodetype, flatten_func, unflatten_func)
if default_registry:
default_registry.register_node(nodetype, flatten_func, unflatten_func)
else:
pytree.register_node(nodetype, flatten_func, unflatten_func) # type: ignore
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
def register_pytree_node_class(cls: U) -> U:
"""Extends the set of types that are considered internal nodes in pytrees.
@ -168,10 +207,13 @@ def register_pytree_node_class(cls: U) -> U:
def tree_unflatten(cls, aux_data, children):
return cls(*children)
"""
register_pytree_node(cls, op.methodcaller('tree_flatten'), cls.tree_unflatten)
register_pytree_node(cls, op.methodcaller("tree_flatten"), cls.tree_unflatten)
return cls
def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
def tree_map(f: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
"""Maps a multi-input function over pytree args to produce a new pytree.
@ -183,9 +225,9 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
rest: a tuple of pytrees, each of which has the same structure as ``tree``
or has ``tree`` as a prefix.
is_leaf: an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether
the flattening should traverse the current object, or if it should be
stopped immediately, with the whole subtree being treated as a leaf.
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
A new pytree with the same structure as ``tree`` but with the value at each
@ -209,13 +251,15 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
return treedef.from_iterable_tree(xs)
def tree_transpose(outer_treedef: PyTreeDef,
inner_treedef: PyTreeDef,
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef,
pytree_to_transpose: Any) -> Any:
"""Transform a tree having tree structure (outer, inner) into one having structure
(inner, outer).
"""
flat, treedef = tree_flatten(pytree_to_transpose)
@ -225,11 +269,14 @@ def tree_transpose(outer_treedef: PyTreeDef,
expected_treedef = outer_treedef.compose(inner_treedef)
raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
iter_flat = iter(flat)
lol = [[next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)]
lol = [
[next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)
]
transposed_lol = zip(*lol)
subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
return tree_unflatten(inner_treedef, subtrees)
# TODO(mattjj): remove the Python-side registry when the C++-side registry is
# sufficiently queryable that we can express _replace_nones. That may mean once
# we have a flatten_one function.
@ -241,6 +288,7 @@ _registry = {
lambda keys, xs: dict(zip(keys, xs))),
type(None): _RegistryEntry(lambda z: ((), None), lambda _, xs: None),
}
def _replace_nones(sentinel, tree):
"""Replaces ``None`` in ``tree`` with ``sentinel``."""
if tree is None:
@ -251,7 +299,7 @@ def _replace_nones(sentinel, tree):
children, metadata = handler.to_iter(tree)
proc_children = [_replace_nones(sentinel, child) for child in children]
return handler.from_iter(metadata, proc_children)
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
elif isinstance(tree, tuple) and hasattr(tree, "_fields"):
# handle namedtuple as a special case, based on heuristic
children = iter(tree)
proc_children = [_replace_nones(sentinel, child) for child in children]
@ -259,8 +307,10 @@ def _replace_nones(sentinel, tree):
else:
return tree
no_initializer = object()
@overload
def tree_reduce(function: Callable[[T, Any], T],
tree: Any,
@ -268,6 +318,7 @@ def tree_reduce(function: Callable[[T, Any], T],
is_leaf: Optional[Callable[[Any], bool]] = None) -> T:
...
@overload
def tree_reduce(function: Callable[[T, Any], T],
tree: Any,
@ -275,6 +326,7 @@ def tree_reduce(function: Callable[[T, Any], T],
is_leaf: Optional[Callable[[Any], bool]] = None) -> T:
...
def tree_reduce(function: Callable[[T, Any], T],
tree: Any,
initializer: Any = no_initializer,
@ -287,6 +339,7 @@ def tree_reduce(function: Callable[[T, Any], T],
def tree_all(tree: Any) -> bool:
return all(tree_leaves(tree))
register_pytree_node(
collections.OrderedDict,
lambda x: (tuple(x.values()), tuple(x.keys())),
@ -300,6 +353,7 @@ register_pytree_node(
class _HashableCallableShim:
"""Object that delegates __call__, __hash__, and __eq__ to another object."""
def __init__(self, fun):
self.fun = fun
@ -353,7 +407,8 @@ class Partial(functools.partial):
>>> call_func(Partial(jnp.add), 1, 2)
Array(3, dtype=int32, weak_type=True)
Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in a
Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in
a
``TypeError``.
Note that if the result of ``Partial`` is used in the context where the
@ -366,6 +421,7 @@ class Partial(functools.partial):
>>> call_func(print_zero) # doctest:+ELLIPSIS
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace...>
"""
def __new__(klass, func, *args, **kw):
# In Python 3.10+, if func is itself a functools.partial instance,
# functools.partial.__new__ would merge the arguments of this Partial
@ -513,6 +569,7 @@ def _equality_errors(path, t1, t2, is_leaf):
class _DeprecatedKeyPathEntry(NamedTuple):
key: Any
def pprint(self) -> str:
assert False # must override
@ -815,11 +872,16 @@ def _child_keys(pytree: Any) -> KeyPath:
return tuple(FlattenedIndexKey(i) for i in range(num_children))
def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Iterable[Callable[[str], ValueError]]:
def _prefix_error(
key_path: KeyPath,
prefix_tree: Any,
full_tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Iterable[Callable[[str], ValueError]]:
# A leaf is a valid prefix of any tree:
if treedef_is_strict_leaf(tree_structure(prefix_tree, is_leaf=is_leaf)): return
if treedef_is_strict_leaf(tree_structure(prefix_tree, is_leaf=is_leaf)):
return
# The subtrees may disagree because their roots are of different types:
if type(prefix_tree) != type(full_tree):
@ -881,8 +943,9 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
prefix_tree_meta_str = str(prefix_tree_meta)
full_tree_meta_str = str(full_tree_meta)
metadata_diff = textwrap.indent(
'\n'.join(difflib.ndiff(prefix_tree_meta_str.splitlines(),
full_tree_meta_str.splitlines())),
"\n".join(
difflib.ndiff(prefix_tree_meta_str.splitlines(),
full_tree_meta_str.splitlines())),
prefix=" ")
yield lambda name: ValueError(
"pytree structure error: different pytree metadata at key path\n"
@ -909,14 +972,19 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
# TODO(jakevdp) remove these deprecated wrappers & their imports in jax/__init__.py
def _deprecate(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
warnings.warn(f"jax.{f.__name__} is deprecated, and will be removed in a future release. "
f"Use jax.tree_util.{f.__name__} instead.",
category=FutureWarning, stacklevel=2)
warnings.warn(
f"jax.{f.__name__} is deprecated, and will be removed in a future release. "
f"Use jax.tree_util.{f.__name__} instead.",
category=FutureWarning,
stacklevel=2)
return f(*args, **kwargs)
return wrapped
def __getattr__(name):
prefix = "_deprecated_"
if name.startswith(prefix):

View File

@ -520,8 +520,8 @@ from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util
from jax._src.lib import pytree
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
@ -626,7 +626,7 @@ def id_tap(tap_func,
FutureWarning)
if result is not None:
flat_results, result_treedef = pytree.flatten(result)
flat_results, result_treedef = tree_util.tree_flatten(result)
for r in flat_results:
dispatch.check_arg(r)
@ -642,7 +642,7 @@ def id_tap(tap_func,
# Return the results, but add a dependency on the call, to ensure it
# is kept in the graph.
if FLAGS.jax_host_callback_ad_transforms:
call_flat_results, _ = pytree.flatten(call_res)
call_flat_results, _ = tree_util.tree_flatten(call_res)
if call_flat_results:
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
for r in flat_results]
@ -783,7 +783,7 @@ def _call(callback_func: Callable,
_initialize_outfeed_receiver(
max_callback_queue_size_bytes=FLAGS.jax_host_callback_max_queue_byte_size)
api.check_callable(callback_func)
flat_args, arg_treedef = pytree.flatten(arg)
flat_args, arg_treedef = tree_util.tree_flatten(arg)
for arg in flat_args:
dispatch.check_arg(arg)
# See definition of outside_call_p for what parameters it takes
@ -797,7 +797,7 @@ def _call(callback_func: Callable,
if not identity:
# Turn abstract values into ShapesDtypeStruct
flat_results_shape, result_treedef = pytree.flatten(result_shape)
flat_results_shape, result_treedef = tree_util.tree_flatten(result_shape)
try:
flat_results_aval = [core.ShapedArray(np.shape(r), dtypes.dtype(r, canonicalize=True))
for r in flat_results_shape]
@ -1316,7 +1316,7 @@ def _outside_call_run_callback(
else: # Check the type of the callback results
assert result_treedef is not None
assert flat_results_aval is not None
actual_flat_results, actual_result_treedef = pytree.flatten(res)
actual_flat_results, actual_result_treedef = tree_util.tree_flatten(res)
if actual_result_treedef != result_treedef:
msg = (f"Callback func {callback} should have returned a result "
f"with pytree {result_treedef} but returned "

View File

@ -413,9 +413,11 @@ class TreeTest(jtu.JaxTestCase):
@parameterized.parameters(*TREES)
def testPickleRoundTrip(self, tree):
treedef = tree_util.tree_structure(tree)
leaves, treedef = tree_util.tree_flatten(tree)
treedef_restored = pickle.loads(pickle.dumps(treedef))
self.assertEqual(treedef, treedef_restored)
reconstituted = treedef_restored.unflatten(leaves)
self.assertEqual(tree, reconstituted)
def testDictKeysSortable(self):
d = {"a": 1, 2: "b"}