mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Export KeyPath
and related types to jax.tree_util
These types lie on the APIs in `jax.tree_util`, so it makes sense to export them. PiperOrigin-RevId: 657987755
This commit is contained in:
parent
9dba6eb16a
commit
a207fe9b77
@ -26,6 +26,8 @@ List of Functions
|
||||
treedef_children
|
||||
treedef_is_leaf
|
||||
treedef_tuple
|
||||
KeyEntry
|
||||
KeyPath
|
||||
keystr
|
||||
|
||||
Legacy APIs
|
||||
|
@ -42,6 +42,8 @@ from jax._src.tree_util import (
|
||||
DictKey as DictKey,
|
||||
FlattenedIndexKey as FlattenedIndexKey,
|
||||
GetAttrKey as GetAttrKey,
|
||||
KeyEntry as KeyEntry,
|
||||
KeyPath as KeyPath,
|
||||
Partial as Partial,
|
||||
PyTreeDef as PyTreeDef,
|
||||
SequenceKey as SequenceKey,
|
||||
|
@ -28,11 +28,15 @@ def _mod(module_name: str, *, include: Sequence[str] = (), exclude: Sequence[str
|
||||
|
||||
|
||||
class PackageStructureTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
# TODO(jakevdp): expand test to other public modules.
|
||||
_mod("jax.errors"),
|
||||
_mod("jax.nn.initializers"),
|
||||
_mod("jax.tree_util", exclude=['PyTreeDef', 'default_registry']),
|
||||
# TODO(jakevdp): expand test to other public modules.
|
||||
_mod("jax.errors"),
|
||||
_mod("jax.nn.initializers"),
|
||||
_mod(
|
||||
"jax.tree_util",
|
||||
exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"],
|
||||
),
|
||||
])
|
||||
def test_exported_names_match_module(self, module_name, include, exclude):
|
||||
"""Test that all public exports have __module__ set correctly."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user