mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge branch 'google:main' into keypath-log
This commit is contained in:
commit
1a4527ed40
@ -23,7 +23,7 @@ import jax
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, keystr
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
@ -407,7 +407,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
|
||||
if v in res_vars:
|
||||
if arg_info is not None:
|
||||
arg_name, arg_path = arg_info[i]
|
||||
src = f'from the argument {arg_name}{arg_path.pprint("")}'
|
||||
src = f'from the argument {arg_name}{keystr(arg_path)}'
|
||||
else:
|
||||
src = 'from the argument at flattened index {i}'
|
||||
results.append((v.aval, src))
|
||||
|
@ -40,7 +40,7 @@ from jax._src import linear_util as lu
|
||||
from jax import stages
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
tree_structure, tree_transpose, tree_leaves,
|
||||
Partial, PyTreeDef, all_leaves)
|
||||
Partial, PyTreeDef, all_leaves, keystr)
|
||||
from jax._src import callback as jcb
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
@ -1815,15 +1815,15 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
|
||||
except (TypeError, ValueError):
|
||||
ba = None
|
||||
if ba is None:
|
||||
args_paths = [f'args{p.pprint("")} '
|
||||
args_paths = [f'args{keystr(p)} '
|
||||
f'of type {shaped_abstractify(x).str_short()}'
|
||||
for p, x in _generate_key_paths(args)]
|
||||
kwargs_paths = [f'kwargs{p.pprint("")} '
|
||||
kwargs_paths = [f'kwargs{keystr(p)} '
|
||||
f'of type {shaped_abstractify(x).str_short()}'
|
||||
for p, x in _generate_key_paths(kwargs)]
|
||||
key_paths = [*args_paths, *kwargs_paths]
|
||||
else:
|
||||
key_paths = [f'argument {name}{p.pprint("")} '
|
||||
key_paths = [f'argument {name}{keystr(p)} '
|
||||
f'of type {shaped_abstractify(x).str_short()}'
|
||||
for name, arg in ba.arguments.items()
|
||||
for p, x in _generate_key_paths(arg)]
|
||||
|
@ -475,10 +475,7 @@ class ArrayImpl(basearray.Array):
|
||||
self._check_if_deleted()
|
||||
if self._npy_value is None:
|
||||
if self.is_fully_replicated:
|
||||
arr = self._arrays[0] # type: ignore
|
||||
# copy_to_host_async implemented in c++ only for single device arrays.
|
||||
if hasattr(arr, "_copy_single_device_array_to_host_async"):
|
||||
arr._copy_single_device_array_to_host_async() # type: ignore
|
||||
self._arrays[0].copy_to_host_async()
|
||||
return
|
||||
try:
|
||||
self.addressable_shards[0].replica_id
|
||||
@ -496,12 +493,7 @@ class ArrayImpl(basearray.Array):
|
||||
|
||||
if self._npy_value is None:
|
||||
if self.is_fully_replicated:
|
||||
arr = self._arrays[0] # type: ignore
|
||||
# Conversion to numpy implemented only for single device arrays.
|
||||
if hasattr(arr, "_single_device_array_to_np_array"):
|
||||
self._npy_value = arr._single_device_array_to_np_array() # type: ignore
|
||||
else:
|
||||
self._npy_value = np.asarray(arr) # type: ignore
|
||||
self._npy_value = np.asarray(self._arrays[0]) # type: ignore
|
||||
self._npy_value.flags.writeable = False
|
||||
return cast(np.ndarray, self._npy_value)
|
||||
|
||||
|
@ -16,6 +16,7 @@ from collections import Counter
|
||||
import dataclasses
|
||||
import functools
|
||||
import math
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
|
||||
|
||||
@ -263,6 +264,10 @@ class GlobalDeviceArray:
|
||||
device_buffers: Sequence[DeviceArray],
|
||||
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
|
||||
_enable_checks: bool = True):
|
||||
warnings.warn(
|
||||
"GlobalDeviceArray has been deprecated. Please migrate to jax.Array. "
|
||||
"See https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration "
|
||||
"on how to migrate to jax.Array.", DeprecationWarning)
|
||||
self._global_shape = global_shape
|
||||
self._global_mesh = global_mesh
|
||||
self._mesh_axes = mesh_axes
|
||||
|
@ -49,7 +49,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten, tree_map
|
||||
from jax.tree_util import tree_flatten, tree_map, keystr
|
||||
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import api_util
|
||||
@ -3058,7 +3058,7 @@ def lower_sharding_computation(
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
|
||||
arg_names = None if arg_info is None else [
|
||||
f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info)
|
||||
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)
|
||||
if i in kept_var_idx]
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
@ -3249,7 +3249,7 @@ def lower_mesh_computation(
|
||||
closed_jaxpr.effects))
|
||||
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
|
||||
arg_names = None if arg_info is None else [
|
||||
f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info)]
|
||||
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)]
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
|
@ -148,10 +148,10 @@ def _device_assignment_mismatch_error(fun, fails, in_tree, args_flat, api_name):
|
||||
|
||||
arg_list = []
|
||||
for arg_key, val in args_aug:
|
||||
ak, *rem_keys = arg_key.keys
|
||||
ak, *rem_keys = arg_key
|
||||
if sig is not None:
|
||||
loc = ''.join(k.pprint() for k in rem_keys)
|
||||
arg_name = f'{list(sig.arguments.keys())[ak.key]}{loc}'
|
||||
loc = ''.join(str(k) for k in rem_keys)
|
||||
arg_name = f'{list(sig.arguments.keys())[ak.idx]}{loc}'
|
||||
else:
|
||||
arg_name = ''
|
||||
da = val.sharding._device_assignment if hasattr(val, 'sharding') else None
|
||||
|
@ -13,20 +13,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
import difflib
|
||||
import functools
|
||||
from functools import partial
|
||||
import operator as op
|
||||
from typing import (Any, Callable, Dict, Hashable, Iterable, List, NamedTuple,
|
||||
Optional, Sequence, Tuple, Type, TypeVar, overload)
|
||||
import textwrap
|
||||
from typing import (Any, Callable, Hashable, Iterable, List, NamedTuple,
|
||||
Optional, Tuple, Type, TypeVar, Union, overload)
|
||||
import warnings
|
||||
|
||||
from jax._src.lib import pytree
|
||||
|
||||
from jax._src.util import safe_zip, unzip2
|
||||
|
||||
from jax._src import traceback_util
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.util import unzip2
|
||||
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -291,7 +293,6 @@ register_pytree_node(
|
||||
lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values))) # type: ignore[index,call-overload]
|
||||
|
||||
|
||||
|
||||
class _HashableCallableShim:
|
||||
"""Object that delegates __call__, __hash__, and __eq__ to another object."""
|
||||
def __init__(self, fun):
|
||||
@ -398,85 +399,290 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
|
||||
return result
|
||||
|
||||
def flatten_one_level(pytree: Any) -> Tuple[List[Any], Hashable]:
|
||||
"""Flatten the given pytree node by one level.
|
||||
|
||||
Args:
|
||||
pytree: A valid pytree node, either built-in or registered via
|
||||
``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
|
||||
Returns:
|
||||
A pair of the pytree's flattened children and its hashable metadata.
|
||||
|
||||
Raises:
|
||||
ValueError: If the given pytree is not a built-in or registered container
|
||||
via ``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
"""
|
||||
handler = _registry.get(type(pytree))
|
||||
if handler:
|
||||
children, meta = handler.to_iter(pytree)
|
||||
return list(children), meta
|
||||
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
|
||||
return list(pytree), None
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
return [getattr(pytree, s) for s in pytree._fields], None
|
||||
else:
|
||||
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
|
||||
|
||||
def prefix_errors(prefix_tree: Any, full_tree: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None,
|
||||
) -> List[Callable[[str], ValueError]]:
|
||||
return list(_prefix_error(KeyPath(()), prefix_tree, full_tree, is_leaf))
|
||||
return list(_prefix_error((), prefix_tree, full_tree, is_leaf))
|
||||
|
||||
class KeyPathEntry(NamedTuple):
|
||||
# TODO(ivyzheng): Remove old APIs when all users migrated.
|
||||
|
||||
class _DeprecatedKeyPathEntry(NamedTuple):
|
||||
key: Any
|
||||
def pprint(self) -> str:
|
||||
assert False # must override
|
||||
|
||||
class KeyPath(NamedTuple):
|
||||
keys: Tuple[KeyPathEntry, ...]
|
||||
def __add__(self, other):
|
||||
if isinstance(other, KeyPathEntry):
|
||||
return KeyPath(self.keys + (other,))
|
||||
raise TypeError(type(other))
|
||||
def pprint(self, root: str = ' tree root') -> str:
|
||||
if not self.keys:
|
||||
return root
|
||||
return ''.join(k.pprint() for k in self.keys)
|
||||
|
||||
class GetitemKeyPathEntry(KeyPathEntry):
|
||||
class GetitemKeyPathEntry(_DeprecatedKeyPathEntry):
|
||||
def pprint(self) -> str:
|
||||
return f'[{repr(self.key)}]'
|
||||
def __str__(self):
|
||||
return self.pprint()
|
||||
|
||||
class AttributeKeyPathEntry(KeyPathEntry):
|
||||
class AttributeKeyPathEntry(_DeprecatedKeyPathEntry):
|
||||
def pprint(self) -> str:
|
||||
return f'.{self.key}'
|
||||
def __str__(self):
|
||||
return self.pprint()
|
||||
|
||||
class FlattenedKeyPathEntry(KeyPathEntry): # fallback
|
||||
class FlattenedKeyPathEntry(_DeprecatedKeyPathEntry): # fallback
|
||||
def pprint(self) -> str:
|
||||
return f'[<flat index {self.key}>]'
|
||||
def __str__(self):
|
||||
return self.pprint()
|
||||
|
||||
def _child_keys(pytree: Any) -> Sequence[KeyPathEntry]:
|
||||
assert not treedef_is_strict_leaf(tree_structure(pytree))
|
||||
handler = _keypath_registry.get(type(pytree))
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SequenceKey():
|
||||
idx: int
|
||||
def __str__(self):
|
||||
return f'[{repr(self.idx)}]'
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DictKey():
|
||||
key: Hashable
|
||||
def __str__(self):
|
||||
return f'[{repr(self.key)}]'
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetAttrKey():
|
||||
name: str
|
||||
def __str__(self):
|
||||
return f'.{self.name}'
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FlattenedIndexKey():
|
||||
key: int
|
||||
def __str__(self):
|
||||
return f'[<flat index {self.key}>]'
|
||||
|
||||
BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey]
|
||||
|
||||
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
|
||||
KeyPath = Tuple[KeyEntry, ...]
|
||||
|
||||
def keystr(keys: KeyPath):
|
||||
return ''.join([str(k) for k in keys])
|
||||
|
||||
|
||||
class _RegistryWithKeypathsEntry(NamedTuple):
|
||||
flatten_with_keys: Callable[..., Any]
|
||||
unflatten_func: Callable[..., Any]
|
||||
|
||||
|
||||
def register_keypaths(
|
||||
ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]]
|
||||
) -> None:
|
||||
"""[Deprecated] Register the method to get keypaths for type.
|
||||
|
||||
Please use ``register_pytree_with_keys`` instead.
|
||||
|
||||
Only works if the type was already registered with ``register_pytree_node``.
|
||||
"""
|
||||
warnings.warn(
|
||||
(
|
||||
"jax.tree_util.register_keypaths is deprecated, and will be removed"
|
||||
" in a future release. Please use `register_pytree_with_keys()`"
|
||||
" instead."
|
||||
),
|
||||
category=FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
_register_keypaths(ty, handler)
|
||||
|
||||
|
||||
def _register_keypaths(
|
||||
ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]]
|
||||
) -> None:
|
||||
def flatten_with_keys(xs):
|
||||
children, treedef = _registry[ty].to_iter(xs)
|
||||
return list(zip(handler(xs), children)), treedef
|
||||
if ty in _registry:
|
||||
_registry_with_keypaths[ty] = _RegistryWithKeypathsEntry(
|
||||
flatten_with_keys, _registry[ty].from_iter
|
||||
)
|
||||
|
||||
|
||||
_registry_with_keypaths = {}
|
||||
|
||||
_register_keypaths(
|
||||
tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
|
||||
)
|
||||
_register_keypaths(
|
||||
list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
|
||||
)
|
||||
_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs)))
|
||||
|
||||
_register_keypaths(
|
||||
collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys())
|
||||
)
|
||||
|
||||
_register_keypaths(
|
||||
collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys())
|
||||
)
|
||||
|
||||
def register_pytree_with_keys(
|
||||
nodetype: Type[T],
|
||||
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, _Children]], _AuxData]],
|
||||
unflatten_func: Callable[[_AuxData, _Children], T]):
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
This is a more powerful alternative to ``register_pytree_node`` that allows
|
||||
you to access each pytree leaf's key path when flattening and tree-mapping.
|
||||
|
||||
Args:
|
||||
nodetype: a Python type to treat as an internal pytree node.
|
||||
flatten_func: a function to be used during flattening, taking a value of
|
||||
type ``nodetype`` and returning a pair, with (1) an iterable for tuples of
|
||||
each key path and its child, and (2) some hashable auxiliary data to be
|
||||
stored in the treedef and to be passed to the ``unflatten_func``.
|
||||
unflatten_func: a function taking two arguments: the auxiliary data that was
|
||||
returned by ``flatten_func`` and stored in the treedef, and the
|
||||
unflattened children. The function should return an instance of
|
||||
``nodetype``.
|
||||
"""
|
||||
def flatten_func(tree):
|
||||
key_children, treedef = flatten_with_keys(tree)
|
||||
return [c for _, c in key_children], treedef
|
||||
register_pytree_node(nodetype, flatten_func, unflatten_func)
|
||||
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
|
||||
flatten_with_keys, unflatten_func
|
||||
)
|
||||
|
||||
def register_pytree_with_keys_class(cls: U) -> U:
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
This function is similar to ``register_pytree_node_class``, but requires a
|
||||
class that defines how it could be flattened with keys.
|
||||
|
||||
It is a thin wrapper around ``register_pytree_with_keys``, and
|
||||
provides a class-oriented interface::
|
||||
|
||||
@register_pytree_with_keys_class
|
||||
class Special:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
def tree_flatten_with_keys(self):
|
||||
return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
"""
|
||||
register_pytree_with_keys(
|
||||
cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten
|
||||
)
|
||||
return cls
|
||||
|
||||
|
||||
def tree_flatten_with_path(
|
||||
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
|
||||
) -> Tuple[List[Tuple[KeyPath, Any]], PyTreeDef]:
|
||||
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
tree: a pytree to flatten. If it contains a custom type, it must be
|
||||
registered with ``register_pytree_with_keys``.
|
||||
Returns:
|
||||
A pair which the first element is a list of key-leaf pairs, each of
|
||||
which contains a leaf and its key path. The second element is a treedef
|
||||
representing the structure of the flattened tree.
|
||||
"""
|
||||
_, tree_def = tree_flatten(tree, is_leaf)
|
||||
return _generate_key_paths(tree, is_leaf), tree_def
|
||||
|
||||
|
||||
def _generate_key_paths(
|
||||
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
|
||||
) -> List[Tuple[KeyPath, Any]]:
|
||||
return list(_generate_key_paths_((), tree, is_leaf))
|
||||
|
||||
|
||||
# The overall logic should be same as PyTreeDef::FlattenIntoImpl
|
||||
def _generate_key_paths_(
|
||||
key_path: KeyPath,
|
||||
tree: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None,
|
||||
) -> Iterable[Tuple[KeyPath, Any]]:
|
||||
if is_leaf and is_leaf(tree):
|
||||
yield key_path, tree
|
||||
return
|
||||
handler = _registry_with_keypaths.get(type(tree))
|
||||
if handler:
|
||||
return handler(pytree)
|
||||
key_children, _ = handler.flatten_with_keys(tree)
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
|
||||
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
|
||||
elif tree is not None: # Some strictly leaf type, like int or numpy array
|
||||
yield key_path, tree
|
||||
|
||||
|
||||
def tree_map_with_path(f: Callable[..., Any],
|
||||
tree: Any, *rest: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
|
||||
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
|
||||
|
||||
This is a more powerful alternative of ``tree_map`` that can take the key path
|
||||
of each leaf as input argument as well.
|
||||
|
||||
Args:
|
||||
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
|
||||
each corresponding leaves of the pytrees.
|
||||
tree: a pytree to be mapped over, with each leaf's key path as the first
|
||||
positional argument and the leaf itself as the second argument to ``f``.
|
||||
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
|
||||
or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each
|
||||
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
|
||||
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
|
||||
the tuple of values at corresponding nodes in ``rest``.
|
||||
"""
|
||||
|
||||
keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf)
|
||||
keypath_leaves = list(zip(*keypath_leaves))
|
||||
all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
|
||||
return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
|
||||
|
||||
|
||||
def _child_keys(pytree: Any) -> KeyPath:
|
||||
assert not treedef_is_strict_leaf(tree_structure(pytree))
|
||||
handler = _registry_with_keypaths.get(type(pytree))
|
||||
if handler:
|
||||
return tuple(k for k, _ in handler.flatten_with_keys(pytree)[0])
|
||||
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
return [AttributeKeyPathEntry(s) for s in pytree._fields]
|
||||
return tuple(GetAttrKey(s) for s in pytree._fields)
|
||||
else:
|
||||
num_children = len(treedef_children(tree_structure(pytree)))
|
||||
return [FlattenedKeyPathEntry(i) for i in range(num_children)]
|
||||
return tuple(FlattenedIndexKey(i) for i in range(num_children))
|
||||
|
||||
_keypath_registry: Dict[Type, Callable[[Any], Sequence[KeyPathEntry]]] = {}
|
||||
|
||||
def register_keypaths(ty: Type, handler: Callable[[Any], Sequence[KeyPathEntry]]
|
||||
) -> None:
|
||||
_keypath_registry[ty] = handler
|
||||
|
||||
register_keypaths(tuple,
|
||||
lambda tup: [GetitemKeyPathEntry(i) for i in range(len(tup))])
|
||||
register_keypaths(list,
|
||||
lambda lst: [GetitemKeyPathEntry(i) for i in range(len(lst))])
|
||||
register_keypaths(dict,
|
||||
lambda dct: [GetitemKeyPathEntry(k) for k in sorted(dct)])
|
||||
|
||||
def _generate_key_paths(tree: Any) -> List[Tuple[KeyPath, Any]]:
|
||||
return list(_generate_key_paths_(KeyPath(()), tree))
|
||||
|
||||
def _generate_key_paths_(key_path: KeyPath, tree: Any
|
||||
) -> Iterable[Tuple[KeyPath, Any]]:
|
||||
if treedef_is_strict_leaf(tree_structure(tree)):
|
||||
yield key_path, tree
|
||||
else:
|
||||
child_keys = _child_keys(tree)
|
||||
tree_children, _ = flatten_one_level(tree)
|
||||
for k, c in zip(child_keys, tree_children):
|
||||
yield from _generate_key_paths_(key_path + k, c)
|
||||
|
||||
def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None,
|
||||
@ -488,7 +694,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
if type(prefix_tree) != type(full_tree):
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different types at key path\n"
|
||||
f" {{name}}{key_path.pprint()}\n"
|
||||
f" {{name}}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
|
||||
f" {type(prefix_tree)}\n"
|
||||
f"but at the same key path the full pytree has a subtree of different type\n"
|
||||
@ -510,7 +716,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
ty = type(prefix_tree)
|
||||
yield lambda name: ValueError(
|
||||
f"pytree structure error: different lengths of {ty.__name__} at key path\n"
|
||||
f" {{name}}{key_path.pprint()}\n"
|
||||
f" {{name}}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type "
|
||||
f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree "
|
||||
f"has a subtree of the same type but of length {len(full_tree)}."
|
||||
@ -525,7 +731,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
if len(prefix_tree_children) != len(full_tree_children):
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different numbers of pytree children at key path\n"
|
||||
f" {{name}}{key_path.pprint()}\n"
|
||||
f" {{name}}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
|
||||
f" {type(prefix_tree)}\n"
|
||||
f"with {len(prefix_tree_children)} child keys\n"
|
||||
@ -549,7 +755,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
prefix=" ")
|
||||
yield lambda name: ValueError(
|
||||
"pytree structure error: different pytree metadata at key path\n"
|
||||
f" {{name}}{key_path.pprint()}\n"
|
||||
f" {{name}}{keystr(key_path)}\n"
|
||||
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
|
||||
f" {type(prefix_tree)}\n"
|
||||
f"with metadata\n"
|
||||
@ -567,7 +773,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
|
||||
("equal pytree nodes gave differing prefix_tree_keys: "
|
||||
f"{prefix_tree_keys} and {full_tree_keys}")
|
||||
for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
|
||||
yield from _prefix_error(key_path + k, t1, t2)
|
||||
yield from _prefix_error(tuple((*key_path, k)), t1, t2)
|
||||
|
||||
|
||||
# TODO(jakevdp) remove these deprecated wrappers & their imports in jax/__init__.py
|
||||
|
@ -48,7 +48,7 @@ from jax.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
tree_structure, tree_leaves)
|
||||
tree_structure, tree_leaves, keystr)
|
||||
from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
|
||||
_generate_key_paths, KeyPath)
|
||||
|
||||
@ -130,7 +130,7 @@ SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
|
||||
def _check_specs(error_type: SpecErrorType, specs: Any) -> None:
|
||||
if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return
|
||||
prefix = 'in' if error_type == SpecErrorType.input else 'out'
|
||||
msgs = [f" {prefix}_specs{key.pprint()} is {x} of type {type(x).__name__}, "
|
||||
msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, "
|
||||
for key, x in _generate_key_paths(specs) if not isinstance(x, P)]
|
||||
raise TypeError(
|
||||
f"shard_map {prefix}_specs argument must be a pytree of "
|
||||
@ -169,15 +169,15 @@ def _spec_rank_error(
|
||||
msgs = []
|
||||
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
|
||||
if error_type == SpecErrorType.input and ba is not None:
|
||||
arg_key, *_ = fail_key.keys
|
||||
extra = (f", where {base}[{arg_key.key}] is bound to {f.__name__}'s "
|
||||
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
|
||||
arg_key, *_ = fail_key
|
||||
extra = (f", where {base}[{arg_key}] is bound to {f.__name__}'s "
|
||||
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
|
||||
else:
|
||||
extra = ""
|
||||
msgs.append(
|
||||
f"{prefix}_specs{spec_key.pprint()} is {spec} which has length "
|
||||
f"{prefix}_specs{keystr(spec_key)} is {spec} which has length "
|
||||
f"{len(spec)}, but "
|
||||
f"{base}{fail_key.pprint()}{extra} has shape {aval.str_short()}, "
|
||||
f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, "
|
||||
f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
|
||||
assert msgs
|
||||
msg = (f"shard_map applied to the function '{f.__name__}' was given an "
|
||||
@ -197,9 +197,9 @@ def _spec_divisibility_error(
|
||||
msgs = []
|
||||
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
|
||||
if ba is not None:
|
||||
arg_key, *_ = fail_key.keys
|
||||
extra = (f", where args[{arg_key.key}] is bound to {f.__name__}'s "
|
||||
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
|
||||
arg_key, *_ = fail_key
|
||||
extra = (f", where args[{arg_key}] is bound to {f.__name__}'s "
|
||||
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
|
||||
names = _canonicalize_spec(spec)
|
||||
for d, ns in names.items():
|
||||
if aval.shape[d] % math.prod(mesh.shape[n] for n in ns):
|
||||
@ -207,8 +207,8 @@ def _spec_divisibility_error(
|
||||
total = 'total ' if len(ns) > 1 else ''
|
||||
sz = math.prod(mesh.shape[n] for n in ns)
|
||||
msgs.append(
|
||||
f"args{fail_key.pprint()} of shape {aval.str_short()}{extra} "
|
||||
f"corresponds to in_specs{spec_key.pprint()} of value {spec}, "
|
||||
f"args{keystr(fail_key)} of shape {aval.str_short()}{extra} "
|
||||
f"corresponds to in_specs{keystr(spec_key)} of value {spec}, "
|
||||
f"which maps array axis {d} (of size {aval.shape[d]}) to mesh "
|
||||
f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
|
||||
f"{aval.shape[d]}")
|
||||
@ -237,14 +237,14 @@ def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
|
||||
got_rep = ','.join(map(str, rep))
|
||||
diff = ','.join(map(str, unmentioned - rep))
|
||||
msgs.append(
|
||||
f"out_specs{spec_key.pprint()} is {spec} which implies that the "
|
||||
f"out_specs{keystr(spec_key)} is {spec} which implies that the "
|
||||
f"corresponding output value is replicated across mesh axes "
|
||||
f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, "
|
||||
f"which is missing the required axes {diff}")
|
||||
else:
|
||||
need_rep_, = unmentioned
|
||||
msgs.append(
|
||||
f"out_specs{spec_key.pprint()} is {spec} which implies that the "
|
||||
f"out_specs{keystr(spec_key)} is {spec} which implies that the "
|
||||
f"corresponding output value is replicated across mesh axis "
|
||||
f"'{need_rep_}', but could not infer replication over any axes")
|
||||
assert msgs
|
||||
|
@ -43,7 +43,7 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
|
||||
unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
|
||||
InputType, OutputType, get_referent, DebugInfo)
|
||||
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
|
||||
KeyPath, _generate_key_paths)
|
||||
KeyPath, _generate_key_paths, keystr)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
merge_lists, partition_list, OrderedSet,
|
||||
as_hashable_function, weakref_lru_cache,
|
||||
@ -1493,7 +1493,7 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
arg_info = arg_info_all(dbg)
|
||||
if invar_pos and arg_info:
|
||||
arg_info = [arg_info[i] for i in invar_pos]
|
||||
arg_names = [f'{name}{path.pprint("")}' for name, path in arg_info]
|
||||
arg_names = [f'{name}{keystr(path)}' for name, path in arg_info]
|
||||
if len(arg_names) == 1:
|
||||
arg_info_str = f"the argument {arg_names[0]}"
|
||||
elif len(arg_names) == 2:
|
||||
|
@ -90,7 +90,6 @@ from jax._src.interpreters.pxla import (
|
||||
lower_mesh_computation as lower_mesh_computation,
|
||||
lower_parallel_callable as lower_parallel_callable,
|
||||
lower_sharding_computation as lower_sharding_computation,
|
||||
make_sharded_device_array as make_sharded_device_array,
|
||||
maybe_extend_axis_env as maybe_extend_axis_env,
|
||||
mesh_sharding_specs as mesh_sharding_specs,
|
||||
multi_host_supported_collectives as multi_host_supported_collectives,
|
||||
@ -132,6 +131,7 @@ from jax._src.interpreters.pxla import (
|
||||
from jax._src.interpreters.pxla import (
|
||||
Mesh as _deprecated_Mesh,
|
||||
PartitionSpec as _deprecated_PartitionSpec,
|
||||
make_sharded_device_array as _deprecated_make_sharded_device_array,
|
||||
)
|
||||
|
||||
import typing
|
||||
@ -139,6 +139,7 @@ if typing.TYPE_CHECKING:
|
||||
from jax._src.interpreters.pxla import (
|
||||
Mesh as Mesh,
|
||||
PartitionSpec as PartitionSpec,
|
||||
make_sharded_device_array as make_sharded_device_array,
|
||||
)
|
||||
del typing
|
||||
|
||||
@ -149,10 +150,21 @@ _deprecations = {
|
||||
_deprecated_Mesh,
|
||||
),
|
||||
"PartitionSpec": (
|
||||
("jax.interpreters.pxla.PartitionSpec is deprecated. Use "
|
||||
"jax.sharding.PartitionSpec."),
|
||||
(
|
||||
"jax.interpreters.pxla.PartitionSpec is deprecated. Use "
|
||||
"jax.sharding.PartitionSpec."
|
||||
),
|
||||
_deprecated_PartitionSpec,
|
||||
),
|
||||
# make_sharded_device_array is deprecated as of March 3, 2023. jax.Array
|
||||
# is the default since November 2022.
|
||||
"make_sharded_device_array": (
|
||||
(
|
||||
"jax.interpreters.pxla.make_sharded_device_array is deprecated as"
|
||||
" of March 3, 2023. Use jax.make_array_from_single_device_arrays."
|
||||
),
|
||||
_deprecated_make_sharded_device_array,
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
|
@ -56,7 +56,17 @@ from jax._src.tree_util import (
|
||||
treedef_children as treedef_children,
|
||||
treedef_is_leaf as treedef_is_leaf,
|
||||
treedef_tuple as treedef_tuple,
|
||||
# TODO(ivyzheng): Remove old APIs when all users migrated.
|
||||
register_keypaths as register_keypaths,
|
||||
AttributeKeyPathEntry as AttributeKeyPathEntry,
|
||||
GetitemKeyPathEntry as GetitemKeyPathEntry,
|
||||
register_pytree_with_keys as register_pytree_with_keys,
|
||||
register_pytree_with_keys_class as register_pytree_with_keys_class,
|
||||
tree_map_with_path as tree_map_with_path,
|
||||
tree_flatten_with_path as tree_flatten_with_path,
|
||||
keystr as keystr,
|
||||
SequenceKey as SequenceKey,
|
||||
DictKey as DictKey,
|
||||
GetAttrKey as GetAttrKey,
|
||||
FlattenedIndexKey as FlattenedIndexKey,
|
||||
)
|
||||
|
@ -19,13 +19,14 @@ load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_windows",
|
||||
"pybind_extension",
|
||||
"pytype_library",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//:__subpackages__"])
|
||||
|
||||
py_library(
|
||||
pytype_library(
|
||||
name = "jaxlib",
|
||||
srcs = [
|
||||
"ducc_fft.py",
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
@ -34,13 +34,14 @@ _C2C = 0
|
||||
_C2R = 1
|
||||
_R2C = 2
|
||||
|
||||
def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
fft_lengths: List[int]) -> bytes:
|
||||
|
||||
def _ducc_fft_descriptor(
|
||||
shape: List[int], dtype, fft_type: FftType, fft_lengths: List[int]
|
||||
) -> Tuple[bytes, np.dtype, List[int]]:
|
||||
n = len(shape)
|
||||
assert len(fft_lengths) >= 1
|
||||
assert len(fft_lengths) <= n, (fft_lengths, n)
|
||||
|
||||
|
||||
forward = fft_type in (FftType.FFT, FftType.RFFT)
|
||||
is_double = np.finfo(dtype).dtype == np.float64
|
||||
if fft_type == FftType.RFFT:
|
||||
|
@ -23,14 +23,14 @@ from .hlo_helpers import custom_call
|
||||
from jaxlib import xla_client
|
||||
|
||||
try:
|
||||
from .cuda import _linalg as _cuda_linalg
|
||||
from .cuda import _linalg as _cuda_linalg # pytype: disable=import-error
|
||||
for _name, _value in _cuda_linalg.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cuda_linalg = None
|
||||
|
||||
try:
|
||||
from .rocm import _linalg as _hip_linalg
|
||||
from .rocm import _linalg as _hip_linalg # pytype: disable=import-error
|
||||
for _name, _value in _hip_linalg.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
|
@ -26,14 +26,14 @@ from jaxlib import xla_client
|
||||
from .hlo_helpers import custom_call
|
||||
|
||||
try:
|
||||
from .cuda import _prng as _cuda_prng
|
||||
from .cuda import _prng as _cuda_prng # pytype: disable=import-error
|
||||
for _name, _value in _cuda_prng.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cuda_prng = None
|
||||
|
||||
try:
|
||||
from .rocm import _prng as _hip_prng
|
||||
from .rocm import _prng as _hip_prng # pytype: disable=import-error
|
||||
for _name, _value in _hip_prng.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
|
@ -20,7 +20,7 @@ import numpy as np
|
||||
from jaxlib import xla_client
|
||||
|
||||
try:
|
||||
from .cuda import _rnn as _rnn
|
||||
from .cuda import _rnn # pytype: disable=import-error
|
||||
for _name, _value in _rnn.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform='CUDA')
|
||||
except ImportError:
|
||||
|
@ -26,14 +26,14 @@ from jaxlib import xla_client
|
||||
from .hlo_helpers import custom_call
|
||||
|
||||
try:
|
||||
from .cuda import _blas as _cublas
|
||||
from .cuda import _blas as _cublas # pytype: disable=import-error
|
||||
for _name, _value in _cublas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cublas = None
|
||||
|
||||
try:
|
||||
from .cuda import _solver as _cusolver
|
||||
from .cuda import _solver as _cusolver # pytype: disable=import-error
|
||||
for _name, _value in _cusolver.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
@ -41,14 +41,14 @@ except ImportError:
|
||||
|
||||
|
||||
try:
|
||||
from .rocm import _blas as _hipblas
|
||||
from .rocm import _blas as _hipblas # pytype: disable=import-error
|
||||
for _name, _value in _hipblas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
_hipblas = None
|
||||
|
||||
try:
|
||||
from .rocm import _solver as _hipsolver
|
||||
from .rocm import _solver as _hipsolver # pytype: disable=import-error
|
||||
for _name, _value in _hipsolver.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
|
@ -26,7 +26,7 @@ from jaxlib import xla_client
|
||||
from .hlo_helpers import custom_call
|
||||
|
||||
try:
|
||||
from .cuda import _sparse as _cusparse
|
||||
from .cuda import _sparse as _cusparse # pytype: disable=import-error
|
||||
except ImportError:
|
||||
_cusparse = None
|
||||
else:
|
||||
@ -34,7 +34,7 @@ else:
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
|
||||
try:
|
||||
from .rocm import _sparse as _hipsparse
|
||||
from .rocm import _sparse as _hipsparse # pytype: disable=import-error
|
||||
except ImportError:
|
||||
_hipsparse = None
|
||||
else:
|
||||
|
@ -21,7 +21,7 @@ import numpy as np
|
||||
|
||||
|
||||
def custom_call(
|
||||
call_target_name: str,
|
||||
call_target_name: Union[str, bytes],
|
||||
out_types: Sequence[ir.Type],
|
||||
operands: Sequence[ir.Value],
|
||||
operand_layouts: Sequence[Sequence[int]],
|
||||
|
@ -26,5 +26,7 @@ filterwarnings =
|
||||
default:Error writing persistent compilation cache entry for 'jit__lambda_'
|
||||
ignore:DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been deprecated.*:DeprecationWarning
|
||||
ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
|
||||
ignore:GlobalDeviceArray has been deprecated.*:DeprecationWarning
|
||||
ignore:jax.interpreters.pxla.make_sharded_device_array is deprecated.*:DeprecationWarning
|
||||
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
|
||||
addopts = --doctest-glob="*.rst"
|
||||
|
@ -3745,7 +3745,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
error = re.escape(
|
||||
"pytree structure error: different lengths of list at "
|
||||
"key path\n"
|
||||
" pjit out_shardings tree root\n")
|
||||
" pjit out_shardings\n")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
pjit(lambda x: x, (p,), [p, None])([x, x, x]) # Error, we raise a generic tree mismatch message
|
||||
|
||||
|
@ -24,7 +24,7 @@ import jax
|
||||
from jax import tree_util
|
||||
from jax import flatten_util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.tree_util import prefix_errors
|
||||
from jax._src.tree_util import prefix_errors, flatten_one_level
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
@ -53,8 +53,11 @@ class AnObject:
|
||||
def __repr__(self):
|
||||
return f"AnObject({self.x},{self.y},{self.z})"
|
||||
|
||||
tree_util.register_pytree_node(AnObject, lambda o: ((o.x, o.y), o.z),
|
||||
lambda z, xy: AnObject(xy[0], xy[1], z))
|
||||
tree_util.register_pytree_with_keys(
|
||||
AnObject,
|
||||
lambda o: ((("x", o.x), ("y", o.y)), o.z), # flatten_with_keys
|
||||
lambda z, xy: AnObject(xy[0], xy[1], z), # unflatten (no key involved)
|
||||
)
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
class Special:
|
||||
@ -75,6 +78,14 @@ class Special:
|
||||
def __eq__(self, other):
|
||||
return type(self) is type(other) and (self.x, self.y) == (other.x, other.y)
|
||||
|
||||
|
||||
@tree_util.register_pytree_with_keys_class
|
||||
class SpecialWithKeys(Special):
|
||||
def tree_flatten_with_keys(self):
|
||||
return (((tree_util.GetAttrKey('x'), self.x),
|
||||
(tree_util.GetAttrKey('y'), self.y)), None)
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
class FlatCache:
|
||||
def __init__(self, structured, *, leaves=None, treedef=None):
|
||||
@ -160,6 +171,25 @@ LEAVES = (
|
||||
(object(),),
|
||||
)
|
||||
|
||||
# All except those decorated by register_pytree_node_class
|
||||
TREES_WITH_KEYPATH = (
|
||||
(None,),
|
||||
((None,),),
|
||||
((),),
|
||||
(([()]),),
|
||||
((1, 0),),
|
||||
(((1, "foo"), ["bar", (3, None, 7)]),),
|
||||
([3],),
|
||||
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
|
||||
([AnObject(3, None, [4, "foo"])],),
|
||||
(SpecialWithKeys(2, 3.),),
|
||||
({"a": 1, "b": 0},),
|
||||
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
|
||||
(collections.defaultdict(dict,
|
||||
[("foo", 34), ("baz", 101), ("something", -42)]),),
|
||||
(ANamedTupleSubclass(foo="hello", bar=3.5),),
|
||||
)
|
||||
|
||||
|
||||
class TreeTest(jtu.JaxTestCase):
|
||||
|
||||
@ -406,6 +436,101 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(nodes_visited, [(None, None), (None, None, None)])
|
||||
self.assertEqual(node_data_visited, [["a", "b"], ["a", "b", "c"]])
|
||||
|
||||
@parameterized.parameters(*(TREES_WITH_KEYPATH + LEAVES))
|
||||
def testRoundtripWithPath(self, inputs):
|
||||
key_leaves, treedef = tree_util.tree_flatten_with_path(inputs)
|
||||
actual = tree_util.tree_unflatten(treedef, [leaf for _, leaf in key_leaves])
|
||||
self.assertEqual(actual, inputs)
|
||||
|
||||
def testTreeMapWithPath(self):
|
||||
tree = [{i: i for i in range(10)}]
|
||||
all_zeros = tree_util.tree_map_with_path(
|
||||
lambda kp, val: val - kp[1].key + kp[0].idx, tree
|
||||
)
|
||||
self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])
|
||||
|
||||
def testTreeMapWithPathMultipleTrees(self):
|
||||
tree1 = [AnObject(x=12,
|
||||
y={'cin': [1, 4, 10], 'bar': None},
|
||||
z='constantdef'),
|
||||
5]
|
||||
tree2 = [AnObject(x=2,
|
||||
y={'cin': [2, 2, 2], 'bar': None},
|
||||
z='constantdef'),
|
||||
2]
|
||||
from_two_trees = tree_util.tree_map_with_path(
|
||||
lambda kp, a, b: a + b, tree1, tree2
|
||||
)
|
||||
from_one_tree = tree_util.tree_map(lambda a: a + 2, tree1)
|
||||
self.assertEqual(from_two_trees, from_one_tree)
|
||||
|
||||
def testKeyStr(self):
|
||||
tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)]
|
||||
flattened, _ = tree_util.tree_flatten_with_path(tree1)
|
||||
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
|
||||
self.assertEqual(
|
||||
strs,
|
||||
[
|
||||
"[0].foo: 12",
|
||||
"[0].bar['cin'][0]: 1",
|
||||
"[0].bar['cin'][1]: 4",
|
||||
"[0].bar['cin'][2]: 10",
|
||||
"[1]: [0 1 2 3 4]",
|
||||
],
|
||||
)
|
||||
|
||||
def testTreeMapWithPathWithIsLeafArgument(self):
|
||||
x = ((1, 2), [3, 4, 5])
|
||||
y = (([3], jnp.array((0))), ([0], 7, [5, 6]))
|
||||
out = tree_util.tree_map_with_path(
|
||||
lambda kp, *xs: tuple((kp[0].idx, *xs)), x, y,
|
||||
is_leaf=lambda n: isinstance(n, list))
|
||||
self.assertEqual(out, (((0, 1, [3]),
|
||||
(0, 2, jnp.array((0)))),
|
||||
(1, [3, 4, 5], ([0], 7, [5, 6]))))
|
||||
|
||||
def testFlattenWithPathWithIsLeafArgument(self):
|
||||
def is_empty(x):
|
||||
try:
|
||||
children, _ = flatten_one_level(x)
|
||||
except ValueError:
|
||||
return True # Cannot flatten x; means it must be a leaf
|
||||
return len(children) == 0
|
||||
|
||||
EmptyTuple = collections.namedtuple("EmptyTuple", ())
|
||||
tree1 = {'a': 1,
|
||||
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
|
||||
'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')}
|
||||
flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty)
|
||||
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
|
||||
self.assertEqual(
|
||||
strs,
|
||||
[
|
||||
"['a']: 1",
|
||||
"['obj']x: EmptyTuple()",
|
||||
"['obj']y: 0",
|
||||
"['sub'][0]: [1 2]",
|
||||
"['sub'][1].foo: ()",
|
||||
"['sub'][1].bar[0]: None",
|
||||
],
|
||||
)
|
||||
|
||||
def testFlattenOneLevel(self):
|
||||
EmptyTuple = collections.namedtuple("EmptyTuple", ())
|
||||
tree1 = {'a': 1,
|
||||
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
|
||||
'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')}
|
||||
self.assertEqual(flatten_one_level(tree1["sub"])[0],
|
||||
tree1["sub"])
|
||||
self.assertEqual(flatten_one_level(tree1["sub"][1])[0],
|
||||
[(), [None]])
|
||||
self.assertEqual(flatten_one_level(tree1["obj"])[0],
|
||||
[EmptyTuple(), 0])
|
||||
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
|
||||
flatten_one_level(1)
|
||||
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
|
||||
flatten_one_level(jnp.array((1, 2)))
|
||||
|
||||
|
||||
class RavelUtilTest(jtu.JaxTestCase):
|
||||
|
||||
@ -485,7 +610,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
def test_different_types(self):
|
||||
e, = prefix_errors((1, 2), [1, 2])
|
||||
expected = ("pytree structure error: different types at key path\n"
|
||||
" in_axes tree root")
|
||||
" in_axes")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
@ -511,7 +636,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
e, = prefix_errors((1,), (2, 3))
|
||||
expected = ("pytree structure error: different lengths of tuple "
|
||||
"at key path\n"
|
||||
" in_axes tree root")
|
||||
" in_axes")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
@ -519,7 +644,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
e, = prefix_errors([1], [2, 3])
|
||||
expected = ("pytree structure error: different lengths of list "
|
||||
"at key path\n"
|
||||
" in_axes tree root")
|
||||
" in_axes")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
@ -528,7 +653,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3})
|
||||
expected = ("pytree structure error: different numbers of pytree children "
|
||||
"at key path\n"
|
||||
" in_axes tree root")
|
||||
" in_axes")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
@ -564,7 +689,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
e, = prefix_errors({1: 2}, {3: 4})
|
||||
expected = ("pytree structure error: different pytree metadata "
|
||||
"at key path\n"
|
||||
" in_axes tree root")
|
||||
" in_axes")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
@ -603,7 +728,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
|
||||
e, = prefix_errors({}, {'a': []})
|
||||
expected = ("pytree structure error: different numbers of pytree children "
|
||||
"at key path\n"
|
||||
" in_axes tree root")
|
||||
" in_axes")
|
||||
with self.assertRaisesRegex(ValueError, expected):
|
||||
raise e('in_axes')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user