diff --git a/jax/_src/util.py b/jax/_src/util.py index 3e52a2584..0e28aea04 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -668,17 +668,12 @@ def use_cpp_class(cpp_cls: type[Any]) -> Callable[[type[T]], type[T]]: exclude_methods = {'__module__', '__dict__', '__doc__'} - originals = {} for attr_name, attr in cls.__dict__.items(): if attr_name not in exclude_methods: - if hasattr(_original_func(attr), "_use_cpp"): - originals[attr_name] = attr - else: + if not hasattr(_original_func(attr), "_use_cpp"): setattr(cpp_cls, attr_name, attr) cpp_cls.__doc__ = cls.__doc__ - # TODO(pschuh): Remove once fastpath is gone. - cpp_cls._original_py_fns = originals return cpp_cls return wrapper