mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Add support for multiple pytree registries.
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another. One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially. PiperOrigin-RevId: 549301796
This commit is contained in:
parent
f97dca79a2
commit
cdb48134e5
@ -62,10 +62,12 @@ from jax._src.api_util import (
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import PmapSharding
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src import tree_util
|
||||
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
|
||||
from jax._src import util
|
||||
|
||||
@ -1850,8 +1852,13 @@ def _cpp_pmap(
|
||||
|
||||
return out, fastpath_data
|
||||
|
||||
cpp_mapped_f = pmap_lib.pmap(
|
||||
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg)
|
||||
if xla_extension_version >= 169:
|
||||
cpp_mapped_f = pmap_lib.pmap( # type: ignore
|
||||
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
|
||||
pytree_registry=tree_util.default_registry)
|
||||
else:
|
||||
cpp_mapped_f = pmap_lib.pmap(
|
||||
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg) # type: ignore
|
||||
_pmap_cache_clears.add(cpp_mapped_f)
|
||||
|
||||
pmap_f = wraps(fun)(cpp_mapped_f)
|
||||
|
@ -45,6 +45,7 @@ from jax._src import profiler
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import stages
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.abstract_arrays import array_types
|
||||
@ -57,6 +58,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -2845,7 +2847,12 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
fastpath_data = None
|
||||
return outs, fastpath_data
|
||||
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], []) # type: ignore
|
||||
if xla_extension_version >= 169:
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
tree_util.default_registry) # type: ignore
|
||||
else:
|
||||
return xc._xla.pjit(
|
||||
self.unsafe_call.name, None, aot_cache_miss, [], [], []) # type: ignore
|
||||
|
||||
|
||||
def check_arg_avals_for_call(ref_avals, arg_avals,
|
||||
|
@ -61,7 +61,6 @@ from jax._src.lax.utils import (
|
||||
standard_named_shape_rule,
|
||||
standard_primitive,
|
||||
)
|
||||
from jax._src.lib import pytree
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
@ -4257,7 +4256,7 @@ def infeed(token, shape=None, partitions=None):
|
||||
`token` is used to sequence infeed and outfeed effects.
|
||||
`partitions` may be specified inside a `sharded_jit` function.
|
||||
"""
|
||||
flat_shapes, treedef = pytree.flatten(shape)
|
||||
flat_shapes, treedef = tree_util.tree_flatten(shape)
|
||||
for shape in flat_shapes:
|
||||
if not isinstance(shape, ShapedArray):
|
||||
raise TypeError("shape argument to infeed must be a pytree of "
|
||||
@ -4323,7 +4322,7 @@ def outfeed(token, xs, partitions = None):
|
||||
if type(partitions) != tuple: # pylint: disable=unidiomatic-typecheck
|
||||
raise ValueError(f"'partitions' argument to outfeed should be a tuple, "
|
||||
f"got {partitions}")
|
||||
flat_xs, _ = pytree.flatten(xs)
|
||||
flat_xs, _ = tree_util.tree_flatten(xs)
|
||||
return outfeed_p.bind(token, *flat_xs, partitions=partitions)
|
||||
|
||||
def _outfeed_abstract_eval(token, *xs, partitions):
|
||||
|
@ -32,6 +32,7 @@ from jax._src import linear_util as lu
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import api
|
||||
from jax._src import xla_bridge as xb
|
||||
@ -52,6 +53,7 @@ from jax._src.interpreters import pxla
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.sharding_impls import (
|
||||
NamedSharding, XLACompatibleSharding, GSPMDSharding,
|
||||
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
|
||||
@ -253,11 +255,19 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
|
||||
return outs, fastpath_data
|
||||
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
if xla_extension_version >= 169:
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
|
||||
fun, cache_miss, static_argnums, static_argnames, # type: ignore
|
||||
donate_argnums, tree_util.default_registry, # type: ignore
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
||||
else:
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
|
||||
fun, cache_miss, static_argnums, static_argnames, # type: ignore
|
||||
donate_argnums, _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
||||
|
||||
|
||||
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
|
||||
cpp_pjitted_f._fun = fun
|
||||
type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn
|
||||
@ -1194,8 +1204,13 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
|
||||
has_explicit_sharding = _pjit_explicit_sharding(
|
||||
in_shardings, out_shardings, None, None)
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
if xla_extension_version >= 169:
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
|
||||
tree_util.default_registry,
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
else:
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
|
||||
pjit_p.def_impl(_pjit_call_impl)
|
||||
|
||||
|
@ -26,6 +26,7 @@ import warnings
|
||||
|
||||
from jax._src import traceback_util
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.util import unzip2
|
||||
|
||||
@ -38,6 +39,16 @@ U = TypeVar("U", bound=type[Any])
|
||||
Leaf = Any
|
||||
PyTreeDef = pytree.PyTreeDef
|
||||
|
||||
# TODO(phawkins): make this unconditional when jaxlib 0.4.14 is the minimum.
|
||||
default_registry: Optional[pytree.PyTreeRegistry]
|
||||
if xla_extension_version >= 169:
|
||||
default_registry = pytree.default_registry()
|
||||
# Set __module__ and __name__, which allow this registry to be pickled by
|
||||
# reference.
|
||||
default_registry.__module__ = __name__
|
||||
default_registry.__name__ = "default_registry"
|
||||
else:
|
||||
default_registry = None
|
||||
|
||||
def tree_flatten(tree: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None
|
||||
@ -54,11 +65,15 @@ def tree_flatten(tree: Any,
|
||||
flattening step. It should return a boolean, with true stopping the
|
||||
traversal and the whole subtree being treated as a leaf, and false
|
||||
indicating the flattening should traverse the current object.
|
||||
|
||||
Returns:
|
||||
A pair where the first element is a list of leaf values and the second
|
||||
element is a treedef representing the structure of the flattened tree.
|
||||
"""
|
||||
return pytree.flatten(tree, is_leaf)
|
||||
if default_registry:
|
||||
return default_registry.flatten(tree, is_leaf)
|
||||
else:
|
||||
return pytree.flatten(tree, is_leaf) # type: ignore
|
||||
|
||||
|
||||
def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
|
||||
@ -68,8 +83,8 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
|
||||
|
||||
Args:
|
||||
treedef: the treedef to reconstruct
|
||||
leaves: the iterable of leaves to use for reconstruction. The iterable
|
||||
must match the leaves of the treedef.
|
||||
leaves: the iterable of leaves to use for reconstruction. The iterable must
|
||||
match the leaves of the treedef.
|
||||
|
||||
Returns:
|
||||
The reconstructed pytree, containing the ``leaves`` placed in the structure
|
||||
@ -77,30 +92,48 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
|
||||
"""
|
||||
return treedef.unflatten(leaves)
|
||||
|
||||
|
||||
def tree_leaves(tree: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None
|
||||
) -> list[Leaf]:
|
||||
"""Gets the leaves of a pytree."""
|
||||
return pytree.flatten(tree, is_leaf)[0]
|
||||
if default_registry:
|
||||
return default_registry.flatten(tree, is_leaf)[0]
|
||||
else:
|
||||
return pytree.flatten(tree, is_leaf)[0] # type: ignore
|
||||
|
||||
|
||||
def tree_structure(tree: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTreeDef:
|
||||
is_leaf: Optional[Callable[[Any],
|
||||
bool]] = None) -> PyTreeDef:
|
||||
"""Gets the treedef for a pytree."""
|
||||
return pytree.flatten(tree, is_leaf)[1]
|
||||
if default_registry:
|
||||
return default_registry.flatten(tree, is_leaf)[1]
|
||||
else:
|
||||
return pytree.flatten(tree, is_leaf)[1] # type: ignore
|
||||
|
||||
|
||||
def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
|
||||
"""Makes a tuple treedef from an iterable of child treedefs."""
|
||||
return pytree.tuple(list(treedefs))
|
||||
if default_registry:
|
||||
return pytree.tuple(default_registry, list(treedefs)) # type: ignore
|
||||
else:
|
||||
return pytree.tuple(list(treedefs)) # type: ignore
|
||||
|
||||
|
||||
|
||||
def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]:
|
||||
return treedef.children()
|
||||
|
||||
|
||||
def treedef_is_leaf(treedef: PyTreeDef) -> bool:
|
||||
return treedef.num_nodes == 1
|
||||
|
||||
|
||||
def treedef_is_strict_leaf(treedef: PyTreeDef) -> bool:
|
||||
return treedef.num_nodes == 1 and treedef.num_leaves == 1
|
||||
|
||||
|
||||
def all_leaves(iterable: Iterable[Any],
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None) -> bool:
|
||||
"""Tests whether all elements in the given iterable are all leaves.
|
||||
@ -120,7 +153,10 @@ def all_leaves(iterable: Iterable[Any],
|
||||
A boolean indicating if all elements in the input are leaves.
|
||||
"""
|
||||
if is_leaf is None:
|
||||
return pytree.all_leaves(iterable)
|
||||
if default_registry:
|
||||
return pytree.all_leaves(default_registry, iterable) # type: ignore
|
||||
else:
|
||||
return pytree.all_leaves(iterable) # type: ignore
|
||||
else:
|
||||
lst = list(iterable)
|
||||
return lst == tree_leaves(lst, is_leaf)
|
||||
@ -140,17 +176,20 @@ def register_pytree_node(nodetype: type[T],
|
||||
nodetype: a Python type to treat as an internal pytree node.
|
||||
flatten_func: a function to be used during flattening, taking a value of
|
||||
type ``nodetype`` and returning a pair, with (1) an iterable for the
|
||||
children to be flattened recursively, and (2) some hashable auxiliary
|
||||
data to be stored in the treedef and to be passed to the
|
||||
``unflatten_func``.
|
||||
children to be flattened recursively, and (2) some hashable auxiliary data
|
||||
to be stored in the treedef and to be passed to the ``unflatten_func``.
|
||||
unflatten_func: a function taking two arguments: the auxiliary data that was
|
||||
returned by ``flatten_func`` and stored in the treedef, and the
|
||||
unflattened children. The function should return an instance of
|
||||
``nodetype``.
|
||||
"""
|
||||
pytree.register_node(nodetype, flatten_func, unflatten_func)
|
||||
if default_registry:
|
||||
default_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
else:
|
||||
pytree.register_node(nodetype, flatten_func, unflatten_func) # type: ignore
|
||||
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
|
||||
|
||||
|
||||
def register_pytree_node_class(cls: U) -> U:
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
@ -168,10 +207,13 @@ def register_pytree_node_class(cls: U) -> U:
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
"""
|
||||
register_pytree_node(cls, op.methodcaller('tree_flatten'), cls.tree_unflatten)
|
||||
register_pytree_node(cls, op.methodcaller("tree_flatten"), cls.tree_unflatten)
|
||||
return cls
|
||||
|
||||
def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
|
||||
|
||||
def tree_map(f: Callable[..., Any],
|
||||
tree: Any,
|
||||
*rest: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
|
||||
"""Maps a multi-input function over pytree args to produce a new pytree.
|
||||
|
||||
@ -183,9 +225,9 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
|
||||
rest: a tuple of pytrees, each of which has the same structure as ``tree``
|
||||
or has ``tree`` as a prefix.
|
||||
is_leaf: an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, which indicates whether
|
||||
the flattening should traverse the current object, or if it should be
|
||||
stopped immediately, with the whole subtree being treated as a leaf.
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each
|
||||
@ -209,13 +251,15 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
|
||||
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
|
||||
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
|
||||
|
||||
|
||||
def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
|
||||
return treedef.from_iterable_tree(xs)
|
||||
|
||||
def tree_transpose(outer_treedef: PyTreeDef,
|
||||
inner_treedef: PyTreeDef,
|
||||
|
||||
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef,
|
||||
pytree_to_transpose: Any) -> Any:
|
||||
"""Transform a tree having tree structure (outer, inner) into one having structure
|
||||
|
||||
(inner, outer).
|
||||
"""
|
||||
flat, treedef = tree_flatten(pytree_to_transpose)
|
||||
@ -225,11 +269,14 @@ def tree_transpose(outer_treedef: PyTreeDef,
|
||||
expected_treedef = outer_treedef.compose(inner_treedef)
|
||||
raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
|
||||
iter_flat = iter(flat)
|
||||
lol = [[next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)]
|
||||
lol = [
|
||||
[next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)
|
||||
]
|
||||
transposed_lol = zip(*lol)
|
||||
subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
|
||||
return tree_unflatten(inner_treedef, subtrees)
|
||||
|
||||
|
||||
# TODO(mattjj): remove the Python-side registry when the C++-side registry is
|
||||
# sufficiently queryable that we can express _replace_nones. That may mean once
|
||||
# we have a flatten_one function.
|
||||
@ -241,6 +288,7 @@ _registry = {
|
||||
lambda keys, xs: dict(zip(keys, xs))),
|
||||
type(None): _RegistryEntry(lambda z: ((), None), lambda _, xs: None),
|
||||
}
|
||||
|
||||
def _replace_nones(sentinel, tree):
|
||||
"""Replaces ``None`` in ``tree`` with ``sentinel``."""
|
||||
if tree is None:
|
||||
@ -251,7 +299,7 @@ def _replace_nones(sentinel, tree):
|
||||
children, metadata = handler.to_iter(tree)
|
||||
proc_children = [_replace_nones(sentinel, child) for child in children]
|
||||
return handler.from_iter(metadata, proc_children)
|
||||
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
|
||||
elif isinstance(tree, tuple) and hasattr(tree, "_fields"):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
children = iter(tree)
|
||||
proc_children = [_replace_nones(sentinel, child) for child in children]
|
||||
@ -259,8 +307,10 @@ def _replace_nones(sentinel, tree):
|
||||
else:
|
||||
return tree
|
||||
|
||||
|
||||
no_initializer = object()
|
||||
|
||||
|
||||
@overload
|
||||
def tree_reduce(function: Callable[[T, Any], T],
|
||||
tree: Any,
|
||||
@ -268,6 +318,7 @@ def tree_reduce(function: Callable[[T, Any], T],
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None) -> T:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_reduce(function: Callable[[T, Any], T],
|
||||
tree: Any,
|
||||
@ -275,6 +326,7 @@ def tree_reduce(function: Callable[[T, Any], T],
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None) -> T:
|
||||
...
|
||||
|
||||
|
||||
def tree_reduce(function: Callable[[T, Any], T],
|
||||
tree: Any,
|
||||
initializer: Any = no_initializer,
|
||||
@ -287,6 +339,7 @@ def tree_reduce(function: Callable[[T, Any], T],
|
||||
def tree_all(tree: Any) -> bool:
|
||||
return all(tree_leaves(tree))
|
||||
|
||||
|
||||
register_pytree_node(
|
||||
collections.OrderedDict,
|
||||
lambda x: (tuple(x.values()), tuple(x.keys())),
|
||||
@ -300,6 +353,7 @@ register_pytree_node(
|
||||
|
||||
class _HashableCallableShim:
|
||||
"""Object that delegates __call__, __hash__, and __eq__ to another object."""
|
||||
|
||||
def __init__(self, fun):
|
||||
self.fun = fun
|
||||
|
||||
@ -353,7 +407,8 @@ class Partial(functools.partial):
|
||||
>>> call_func(Partial(jnp.add), 1, 2)
|
||||
Array(3, dtype=int32, weak_type=True)
|
||||
|
||||
Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in a
|
||||
Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in
|
||||
a
|
||||
``TypeError``.
|
||||
|
||||
Note that if the result of ``Partial`` is used in the context where the
|
||||
@ -366,6 +421,7 @@ class Partial(functools.partial):
|
||||
>>> call_func(print_zero) # doctest:+ELLIPSIS
|
||||
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace...>
|
||||
"""
|
||||
|
||||
def __new__(klass, func, *args, **kw):
|
||||
# In Python 3.10+, if func is itself a functools.partial instance,
|
||||
# functools.partial.__new__ would merge the arguments of this Partial
|
||||
@ -513,6 +569,7 @@ def _equality_errors(path, t1, t2, is_leaf):
|
||||
|
||||
class _DeprecatedKeyPathEntry(NamedTuple):
|
||||
key: Any
|
||||
|
||||
def pprint(self) -> str:
|
||||
assert False # must override
|
||||
|
||||
@ -815,11 +872,16 @@ def _child_keys(pytree: Any) -> KeyPath:
|
||||
return tuple(FlattenedIndexKey(i) for i in range(num_children))
|
||||
|
||||
|
||||
def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None,
|
||||
) -> Iterable[Callable[[str], ValueError]]:
|
||||
|
||||
def _prefix_error(
|
||||
key_path: KeyPath,
|
||||
prefix_tree: Any,
|
||||
full_tree: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None,
|
||||
) -> Iterable[Callable[[str], ValueError]]:
|
||||
# A leaf is a valid prefix of any tree:
|
||||
if treedef_is_strict_leaf(tree_structure(prefix_tree, is_leaf=is_leaf)): return
|
||||
if treedef_is_strict_leaf(tree_structure(prefix_tree, is_leaf=is_leaf)):
|
||||
return
|
||||
|
||||
# The subtrees may disagree because their roots are of different types:
|
||||
if type(prefix_tree) != type(full_tree):
|
||||
@ -881,8 +943,9 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
prefix_tree_meta_str = str(prefix_tree_meta)
|
||||
full_tree_meta_str = str(full_tree_meta)
|
||||
metadata_diff = textwrap.indent(
|
||||
'\n'.join(difflib.ndiff(prefix_tree_meta_str.splitlines(),
|
||||
full_tree_meta_str.splitlines())),
|
||||
"\n".join(
|
||||
difflib.ndiff(prefix_tree_meta_str.splitlines(),
|
||||
full_tree_meta_str.splitlines())),
|
||||
prefix=" ")
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different pytree metadata at key path\n"
|
||||
@ -909,14 +972,19 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
|
||||
# TODO(jakevdp) remove these deprecated wrappers & their imports in jax/__init__.py
|
||||
def _deprecate(f):
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
warnings.warn(f"jax.{f.__name__} is deprecated, and will be removed in a future release. "
|
||||
f"Use jax.tree_util.{f.__name__} instead.",
|
||||
category=FutureWarning, stacklevel=2)
|
||||
warnings.warn(
|
||||
f"jax.{f.__name__} is deprecated, and will be removed in a future release. "
|
||||
f"Use jax.tree_util.{f.__name__} instead.",
|
||||
category=FutureWarning,
|
||||
stacklevel=2)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
prefix = "_deprecated_"
|
||||
if name.startswith(prefix):
|
||||
|
@ -520,8 +520,8 @@ from jax._src import dispatch
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.lib import pytree
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
@ -626,7 +626,7 @@ def id_tap(tap_func,
|
||||
FutureWarning)
|
||||
|
||||
if result is not None:
|
||||
flat_results, result_treedef = pytree.flatten(result)
|
||||
flat_results, result_treedef = tree_util.tree_flatten(result)
|
||||
for r in flat_results:
|
||||
dispatch.check_arg(r)
|
||||
|
||||
@ -642,7 +642,7 @@ def id_tap(tap_func,
|
||||
# Return the results, but add a dependency on the call, to ensure it
|
||||
# is kept in the graph.
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
call_flat_results, _ = pytree.flatten(call_res)
|
||||
call_flat_results, _ = tree_util.tree_flatten(call_res)
|
||||
if call_flat_results:
|
||||
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
|
||||
for r in flat_results]
|
||||
@ -783,7 +783,7 @@ def _call(callback_func: Callable,
|
||||
_initialize_outfeed_receiver(
|
||||
max_callback_queue_size_bytes=FLAGS.jax_host_callback_max_queue_byte_size)
|
||||
api.check_callable(callback_func)
|
||||
flat_args, arg_treedef = pytree.flatten(arg)
|
||||
flat_args, arg_treedef = tree_util.tree_flatten(arg)
|
||||
for arg in flat_args:
|
||||
dispatch.check_arg(arg)
|
||||
# See definition of outside_call_p for what parameters it takes
|
||||
@ -797,7 +797,7 @@ def _call(callback_func: Callable,
|
||||
|
||||
if not identity:
|
||||
# Turn abstract values into ShapesDtypeStruct
|
||||
flat_results_shape, result_treedef = pytree.flatten(result_shape)
|
||||
flat_results_shape, result_treedef = tree_util.tree_flatten(result_shape)
|
||||
try:
|
||||
flat_results_aval = [core.ShapedArray(np.shape(r), dtypes.dtype(r, canonicalize=True))
|
||||
for r in flat_results_shape]
|
||||
@ -1316,7 +1316,7 @@ def _outside_call_run_callback(
|
||||
else: # Check the type of the callback results
|
||||
assert result_treedef is not None
|
||||
assert flat_results_aval is not None
|
||||
actual_flat_results, actual_result_treedef = pytree.flatten(res)
|
||||
actual_flat_results, actual_result_treedef = tree_util.tree_flatten(res)
|
||||
if actual_result_treedef != result_treedef:
|
||||
msg = (f"Callback func {callback} should have returned a result "
|
||||
f"with pytree {result_treedef} but returned "
|
||||
|
@ -413,9 +413,11 @@ class TreeTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters(*TREES)
|
||||
def testPickleRoundTrip(self, tree):
|
||||
treedef = tree_util.tree_structure(tree)
|
||||
leaves, treedef = tree_util.tree_flatten(tree)
|
||||
treedef_restored = pickle.loads(pickle.dumps(treedef))
|
||||
self.assertEqual(treedef, treedef_restored)
|
||||
reconstituted = treedef_restored.unflatten(leaves)
|
||||
self.assertEqual(tree, reconstituted)
|
||||
|
||||
def testDictKeysSortable(self):
|
||||
d = {"a": 1, 2: "b"}
|
||||
|
Loading…
x
Reference in New Issue
Block a user