mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Export KeyPath
and related types to jax.tree_util
These types lie on the APIs in `jax.tree_util`, so it makes sense to export them. PiperOrigin-RevId: 657987755
This commit is contained in:
parent
9dba6eb16a
commit
a207fe9b77
@ -26,6 +26,8 @@ List of Functions
|
|||||||
treedef_children
|
treedef_children
|
||||||
treedef_is_leaf
|
treedef_is_leaf
|
||||||
treedef_tuple
|
treedef_tuple
|
||||||
|
KeyEntry
|
||||||
|
KeyPath
|
||||||
keystr
|
keystr
|
||||||
|
|
||||||
Legacy APIs
|
Legacy APIs
|
||||||
|
@ -42,6 +42,8 @@ from jax._src.tree_util import (
|
|||||||
DictKey as DictKey,
|
DictKey as DictKey,
|
||||||
FlattenedIndexKey as FlattenedIndexKey,
|
FlattenedIndexKey as FlattenedIndexKey,
|
||||||
GetAttrKey as GetAttrKey,
|
GetAttrKey as GetAttrKey,
|
||||||
|
KeyEntry as KeyEntry,
|
||||||
|
KeyPath as KeyPath,
|
||||||
Partial as Partial,
|
Partial as Partial,
|
||||||
PyTreeDef as PyTreeDef,
|
PyTreeDef as PyTreeDef,
|
||||||
SequenceKey as SequenceKey,
|
SequenceKey as SequenceKey,
|
||||||
|
@ -28,11 +28,15 @@ def _mod(module_name: str, *, include: Sequence[str] = (), exclude: Sequence[str
|
|||||||
|
|
||||||
|
|
||||||
class PackageStructureTest(jtu.JaxTestCase):
|
class PackageStructureTest(jtu.JaxTestCase):
|
||||||
|
|
||||||
@parameterized.parameters([
|
@parameterized.parameters([
|
||||||
# TODO(jakevdp): expand test to other public modules.
|
# TODO(jakevdp): expand test to other public modules.
|
||||||
_mod("jax.errors"),
|
_mod("jax.errors"),
|
||||||
_mod("jax.nn.initializers"),
|
_mod("jax.nn.initializers"),
|
||||||
_mod("jax.tree_util", exclude=['PyTreeDef', 'default_registry']),
|
_mod(
|
||||||
|
"jax.tree_util",
|
||||||
|
exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"],
|
||||||
|
),
|
||||||
])
|
])
|
||||||
def test_exported_names_match_module(self, module_name, include, exclude):
|
def test_exported_names_match_module(self, module_name, include, exclude):
|
||||||
"""Test that all public exports have __module__ set correctly."""
|
"""Test that all public exports have __module__ set correctly."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user