Export KeyPath and related types to jax.tree_util

These types lie on the APIs in `jax.tree_util`, so it makes sense to export them.

PiperOrigin-RevId: 657987755
This commit is contained in:
jax authors 2024-07-31 06:40:53 -07:00 committed by jax authors
parent 9dba6eb16a
commit a207fe9b77
3 changed files with 12 additions and 4 deletions

View File

@ -26,6 +26,8 @@ List of Functions
treedef_children treedef_children
treedef_is_leaf treedef_is_leaf
treedef_tuple treedef_tuple
KeyEntry
KeyPath
keystr keystr
Legacy APIs Legacy APIs

View File

@ -42,6 +42,8 @@ from jax._src.tree_util import (
DictKey as DictKey, DictKey as DictKey,
FlattenedIndexKey as FlattenedIndexKey, FlattenedIndexKey as FlattenedIndexKey,
GetAttrKey as GetAttrKey, GetAttrKey as GetAttrKey,
KeyEntry as KeyEntry,
KeyPath as KeyPath,
Partial as Partial, Partial as Partial,
PyTreeDef as PyTreeDef, PyTreeDef as PyTreeDef,
SequenceKey as SequenceKey, SequenceKey as SequenceKey,

View File

@ -28,11 +28,15 @@ def _mod(module_name: str, *, include: Sequence[str] = (), exclude: Sequence[str
class PackageStructureTest(jtu.JaxTestCase): class PackageStructureTest(jtu.JaxTestCase):
@parameterized.parameters([ @parameterized.parameters([
# TODO(jakevdp): expand test to other public modules. # TODO(jakevdp): expand test to other public modules.
_mod("jax.errors"), _mod("jax.errors"),
_mod("jax.nn.initializers"), _mod("jax.nn.initializers"),
_mod("jax.tree_util", exclude=['PyTreeDef', 'default_registry']), _mod(
"jax.tree_util",
exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"],
),
]) ])
def test_exported_names_match_module(self, module_name, include, exclude): def test_exported_names_match_module(self, module_name, include, exclude):
"""Test that all public exports have __module__ set correctly.""" """Test that all public exports have __module__ set correctly."""