tree_util: fix warning category and stacklevel

This commit is contained in:
Jake VanderPlas 2022-07-15 09:24:22 -07:00
parent 10720258ea
commit 6907dfad00

View File

@ -527,7 +527,8 @@ def _deprecate(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
warnings.warn(f"jax.{f.__name__} is deprecated, and will be removed in a future release. "
f"Use jax.tree_util.{f.__name__} instead.")
f"Use jax.tree_util.{f.__name__} instead.",
category=FutureWarning, stacklevel=2)
return f(*args, **kwargs)
return wrapped