Merge branch 'google:main' into keypath-log

This commit is contained in:
Ivy Zheng 2023-03-05 12:35:23 -08:00 committed by GitHub
commit 1a4527ed40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 491 additions and 137 deletions

View File

@ -23,7 +23,7 @@ import jax
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax.tree_util import tree_flatten, tree_unflatten, keystr
from jax._src import ad_util
from jax._src import core
from jax._src import linear_util as lu
@ -407,7 +407,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
if v in res_vars:
if arg_info is not None:
arg_name, arg_path = arg_info[i]
src = f'from the argument {arg_name}{arg_path.pprint("")}'
src = f'from the argument {arg_name}{keystr(arg_path)}'
else:
src = 'from the argument at flattened index {i}'
results.append((v.aval, src))

View File

@ -40,7 +40,7 @@ from jax._src import linear_util as lu
from jax import stages
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
Partial, PyTreeDef, all_leaves)
Partial, PyTreeDef, all_leaves, keystr)
from jax._src import callback as jcb
from jax._src import core
from jax._src import device_array
@ -1815,15 +1815,15 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
except (TypeError, ValueError):
ba = None
if ba is None:
args_paths = [f'args{p.pprint("")} '
args_paths = [f'args{keystr(p)} '
f'of type {shaped_abstractify(x).str_short()}'
for p, x in _generate_key_paths(args)]
kwargs_paths = [f'kwargs{p.pprint("")} '
kwargs_paths = [f'kwargs{keystr(p)} '
f'of type {shaped_abstractify(x).str_short()}'
for p, x in _generate_key_paths(kwargs)]
key_paths = [*args_paths, *kwargs_paths]
else:
key_paths = [f'argument {name}{p.pprint("")} '
key_paths = [f'argument {name}{keystr(p)} '
f'of type {shaped_abstractify(x).str_short()}'
for name, arg in ba.arguments.items()
for p, x in _generate_key_paths(arg)]

View File

@ -475,10 +475,7 @@ class ArrayImpl(basearray.Array):
self._check_if_deleted()
if self._npy_value is None:
if self.is_fully_replicated:
arr = self._arrays[0] # type: ignore
# copy_to_host_async implemented in c++ only for single device arrays.
if hasattr(arr, "_copy_single_device_array_to_host_async"):
arr._copy_single_device_array_to_host_async() # type: ignore
self._arrays[0].copy_to_host_async()
return
try:
self.addressable_shards[0].replica_id
@ -496,12 +493,7 @@ class ArrayImpl(basearray.Array):
if self._npy_value is None:
if self.is_fully_replicated:
arr = self._arrays[0] # type: ignore
# Conversion to numpy implemented only for single device arrays.
if hasattr(arr, "_single_device_array_to_np_array"):
self._npy_value = arr._single_device_array_to_np_array() # type: ignore
else:
self._npy_value = np.asarray(arr) # type: ignore
self._npy_value = np.asarray(self._arrays[0]) # type: ignore
self._npy_value.flags.writeable = False
return cast(np.ndarray, self._npy_value)

View File

@ -16,6 +16,7 @@ from collections import Counter
import dataclasses
import functools
import math
import warnings
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
@ -263,6 +264,10 @@ class GlobalDeviceArray:
device_buffers: Sequence[DeviceArray],
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
_enable_checks: bool = True):
warnings.warn(
"GlobalDeviceArray has been deprecated. Please migrate to jax.Array. "
"See https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration "
"on how to migrate to jax.Array.", DeprecationWarning)
self._global_shape = global_shape
self._global_mesh = global_mesh
self._mesh_axes = mesh_axes

View File

@ -49,7 +49,7 @@ import numpy as np
import jax
from jax.errors import JAXTypeError
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map
from jax.tree_util import tree_flatten, tree_map, keystr
from jax._src import abstract_arrays
from jax._src import api_util
@ -3058,7 +3058,7 @@ def lower_sharding_computation(
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
arg_names = None if arg_info is None else [
f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info)
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)
if i in kept_var_idx]
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
@ -3249,7 +3249,7 @@ def lower_mesh_computation(
closed_jaxpr.effects))
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
arg_names = None if arg_info is None else [
f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info)]
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)]
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,

View File

