mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Fix keyword argument confusion in jax.profiler.annotate_function decorator.
This commit is contained in:
parent
14bc95fe1b
commit
a288b154b4
@ -169,7 +169,8 @@ class StepTraceContext(StepTraceAnnotation):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def annotate_function(func: Callable, name: Optional[str] = None, **kwargs):
|
||||
def annotate_function(func: Callable, name: Optional[str] = None,
|
||||
**decorator_kwargs):
|
||||
"""Decorator that generates a trace event for the execution of a function.
|
||||
|
||||
For example:
|
||||
@ -198,7 +199,7 @@ def annotate_function(func: Callable, name: Optional[str] = None, **kwargs):
|
||||
name = name or func.__name__
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with TraceAnnotation(name, **kwargs):
|
||||
with TraceAnnotation(name, **decorator_kwargs):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
return wrapper
|
||||
|
@ -117,9 +117,14 @@ class ProfilerTest(unittest.TestCase):
|
||||
|
||||
def testTraceFunction(self):
|
||||
@jax.profiler.annotate_function
|
||||
def f(x):
|
||||
return x + 2
|
||||
self.assertEqual(f(7), 9)
|
||||
def f(x, *, y):
|
||||
return x + 2 * y
|
||||
self.assertEqual(f(7, y=3), 13)
|
||||
|
||||
@jax.profiler.annotate_function
|
||||
def f(x, *, name):
|
||||
return x + 2 * len(name)
|
||||
self.assertEqual(f(7, name="abc"), 13)
|
||||
|
||||
@partial(jax.profiler.annotate_function, name="aname")
|
||||
def g(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user