Fix keyword argument confusion in jax.profiler.annotate_function decorator.

This commit is contained in:
Peter Hawkins 2021-12-06 09:10:44 -05:00
parent 14bc95fe1b
commit a288b154b4
2 changed files with 11 additions and 5 deletions

View File

@ -169,7 +169,8 @@ class StepTraceContext(StepTraceAnnotation):
super().__init__(*args, **kwargs)
def annotate_function(func: Callable, name: Optional[str] = None, **kwargs):
def annotate_function(func: Callable, name: Optional[str] = None,
**decorator_kwargs):
"""Decorator that generates a trace event for the execution of a function.
For example:
@ -198,7 +199,7 @@ def annotate_function(func: Callable, name: Optional[str] = None, **kwargs):
name = name or func.__name__
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
return func(*args, **kwargs)
return wrapper
return wrapper

View File

@ -117,9 +117,14 @@ class ProfilerTest(unittest.TestCase):
def testTraceFunction(self):
@jax.profiler.annotate_function
def f(x):
return x + 2
self.assertEqual(f(7), 9)
def f(x, *, y):
return x + 2 * y
self.assertEqual(f(7, y=3), 13)
@jax.profiler.annotate_function
def f(x, *, name):
return x + 2 * len(name)
self.assertEqual(f(7, name="abc"), 13)
@partial(jax.profiler.annotate_function, name="aname")
def g(x):