@ -148,10 +148,10 @@ def _device_assignment_mismatch_error(fun, fails, in_tree, args_flat, api_name):
arg_list = []
for arg_key, val in args_aug:
ak, *rem_keys = arg_key.keys
ak, *rem_keys = arg_key
if sig is not None:
loc = ''.join(k.pprint() for k in rem_keys)
arg_name = f'{list(sig.arguments.keys())[ak.key]}{loc}'
loc = ''.join(str(k) for k in rem_keys)
arg_name = f'{list(sig.arguments.keys())[ak.idx]}{loc}'
else:
arg_name = ''
da = val.sharding._device_assignment if hasattr(val, 'sharding') else None

View File

@ -13,20 +13,22 @@
# limitations under the License.
import collections
from dataclasses import dataclass
import difflib
import functools
from functools import partial
import operator as op
from typing import (Any, Callable, Dict, Hashable, Iterable, List, NamedTuple,
Optional, Sequence, Tuple, Type, TypeVar, overload)
import textwrap
from typing import (Any, Callable, Hashable, Iterable, List, NamedTuple,
Optional, Tuple, Type, TypeVar, Union, overload)
import warnings
from jax._src.lib import pytree
from jax._src.util import safe_zip, unzip2
from jax._src import traceback_util
from jax._src.lib import pytree
from jax._src.util import safe_zip
from jax._src.util import unzip2
traceback_util.register_exclusion(__file__)
T = TypeVar("T")
@ -291,7 +293,6 @@ register_pytree_node(
lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values))) # type: ignore[index,call-overload]
class _HashableCallableShim:
"""Object that delegates __call__, __hash__, and __eq__ to another object."""
def __init__(self, fun):
@ -398,85 +399,290 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
return result
def flatten_one_level(pytree: Any) -> Tuple[List[Any], Hashable]:
"""Flatten the given pytree node by one level.
Args:
pytree: A valid pytree node, either built-in or registered via
``register_pytree_node`` or ``register_pytree_with_keys``.
Returns:
A pair of the pytree's flattened children and its hashable metadata.
Raises:
ValueError: If the given pytree is not a built-in or registered container
via ``register_pytree_node`` or ``register_pytree_with_keys``.
"""
handler = _registry.get(type(pytree))
if handler:
children, meta = handler.to_iter(pytree)
return list(children), meta
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
return list(pytree), None
# handle namedtuple as a special case, based on heuristic
return [getattr(pytree, s) for s in pytree._fields], None
else:
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
def prefix_errors(prefix_tree: Any, full_tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> List[Callable[[str], ValueError]]:
return list(_prefix_error(KeyPath(()), prefix_tree, full_tree, is_leaf))
return list(_prefix_error((), prefix_tree, full_tree, is_leaf))
class KeyPathEntry(NamedTuple):
# TODO(ivyzheng): Remove old APIs when all users migrated.
class _DeprecatedKeyPathEntry(NamedTuple):
key: Any
def pprint(self) -> str:
assert False # must override
class KeyPath(NamedTuple):
keys: Tuple[KeyPathEntry, ...]
def __add__(self, other):
if isinstance(other, KeyPathEntry):
return KeyPath(self.keys + (other,))
raise TypeError(type(other))
def pprint(self, root: str = ' tree root') -> str:
if not self.keys:
return root
return ''.join(k.pprint() for k in self.keys)
class GetitemKeyPathEntry(KeyPathEntry):
class GetitemKeyPathEntry(_DeprecatedKeyPathEntry):
def pprint(self) -> str:
return f'[{repr(self.key)}]'
def __str__(self):
return self.pprint()
class AttributeKeyPathEntry(KeyPathEntry):
class AttributeKeyPathEntry(_DeprecatedKeyPathEntry):
def pprint(self) -> str:
return f'.{self.key}'
def __str__(self):
return self.pprint()
class FlattenedKeyPathEntry(KeyPathEntry): # fallback
class FlattenedKeyPathEntry(_DeprecatedKeyPathEntry): # fallback
def pprint(self) -> str:
return f'[<flat index {self.key}>]'
def __str__(self):
return self.pprint()
def _child_keys(pytree: Any) -> Sequence[KeyPathEntry]:
assert not treedef_is_strict_leaf(tree_structure(pytree))
handler = _keypath_registry.get(type(pytree))
@dataclass(frozen=True)
class SequenceKey():
idx: int
def __str__(self):
return f'[{repr(self.idx)}]'
@dataclass(frozen=True)
class DictKey():
key: Hashable
def __str__(self):
return f'[{repr(self.key)}]'
@dataclass(frozen=True)
class GetAttrKey():
name: str
def __str__(self):
return f'.{self.name}'
@dataclass(frozen=True)
class FlattenedIndexKey():
key: int
def __str__(self):
return f'[<flat index {self.key}>]'
BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey]
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
KeyPath = Tuple[KeyEntry, ...]
def keystr(keys: KeyPath):
return ''.join([str(k) for k in keys])
class _RegistryWithKeypathsEntry(NamedTuple):
flatten_with_keys: Callable[..., Any]
unflatten_func: Callable[..., Any]
def register_keypaths(
ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]]
) -> None:
"""[Deprecated] Register the method to get keypaths for type.
Please use ``register_pytree_with_keys`` instead.
Only works if the type was already registered with ``register_pytree_node``.
"""
warnings.warn(
(
"jax.tree_util.register_keypaths is deprecated, and will be removed"
" in a future release. Please use `register_pytree_with_keys()`"
" instead."
),
category=FutureWarning,
stacklevel=2,
)
_register_keypaths(ty, handler)
def _register_keypaths(
ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]]
) -> None:
def flatten_with_keys(xs):
children, treedef = _registry[ty].to_iter(xs)
return list(zip(handler(xs), children)), treedef
if ty in _registry:
_registry_with_keypaths[ty] = _RegistryWithKeypathsEntry(
flatten_with_keys, _registry[ty].from_iter
)
_registry_with_keypaths = {}
_register_keypaths(
tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
)
_register_keypaths(
list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
)
_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs)))
_register_keypaths(
collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys())
)
_register_keypaths(
collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys())
)
def register_pytree_with_keys(
nodetype: Type[T],
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, _Children]], _AuxData]],
unflatten_func: Callable[[_AuxData, _Children], T]):
"""Extends the set of types that are considered internal nodes in pytrees.
This is a more powerful alternative to ``register_pytree_node`` that allows
you to access each pytree leaf's key path when flattening and tree-mapping.
Args:
nodetype: a Python type to treat as an internal pytree node.
flatten_func: a function to be used during flattening, taking a value of
type ``nodetype`` and returning a pair, with (1) an iterable for tuples of
each key path and its child, and (2) some hashable auxiliary data to be
stored in the treedef and to be passed to the ``unflatten_func``.
unflatten_func: a function taking two arguments: the auxiliary data that was
returned by ``flatten_func`` and stored in the treedef, and the
unflattened children. The function should return an instance of
``nodetype``.
"""
def flatten_func(tree):
key_children, treedef = flatten_with_keys(tree)
return [c for _, c in key_children], treedef
register_pytree_node(nodetype, flatten_func, unflatten_func)
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
flatten_with_keys, unflatten_func
)
def register_pytree_with_keys_class(cls: U) -> U:
"""Extends the set of types that are considered internal nodes in pytrees.
This function is similar to ``register_pytree_node_class``, but requires a
class that defines how it could be flattened with keys.
It is a thin wrapper around ``register_pytree_with_keys``, and
provides a class-oriented interface::
@register_pytree_with_keys_class
class Special:
def __init__(self, x, y):
self.x = x
self.y = y
def tree_flatten_with_keys(self):
return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
"""
register_pytree_with_keys(
cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten
)
return cls
def tree_flatten_with_path(
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> Tuple[List[Tuple[KeyPath, Any]], PyTreeDef]:
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
Args:
tree: a pytree to flatten. If it contains a custom type, it must be
registered with ``register_pytree_with_keys``.
Returns:
A pair which the first element is a list of key-leaf pairs, each of
which contains a leaf and its key path. The second element is a treedef
representing the structure of the flattened tree.
"""
_, tree_def = tree_flatten(tree, is_leaf)
return _generate_key_paths(tree, is_leaf), tree_def
def _generate_key_paths(
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[KeyPath, Any]]:
return list(_generate_key_paths_((), tree, is_leaf))
# The overall logic should be same as PyTreeDef::FlattenIntoImpl
def _generate_key_paths_(
key_path: KeyPath,
tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Iterable[Tuple[KeyPath, Any]]:
if is_leaf and is_leaf(tree):
yield key_path, tree
return
handler = _registry_with_keypaths.get(type(tree))
if handler:
return handler(pytree)
key_children, _ = handler.flatten_with_keys(tree)
for k, c in key_children:
yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
# handle namedtuple as a special case, based on heuristic
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
for k, c in key_children:
yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
elif tree is not None: # Some strictly leaf type, like int or numpy array
yield key_path, tree
def tree_map_with_path(f: Callable[..., Any],
tree: Any, *rest: Any,
is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
This is a more powerful alternative of ``tree_map`` that can take the key path
of each leaf as input argument as well.
Args:
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
each corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf's key path as the first
positional argument and the leaf itself as the second argument to ``f``.
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
or has ``tree`` as a prefix.
Returns:
A new pytree with the same structure as ``tree`` but with the value at each
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
the tuple of values at corresponding nodes in ``rest``.
"""
keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf)
keypath_leaves = list(zip(*keypath_leaves))
all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
def _child_keys(pytree: Any) -> KeyPath:
assert not treedef_is_strict_leaf(tree_structure(pytree))
handler = _registry_with_keypaths.get(type(pytree))
if handler:
return tuple(k for k, _ in handler.flatten_with_keys(pytree)[0])
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
# handle namedtuple as a special case, based on heuristic
return [AttributeKeyPathEntry(s) for s in pytree._fields]
return tuple(GetAttrKey(s) for s in pytree._fields)
else:
num_children = len(treedef_children(tree_structure(pytree)))
return [FlattenedKeyPathEntry(i) for i in range(num_children)]
return tuple(FlattenedIndexKey(i) for i in range(num_children))
_keypath_registry: Dict[Type, Callable[[Any], Sequence[KeyPathEntry]]] = {}
def register_keypaths(ty: Type, handler: Callable[[Any], Sequence[KeyPathEntry]]
) -> None:
_keypath_registry[ty] = handler
register_keypaths(tuple,
lambda tup: [GetitemKeyPathEntry(i) for i in range(len(tup))])
register_keypaths(list,
lambda lst: [GetitemKeyPathEntry(i) for i in range(len(lst))])
register_keypaths(dict,
lambda dct: [GetitemKeyPathEntry(k) for k in sorted(dct)])
def _generate_key_paths(tree: Any) -> List[Tuple[KeyPath, Any]]:
return list(_generate_key_paths_(KeyPath(()), tree))
def _generate_key_paths_(key_path: KeyPath, tree: Any
) -> Iterable[Tuple[KeyPath, Any]]:
if treedef_is_strict_leaf(tree_structure(tree)):
yield key_path, tree
else:
child_keys = _child_keys(tree)
tree_children, _ = flatten_one_level(tree)
for k, c in zip(child_keys, tree_children):
yield from _generate_key_paths_(key_path + k, c)
def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None,
@ -488,7 +694,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
if type(prefix_tree) != type(full_tree):
yield lambda name: ValueError(
"pytree structure error: different types at key path\n"
f" {{name}}{key_path.pprint()}\n"
f" {{name}}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"but at the same key path the full pytree has a subtree of different type\n"
@ -510,7 +716,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
ty = type(prefix_tree)
yield lambda name: ValueError(
f"pytree structure error: different lengths of {ty.__name__} at key path\n"
f" {{name}}{key_path.pprint()}\n"
f" {{name}}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type "
f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree "
f"has a subtree of the same type but of length {len(full_tree)}."
@ -525,7 +731,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
"pytree structure error: different numbers of pytree children at key path\n"
f" {{name}}{key_path.pprint()}\n"
f" {{name}}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"with {len(prefix_tree_children)} child keys\n"
@ -549,7 +755,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
prefix=" ")
yield lambda name: ValueError(
"pytree structure error: different pytree metadata at key path\n"
f" {{name}}{key_path.pprint()}\n"
f" {{name}}{keystr(key_path)}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"with metadata\n"
@ -567,7 +773,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
("equal pytree nodes gave differing prefix_tree_keys: "
f"{prefix_tree_keys} and {full_tree_keys}")
for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
yield from _prefix_error(key_path + k, t1, t2)
yield from _prefix_error(tuple((*key_path, k)), t1, t2)
# TODO(jakevdp) remove these deprecated wrappers & their imports in jax/__init__.py

View File

@ -48,7 +48,7 @@ from jax.interpreters import xla
from jax._src.interpreters import pxla
from jax.interpreters import ad
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_leaves)
tree_structure, tree_leaves, keystr)
from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
_generate_key_paths, KeyPath)
@ -130,7 +130,7 @@ SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
def _check_specs(error_type: SpecErrorType, specs: Any) -> None:
if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return
prefix = 'in' if error_type == SpecErrorType.input else 'out'
msgs = [f" {prefix}_specs{key.pprint()} is {x} of type {type(x).__name__}, "
msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, "
for key, x in _generate_key_paths(specs) if not isinstance(x, P)]
raise TypeError(
f"shard_map {prefix}_specs argument must be a pytree of "
@ -169,15 +169,15 @@ def _spec_rank_error(
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
if error_type == SpecErrorType.input and ba is not None:
arg_key, *_ = fail_key.keys
extra = (f", where {base}[{arg_key.key}] is bound to {f.__name__}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
arg_key, *_ = fail_key
extra = (f", where {base}[{arg_key}] is bound to {f.__name__}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
else:
extra = ""
msgs.append(
f"{prefix}_specs{spec_key.pprint()} is {spec} which has length "
f"{prefix}_specs{keystr(spec_key)} is {spec} which has length "
f"{len(spec)}, but "
f"{base}{fail_key.pprint()}{extra} has shape {aval.str_short()}, "
f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, "
f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
assert msgs
msg = (f"shard_map applied to the function '{f.__name__}' was given an "
@ -197,9 +197,9 @@ def _spec_divisibility_error(
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
if ba is not None:
arg_key, *_ = fail_key.keys
extra = (f", where args[{arg_key.key}] is bound to {f.__name__}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
arg_key, *_ = fail_key
extra = (f", where args[{arg_key}] is bound to {f.__name__}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
names = _canonicalize_spec(spec)
for d, ns in names.items():
if aval.shape[d] % math.prod(mesh.shape[n] for n in ns):
@ -207,8 +207,8 @@ def _spec_divisibility_error(
total = 'total ' if len(ns) > 1 else ''
sz = math.prod(mesh.shape[n] for n in ns)
msgs.append(
f"args{fail_key.pprint()} of shape {aval.str_short()}{extra} "
f"corresponds to in_specs{spec_key.pprint()} of value {spec}, "
f"args{keystr(fail_key)} of shape {aval.str_short()}{extra} "
f"corresponds to in_specs{keystr(spec_key)} of value {spec}, "
f"which maps array axis {d} (of size {aval.shape[d]}) to mesh "
f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
f"{aval.shape[d]}")
@ -237,14 +237,14 @@ def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
got_rep = ','.join(map(str, rep))
diff = ','.join(map(str, unmentioned - rep))
msgs.append(
f"out_specs{spec_key.pprint()} is {spec} which implies that the "
f"out_specs{keystr(spec_key)} is {spec} which implies that the "
f"corresponding output value is replicated across mesh axes "
f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, "
f"which is missing the required axes {diff}")
else:
need_rep_, = unmentioned
msgs.append(
f"out_specs{spec_key.pprint()} is {spec} which implies that the "
f"out_specs{keystr(spec_key)} is {spec} which implies that the "
f"corresponding output value is replicated across mesh axis "
f"'{need_rep_}', but could not infer replication over any axes")
assert msgs

View File

@ -43,7 +43,7 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, DebugInfo)
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
KeyPath, _generate_key_paths)
KeyPath, _generate_key_paths, keystr)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache,
@ -1493,7 +1493,7 @@ class DynamicJaxprTracer(core.Tracer):
arg_info = arg_info_all(dbg)
if invar_pos and arg_info:
arg_info = [arg_info[i] for i in invar_pos]
arg_names = [f'{name}{path.pprint("")}' for name, path in arg_info]
arg_names = [f'{name}{keystr(path)}' for name, path in arg_info]
if len(arg_names) == 1:
arg_info_str = f"the argument {arg_names[0]}"
elif len(arg_names) == 2:

View File

@ -90,7 +90,6 @@ from jax._src.interpreters.pxla import (
lower_mesh_computation as lower_mesh_computation,
lower_parallel_callable as lower_parallel_callable,
lower_sharding_computation as lower_sharding_computation,
make_sharded_device_array as make_sharded_device_array,
maybe_extend_axis_env as maybe_extend_axis_env,
mesh_sharding_specs as mesh_sharding_specs,
multi_host_supported_collectives as multi_host_supported_collectives,
@ -132,6 +131,7 @@ from jax._src.interpreters.pxla import (
from jax._src.interpreters.pxla import (
Mesh as _deprecated_Mesh,
PartitionSpec as _deprecated_PartitionSpec,
make_sharded_device_array as _deprecated_make_sharded_device_array,
)
import typing
@ -139,6 +139,7 @@ if typing.TYPE_CHECKING:
from jax._src.interpreters.pxla import (
Mesh as Mesh,
PartitionSpec as PartitionSpec,
make_sharded_device_array as make_sharded_device_array,
)
del typing
@ -149,10 +150,21 @@ _deprecations = {
_deprecated_Mesh,
),
"PartitionSpec": (
("jax.interpreters.pxla.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."),
(
"jax.interpreters.pxla.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."
),
_deprecated_PartitionSpec,
),
# make_sharded_device_array is deprecated as of March 3, 2023. jax.Array
# is the default since November 2022.
"make_sharded_device_array": (
(
"jax.interpreters.pxla.make_sharded_device_array is deprecated as"
" of March 3, 2023. Use jax.make_array_from_single_device_arrays."
),
_deprecated_make_sharded_device_array,
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr

View File

@ -56,7 +56,17 @@ from jax._src.tree_util import (
treedef_children as treedef_children,
treedef_is_leaf as treedef_is_leaf,
treedef_tuple as treedef_tuple,
# TODO(ivyzheng): Remove old APIs when all users migrated.
register_keypaths as register_keypaths,
AttributeKeyPathEntry as AttributeKeyPathEntry,
GetitemKeyPathEntry as GetitemKeyPathEntry,
register_pytree_with_keys as register_pytree_with_keys,
register_pytree_with_keys_class as register_pytree_with_keys_class,
tree_map_with_path as tree_map_with_path,
tree_flatten_with_path as tree_flatten_with_path,
keystr as keystr,
SequenceKey as SequenceKey,
DictKey as DictKey,
GetAttrKey as GetAttrKey,
FlattenedIndexKey as FlattenedIndexKey,
)

View File

@ -19,13 +19,14 @@ load(
"//jaxlib:jax.bzl",
"if_windows",
"pybind_extension",
"pytype_library",
)
licenses(["notice"])
package(default_visibility = ["//:__subpackages__"])
py_library(
pytype_library(
name = "jaxlib",
srcs = [
"ducc_fft.py",

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import List, Tuple
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
@ -34,13 +34,14 @@ _C2C = 0
_C2R = 1
_R2C = 2
def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType,
fft_lengths: List[int]) -> bytes:
def _ducc_fft_descriptor(
shape: List[int], dtype, fft_type: FftType, fft_lengths: List[int]
) -> Tuple[bytes, np.dtype, List[int]]:
n = len(shape)
assert len(fft_lengths) >= 1
assert len(fft_lengths) <= n, (fft_lengths, n)
forward = fft_type in (FftType.FFT, FftType.RFFT)
is_double = np.finfo(dtype).dtype == np.float64
if fft_type == FftType.RFFT:

View File

@ -23,14 +23,14 @@ from .hlo_helpers import custom_call
from jaxlib import xla_client
try:
from .cuda import _linalg as _cuda_linalg
from .cuda import _linalg as _cuda_linalg # pytype: disable=import-error
for _name, _value in _cuda_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cuda_linalg = None
try:
from .rocm import _linalg as _hip_linalg
from .rocm import _linalg as _hip_linalg # pytype: disable=import-error
for _name, _value in _hip_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:

View File

@ -26,14 +26,14 @@ from jaxlib import xla_client
from .hlo_helpers import custom_call
try:
from .cuda import _prng as _cuda_prng
from .cuda import _prng as _cuda_prng # pytype: disable=import-error
for _name, _value in _cuda_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cuda_prng = None
try:
from .rocm import _prng as _hip_prng
from .rocm import _prng as _hip_prng # pytype: disable=import-error
for _name, _value in _hip_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:

View File

@ -20,7 +20,7 @@ import numpy as np
from jaxlib import xla_client
try:
from .cuda import _rnn as _rnn
from .cuda import _rnn # pytype: disable=import-error
for _name, _value in _rnn.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform='CUDA')
except ImportError:

View File

@ -26,14 +26,14 @@ from jaxlib import xla_client
from .hlo_helpers import custom_call
try:
from .cuda import _blas as _cublas
from .cuda import _blas as _cublas # pytype: disable=import-error
for _name, _value in _cublas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cublas = None
try:
from .cuda import _solver as _cusolver
from .cuda import _solver as _cusolver # pytype: disable=import-error
for _name, _value in _cusolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
@ -41,14 +41,14 @@ except ImportError:
try:
from .rocm import _blas as _hipblas
from .rocm import _blas as _hipblas # pytype: disable=import-error
for _name, _value in _hipblas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
_hipblas = None
try:
from .rocm import _solver as _hipsolver
from .rocm import _solver as _hipsolver # pytype: disable=import-error
for _name, _value in _hipsolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:

View File

@ -26,7 +26,7 @@ from jaxlib import xla_client
from .hlo_helpers import custom_call
try:
from .cuda import _sparse as _cusparse
from .cuda import _sparse as _cusparse # pytype: disable=import-error
except ImportError:
_cusparse = None
else:
@ -34,7 +34,7 @@ else:
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
try:
from .rocm import _sparse as _hipsparse
from .rocm import _sparse as _hipsparse # pytype: disable=import-error
except ImportError:
_hipsparse = None
else:

View File

@ -21,7 +21,7 @@ import numpy as np
def custom_call(
call_target_name: str,
call_target_name: Union[str, bytes],
out_types: Sequence[ir.Type],
operands: Sequence[ir.Value],
operand_layouts: Sequence[Sequence[int]],

View File

@ -26,5 +26,7 @@ filterwarnings =
default:Error writing persistent compilation cache entry for 'jit__lambda_'
ignore:DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been deprecated.*:DeprecationWarning
ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
ignore:GlobalDeviceArray has been deprecated.*:DeprecationWarning
ignore:jax.interpreters.pxla.make_sharded_device_array is deprecated.*:DeprecationWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"

View File

@ -3745,7 +3745,7 @@ class PJitErrorTest(jtu.JaxTestCase):
error = re.escape(
"pytree structure error: different lengths of list at "
"key path\n"
" pjit out_shardings tree root\n")
" pjit out_shardings\n")
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, (p,), [p, None])([x, x, x]) # Error, we raise a generic tree mismatch message

View File

@ -24,7 +24,7 @@ import jax
from jax import tree_util
from jax import flatten_util
from jax._src import test_util as jtu
from jax._src.tree_util import prefix_errors
from jax._src.tree_util import prefix_errors, flatten_one_level
import jax.numpy as jnp
@ -53,8 +53,11 @@ class AnObject:
def __repr__(self):
return f"AnObject({self.x},{self.y},{self.z})"
tree_util.register_pytree_node(AnObject, lambda o: ((o.x, o.y), o.z),
lambda z, xy: AnObject(xy[0], xy[1], z))
tree_util.register_pytree_with_keys(
AnObject,
lambda o: ((("x", o.x), ("y", o.y)), o.z), # flatten_with_keys
lambda z, xy: AnObject(xy[0], xy[1], z), # unflatten (no key involved)
)
@tree_util.register_pytree_node_class
class Special:
@ -75,6 +78,14 @@ class Special:
def __eq__(self, other):
return type(self) is type(other) and (self.x, self.y) == (other.x, other.y)
@tree_util.register_pytree_with_keys_class
class SpecialWithKeys(Special):
def tree_flatten_with_keys(self):
return (((tree_util.GetAttrKey('x'), self.x),
(tree_util.GetAttrKey('y'), self.y)), None)
@tree_util.register_pytree_node_class
class FlatCache:
def __init__(self, structured, *, leaves=None, treedef=None):
@ -160,6 +171,25 @@ LEAVES = (
(object(),),
)
# All except those decorated by register_pytree_node_class
TREES_WITH_KEYPATH = (
(None,),
((None,),),
((),),
(([()]),),
((1, 0),),
(((1, "foo"), ["bar", (3, None, 7)]),),
([3],),
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
([AnObject(3, None, [4, "foo"])],),
(SpecialWithKeys(2, 3.),),
({"a": 1, "b": 0},),
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
(collections.defaultdict(dict,
[("foo", 34), ("baz", 101), ("something", -42)]),),
(ANamedTupleSubclass(foo="hello", bar=3.5),),
)
class TreeTest(jtu.JaxTestCase):
@ -406,6 +436,101 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(nodes_visited, [(None, None), (None, None, None)])
self.assertEqual(node_data_visited, [["a", "b"], ["a", "b", "c"]])
@parameterized.parameters(*(TREES_WITH_KEYPATH + LEAVES))
def testRoundtripWithPath(self, inputs):
key_leaves, treedef = tree_util.tree_flatten_with_path(inputs)
actual = tree_util.tree_unflatten(treedef, [leaf for _, leaf in key_leaves])
self.assertEqual(actual, inputs)
def testTreeMapWithPath(self):
tree = [{i: i for i in range(10)}]
all_zeros = tree_util.tree_map_with_path(
lambda kp, val: val - kp[1].key + kp[0].idx, tree
)
self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])
def testTreeMapWithPathMultipleTrees(self):
tree1 = [AnObject(x=12,
y={'cin': [1, 4, 10], 'bar': None},
z='constantdef'),
5]
tree2 = [AnObject(x=2,
y={'cin': [2, 2, 2], 'bar': None},
z='constantdef'),
2]
from_two_trees = tree_util.tree_map_with_path(
lambda kp, a, b: a + b, tree1, tree2
)
from_one_tree = tree_util.tree_map(lambda a: a + 2, tree1)
self.assertEqual(from_two_trees, from_one_tree)
def testKeyStr(self):
tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)]
flattened, _ = tree_util.tree_flatten_with_path(tree1)
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
self.assertEqual(
strs,
[
"[0].foo: 12",
"[0].bar['cin'][0]: 1",
"[0].bar['cin'][1]: 4",
"[0].bar['cin'][2]: 10",
"[1]: [0 1 2 3 4]",
],
)
def testTreeMapWithPathWithIsLeafArgument(self):
x = ((1, 2), [3, 4, 5])
y = (([3], jnp.array((0))), ([0], 7, [5, 6]))
out = tree_util.tree_map_with_path(
lambda kp, *xs: tuple((kp[0].idx, *xs)), x, y,
is_leaf=lambda n: isinstance(n, list))
self.assertEqual(out, (((0, 1, [3]),
(0, 2, jnp.array((0)))),
(1, [3, 4, 5], ([0], 7, [5, 6]))))
def testFlattenWithPathWithIsLeafArgument(self):
def is_empty(x):
try:
children, _ = flatten_one_level(x)
except ValueError:
return True # Cannot flatten x; means it must be a leaf
return len(children) == 0
EmptyTuple = collections.namedtuple("EmptyTuple", ())
tree1 = {'a': 1,
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')}
flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty)
strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened]
self.assertEqual(
strs,
[
"['a']: 1",
"['obj']x: EmptyTuple()",
"['obj']y: 0",
"['sub'][0]: [1 2]",
"['sub'][1].foo: ()",
"['sub'][1].bar[0]: None",
],
)
def testFlattenOneLevel(self):
EmptyTuple = collections.namedtuple("EmptyTuple", ())
tree1 = {'a': 1,
'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')}
self.assertEqual(flatten_one_level(tree1["sub"])[0],
tree1["sub"])
self.assertEqual(flatten_one_level(tree1["sub"][1])[0],
[(), [None]])
self.assertEqual(flatten_one_level(tree1["obj"])[0],
[EmptyTuple(), 0])
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
flatten_one_level(1)
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
flatten_one_level(jnp.array((1, 2)))
class RavelUtilTest(jtu.JaxTestCase):
@ -485,7 +610,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
def test_different_types(self):
e, = prefix_errors((1, 2), [1, 2])
expected = ("pytree structure error: different types at key path\n"
" in_axes tree root")
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
@ -511,7 +636,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
e, = prefix_errors((1,), (2, 3))
expected = ("pytree structure error: different lengths of tuple "
"at key path\n"
" in_axes tree root")
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
@ -519,7 +644,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
e, = prefix_errors([1], [2, 3])
expected = ("pytree structure error: different lengths of list "
"at key path\n"
" in_axes tree root")
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
@ -528,7 +653,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3})
expected = ("pytree structure error: different numbers of pytree children "
"at key path\n"
" in_axes tree root")
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
@ -564,7 +689,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
e, = prefix_errors({1: 2}, {3: 4})
expected = ("pytree structure error: different pytree metadata "
"at key path\n"
" in_axes tree root")
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')
@ -603,7 +728,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):
e, = prefix_errors({}, {'a': []})
expected = ("pytree structure error: different numbers of pytree children "
"at key path\n"
" in_axes tree root")
" in_axes")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')