mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

* Make the tree_util.tree_flatten_with_path and tree_map_with_path APIs C++-based, to speed up pytree flattening.
* Move all the key classes down to the C++ level, while keeping the APIs unchanged.
* Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy.
* Register defaultdict and OrderedDict via the keypath API now.
PiperOrigin-RevId: 701613257
103 lines
3.1 KiB
Python
103 lines
3.1 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests of the JAX public package structure"""
|
|
|
|
from collections.abc import Sequence
|
|
import importlib
|
|
import types
|
|
|
|
from absl.testing import absltest, parameterized
|
|
|
|
from jax._src import test_util as jtu
|
|
|
|
|
|
def _mod(module_name: str, *, include: Sequence[str] = (), exclude: Sequence[str] = ()):
|
|
return {"module_name": module_name, "include": include, "exclude": exclude}
|
|
|
|
|
|
class PackageStructureTest(jtu.JaxTestCase):
  """Checks that names exported by public modules report that module."""

  @parameterized.parameters([
    # TODO(jakevdp): expand test to other public modules.
    _mod("jax.errors", exclude=["JaxRuntimeError"]),
    _mod(
      "jax.numpy",
      exclude=[
        "array_repr",
        "array_str",
        "can_cast",
        "character",
        "complexfloating",
        "dtype",
        "iinfo",
        "index_exp",
        "inexact",
        "integer",
        "iterable",
        "finfo",
        "flexible",
        "floating",
        "generic",
        "get_printoptions",
        "ndarray",
        "ndim",
        "number",
        "object_",
        "printoptions",
        "save",
        "savez",
        "set_printoptions",
        "shape",
        "signedinteger",
        "size",
        "s_",
        "unsignedinteger",
        "ComplexWarning",
      ],
    ),
    _mod("jax.numpy.linalg"),
    _mod("jax.nn.initializers"),
    _mod(
      "jax.tree_util",
      exclude=[
        "PyTreeDef",
        "default_registry",
        "KeyEntry",
        "KeyPath",
        "DictKey",
        "GetAttrKey",
        "SequenceKey",
        "FlattenedIndexKey",
      ],
    ),
  ])
  def test_exported_names_match_module(self, module_name, include, exclude):
    """Test that all public exports have __module__ set correctly.

    Args:
      module_name: Dotted name of the module to import and inspect.
      include: Names that are always checked, even when they start with an
        underscore or appear in ``exclude``.
      exclude: Names that are skipped.
    """
    module = importlib.import_module(module_name)
    self.assertEqual(module.__name__, module_name)
    for name in dir(module):
      # `include` takes precedence over both the underscore rule and `exclude`.
      if name not in include and (name.startswith('_') or name in exclude):
        continue
      obj = getattr(module, name)
      if obj is None or isinstance(obj, (bool, int, float, complex, types.ModuleType)):
        # No __module__ attribute expected.
        continue
      # Read the attribute defensively so that an export lacking __module__
      # produces the descriptive failure below rather than an AttributeError.
      actual = getattr(obj, "__module__", None)
      self.assertEqual(actual, module_name,
                       f"{obj} has obj.__module__={actual!r}, expected {module_name}")
|
|
|
|
|
|
# Run the tests with JAX's test loader when this file is executed as a script.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
|