Export KeyPath and related types to jax.tree_util

These types lie on the APIs in `jax.tree_util`, so it makes sense to export them.

PiperOrigin-RevId: 657987755
This commit is contained in:
jax authors 2024-07-31 06:40:53 -07:00 committed by jax authors
parent 9dba6eb16a
commit a207fe9b77
3 changed files with 12 additions and 4 deletions

View File

@ -26,6 +26,8 @@ List of Functions
treedef_children
treedef_is_leaf
treedef_tuple
KeyEntry
KeyPath
keystr
Legacy APIs

View File

@ -42,6 +42,8 @@ from jax._src.tree_util import (
DictKey as DictKey,
FlattenedIndexKey as FlattenedIndexKey,
GetAttrKey as GetAttrKey,
KeyEntry as KeyEntry,
KeyPath as KeyPath,
Partial as Partial,
PyTreeDef as PyTreeDef,
SequenceKey as SequenceKey,

View File

@ -28,11 +28,15 @@ def _mod(module_name: str, *, include: Sequence[str] = (), exclude: Sequence[str
class PackageStructureTest(jtu.JaxTestCase):
@parameterized.parameters([
# TODO(jakevdp): expand test to other public modules.
_mod("jax.errors"),
_mod("jax.nn.initializers"),
_mod("jax.tree_util", exclude=['PyTreeDef', 'default_registry']),
# TODO(jakevdp): expand test to other public modules.
_mod("jax.errors"),
_mod("jax.nn.initializers"),
_mod(
"jax.tree_util",
exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"],
),
])
def test_exported_names_match_module(self, module_name, include, exclude):
"""Test that all public exports have __module__ set correctly."""