From ad8c39ad7ca35de3fbd2fce361e7a2f4861ecbd7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 4 Mar 2023 00:48:29 +0000 Subject: [PATCH 1/4] Internal change PiperOrigin-RevId: 513953876 --- jax/_src/ad_checkpoint.py | 4 +- jax/_src/api.py | 8 +- jax/_src/global_device_array.py | 5 + jax/_src/interpreters/pxla.py | 6 +- jax/_src/pjit.py | 6 +- jax/_src/tree_util.py | 326 +++++++++++++++++++++++++------ jax/experimental/shard_map.py | 28 +-- jax/interpreters/partial_eval.py | 4 +- jax/interpreters/pxla.py | 34 ++-- jax/tree_util.py | 10 + jaxlib/BUILD | 3 +- jaxlib/ducc_fft.py | 9 +- jaxlib/gpu_linalg.py | 4 +- jaxlib/gpu_prng.py | 4 +- jaxlib/gpu_rnn.py | 2 +- jaxlib/gpu_solver.py | 8 +- jaxlib/gpu_sparse.py | 4 +- jaxlib/hlo_helpers.py | 2 +- pytest.ini | 2 + tests/pjit_test.py | 2 +- tests/tree_util_test.py | 143 +++++++++++++- 21 files changed, 488 insertions(+), 126 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index b50b24906..1dbfedabd 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -23,7 +23,7 @@ import jax from jax.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.interpreters import xla -from jax.tree_util import tree_flatten, tree_unflatten +from jax.tree_util import tree_flatten, tree_unflatten, keystr from jax._src import ad_util from jax._src import core from jax._src import linear_util as lu @@ -407,7 +407,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]: if v in res_vars: if arg_info is not None: arg_name, arg_path = arg_info[i] - src = f'from the argument {arg_name}{arg_path.pprint("")}' + src = f'from the argument {arg_name}{keystr(arg_path)}' else: src = 'from the argument at flattened index {i}' results.append((v.aval, src)) diff --git a/jax/_src/api.py b/jax/_src/api.py index 5e82c0fd5..50c7da1f2 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -40,7 +40,7 @@ from jax._src import linear_util as lu from jax import stages from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, - Partial, PyTreeDef, all_leaves) + Partial, PyTreeDef, all_leaves, keystr) from jax._src import callback as jcb from jax._src import core from jax._src import device_array @@ -1815,15 +1815,15 @@ def _mapped_axis_size(fn, tree, vals, dims, name): except (TypeError, ValueError): ba = None if ba is None: - args_paths = [f'args{p.pprint("")} ' + args_paths = [f'args{keystr(p)} ' f'of type {shaped_abstractify(x).str_short()}' for p, x in _generate_key_paths(args)] - kwargs_paths = [f'kwargs{p.pprint("")} ' + kwargs_paths = [f'kwargs{keystr(p)} ' f'of type {shaped_abstractify(x).str_short()}' for p, x in _generate_key_paths(kwargs)] key_paths = [*args_paths, *kwargs_paths] else: - key_paths = [f'argument {name}{p.pprint("")} ' + key_paths = [f'argument {name}{keystr(p)} ' f'of type {shaped_abstractify(x).str_short()}' for name, arg in ba.arguments.items() for p, x in _generate_key_paths(arg)] diff --git a/jax/_src/global_device_array.py b/jax/_src/global_device_array.py index 8a4b3888c..a1c97fd45 100644 --- a/jax/_src/global_device_array.py +++ b/jax/_src/global_device_array.py @@ -16,6 +16,7 @@ from collections import Counter import dataclasses import functools import math +import warnings import numpy as np from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple @@ -263,6 +264,10 @@ class GlobalDeviceArray: device_buffers: Sequence[DeviceArray], _gda_fast_path_args: Optional[_GdaFastPathArgs] = None, _enable_checks: bool = True): + warnings.warn( + "GlobalDeviceArray has been deprecated. Please migrate to jax.Array. " + "See https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration " + "on how to migrate to jax.Array.", DeprecationWarning) self._global_shape = global_shape self._global_mesh = global_mesh self._mesh_axes = mesh_axes diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index b941a7d38..79d850f65 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -49,7 +49,7 @@ import numpy as np import jax from jax.errors import JAXTypeError from jax.interpreters import partial_eval as pe -from jax.tree_util import tree_flatten, tree_map +from jax.tree_util import tree_flatten, tree_map, keystr from jax._src import abstract_arrays from jax._src import api_util @@ -3058,7 +3058,7 @@ def lower_sharding_computation( ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info) arg_names = None if arg_info is None else [ - f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info) + f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info) if i in kept_var_idx] lowering_result = mlir.lower_jaxpr_to_module( module_name, @@ -3249,7 +3249,7 @@ def lower_mesh_computation( closed_jaxpr.effects)) arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info) arg_names = None if arg_info is None else [ - f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info)] + f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)] lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index c292e1f23..b8a449043 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -148,10 +148,10 @@ def _device_assignment_mismatch_error(fun, fails, in_tree, args_flat, api_name): arg_list = [] for arg_key, val in args_aug: - ak, *rem_keys = arg_key.keys + ak, *rem_keys = arg_key if sig is not None: - loc = ''.join(k.pprint() for k in rem_keys) - arg_name = f'{list(sig.arguments.keys())[ak.key]}{loc}' + loc = ''.join(str(k) for k in rem_keys) + arg_name = f'{list(sig.arguments.keys())[ak.idx]}{loc}' else: arg_name = '' da = val.sharding._device_assignment if hasattr(val, 'sharding') else None diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 74b11b230..271c95aee 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -13,20 +13,22 @@ # limitations under the License. import collections +from dataclasses import dataclass import difflib import functools from functools import partial import operator as op -from typing import (Any, Callable, Dict, Hashable, Iterable, List, NamedTuple, - Optional, Sequence, Tuple, Type, TypeVar, overload) import textwrap +from typing import (Any, Callable, Hashable, Iterable, List, NamedTuple, + Optional, Tuple, Type, TypeVar, Union, overload) import warnings -from jax._src.lib import pytree - -from jax._src.util import safe_zip, unzip2 - from jax._src import traceback_util +from jax._src.lib import pytree +from jax._src.util import safe_zip +from jax._src.util import unzip2 + + traceback_util.register_exclusion(__file__) T = TypeVar("T") @@ -291,7 +293,6 @@ register_pytree_node( lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values))) # type: ignore[index,call-overload] - class _HashableCallableShim: """Object that delegates __call__, __hash__, and __eq__ to another object.""" def __init__(self, fun): @@ -398,85 +399,290 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any, return result def flatten_one_level(pytree: Any) -> Tuple[List[Any], Hashable]: + """Flatten the given pytree node by one level. + + Args: + pytree: A valid pytree node, either built-in or registered via + ``register_pytree_node`` or ``register_pytree_with_keys``. + + Returns: + A pair of the pytree's flattened children and its hashable metadata. + + Raises: + ValueError: If the given pytree is not a built-in or registered container + via ``register_pytree_node`` or ``register_pytree_with_keys``. + """ handler = _registry.get(type(pytree)) if handler: children, meta = handler.to_iter(pytree) return list(children), meta elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'): - return list(pytree), None + # handle namedtuple as a special case, based on heuristic + return [getattr(pytree, s) for s in pytree._fields], None else: raise ValueError(f"can't tree-flatten type: {type(pytree)}") def prefix_errors(prefix_tree: Any, full_tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None, ) -> List[Callable[[str], ValueError]]: - return list(_prefix_error(KeyPath(()), prefix_tree, full_tree, is_leaf)) + return list(_prefix_error((), prefix_tree, full_tree, is_leaf)) -class KeyPathEntry(NamedTuple): +# TODO(ivyzheng): Remove old APIs when all users migrated. + +class _DeprecatedKeyPathEntry(NamedTuple): key: Any def pprint(self) -> str: assert False # must override -class KeyPath(NamedTuple): - keys: Tuple[KeyPathEntry, ...] - def __add__(self, other): - if isinstance(other, KeyPathEntry): - return KeyPath(self.keys + (other,)) - raise TypeError(type(other)) - def pprint(self, root: str = ' tree root') -> str: - if not self.keys: - return root - return ''.join(k.pprint() for k in self.keys) - -class GetitemKeyPathEntry(KeyPathEntry): +class GetitemKeyPathEntry(_DeprecatedKeyPathEntry): def pprint(self) -> str: return f'[{repr(self.key)}]' + def __str__(self): + return self.pprint() -class AttributeKeyPathEntry(KeyPathEntry): +class AttributeKeyPathEntry(_DeprecatedKeyPathEntry): def pprint(self) -> str: return f'.{self.key}' + def __str__(self): + return self.pprint() -class FlattenedKeyPathEntry(KeyPathEntry): # fallback +class FlattenedKeyPathEntry(_DeprecatedKeyPathEntry): # fallback def pprint(self) -> str: return f'[]' + def __str__(self): + return self.pprint() -def _child_keys(pytree: Any) -> Sequence[KeyPathEntry]: - assert not treedef_is_strict_leaf(tree_structure(pytree)) - handler = _keypath_registry.get(type(pytree)) + +@dataclass(frozen=True) +class SequenceKey(): + idx: int + def __str__(self): + return f'[{repr(self.idx)}]' + +@dataclass(frozen=True) +class DictKey(): + key: Hashable + def __str__(self): + return f'[{repr(self.key)}]' + +@dataclass(frozen=True) +class GetAttrKey(): + name: str + def __str__(self): + return f'.{self.name}' + +@dataclass(frozen=True) +class FlattenedIndexKey(): + key: int + def __str__(self): + return f'[]' + +BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey] + +KeyEntry = TypeVar("KeyEntry", bound=Hashable) +KeyPath = Tuple[KeyEntry, ...] + +def keystr(keys: KeyPath): + return ''.join([str(k) for k in keys]) + + +class _RegistryWithKeypathsEntry(NamedTuple): + flatten_with_keys: Callable[..., Any] + unflatten_func: Callable[..., Any] + + +def register_keypaths( + ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]] +) -> None: + """[Deprecated] Register the method to get keypaths for type. + + Please use ``register_pytree_with_keys`` instead. + + Only works if the type was already registered with ``register_pytree_node``. + """ + warnings.warn( + ( + "jax.tree_util.register_keypaths is deprecated, and will be removed" + " in a future release. Please use `register_pytree_with_keys()`" + " instead." + ), + category=FutureWarning, + stacklevel=2, + ) + _register_keypaths(ty, handler) + + +def _register_keypaths( + ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]] +) -> None: + def flatten_with_keys(xs): + children, treedef = _registry[ty].to_iter(xs) + return list(zip(handler(xs), children)), treedef + if ty in _registry: + _registry_with_keypaths[ty] = _RegistryWithKeypathsEntry( + flatten_with_keys, _registry[ty].from_iter + ) + + +_registry_with_keypaths = {} + +_register_keypaths( + tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) +) +_register_keypaths( + list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) +) +_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs))) + +_register_keypaths( + collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys()) +) + +_register_keypaths( + collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys()) +) + +def register_pytree_with_keys( + nodetype: Type[T], + flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, _Children]], _AuxData]], + unflatten_func: Callable[[_AuxData, _Children], T]): + """Extends the set of types that are considered internal nodes in pytrees. + + This is a more powerful alternative to ``register_pytree_node`` that allows + you to access each pytree leaf's key path when flattening and tree-mapping. + + Args: + nodetype: a Python type to treat as an internal pytree node. + flatten_func: a function to be used during flattening, taking a value of + type ``nodetype`` and returning a pair, with (1) an iterable for tuples of + each key path and its child, and (2) some hashable auxiliary data to be + stored in the treedef and to be passed to the ``unflatten_func``. + unflatten_func: a function taking two arguments: the auxiliary data that was + returned by ``flatten_func`` and stored in the treedef, and the + unflattened children. The function should return an instance of + ``nodetype``. + """ + def flatten_func(tree): + key_children, treedef = flatten_with_keys(tree) + return [c for _, c in key_children], treedef + register_pytree_node(nodetype, flatten_func, unflatten_func) + _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( + flatten_with_keys, unflatten_func + ) + +def register_pytree_with_keys_class(cls: U) -> U: + """Extends the set of types that are considered internal nodes in pytrees. + + This function is similar to ``register_pytree_node_class``, but requires a + class that defines how it could be flattened with keys. + + It is a thin wrapper around ``register_pytree_with_keys``, and + provides a class-oriented interface:: + + @register_pytree_with_keys_class + class Special: + def __init__(self, x, y): + self.x = x + self.y = y + def tree_flatten_with_keys(self): + return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None) + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) + """ + register_pytree_with_keys( + cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten + ) + return cls + + +def tree_flatten_with_path( + tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None +) -> Tuple[List[Tuple[KeyPath, Any]], PyTreeDef]: + """Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path. + + Args: + tree: a pytree to flatten. If it contains a custom type, it must be + registered with ``register_pytree_with_keys``. + Returns: + A pair which the first element is a list of key-leaf pairs, each of + which contains a leaf and its key path. The second element is a treedef + representing the structure of the flattened tree. + """ + _, tree_def = tree_flatten(tree, is_leaf) + return _generate_key_paths(tree, is_leaf), tree_def + + +def _generate_key_paths( + tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None +) -> List[Tuple[KeyPath, Any]]: + return list(_generate_key_paths_((), tree, is_leaf)) + + +# The overall logic should be same as PyTreeDef::FlattenIntoImpl +def _generate_key_paths_( + key_path: KeyPath, + tree: Any, + is_leaf: Optional[Callable[[Any], bool]] = None, +) -> Iterable[Tuple[KeyPath, Any]]: + if is_leaf and is_leaf(tree): + yield key_path, tree + return + handler = _registry_with_keypaths.get(type(tree)) if handler: - return handler(pytree) + key_children, _ = handler.flatten_with_keys(tree) + for k, c in key_children: + yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf) + elif isinstance(tree, tuple) and hasattr(tree, '_fields'): + # handle namedtuple as a special case, based on heuristic + key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields] + for k, c in key_children: + yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf) + elif tree is not None: # Some strictly leaf type, like int or numpy array + yield key_path, tree + + +def tree_map_with_path(f: Callable[..., Any], + tree: Any, *rest: Any, + is_leaf: Optional[Callable[[Any], bool]] = None) -> Any: + """Maps a multi-input function over pytree key path and args to produce a new pytree. + + This is a more powerful alternative of ``tree_map`` that can take the key path + of each leaf as input argument as well. + + Args: + f: function that takes ``2 + len(rest)`` arguments, aka. the key path and + each corresponding leaves of the pytrees. + tree: a pytree to be mapped over, with each leaf's key path as the first + positional argument and the leaf itself as the second argument to ``f``. + *rest: a tuple of pytrees, each of which has the same structure as ``tree`` + or has ``tree`` as a prefix. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each + leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at + the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is + the tuple of values at corresponding nodes in ``rest``. + """ + + keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) + keypath_leaves = list(zip(*keypath_leaves)) + all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest] + return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves)) + + +def _child_keys(pytree: Any) -> KeyPath: + assert not treedef_is_strict_leaf(tree_structure(pytree)) + handler = _registry_with_keypaths.get(type(pytree)) + if handler: + return tuple(k for k, _ in handler.flatten_with_keys(pytree)[0]) elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'): # handle namedtuple as a special case, based on heuristic - return [AttributeKeyPathEntry(s) for s in pytree._fields] + return tuple(GetAttrKey(s) for s in pytree._fields) else: num_children = len(treedef_children(tree_structure(pytree))) - return [FlattenedKeyPathEntry(i) for i in range(num_children)] + return tuple(FlattenedIndexKey(i) for i in range(num_children)) -_keypath_registry: Dict[Type, Callable[[Any], Sequence[KeyPathEntry]]] = {} - -def register_keypaths(ty: Type, handler: Callable[[Any], Sequence[KeyPathEntry]] - ) -> None: - _keypath_registry[ty] = handler - -register_keypaths(tuple, - lambda tup: [GetitemKeyPathEntry(i) for i in range(len(tup))]) -register_keypaths(list, - lambda lst: [GetitemKeyPathEntry(i) for i in range(len(lst))]) -register_keypaths(dict, - lambda dct: [GetitemKeyPathEntry(k) for k in sorted(dct)]) - -def _generate_key_paths(tree: Any) -> List[Tuple[KeyPath, Any]]: - return list(_generate_key_paths_(KeyPath(()), tree)) - -def _generate_key_paths_(key_path: KeyPath, tree: Any - ) -> Iterable[Tuple[KeyPath, Any]]: - if treedef_is_strict_leaf(tree_structure(tree)): - yield key_path, tree - else: - child_keys = _child_keys(tree) - tree_children, _ = flatten_one_level(tree) - for k, c in zip(child_keys, tree_children): - yield from _generate_key_paths_(key_path + k, c) def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None, @@ -488,7 +694,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any, if type(prefix_tree) != type(full_tree): yield lambda name: ValueError( "pytree structure error: different types at key path\n" - f" {{name}}{key_path.pprint()}\n" + f" {{name}}{keystr(key_path)}\n" f"At that key path, the prefix pytree {{name}} has a subtree of type\n" f" {type(prefix_tree)}\n" f"but at the same key path the full pytree has a subtree of different type\n" @@ -510,7 +716,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any, ty = type(prefix_tree) yield lambda name: ValueError( f"pytree structure error: different lengths of {ty.__name__} at key path\n" - f" {{name}}{key_path.pprint()}\n" + f" {{name}}{keystr(key_path)}\n" f"At that key path, the prefix pytree {{name}} has a subtree of type " f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree " f"has a subtree of the same type but of length {len(full_tree)}." @@ -525,7 +731,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any, if len(prefix_tree_children) != len(full_tree_children): yield lambda name: ValueError( "pytree structure error: different numbers of pytree children at key path\n" - f" {{name}}{key_path.pprint()}\n" + f" {{name}}{keystr(key_path)}\n" f"At that key path, the prefix pytree {{name}} has a subtree of type\n" f" {type(prefix_tree)}\n" f"with {len(prefix_tree_children)} child keys\n" @@ -549,7 +755,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any, prefix=" ") yield lambda name: ValueError( "pytree structure error: different pytree metadata at key path\n" - f" {{name}}{key_path.pprint()}\n" + f" {{name}}{keystr(key_path)}\n" f"At that key path, the prefix pytree {{name}} has a subtree of type\n" f" {type(prefix_tree)}\n" f"with metadata\n" @@ -567,7 +773,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any, ("equal pytree nodes gave differing prefix_tree_keys: " f"{prefix_tree_keys} and {full_tree_keys}") for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children): - yield from _prefix_error(key_path + k, t1, t2) + yield from _prefix_error(tuple((*key_path, k)), t1, t2) # TODO(jakevdp) remove these deprecated wrappers & their imports in jax/__init__.py diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3d503a9f4..cab022a09 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -48,7 +48,7 @@ from jax.interpreters import xla from jax._src.interpreters import pxla from jax.interpreters import ad from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, - tree_structure, tree_leaves) + tree_structure, tree_leaves, keystr) from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef, _generate_key_paths, KeyPath) @@ -130,7 +130,7 @@ SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) def _check_specs(error_type: SpecErrorType, specs: Any) -> None: if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return prefix = 'in' if error_type == SpecErrorType.input else 'out' - msgs = [f" {prefix}_specs{key.pprint()} is {x} of type {type(x).__name__}, " + msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " for key, x in _generate_key_paths(specs) if not isinstance(x, P)] raise TypeError( f"shard_map {prefix}_specs argument must be a pytree of " @@ -169,15 +169,15 @@ def _spec_rank_error( msgs = [] for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): if error_type == SpecErrorType.input and ba is not None: - arg_key, *_ = fail_key.keys - extra = (f", where {base}[{arg_key.key}] is bound to {f.__name__}'s " - f"parameter '{list(ba.arguments.keys())[arg_key.key]}',") + arg_key, *_ = fail_key + extra = (f", where {base}[{arg_key}] is bound to {f.__name__}'s " + f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',") else: extra = "" msgs.append( - f"{prefix}_specs{spec_key.pprint()} is {spec} which has length " + f"{prefix}_specs{keystr(spec_key)} is {spec} which has length " f"{len(spec)}, but " - f"{base}{fail_key.pprint()}{extra} has shape {aval.str_short()}, " + f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, " f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})") assert msgs msg = (f"shard_map applied to the function '{f.__name__}' was given an " @@ -197,9 +197,9 @@ def _spec_divisibility_error( msgs = [] for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): if ba is not None: - arg_key, *_ = fail_key.keys - extra = (f", where args[{arg_key.key}] is bound to {f.__name__}'s " - f"parameter '{list(ba.arguments.keys())[arg_key.key]}',") + arg_key, *_ = fail_key + extra = (f", where args[{arg_key}] is bound to {f.__name__}'s " + f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',") names = _canonicalize_spec(spec) for d, ns in names.items(): if aval.shape[d] % math.prod(mesh.shape[n] for n in ns): @@ -207,8 +207,8 @@ def _spec_divisibility_error( total = 'total ' if len(ns) > 1 else '' sz = math.prod(mesh.shape[n] for n in ns) msgs.append( - f"args{fail_key.pprint()} of shape {aval.str_short()}{extra} " - f"corresponds to in_specs{spec_key.pprint()} of value {spec}, " + f"args{keystr(fail_key)} of shape {aval.str_short()}{extra} " + f"corresponds to in_specs{keystr(spec_key)} of value {spec}, " f"which maps array axis {d} (of size {aval.shape[d]}) to mesh " f"{axis} (of {total}size {sz}), but {sz} does not evenly divide " f"{aval.shape[d]}") @@ -237,14 +237,14 @@ def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, got_rep = ','.join(map(str, rep)) diff = ','.join(map(str, unmentioned - rep)) msgs.append( - f"out_specs{spec_key.pprint()} is {spec} which implies that the " + f"out_specs{keystr(spec_key)} is {spec} which implies that the " f"corresponding output value is replicated across mesh axes " f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, " f"which is missing the required axes {diff}") else: need_rep_, = unmentioned msgs.append( - f"out_specs{spec_key.pprint()} is {spec} which implies that the " + f"out_specs{keystr(spec_key)} is {spec} which implies that the " f"corresponding output value is replicated across mesh axis " f"'{need_rep_}', but could not infer replication over any axes") assert msgs diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index f73ba75ca..b47527090 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -43,7 +43,7 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, DebugInfo) from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten, - KeyPath, _generate_key_paths) + KeyPath, _generate_key_paths, keystr) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, @@ -1493,7 +1493,7 @@ class DynamicJaxprTracer(core.Tracer): arg_info = arg_info_all(dbg) if invar_pos and arg_info: arg_info = [arg_info[i] for i in invar_pos] - arg_names = [f'{name}{path.pprint("")}' for name, path in arg_info] + arg_names = [f'{name}{keystr(path)}' for name, path in arg_info] if len(arg_names) == 1: arg_info_str = f"the argument {arg_names[0]}" elif len(arg_names) == 2: diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index f1f96e0c6..47c45af6c 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -90,7 +90,6 @@ from jax._src.interpreters.pxla import ( lower_mesh_computation as lower_mesh_computation, lower_parallel_callable as lower_parallel_callable, lower_sharding_computation as lower_sharding_computation, - make_sharded_device_array as make_sharded_device_array, maybe_extend_axis_env as maybe_extend_axis_env, mesh_sharding_specs as mesh_sharding_specs, multi_host_supported_collectives as multi_host_supported_collectives, @@ -132,6 +131,7 @@ from jax._src.interpreters.pxla import ( from jax._src.interpreters.pxla import ( Mesh as _deprecated_Mesh, PartitionSpec as _deprecated_PartitionSpec, + make_sharded_device_array as _deprecated_make_sharded_device_array, ) import typing @@ -139,20 +139,32 @@ if typing.TYPE_CHECKING: from jax._src.interpreters.pxla import ( Mesh as Mesh, PartitionSpec as PartitionSpec, + make_sharded_device_array as make_sharded_device_array, ) del typing _deprecations = { - # Added Feb 8, 2023: - "Mesh": ( - "jax.interpreters.pxla.Mesh is deprecated. Use jax.sharding.Mesh.", - _deprecated_Mesh, - ), - "PartitionSpec": ( - ("jax.interpreters.pxla.PartitionSpec is deprecated. Use " - "jax.sharding.PartitionSpec."), - _deprecated_PartitionSpec, - ), + # Added Feb 8, 2023: + "Mesh": ( + "jax.interpreters.pxla.Mesh is deprecated. Use jax.sharding.Mesh.", + _deprecated_Mesh, + ), + "PartitionSpec": ( + ( + "jax.interpreters.pxla.PartitionSpec is deprecated. Use " + "jax.sharding.PartitionSpec." + ), + _deprecated_PartitionSpec, + ), + # make_sharded_device_array is deprecated as of March 3, 2023. jax.Array + # is the default since November 2022. + "make_sharded_device_array": ( + ( + "jax.interpreters.pxla.make_sharded_device_array is deprecated as" + " of March 3, 2023. Use jax.make_array_from_single_device_arrays." + ), + _deprecated_make_sharded_device_array, + ), } from jax._src.deprecations import deprecation_getattr as _deprecation_getattr diff --git a/jax/tree_util.py b/jax/tree_util.py index f00f8d1d0..af265c0a9 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -56,7 +56,17 @@ from jax._src.tree_util import ( treedef_children as treedef_children, treedef_is_leaf as treedef_is_leaf, treedef_tuple as treedef_tuple, + # TODO(ivyzheng): Remove old APIs when all users migrated. register_keypaths as register_keypaths, AttributeKeyPathEntry as AttributeKeyPathEntry, GetitemKeyPathEntry as GetitemKeyPathEntry, + register_pytree_with_keys as register_pytree_with_keys, + register_pytree_with_keys_class as register_pytree_with_keys_class, + tree_map_with_path as tree_map_with_path, + tree_flatten_with_path as tree_flatten_with_path, + keystr as keystr, + SequenceKey as SequenceKey, + DictKey as DictKey, + GetAttrKey as GetAttrKey, + FlattenedIndexKey as FlattenedIndexKey, ) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 764583e35..5f4e80f96 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -19,13 +19,14 @@ load( "//jaxlib:jax.bzl", "if_windows", "pybind_extension", + "pytype_library", ) licenses(["notice"]) package(default_visibility = ["//:__subpackages__"]) -py_library( +pytype_library( name = "jaxlib", srcs = [ "ducc_fft.py", diff --git a/jaxlib/ducc_fft.py b/jaxlib/ducc_fft.py index acdbf087d..75e2207a1 100644 --- a/jaxlib/ducc_fft.py +++ b/jaxlib/ducc_fft.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Tuple import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo @@ -34,13 +34,14 @@ _C2C = 0 _C2R = 1 _R2C = 2 -def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType, - fft_lengths: List[int]) -> bytes: + +def _ducc_fft_descriptor( + shape: List[int], dtype, fft_type: FftType, fft_lengths: List[int] +) -> Tuple[bytes, np.dtype, List[int]]: n = len(shape) assert len(fft_lengths) >= 1 assert len(fft_lengths) <= n, (fft_lengths, n) - forward = fft_type in (FftType.FFT, FftType.RFFT) is_double = np.finfo(dtype).dtype == np.float64 if fft_type == FftType.RFFT: diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 575a80852..2b7310d56 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -23,14 +23,14 @@ from .hlo_helpers import custom_call from jaxlib import xla_client try: - from .cuda import _linalg as _cuda_linalg + from .cuda import _linalg as _cuda_linalg # pytype: disable=import-error for _name, _value in _cuda_linalg.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") except ImportError: _cuda_linalg = None try: - from .rocm import _linalg as _hip_linalg + from .rocm import _linalg as _hip_linalg # pytype: disable=import-error for _name, _value in _hip_linalg.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index de9175e60..52b970278 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -26,14 +26,14 @@ from jaxlib import xla_client from .hlo_helpers import custom_call try: - from .cuda import _prng as _cuda_prng + from .cuda import _prng as _cuda_prng # pytype: disable=import-error for _name, _value in _cuda_prng.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") except ImportError: _cuda_prng = None try: - from .rocm import _prng as _hip_prng + from .rocm import _prng as _hip_prng # pytype: disable=import-error for _name, _value in _hip_prng.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: diff --git a/jaxlib/gpu_rnn.py b/jaxlib/gpu_rnn.py index f647d5ba5..51b196ad9 100644 --- a/jaxlib/gpu_rnn.py +++ b/jaxlib/gpu_rnn.py @@ -20,7 +20,7 @@ import numpy as np from jaxlib import xla_client try: - from .cuda import _rnn as _rnn + from .cuda import _rnn # pytype: disable=import-error for _name, _value in _rnn.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform='CUDA') except ImportError: diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 14db64b50..30fc096bc 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -26,14 +26,14 @@ from jaxlib import xla_client from .hlo_helpers import custom_call try: - from .cuda import _blas as _cublas + from .cuda import _blas as _cublas # pytype: disable=import-error for _name, _value in _cublas.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") except ImportError: _cublas = None try: - from .cuda import _solver as _cusolver + from .cuda import _solver as _cusolver # pytype: disable=import-error for _name, _value in _cusolver.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") except ImportError: @@ -41,14 +41,14 @@ except ImportError: try: - from .rocm import _blas as _hipblas + from .rocm import _blas as _hipblas # pytype: disable=import-error for _name, _value in _hipblas.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: _hipblas = None try: - from .rocm import _solver as _hipsolver + from .rocm import _solver as _hipsolver # pytype: disable=import-error for _name, _value in _hipsolver.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index f75a51445..791f57308 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -26,7 +26,7 @@ from jaxlib import xla_client from .hlo_helpers import custom_call try: - from .cuda import _sparse as _cusparse + from .cuda import _sparse as _cusparse # pytype: disable=import-error except ImportError: _cusparse = None else: @@ -34,7 +34,7 @@ else: xla_client.register_custom_call_target(_name, _value, platform="CUDA") try: - from .rocm import _sparse as _hipsparse + from .rocm import _sparse as _hipsparse # pytype: disable=import-error except ImportError: _hipsparse = None else: diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index 38100bf1f..3efa6a68b 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -21,7 +21,7 @@ import numpy as np def custom_call( - call_target_name: str, + call_target_name: Union[str, bytes], out_types: Sequence[ir.Type], operands: Sequence[ir.Value], operand_layouts: Sequence[Sequence[int]], diff --git a/pytest.ini b/pytest.ini index 9beeeeb36..5737c6d29 100644 --- a/pytest.ini +++ b/pytest.ini @@ -26,5 +26,7 @@ filterwarnings = default:Error writing persistent compilation cache entry for 'jit__lambda_' ignore:DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been deprecated.*:DeprecationWarning ignore:backend and device argument on jit is deprecated.*:DeprecationWarning + ignore:GlobalDeviceArray has been deprecated.*:DeprecationWarning + ignore:jax.interpreters.pxla.make_sharded_device_array is deprecated.*:DeprecationWarning doctest_optionflags = NUMBER NORMALIZE_WHITESPACE addopts = --doctest-glob="*.rst" diff --git a/tests/pjit_test.py b/tests/pjit_test.py index fa4bd37b7..ca2586e0a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3745,7 +3745,7 @@ class PJitErrorTest(jtu.JaxTestCase): error = re.escape( "pytree structure error: different lengths of list at " "key path\n" - " pjit out_shardings tree root\n") + " pjit out_shardings\n") with self.assertRaisesRegex(ValueError, error): pjit(lambda x: x, (p,), [p, None])([x, x, x]) # Error, we raise a generic tree mismatch message diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 890a834fc..49f784c91 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -24,7 +24,7 @@ import jax from jax import tree_util from jax import flatten_util from jax._src import test_util as jtu -from jax._src.tree_util import prefix_errors +from jax._src.tree_util import prefix_errors, flatten_one_level import jax.numpy as jnp @@ -53,8 +53,11 @@ class AnObject: def __repr__(self): return f"AnObject({self.x},{self.y},{self.z})" -tree_util.register_pytree_node(AnObject, lambda o: ((o.x, o.y), o.z), - lambda z, xy: AnObject(xy[0], xy[1], z)) +tree_util.register_pytree_with_keys( + AnObject, + lambda o: ((("x", o.x), ("y", o.y)), o.z), # flatten_with_keys + lambda z, xy: AnObject(xy[0], xy[1], z), # unflatten (no key involved) +) @tree_util.register_pytree_node_class class Special: @@ -75,6 +78,14 @@ class Special: def __eq__(self, other): return type(self) is type(other) and (self.x, self.y) == (other.x, other.y) + +@tree_util.register_pytree_with_keys_class +class SpecialWithKeys(Special): + def tree_flatten_with_keys(self): + return (((tree_util.GetAttrKey('x'), self.x), + (tree_util.GetAttrKey('y'), self.y)), None) + + @tree_util.register_pytree_node_class class FlatCache: def __init__(self, structured, *, leaves=None, treedef=None): @@ -160,6 +171,25 @@ LEAVES = ( (object(),), ) +# All except those decorated by register_pytree_node_class +TREES_WITH_KEYPATH = ( + (None,), + ((None,),), + ((),), + (([()]),), + ((1, 0),), + (((1, "foo"), ["bar", (3, None, 7)]),), + ([3],), + ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],), + ([AnObject(3, None, [4, "foo"])],), + (SpecialWithKeys(2, 3.),), + ({"a": 1, "b": 0},), + (collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),), + (collections.defaultdict(dict, + [("foo", 34), ("baz", 101), ("something", -42)]),), + (ANamedTupleSubclass(foo="hello", bar=3.5),), +) + class TreeTest(jtu.JaxTestCase): @@ -406,6 +436,101 @@ class TreeTest(jtu.JaxTestCase): self.assertEqual(nodes_visited, [(None, None), (None, None, None)]) self.assertEqual(node_data_visited, [["a", "b"], ["a", "b", "c"]]) + @parameterized.parameters(*(TREES_WITH_KEYPATH + LEAVES)) + def testRoundtripWithPath(self, inputs): + key_leaves, treedef = tree_util.tree_flatten_with_path(inputs) + actual = tree_util.tree_unflatten(treedef, [leaf for _, leaf in key_leaves]) + self.assertEqual(actual, inputs) + + def testTreeMapWithPath(self): + tree = [{i: i for i in range(10)}] + all_zeros = tree_util.tree_map_with_path( + lambda kp, val: val - kp[1].key + kp[0].idx, tree + ) + self.assertEqual(all_zeros, [{i: 0 for i in range(10)}]) + + def testTreeMapWithPathMultipleTrees(self): + tree1 = [AnObject(x=12, + y={'cin': [1, 4, 10], 'bar': None}, + z='constantdef'), + 5] + tree2 = [AnObject(x=2, + y={'cin': [2, 2, 2], 'bar': None}, + z='constantdef'), + 2] + from_two_trees = tree_util.tree_map_with_path( + lambda kp, a, b: a + b, tree1, tree2 + ) + from_one_tree = tree_util.tree_map(lambda a: a + 2, tree1) + self.assertEqual(from_two_trees, from_one_tree) + + def testKeyStr(self): + tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)] + flattened, _ = tree_util.tree_flatten_with_path(tree1) + strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened] + self.assertEqual( + strs, + [ + "[0].foo: 12", + "[0].bar['cin'][0]: 1", + "[0].bar['cin'][1]: 4", + "[0].bar['cin'][2]: 10", + "[1]: [0 1 2 3 4]", + ], + ) + + def testTreeMapWithPathWithIsLeafArgument(self): + x = ((1, 2), [3, 4, 5]) + y = (([3], jnp.array((0))), ([0], 7, [5, 6])) + out = tree_util.tree_map_with_path( + lambda kp, *xs: tuple((kp[0].idx, *xs)), x, y, + is_leaf=lambda n: isinstance(n, list)) + self.assertEqual(out, (((0, 1, [3]), + (0, 2, jnp.array((0)))), + (1, [3, 4, 5], ([0], 7, [5, 6])))) + + def testFlattenWithPathWithIsLeafArgument(self): + def is_empty(x): + try: + children, _ = flatten_one_level(x) + except ValueError: + return True # Cannot flatten x; means it must be a leaf + return len(children) == 0 + + EmptyTuple = collections.namedtuple("EmptyTuple", ()) + tree1 = {'a': 1, + 'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])], + 'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')} + flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty) + strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened] + self.assertEqual( + strs, + [ + "['a']: 1", + "['obj']x: EmptyTuple()", + "['obj']y: 0", + "['sub'][0]: [1 2]", + "['sub'][1].foo: ()", + "['sub'][1].bar[0]: None", + ], + ) + + def testFlattenOneLevel(self): + EmptyTuple = collections.namedtuple("EmptyTuple", ()) + tree1 = {'a': 1, + 'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])], + 'obj': AnObject(x=EmptyTuple(), y=0, z='constantdef')} + self.assertEqual(flatten_one_level(tree1["sub"])[0], + tree1["sub"]) + self.assertEqual(flatten_one_level(tree1["sub"][1])[0], + [(), [None]]) + self.assertEqual(flatten_one_level(tree1["obj"])[0], + [EmptyTuple(), 0]) + with self.assertRaisesRegex(ValueError, "can't tree-flatten type"): + flatten_one_level(1) + with self.assertRaisesRegex(ValueError, "can't tree-flatten type"): + flatten_one_level(jnp.array((1, 2))) + class RavelUtilTest(jtu.JaxTestCase): @@ -485,7 +610,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase): def test_different_types(self): e, = prefix_errors((1, 2), [1, 2]) expected = ("pytree structure error: different types at key path\n" - " in_axes tree root") + " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') @@ -511,7 +636,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase): e, = prefix_errors((1,), (2, 3)) expected = ("pytree structure error: different lengths of tuple " "at key path\n" - " in_axes tree root") + " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') @@ -519,7 +644,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase): e, = prefix_errors([1], [2, 3]) expected = ("pytree structure error: different lengths of list " "at key path\n" - " in_axes tree root") + " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') @@ -528,7 +653,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase): e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3}) expected = ("pytree structure error: different numbers of pytree children " "at key path\n" - " in_axes tree root") + " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') @@ -564,7 +689,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase): e, = prefix_errors({1: 2}, {3: 4}) expected = ("pytree structure error: different pytree metadata " "at key path\n" - " in_axes tree root") + " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') @@ -603,7 +728,7 @@ class TreePrefixErrorsTest(jtu.JaxTestCase): e, = prefix_errors({}, {'a': []}) expected = ("pytree structure error: different numbers of pytree children " "at key path\n" - " in_axes tree root") + " in_axes") with self.assertRaisesRegex(ValueError, expected): raise e('in_axes') From f4ed8b634f842beffbf5a81e5cf1be7883d77ae9 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 1 Mar 2023 15:24:39 +0100 Subject: [PATCH 2/4] [shape_poly, call_tf] Some improvements for call_tf in a shape polymorphic program The use case for call_tf with shape polymorphism is when we have a JAX program that calls into TF function, and we want to serialize the JAX program with some shapes unknown. Previously this use case did not work, except in the special case when the output shape of the called TF function returns statically known shapes. The idea is that we allow the user of call_tf to specify the output shape. This can be done even in presence of shape polymorphism, by writing the output shape as an expression in terms of the input shapes. This is what other JAX primitives do, e.g., concat, so we are simply enabling call_tf to get the same behavior. This change should be enough for old-style jax2tf, but will require more work for native serialization. We also removed some old code that was trying to workaround some limitations in shape inference in TF. I think that those workarounds are ugly, and I am prepared to give error messages rather than keep that code. So far no tests fail. --- CHANGELOG.md | 5 + jax/experimental/jax2tf/README.md | 83 +++++---- jax/experimental/jax2tf/call_tf.py | 169 +++++++++++------- jax/experimental/jax2tf/jax2tf.py | 3 +- jax/experimental/jax2tf/tests/call_tf_test.py | 166 +++++++++++++---- 5 files changed, 298 insertions(+), 128 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f82fecca1..71891e99d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.6 +* Changes + * {func}`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`) + that can be used to declare the output shape and type of the result. This enables + {func}`jax2tf.call_tf` to work in the presence of shape polymorphism. ({jax-issue}`#14734`). + ## jaxlib 0.4.6 ## jax 0.4.5 (Mar 2, 2023) diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 4a40c8c5f..a49be441c 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -1229,6 +1229,57 @@ DeviceArray data or for np.ndarray that are aligned on 16-byte boundaries) and on GPU (for DeviceArray). The zero-copy does not yet work on TPU. +`call_tf` works even with shape polymorphism, but in that case +the user must pass the `output_shape_dtype` parameter to `call_tf` to declare +the expected output shapes. This allows JAX tracing to know the shape and +dtype of the results so that it can continue tracing the rest of the program. +When `output_shape_dtype` is not given (the default case), `call_tf` will +form a `tf.Graph` for the called TF function and will use the inferred +type and shape. However, in presence of dynamic shape the inferred TF +type will contain `None` for the dynamic dimensions, which is not enough +information for JAX shape polymorphism. + +For example: + +```python +def fun_jax(x): + y_shape = (x.shape[0] * 2, y.shape[1:]) + y = jax2tf.call_tf( + lambda x: tf.concat([x, x], axis=0), + output_shape_dype=jax.ShapeDtypeStruct(y_shape, x.dtype))(x) + # JAX will know the y.shape + return jnp.ones(y.shape, dtype=y.dtype) + y + +jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) +``` + +An even simpler example for a function that returns the same shape as the input: + +```python +def fun_jax(x): + return jax2tf.call_tf(tf.math.sin, + output_shape_dtype=x) + )(x) + +jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) +``` + +If all the output shapes of the TF function are static, JAX does not need the +`output_shape_dtype` argument: + +```python +def fun_tf(x): + return tf.math.reduce_sum(tf.math.sin(x)) + +def fun_jax(x): + return jax2tf.call_tf(fun_tf)(x) + +# The following will not throw an error because the output shape of fun_tf is static. +jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) +``` + +The shape polymorphism support for `call_tf` does not yet work for native lowering. + ### Limitations of call_tf The TF function must be compileable (`tf.function(func, jit_compile=True)`) @@ -1312,38 +1363,6 @@ JAX computation runs on TPU. This will fail if the computation captures variables on some other devices. It is best to use ``call_tf`` with TF functions that do not capture variables. -A TF function wrapped with `call_tf` cannot be applied to inputs whose -shapes are not constants, unless all the output shapes of the TF function -are static. The may arise when you try to apply `jax2tf.convert` with -polymorphic shapes on the result of `call_tf`: - -```python -def fun_jax(x): - return jax2tf.call_tf(tf.math.sin)(x) - -# The following will throw an error. -jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) -``` - -This is unsatisfying, because the result of the above conversion -could be simply `tf.math.sin`, which is batch polymorphic. But -JAX cannot keep track of shapes through a `call_tf` call, and it -cannot be sure that the shape-polymorphic conversion is safe. - -If all the output shapes of the TF function are static, JAX does not need to -keep track of shapes after a `call_tf` call, hence allows shape-polymorphic -inputs in such cases: - -```python -def fun_tf(x): - return tf.math.reduce_sum(tf.math.sin(x)) - -def fun_jax(x): - return jax2tf.call_tf(fun_tf)(x) - -# The following will not throw an error because the output shape of fun_tf is static. -jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) -``` # Misc notes diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 78773021e..ad8609ddc 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -55,13 +55,15 @@ map = util.safe_map zip = util.safe_zip TfConcreteFunction = Any +TfVal = jax2tf_internal.TfVal # The platforms for which to use DLPack to avoid copying (only works on GPU # and CPU at the moment, and only for DeviceArray). For CPU we don't need # DLPack, if we are careful. _DLPACK_PLATFORMS = ("gpu",) -def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable: +def call_tf(callable_tf: Callable, has_side_effects=True, + output_shape_dtype=None) -> Callable: """Calls a TensorFlow function from JAX, with support for reverse autodiff. The ``callable_tf`` will be called with TensorFlow-compatible arguments ( @@ -90,6 +92,14 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable: has_side_effects: if True then it ensures that instances of this primitive are not removed or replicated by JAX optimizations such as dead-code elimination. + output_shape_dtype: An optional declaration of the expected shapes and dtypes + from the called TensorFlow function. If given it will be used during JAX + tracing to form the abstract values of the results of the `call_tf`. If + not given then we form a `tf.Graph` for the called TensorFlow function and + we use the TensorFlow-inferred shapes and types. Must be a pytree matching the + structure of the nested structure returned from the TensorFlow function, + containing objects with `.shape` and `.dtype` attributes, + e.g., `jax.ShapeDtypeStruct` or `jax.Array`. Returns: a JAX callable that can be invoked with JAX pytree arguments, in op-by-op mode or in a staged context. This callable can be used with @@ -113,67 +123,58 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable: def make_tensorspec(a_jax): a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype) a_tf_shape = [ - d if core.is_constant_dim(d) else None for d in a_jax.shape - ] + d if core.is_constant_dim(d) else None for d in a_jax.shape] return tf.TensorSpec(a_tf_shape, a_tf_dtype) args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax)) - def check_tf_result(r_tf): - # Check that the TF function returns values of expected types. This - # improves error reporting, preventing hard-to-diagnose errors downstream - try: - jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf) - except Exception as e: - msg = ("The called TF function returns a result that is not " - f"convertible to JAX: {r_tf}.") - raise ValueError(msg) from e + if output_shape_dtype is not None: + output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype) + output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat) + else: + output_avals, output_shape_dtype_tree = None, None res_treedef = None # We'll store here the result treedef res_tf_flat = None # For error reporting # The function below will be called at least once, either in eager - # or in graph mode. + # mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf() def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]: args_tf = args_treedef.unflatten(args_tf_flat) res_tf = callable_tf(*args_tf) nonlocal res_treedef, res_tf_flat res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf) - for r_tf in res_tf_flat: - check_tf_result(r_tf) - assert res_treedef is None or res_treedef == res_treedef_now, f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}" + assert res_treedef is None or res_treedef == res_treedef_now, ( + f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}") res_treedef = res_treedef_now - return res_tf_flat + if output_avals is not None: + if res_treedef != output_shape_dtype_tree: + raise ValueError( + "The pytree of the TensorFlow function results does not match the " + "pytree of the declared output_shape_dtype:\n" + f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}") + assert len(output_avals) == len(res_tf_flat) + + checked_res_tf_flat = [ + check_tf_result(i, r_tf, r_aval) + for i, (r_tf, r_aval) in enumerate( + zip(res_tf_flat, + (output_avals if output_avals is not None + else (None,) * len(res_tf_flat))))] + return checked_res_tf_flat # Prepare a tf.function ahead of time, to cache the concrete functions. This # won't be used in op-by-op execution mode. function_flat_tf = tf.function(callable_flat_tf, autograph=False, jit_compile=True) - input_shapes_tf = [s.shape for s in args_flat_sig_tf] - output_shapes_tf = _get_concrete_function_tf( - function_flat_tf, args_flat_sig_tf - ).output_shapes - - if not all(s.is_fully_defined() for s in input_shapes_tf) and not all( - s.is_fully_defined() for s in output_shapes_tf - ): - for a_jax, a_tf_shape in zip(args_flat_jax, input_shapes_tf): - if not a_tf_shape.is_fully_defined(): - msg = ( - "call_tf cannot be applied to shape-polymorphic arguments unless" - " all the output shapes are static. Found argument shape:" - f" {a_jax.shape}. See" - " https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" - " for a discussion." - ) - raise ValueError(msg) - res_jax_flat = call_tf_p.bind( *args_flat_jax, # Carry the actual function such that op-by-op call can call in TF eager mode. callable_flat_tf=callable_flat_tf, function_flat_tf=function_flat_tf, args_flat_sig_tf=args_flat_sig_tf, + output_avals=output_avals, has_side_effects=has_side_effects) + # We must have called callable_flat_tf by nοw assert res_treedef is not None # Sometimes, in compiled mode, we get a different number of results than we # got when tracing the TF function (and building the res_treedef). This @@ -248,6 +249,44 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable: return util.wraps(callable_tf)(make_call) +def check_tf_result(idx: int, r_tf: TfVal, r_aval: Optional[core.ShapedArray]) -> TfVal: + # Check that the TF function returns values of expected types. This + # improves error reporting, preventing hard-to-diagnose errors downstream + try: + jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf) + except Exception as e: + msg = ("The called TF function returns a result that is not " + f"convertible to JAX: {r_tf}.") + raise ValueError(msg) from e + + if r_aval is None: + return r_tf + # We convert to TF type, and canonicalize to 32-bit if necessary + r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype) + # Checking shapes is trickier in presence of dynamic shapes. I wish we could + # check at runtime that the returned shape matches the declared shape. I wish + # that tf.ensure_shape did this, but it can only take shapes that contain None + # not computed shapes. However, in eager mode we should be able to resolve + # the declared shapes to constants and we get better checking. + if tf.executing_eagerly(): + r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape) + else: + r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval) + # We do as much checking as we can here, instead of relying on tf.ensure_shape + # because the latter gives different errors in eager vs. compiled mode. + if (r_tf.dtype != r_aval_dtype_tf or + len(r_tf.shape) != len(r_aval_shape_tf) or + any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d + for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))): + msg = ("The shapes or dtypes returned by the TensorFlow function " + "do not match the declared output_shape_dtype:\n" + f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]") + raise ValueError(msg) + # At this point tf.ensure_shape does not do much, it should never throw an + # error, albeit it may refine the shape a bit. + return tf.ensure_shape(r_tf, r_aval_shape_tf) + + call_tf_p = core.Primitive("call_tf") call_tf_p.multiple_results = True @@ -309,39 +348,43 @@ effects.remat_allowed_effects.add_type(CallTfEffect) effects.custom_derivatives_allowed_effects.add_type(CallTfEffect) -def _call_tf_abstract_eval(*_, +def _call_tf_abstract_eval(*args_flat_avals, function_flat_tf, args_flat_sig_tf, - has_side_effects, **__): + has_side_effects, + output_avals, **__): # Called only when we form a Jaxpr, i.e., under jit, scan, etc. + effects = {call_tf_effect} if has_side_effects else set() + # If not output_avals is given, then we ask TF to infer the output shapes. + # We call this even if output_avals is given because it will ensure that + # callable_flat_tf is called. Since _get_concrete_function_tf is cached + # there is a small cost of calling it more often than needed. concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf) + if output_avals is not None: + return output_avals, effects + def is_fully_known_shape(s): return s.rank is not None and all([d is not None for d in s]) - effects = {call_tf_effect} if has_side_effects else set() - if all([is_fully_known_shape(s) - for s in concrete_function_flat_tf.output_shapes]): - return ( - tuple([ - # We convert to JAX type, and canonicalize to 32-bit if necessary - core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype)) - for dtype, shape in zip(concrete_function_flat_tf.output_dtypes, - concrete_function_flat_tf.output_shapes) - ]), - effects) - # There are some cases when TF shape inference is not powerful enough to - # figure out the output shapes (e.g., b/128924522), even in situations where - # XLA can compile the code, from which we can get the shapes. - - # We use the "cpu" as the platform, since JAX abstract eval is not platform - # specific; the "cpu" backend is always available and for abstract evaluation - # it should not matter which platform we use. - _, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf, - "CPU") - return tuple(result_avals), effects + if all(is_fully_known_shape(s) + for s in concrete_function_flat_tf.output_shapes): + avals_from_tf = tuple( + # We convert to JAX type, and canonicalize to 32-bit if necessary + core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype)) + for dtype, shape in zip(concrete_function_flat_tf.output_dtypes, + concrete_function_flat_tf.output_shapes)) + return avals_from_tf, effects + else: + msg = "call_tf cannot call functions whose output has dynamic shape. " + if any(not core.is_constant_shape(a.shape) for a in args_flat_avals): + msg += ("The function is called with shape-polymorphic inputs. Consider " + "using the `output_shape_dtype` argument to call_tf. ") + msg += ("\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + " for a discussion.") + raise ValueError(msg) call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval) @@ -372,6 +415,11 @@ def _code_generator_and_avals( ) -> Tuple[Optional[Callable[[mlir.ModuleContext, Sequence[ir.Value]], Sequence[ir.Value]]], Sequence[core.ShapedArray]]: + # TODO(necula): we have refactored the code to not need to lower the code + # just in order to get the avals, so in fact the returned avals from this + # function are never used. We keep it here for now in case we detect + # a regressions, but if not we should simplify this function. + # Returns and caches a code generator (taking a builder and the # XlaOps for the arguments) and a sequence of result abstract shapes. @@ -478,8 +526,7 @@ def _register_call_lowering(platform): for platform in ("cpu", "cuda", "tpu"): _register_call_lowering(platform) -# Support the call_tf under jax2tf.convert -TfVal = jax2tf_internal.TfVal +# Support the call_tf under jax2tf.convert in eager mode def _jax2tf_call_tf(*args: TfVal, callable_flat_tf: Callable, **_) -> TfVal: diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 4116ecbc2..c9b4efb35 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -335,8 +335,7 @@ def convert(fun_jax: Callable, _thread_local_state.tf_outer_name_scope = tf.get_current_name_scope() # TODO: is there a better way to check if we are inside a transformation? - if not core.trace_state_clean( - ) and not _thread_local_state.inside_call_tf: + if not core.trace_state_clean() and not _thread_local_state.inside_call_tf: # It is Ok to nest convert when we are inside a call_tf raise ValueError( "convert must be used outside all JAX transformations." + diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 370e3e925..f4ecb8f89 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -44,6 +44,12 @@ def _maybe_jit(with_jit: bool, func: Callable) -> Callable: else: return func +def _maybe_tf_jit(with_jit: bool, func: Callable) -> Callable: + if with_jit: + return tf.function(func, autograph=False, jit_compile=True) + else: + return func + def _named_test(**kwargs): return dict(kwargs, testcase_name = "_".join([f"{k}={kwargs[k]}" for k in sorted(kwargs.keys())])) @@ -53,8 +59,7 @@ _parameterized_jit = parameterized.named_parameters( for with_jit in [True, False]) _call_tf_non_compileable_error = "Error compiling TensorFlow function. call_tf can used in a staged context .* only with compileable functions" -_call_tf_dynamic_shape_error = "Compiled TensorFlow function has dynamic output shape.* call_tf can used in a staged context .* only with compileable functions" - +_call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has dynamic shape" class CallTfTest(tf_test_util.JaxToTfTestCase): @@ -171,8 +176,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): x = np.array([True, False], dtype=np.bool_) self.assertAllClose(f_tf_non_compileable(x), f_jax(x)) # Works in eager mode - with self.assertRaisesRegex(ValueError, - _call_tf_dynamic_shape_error): + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): jax.jit(f_jax)(x) def test_error_bad_result_tensorarray(self): @@ -569,9 +573,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(x[0:x[1]], res1) # Now under jit, should fail because the function is not compileable - with self.assertRaisesRegex( - ValueError, "Compiled TensorFlow function has dynamic output shape" - ): + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): fun_jax = jax.jit(jax2tf.call_tf(fun_tf)) fun_jax(x) @@ -1099,29 +1101,143 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(expected, res1, check_dtypes=False) # Now under jit, should fail because the function is not compileable - with self.assertRaisesRegex(ValueError, - _call_tf_dynamic_shape_error): + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): fun_jax = jax.jit(jax2tf.call_tf(fun_tf)) fun_jax(x) # TODO(necula): this should work in op-by-op mode, but it fails because # jax2tf.convert does abstract evaluation. - with self.assertRaisesRegex(ValueError, - _call_tf_dynamic_shape_error): + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf)) fun_tf_rt(x) - def test_shape_polymorphism_error(self): - x = np.array([.7, .8], dtype=np.float32) + @_parameterized_jit + def test_shape_poly_static_output_shape(self, with_jit=True): + if config.jax2tf_default_experimental_native_lowering: + raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.") + x = np.array([0.7, 0.8], dtype=np.float32) + def fun_tf(x): - return tf.math.sin(x) + return tf.math.reduce_sum(tf.math.sin(x)) fun_jax = jax2tf.call_tf(fun_tf) + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + self.assertAllClose(fun_tf(x), fun_tf_rt(x)) + + @_parameterized_jit + def test_shape_poly(self, with_jit=False): + if config.jax2tf_default_experimental_native_lowering: + raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.") + x = np.array([7, 8, 9, 10], dtype=np.float32) + def fun_jax(x): + y = jax2tf.call_tf(tf.math.sin, + output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x) + z = jnp.cos(y) + w = jax2tf.call_tf(lambda z: tf.concat([z, z], axis=0), + output_shape_dtype=jax.ShapeDtypeStruct((2 * z.shape[0],), z.dtype))(z) + assert w.shape[0] == 2 * x.shape[0] + return w + + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + res_tf = fun_tf_rt(x) + self.assertAllClose(fun_jax(x), res_tf) + + @_parameterized_jit + def test_shape_poly_pytree_result(self, with_jit=True): + if config.jax2tf_default_experimental_native_lowering: + raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.") + x = np.array([7, 8, 9, 10], dtype=np.float32) + def fun_jax(x): + # Returns a tuple + y = jax2tf.call_tf(lambda x: (x, tf.concat([x, x], axis=0)), + output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct((2 * x.shape[0],), x.dtype)))(x) + assert y[0].shape[0] == x.shape[0] + assert y[1].shape[0] == 2 * x.shape[0] + return y + + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + res_tf = fun_tf_rt(x) + self.assertAllClose(fun_jax(x), res_tf) + + @_parameterized_jit + def test_shape_poly_error_no_output_shape_dtype(self, with_jit=True): + x = np.array([7, 8, 9, 10], dtype=np.float32) + def fun_jax(x): + return jax2tf.call_tf(tf.math.sin)(x) + + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): + fun_tf_rt(x) + + @_parameterized_jit + def test_shape_poly_error_mismatch_output_shape_dtype_tree(self, with_jit=False): + x = np.array([7, 8, 9, 10], dtype=np.float32) + def fun_jax(x): + return jax2tf.call_tf(tf.math.sin, + output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(x.shape, x.dtype)))(x) + + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]) with self.assertRaisesRegex( - ValueError, "call_tf cannot be applied to shape-polymorphic arguments" - ): + ValueError, + "The pytree of the TensorFlow function results does not match the pytree of the declared output_shape_dtype"): + fun_tf_rt(x) + + @parameterized.named_parameters( + _named_test(with_jit=with_jit, kind=kind) + for with_jit in [True, False] + for kind in ["bad_rank", "bad_dim", "bad_dtype", "bad_dtype_x64"]) + def test_shape_poly_error_mismatch_output_shape_dtype(self, with_jit=False, kind="bad_rank"): + x = np.array([7, 8, 9, 10], dtype=np.float32) + + if kind == "bad_rank": + def fun_jax(x): + return jax2tf.call_tf(lambda x: x, + # Wrong shape rank + output_shape_dtype=jax.ShapeDtypeStruct((), x.dtype))(x) + elif kind == "bad_dim": + def fun_jax(x): + bad_shape = (5 + x.shape[0],) + y = jax2tf.call_tf(lambda x: x, + # Wrong dimension + output_shape_dtype=jax.ShapeDtypeStruct(bad_shape, x.dtype))(x) + # JAX will believe that the following is Ok, leading to downstream error in TF + return y + jnp.ones(bad_shape, dtype=x.dtype) + elif kind == "bad_dtype": + def fun_jax(x): + return jax2tf.call_tf(lambda x: x, + output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.int32))(x) + elif kind == "bad_dtype_x64": + def fun_jax(x): + return jax2tf.call_tf(lambda x: x * np.float64(3.), + output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.float64))(x) + else: + assert False + expect_ex = ValueError + expect_error = r"The shapes or dtypes returned by the TensorFlow function do not match the declared output_shape_dtype" + + # Call without shape polymorphism + fun_tf_rt = _maybe_tf_jit(with_jit, jax2tf.convert(fun_jax)) + with self.assertRaisesRegex(expect_ex, expect_error): + fun_tf_rt(x) + + # Now with shape polymorphism + if kind == "bad_dim" and with_jit: + # TODO: in jit more the error pops up later, at AddV2 + expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2" + if kind == "bad_dim" and config.jax2tf_default_experimental_native_lowering: + # TODO(b/268386622): call_tf with shape polymorphism and native lowering. + expect_error = "Error compiling TensorFlow function. call_tf can used .* only with compileable functions with static output shapes" + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + with self.assertRaisesRegex(expect_ex, expect_error): fun_tf_rt(x) def test_inner_native_lowering(self): @@ -1142,22 +1258,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): self.assertIn('op: "XlaCallModule"', f_outer_graph) self.assertNotIn('op: "Sin"', f_outer_graph) - @_parameterized_jit - def test_shape_polymorphism_static_output_shape(self, with_jit=True): - # TODO(b/268386622) Dynamic shapes not yet supported. - if config.jax2tf_default_experimental_native_lowering: - raise unittest.SkipTest("Skip test because of dynamic shapes.") - x = np.array([0.7, 0.8], dtype=np.float32) - - def fun_tf(x): - return tf.math.reduce_sum(tf.math.sin(x)) - - fun_jax = jax2tf.call_tf(fun_tf) - fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]) - if with_jit: - fun_tf_rt = tf.function(jit_compile=True, autograph=False)(fun_tf_rt) - self.assertAllClose(fun_tf(x), fun_tf_rt(x)) - @parameterized.named_parameters( _named_test(f2_function=f2_function, f2_saved_model=f2_saved_model, f4_function=f4_function, f4_saved_model=f4_saved_model) From 449e9f5ad53f6a252578c5d358735e919dc35832 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sat, 4 Mar 2023 18:06:26 +0000 Subject: [PATCH 3/4] Cleanup the logic that was rolled back and doesn't exist in C++ but still exists in python. PiperOrigin-RevId: 514078264 --- jax/_src/array.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 378ccf27b..dab3e05a2 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -475,11 +475,8 @@ class ArrayImpl(basearray.Array): self._check_if_deleted() if self._npy_value is None: if self.is_fully_replicated: - arr = self._arrays[0] # type: ignore - # copy_to_host_async implemented in c++ only for single device arrays. - if hasattr(arr, "_copy_single_device_array_to_host_async"): - arr._copy_single_device_array_to_host_async() # type: ignore - return + self._arrays[0].copy_to_host_async() + return try: self.addressable_shards[0].replica_id replica_id_exists = True @@ -496,12 +493,7 @@ class ArrayImpl(basearray.Array): if self._npy_value is None: if self.is_fully_replicated: - arr = self._arrays[0] # type: ignore - # Conversion to numpy implemented only for single device arrays. - if hasattr(arr, "_single_device_array_to_np_array"): - self._npy_value = arr._single_device_array_to_np_array() # type: ignore - else: - self._npy_value = np.asarray(arr) # type: ignore + self._npy_value = np.asarray(self._arrays[0]) # type: ignore self._npy_value.flags.writeable = False return cast(np.ndarray, self._npy_value) From 2ccd785e167ab424f0c0d79de7b549a7f94aea61 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 4 Mar 2023 18:06:54 +0000 Subject: [PATCH 4/4] Internal change PiperOrigin-RevId: 514078303 --- CHANGELOG.md | 5 - jax/experimental/jax2tf/README.md | 83 ++++----- jax/experimental/jax2tf/call_tf.py | 169 +++++++----------- jax/experimental/jax2tf/jax2tf.py | 3 +- jax/experimental/jax2tf/tests/call_tf_test.py | 166 ++++------------- 5 files changed, 128 insertions(+), 298 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71891e99d..f82fecca1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,11 +8,6 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.6 -* Changes - * {func}`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`) - that can be used to declare the output shape and type of the result. This enables - {func}`jax2tf.call_tf` to work in the presence of shape polymorphism. ({jax-issue}`#14734`). - ## jaxlib 0.4.6 ## jax 0.4.5 (Mar 2, 2023) diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index a49be441c..4a40c8c5f 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -1229,57 +1229,6 @@ DeviceArray data or for np.ndarray that are aligned on 16-byte boundaries) and on GPU (for DeviceArray). The zero-copy does not yet work on TPU. -`call_tf` works even with shape polymorphism, but in that case -the user must pass the `output_shape_dtype` parameter to `call_tf` to declare -the expected output shapes. This allows JAX tracing to know the shape and -dtype of the results so that it can continue tracing the rest of the program. -When `output_shape_dtype` is not given (the default case), `call_tf` will -form a `tf.Graph` for the called TF function and will use the inferred -type and shape. However, in presence of dynamic shape the inferred TF -type will contain `None` for the dynamic dimensions, which is not enough -information for JAX shape polymorphism. - -For example: - -```python -def fun_jax(x): - y_shape = (x.shape[0] * 2, y.shape[1:]) - y = jax2tf.call_tf( - lambda x: tf.concat([x, x], axis=0), - output_shape_dype=jax.ShapeDtypeStruct(y_shape, x.dtype))(x) - # JAX will know the y.shape - return jnp.ones(y.shape, dtype=y.dtype) + y - -jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) -``` - -An even simpler example for a function that returns the same shape as the input: - -```python -def fun_jax(x): - return jax2tf.call_tf(tf.math.sin, - output_shape_dtype=x) - )(x) - -jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) -``` - -If all the output shapes of the TF function are static, JAX does not need the -`output_shape_dtype` argument: - -```python -def fun_tf(x): - return tf.math.reduce_sum(tf.math.sin(x)) - -def fun_jax(x): - return jax2tf.call_tf(fun_tf)(x) - -# The following will not throw an error because the output shape of fun_tf is static. -jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) -``` - -The shape polymorphism support for `call_tf` does not yet work for native lowering. - ### Limitations of call_tf The TF function must be compileable (`tf.function(func, jit_compile=True)`) @@ -1363,6 +1312,38 @@ JAX computation runs on TPU. This will fail if the computation captures variables on some other devices. It is best to use ``call_tf`` with TF functions that do not capture variables. +A TF function wrapped with `call_tf` cannot be applied to inputs whose +shapes are not constants, unless all the output shapes of the TF function +are static. The may arise when you try to apply `jax2tf.convert` with +polymorphic shapes on the result of `call_tf`: + +```python +def fun_jax(x): + return jax2tf.call_tf(tf.math.sin)(x) + +# The following will throw an error. +jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) +``` + +This is unsatisfying, because the result of the above conversion +could be simply `tf.math.sin`, which is batch polymorphic. But +JAX cannot keep track of shapes through a `call_tf` call, and it +cannot be sure that the shape-polymorphic conversion is safe. + +If all the output shapes of the TF function are static, JAX does not need to +keep track of shapes after a `call_tf` call, hence allows shape-polymorphic +inputs in such cases: + +```python +def fun_tf(x): + return tf.math.reduce_sum(tf.math.sin(x)) + +def fun_jax(x): + return jax2tf.call_tf(fun_tf)(x) + +# The following will not throw an error because the output shape of fun_tf is static. +jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) +``` # Misc notes diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index ad8609ddc..78773021e 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -55,15 +55,13 @@ map = util.safe_map zip = util.safe_zip TfConcreteFunction = Any -TfVal = jax2tf_internal.TfVal # The platforms for which to use DLPack to avoid copying (only works on GPU # and CPU at the moment, and only for DeviceArray). For CPU we don't need # DLPack, if we are careful. _DLPACK_PLATFORMS = ("gpu",) -def call_tf(callable_tf: Callable, has_side_effects=True, - output_shape_dtype=None) -> Callable: +def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable: """Calls a TensorFlow function from JAX, with support for reverse autodiff. The ``callable_tf`` will be called with TensorFlow-compatible arguments ( @@ -92,14 +90,6 @@ def call_tf(callable_tf: Callable, has_side_effects=True, has_side_effects: if True then it ensures that instances of this primitive are not removed or replicated by JAX optimizations such as dead-code elimination. - output_shape_dtype: An optional declaration of the expected shapes and dtypes - from the called TensorFlow function. If given it will be used during JAX - tracing to form the abstract values of the results of the `call_tf`. If - not given then we form a `tf.Graph` for the called TensorFlow function and - we use the TensorFlow-inferred shapes and types. Must be a pytree matching the - structure of the nested structure returned from the TensorFlow function, - containing objects with `.shape` and `.dtype` attributes, - e.g., `jax.ShapeDtypeStruct` or `jax.Array`. Returns: a JAX callable that can be invoked with JAX pytree arguments, in op-by-op mode or in a staged context. This callable can be used with @@ -123,58 +113,67 @@ def call_tf(callable_tf: Callable, has_side_effects=True, def make_tensorspec(a_jax): a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype) a_tf_shape = [ - d if core.is_constant_dim(d) else None for d in a_jax.shape] + d if core.is_constant_dim(d) else None for d in a_jax.shape + ] return tf.TensorSpec(a_tf_shape, a_tf_dtype) args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax)) - if output_shape_dtype is not None: - output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype) - output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat) - else: - output_avals, output_shape_dtype_tree = None, None + def check_tf_result(r_tf): + # Check that the TF function returns values of expected types. This + # improves error reporting, preventing hard-to-diagnose errors downstream + try: + jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf) + except Exception as e: + msg = ("The called TF function returns a result that is not " + f"convertible to JAX: {r_tf}.") + raise ValueError(msg) from e res_treedef = None # We'll store here the result treedef res_tf_flat = None # For error reporting # The function below will be called at least once, either in eager - # mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf() + # or in graph mode. def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]: args_tf = args_treedef.unflatten(args_tf_flat) res_tf = callable_tf(*args_tf) nonlocal res_treedef, res_tf_flat res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf) - assert res_treedef is None or res_treedef == res_treedef_now, ( - f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}") + for r_tf in res_tf_flat: + check_tf_result(r_tf) + assert res_treedef is None or res_treedef == res_treedef_now, f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}" res_treedef = res_treedef_now - if output_avals is not None: - if res_treedef != output_shape_dtype_tree: - raise ValueError( - "The pytree of the TensorFlow function results does not match the " - "pytree of the declared output_shape_dtype:\n" - f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}") - assert len(output_avals) == len(res_tf_flat) - - checked_res_tf_flat = [ - check_tf_result(i, r_tf, r_aval) - for i, (r_tf, r_aval) in enumerate( - zip(res_tf_flat, - (output_avals if output_avals is not None - else (None,) * len(res_tf_flat))))] - return checked_res_tf_flat + return res_tf_flat # Prepare a tf.function ahead of time, to cache the concrete functions. This # won't be used in op-by-op execution mode. function_flat_tf = tf.function(callable_flat_tf, autograph=False, jit_compile=True) + input_shapes_tf = [s.shape for s in args_flat_sig_tf] + output_shapes_tf = _get_concrete_function_tf( + function_flat_tf, args_flat_sig_tf + ).output_shapes + + if not all(s.is_fully_defined() for s in input_shapes_tf) and not all( + s.is_fully_defined() for s in output_shapes_tf + ): + for a_jax, a_tf_shape in zip(args_flat_jax, input_shapes_tf): + if not a_tf_shape.is_fully_defined(): + msg = ( + "call_tf cannot be applied to shape-polymorphic arguments unless" + " all the output shapes are static. Found argument shape:" + f" {a_jax.shape}. See" + " https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + " for a discussion." + ) + raise ValueError(msg) + res_jax_flat = call_tf_p.bind( *args_flat_jax, # Carry the actual function such that op-by-op call can call in TF eager mode. callable_flat_tf=callable_flat_tf, function_flat_tf=function_flat_tf, args_flat_sig_tf=args_flat_sig_tf, - output_avals=output_avals, has_side_effects=has_side_effects) - # We must have called callable_flat_tf by nοw assert res_treedef is not None # Sometimes, in compiled mode, we get a different number of results than we # got when tracing the TF function (and building the res_treedef). This @@ -249,44 +248,6 @@ def call_tf(callable_tf: Callable, has_side_effects=True, return util.wraps(callable_tf)(make_call) -def check_tf_result(idx: int, r_tf: TfVal, r_aval: Optional[core.ShapedArray]) -> TfVal: - # Check that the TF function returns values of expected types. This - # improves error reporting, preventing hard-to-diagnose errors downstream - try: - jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf) - except Exception as e: - msg = ("The called TF function returns a result that is not " - f"convertible to JAX: {r_tf}.") - raise ValueError(msg) from e - - if r_aval is None: - return r_tf - # We convert to TF type, and canonicalize to 32-bit if necessary - r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype) - # Checking shapes is trickier in presence of dynamic shapes. I wish we could - # check at runtime that the returned shape matches the declared shape. I wish - # that tf.ensure_shape did this, but it can only take shapes that contain None - # not computed shapes. However, in eager mode we should be able to resolve - # the declared shapes to constants and we get better checking. - if tf.executing_eagerly(): - r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape) - else: - r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval) - # We do as much checking as we can here, instead of relying on tf.ensure_shape - # because the latter gives different errors in eager vs. compiled mode. - if (r_tf.dtype != r_aval_dtype_tf or - len(r_tf.shape) != len(r_aval_shape_tf) or - any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d - for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))): - msg = ("The shapes or dtypes returned by the TensorFlow function " - "do not match the declared output_shape_dtype:\n" - f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]") - raise ValueError(msg) - # At this point tf.ensure_shape does not do much, it should never throw an - # error, albeit it may refine the shape a bit. - return tf.ensure_shape(r_tf, r_aval_shape_tf) - - call_tf_p = core.Primitive("call_tf") call_tf_p.multiple_results = True @@ -348,43 +309,39 @@ effects.remat_allowed_effects.add_type(CallTfEffect) effects.custom_derivatives_allowed_effects.add_type(CallTfEffect) -def _call_tf_abstract_eval(*args_flat_avals, +def _call_tf_abstract_eval(*_, function_flat_tf, args_flat_sig_tf, - has_side_effects, - output_avals, **__): + has_side_effects, **__): # Called only when we form a Jaxpr, i.e., under jit, scan, etc. - effects = {call_tf_effect} if has_side_effects else set() - # If not output_avals is given, then we ask TF to infer the output shapes. - # We call this even if output_avals is given because it will ensure that - # callable_flat_tf is called. Since _get_concrete_function_tf is cached - # there is a small cost of calling it more often than needed. concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf) - if output_avals is not None: - return output_avals, effects - def is_fully_known_shape(s): return s.rank is not None and all([d is not None for d in s]) + effects = {call_tf_effect} if has_side_effects else set() + if all([is_fully_known_shape(s) + for s in concrete_function_flat_tf.output_shapes]): + return ( + tuple([ + # We convert to JAX type, and canonicalize to 32-bit if necessary + core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype)) + for dtype, shape in zip(concrete_function_flat_tf.output_dtypes, + concrete_function_flat_tf.output_shapes) + ]), + effects) - if all(is_fully_known_shape(s) - for s in concrete_function_flat_tf.output_shapes): - avals_from_tf = tuple( - # We convert to JAX type, and canonicalize to 32-bit if necessary - core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype)) - for dtype, shape in zip(concrete_function_flat_tf.output_dtypes, - concrete_function_flat_tf.output_shapes)) - return avals_from_tf, effects - else: - msg = "call_tf cannot call functions whose output has dynamic shape. " - if any(not core.is_constant_shape(a.shape) for a in args_flat_avals): - msg += ("The function is called with shape-polymorphic inputs. Consider " - "using the `output_shape_dtype` argument to call_tf. ") - msg += ("\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" - " for a discussion.") - raise ValueError(msg) + # There are some cases when TF shape inference is not powerful enough to + # figure out the output shapes (e.g., b/128924522), even in situations where + # XLA can compile the code, from which we can get the shapes. + + # We use the "cpu" as the platform, since JAX abstract eval is not platform + # specific; the "cpu" backend is always available and for abstract evaluation + # it should not matter which platform we use. + _, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf, + "CPU") + return tuple(result_avals), effects call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval) @@ -415,11 +372,6 @@ def _code_generator_and_avals( ) -> Tuple[Optional[Callable[[mlir.ModuleContext, Sequence[ir.Value]], Sequence[ir.Value]]], Sequence[core.ShapedArray]]: - # TODO(necula): we have refactored the code to not need to lower the code - # just in order to get the avals, so in fact the returned avals from this - # function are never used. We keep it here for now in case we detect - # a regressions, but if not we should simplify this function. - # Returns and caches a code generator (taking a builder and the # XlaOps for the arguments) and a sequence of result abstract shapes. @@ -526,7 +478,8 @@ def _register_call_lowering(platform): for platform in ("cpu", "cuda", "tpu"): _register_call_lowering(platform) -# Support the call_tf under jax2tf.convert in eager mode +# Support the call_tf under jax2tf.convert +TfVal = jax2tf_internal.TfVal def _jax2tf_call_tf(*args: TfVal, callable_flat_tf: Callable, **_) -> TfVal: diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index c9b4efb35..4116ecbc2 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -335,7 +335,8 @@ def convert(fun_jax: Callable, _thread_local_state.tf_outer_name_scope = tf.get_current_name_scope() # TODO: is there a better way to check if we are inside a transformation? - if not core.trace_state_clean() and not _thread_local_state.inside_call_tf: + if not core.trace_state_clean( + ) and not _thread_local_state.inside_call_tf: # It is Ok to nest convert when we are inside a call_tf raise ValueError( "convert must be used outside all JAX transformations." + diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index f4ecb8f89..370e3e925 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -44,12 +44,6 @@ def _maybe_jit(with_jit: bool, func: Callable) -> Callable: else: return func -def _maybe_tf_jit(with_jit: bool, func: Callable) -> Callable: - if with_jit: - return tf.function(func, autograph=False, jit_compile=True) - else: - return func - def _named_test(**kwargs): return dict(kwargs, testcase_name = "_".join([f"{k}={kwargs[k]}" for k in sorted(kwargs.keys())])) @@ -59,7 +53,8 @@ _parameterized_jit = parameterized.named_parameters( for with_jit in [True, False]) _call_tf_non_compileable_error = "Error compiling TensorFlow function. call_tf can used in a staged context .* only with compileable functions" -_call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has dynamic shape" +_call_tf_dynamic_shape_error = "Compiled TensorFlow function has dynamic output shape.* call_tf can used in a staged context .* only with compileable functions" + class CallTfTest(tf_test_util.JaxToTfTestCase): @@ -176,7 +171,8 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): x = np.array([True, False], dtype=np.bool_) self.assertAllClose(f_tf_non_compileable(x), f_jax(x)) # Works in eager mode - with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): + with self.assertRaisesRegex(ValueError, + _call_tf_dynamic_shape_error): jax.jit(f_jax)(x) def test_error_bad_result_tensorarray(self): @@ -573,7 +569,9 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(x[0:x[1]], res1) # Now under jit, should fail because the function is not compileable - with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): + with self.assertRaisesRegex( + ValueError, "Compiled TensorFlow function has dynamic output shape" + ): fun_jax = jax.jit(jax2tf.call_tf(fun_tf)) fun_jax(x) @@ -1101,143 +1099,29 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(expected, res1, check_dtypes=False) # Now under jit, should fail because the function is not compileable - with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): + with self.assertRaisesRegex(ValueError, + _call_tf_dynamic_shape_error): fun_jax = jax.jit(jax2tf.call_tf(fun_tf)) fun_jax(x) # TODO(necula): this should work in op-by-op mode, but it fails because # jax2tf.convert does abstract evaluation. - with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): + with self.assertRaisesRegex(ValueError, + _call_tf_dynamic_shape_error): fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf)) fun_tf_rt(x) - @_parameterized_jit - def test_shape_poly_static_output_shape(self, with_jit=True): - if config.jax2tf_default_experimental_native_lowering: - raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.") - x = np.array([0.7, 0.8], dtype=np.float32) - + def test_shape_polymorphism_error(self): + x = np.array([.7, .8], dtype=np.float32) def fun_tf(x): - return tf.math.reduce_sum(tf.math.sin(x)) + return tf.math.sin(x) fun_jax = jax2tf.call_tf(fun_tf) - fun_tf_rt = _maybe_tf_jit(with_jit, - jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - self.assertAllClose(fun_tf(x), fun_tf_rt(x)) - - @_parameterized_jit - def test_shape_poly(self, with_jit=False): - if config.jax2tf_default_experimental_native_lowering: - raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.") - x = np.array([7, 8, 9, 10], dtype=np.float32) - def fun_jax(x): - y = jax2tf.call_tf(tf.math.sin, - output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x) - z = jnp.cos(y) - w = jax2tf.call_tf(lambda z: tf.concat([z, z], axis=0), - output_shape_dtype=jax.ShapeDtypeStruct((2 * z.shape[0],), z.dtype))(z) - assert w.shape[0] == 2 * x.shape[0] - return w - - fun_tf_rt = _maybe_tf_jit(with_jit, - jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - res_tf = fun_tf_rt(x) - self.assertAllClose(fun_jax(x), res_tf) - - @_parameterized_jit - def test_shape_poly_pytree_result(self, with_jit=True): - if config.jax2tf_default_experimental_native_lowering: - raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.") - x = np.array([7, 8, 9, 10], dtype=np.float32) - def fun_jax(x): - # Returns a tuple - y = jax2tf.call_tf(lambda x: (x, tf.concat([x, x], axis=0)), - output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype), - jax.ShapeDtypeStruct((2 * x.shape[0],), x.dtype)))(x) - assert y[0].shape[0] == x.shape[0] - assert y[1].shape[0] == 2 * x.shape[0] - return y - - fun_tf_rt = _maybe_tf_jit(with_jit, - jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - res_tf = fun_tf_rt(x) - self.assertAllClose(fun_jax(x), res_tf) - - @_parameterized_jit - def test_shape_poly_error_no_output_shape_dtype(self, with_jit=True): - x = np.array([7, 8, 9, 10], dtype=np.float32) - def fun_jax(x): - return jax2tf.call_tf(tf.math.sin)(x) - - fun_tf_rt = _maybe_tf_jit(with_jit, - jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): - fun_tf_rt(x) - - @_parameterized_jit - def test_shape_poly_error_mismatch_output_shape_dtype_tree(self, with_jit=False): - x = np.array([7, 8, 9, 10], dtype=np.float32) - def fun_jax(x): - return jax2tf.call_tf(tf.math.sin, - output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype), - jax.ShapeDtypeStruct(x.shape, x.dtype)))(x) - - fun_tf_rt = _maybe_tf_jit(with_jit, - jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]) with self.assertRaisesRegex( - ValueError, - "The pytree of the TensorFlow function results does not match the pytree of the declared output_shape_dtype"): - fun_tf_rt(x) - - @parameterized.named_parameters( - _named_test(with_jit=with_jit, kind=kind) - for with_jit in [True, False] - for kind in ["bad_rank", "bad_dim", "bad_dtype", "bad_dtype_x64"]) - def test_shape_poly_error_mismatch_output_shape_dtype(self, with_jit=False, kind="bad_rank"): - x = np.array([7, 8, 9, 10], dtype=np.float32) - - if kind == "bad_rank": - def fun_jax(x): - return jax2tf.call_tf(lambda x: x, - # Wrong shape rank - output_shape_dtype=jax.ShapeDtypeStruct((), x.dtype))(x) - elif kind == "bad_dim": - def fun_jax(x): - bad_shape = (5 + x.shape[0],) - y = jax2tf.call_tf(lambda x: x, - # Wrong dimension - output_shape_dtype=jax.ShapeDtypeStruct(bad_shape, x.dtype))(x) - # JAX will believe that the following is Ok, leading to downstream error in TF - return y + jnp.ones(bad_shape, dtype=x.dtype) - elif kind == "bad_dtype": - def fun_jax(x): - return jax2tf.call_tf(lambda x: x, - output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.int32))(x) - elif kind == "bad_dtype_x64": - def fun_jax(x): - return jax2tf.call_tf(lambda x: x * np.float64(3.), - output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.float64))(x) - else: - assert False - expect_ex = ValueError - expect_error = r"The shapes or dtypes returned by the TensorFlow function do not match the declared output_shape_dtype" - - # Call without shape polymorphism - fun_tf_rt = _maybe_tf_jit(with_jit, jax2tf.convert(fun_jax)) - with self.assertRaisesRegex(expect_ex, expect_error): - fun_tf_rt(x) - - # Now with shape polymorphism - if kind == "bad_dim" and with_jit: - # TODO: in jit more the error pops up later, at AddV2 - expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2" - if kind == "bad_dim" and config.jax2tf_default_experimental_native_lowering: - # TODO(b/268386622): call_tf with shape polymorphism and native lowering. - expect_error = "Error compiling TensorFlow function. call_tf can used .* only with compileable functions with static output shapes" - fun_tf_rt = _maybe_tf_jit(with_jit, - jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - with self.assertRaisesRegex(expect_ex, expect_error): + ValueError, "call_tf cannot be applied to shape-polymorphic arguments" + ): fun_tf_rt(x) def test_inner_native_lowering(self): @@ -1258,6 +1142,22 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): self.assertIn('op: "XlaCallModule"', f_outer_graph) self.assertNotIn('op: "Sin"', f_outer_graph) + @_parameterized_jit + def test_shape_polymorphism_static_output_shape(self, with_jit=True): + # TODO(b/268386622) Dynamic shapes not yet supported. + if config.jax2tf_default_experimental_native_lowering: + raise unittest.SkipTest("Skip test because of dynamic shapes.") + x = np.array([0.7, 0.8], dtype=np.float32) + + def fun_tf(x): + return tf.math.reduce_sum(tf.math.sin(x)) + + fun_jax = jax2tf.call_tf(fun_tf) + fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]) + if with_jit: + fun_tf_rt = tf.function(jit_compile=True, autograph=False)(fun_tf_rt) + self.assertAllClose(fun_tf(x), fun_tf_rt(x)) + @parameterized.named_parameters( _named_test(f2_function=f2_function, f2_saved_model=f2_saved_model, f4_function=f4_function, f4_saved_model=f4_saved_model)