Make utils for reporting function name work with functools.partial by using the inner .func attribute if the object doesn't have a __name__ attribute. functools.partial objects do not have __name__ attributes by default.

PiperOrigin-RevId: 715881812
This commit is contained in:
Zachary Garrett 2025-01-15 11:40:17 -08:00 committed by jax authors
parent 41993fdb24
commit f7d097f7cc
2 changed files with 52 additions and 11 deletions

View File

@ -376,11 +376,21 @@ def wrap_name(name, transform_name):
return transform_name + '(' + name + ')'
def fun_name(fun: Callable):
return getattr(fun, "__name__", "<unnamed function>")
name = getattr(fun, "__name__", None)
if name is not None:
return name
if isinstance(fun, partial):
return fun_name(fun.func)
else:
return "<unnamed function>"
def fun_qual_name(fun: Callable):
return getattr(fun, "__qualname__",
getattr(fun, "__name__", "<unnamed function>"))
qual_name = getattr(fun, "__qualname__", None)
if qual_name is not None:
return qual_name
if isinstance(fun, partial):
return fun_qual_name(fun.func)
return fun_name(fun)
def canonicalize_axis(axis, num_dims) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
@ -678,7 +688,6 @@ class StrictABC(metaclass=StrictABCMeta):
__slots__ = ()
test_event_listener: Callable | None = None
def test_event(name: str, *args) -> None:

View File

@ -96,6 +96,38 @@ class JitTest(jtu.BufferDonationTestCase):
jitted = jit(my_function)
self.assertEqual(repr(jitted), f"<PjitFunction of {repr(my_function)}>")
def test_fun_name(self):
def my_function():
return
with self.subTest("function"):
jitted = jit(my_function)
self.assertEqual(
jitted.__getstate__()["function_name"], my_function.__name__
)
with self.subTest("default_partial"):
my_partial = partial(my_function)
jitted = jit(my_partial)
self.assertEqual(
jitted.__getstate__()["function_name"], my_function.__name__
)
with self.subTest("nested_default_partial"):
my_partial = partial(partial(my_function))
jitted = jit(my_partial)
self.assertEqual(
jitted.__getstate__()["function_name"], my_function.__name__
)
with self.subTest("named_partial"):
my_partial = partial(my_function)
my_partial.__name__ = "my_partial"
jitted = jit(my_partial)
self.assertEqual(
jitted.__getstate__()["function_name"], my_partial.__name__
)
with self.subTest("lambda"):
jitted = jit(lambda: my_function())
self.assertEqual(jitted.__getstate__()["function_name"], "<lambda>")
def test_jit_repr_errors(self):
class Callable:
def __call__(self): pass
@ -288,14 +320,14 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertEqual(f(1).devices(), system_default_devices)
def test_jit_default_platform(self):
with jax.default_device("cpu"):
result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, "cpu")
self.assertEqual(result.device, jax.local_devices(backend="cpu")[0])
with jax.default_device("cpu"):
result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, jax.default_backend())
self.assertEqual(result.device, jax.local_devices()[0])
self.assertEqual(result.device.platform, "cpu")
self.assertEqual(result.device, jax.local_devices(backend="cpu")[0])
result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, jax.default_backend())
self.assertEqual(result.device, jax.local_devices()[0])
def test_complex_support(self):
self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)