mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Make utils for reporting function name work with functools.partial
by using the inner .func
attribute if the object doesn't have a __name__
attribute. functools.partial
objects do not have __name__
attributes by default.
PiperOrigin-RevId: 715881812
This commit is contained in:
parent
41993fdb24
commit
f7d097f7cc
@ -376,11 +376,21 @@ def wrap_name(name, transform_name):
|
||||
return transform_name + '(' + name + ')'
|
||||
|
||||
def fun_name(fun: Callable):
|
||||
return getattr(fun, "__name__", "<unnamed function>")
|
||||
name = getattr(fun, "__name__", None)
|
||||
if name is not None:
|
||||
return name
|
||||
if isinstance(fun, partial):
|
||||
return fun_name(fun.func)
|
||||
else:
|
||||
return "<unnamed function>"
|
||||
|
||||
def fun_qual_name(fun: Callable):
|
||||
return getattr(fun, "__qualname__",
|
||||
getattr(fun, "__name__", "<unnamed function>"))
|
||||
qual_name = getattr(fun, "__qualname__", None)
|
||||
if qual_name is not None:
|
||||
return qual_name
|
||||
if isinstance(fun, partial):
|
||||
return fun_qual_name(fun.func)
|
||||
return fun_name(fun)
|
||||
|
||||
def canonicalize_axis(axis, num_dims) -> int:
|
||||
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
||||
@ -678,7 +688,6 @@ class StrictABC(metaclass=StrictABCMeta):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
|
||||
test_event_listener: Callable | None = None
|
||||
|
||||
def test_event(name: str, *args) -> None:
|
||||
|
@ -96,6 +96,38 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
jitted = jit(my_function)
|
||||
self.assertEqual(repr(jitted), f"<PjitFunction of {repr(my_function)}>")
|
||||
|
||||
def test_fun_name(self):
|
||||
def my_function():
|
||||
return
|
||||
|
||||
with self.subTest("function"):
|
||||
jitted = jit(my_function)
|
||||
self.assertEqual(
|
||||
jitted.__getstate__()["function_name"], my_function.__name__
|
||||
)
|
||||
with self.subTest("default_partial"):
|
||||
my_partial = partial(my_function)
|
||||
jitted = jit(my_partial)
|
||||
self.assertEqual(
|
||||
jitted.__getstate__()["function_name"], my_function.__name__
|
||||
)
|
||||
with self.subTest("nested_default_partial"):
|
||||
my_partial = partial(partial(my_function))
|
||||
jitted = jit(my_partial)
|
||||
self.assertEqual(
|
||||
jitted.__getstate__()["function_name"], my_function.__name__
|
||||
)
|
||||
with self.subTest("named_partial"):
|
||||
my_partial = partial(my_function)
|
||||
my_partial.__name__ = "my_partial"
|
||||
jitted = jit(my_partial)
|
||||
self.assertEqual(
|
||||
jitted.__getstate__()["function_name"], my_partial.__name__
|
||||
)
|
||||
with self.subTest("lambda"):
|
||||
jitted = jit(lambda: my_function())
|
||||
self.assertEqual(jitted.__getstate__()["function_name"], "<lambda>")
|
||||
|
||||
def test_jit_repr_errors(self):
|
||||
class Callable:
|
||||
def __call__(self): pass
|
||||
@ -288,14 +320,14 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(f(1).devices(), system_default_devices)
|
||||
|
||||
def test_jit_default_platform(self):
|
||||
with jax.default_device("cpu"):
|
||||
result = jax.jit(lambda x: x + 1)(1)
|
||||
self.assertEqual(result.device.platform, "cpu")
|
||||
self.assertEqual(result.device, jax.local_devices(backend="cpu")[0])
|
||||
|
||||
with jax.default_device("cpu"):
|
||||
result = jax.jit(lambda x: x + 1)(1)
|
||||
self.assertEqual(result.device.platform, jax.default_backend())
|
||||
self.assertEqual(result.device, jax.local_devices()[0])
|
||||
self.assertEqual(result.device.platform, "cpu")
|
||||
self.assertEqual(result.device, jax.local_devices(backend="cpu")[0])
|
||||
|
||||
result = jax.jit(lambda x: x + 1)(1)
|
||||
self.assertEqual(result.device.platform, jax.default_backend())
|
||||
self.assertEqual(result.device, jax.local_devices()[0])
|
||||
|
||||
def test_complex_support(self):
|
||||
self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)
|
||||
|
Loading…
x
Reference in New Issue
Block a user