Merge pull request #18023 from jakevdp:make-jaxpr-name

PiperOrigin-RevId: 572094635
This commit is contained in:
jax authors 2023-10-09 18:22:26 -07:00
commit 4c306381c9
2 changed files with 13 additions and 1 deletions

View File

@ -2480,7 +2480,11 @@ def make_jaxpr(fun: Callable,
return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat)
return closed_jaxpr
make_jaxpr_f.__name__ = f"make_jaxpr({make_jaxpr.__name__})"
make_jaxpr_f.__module__ = "jax"
if hasattr(fun, "__qualname__"):
make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
if hasattr(fun, "__name__"):
make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
return make_jaxpr_f
def _infer_src_sharding(src, x):

View File

@ -4351,6 +4351,14 @@ class APITest(jtu.JaxTestCase):
jax.make_jaxpr(Foo(1))(3) # don't crash
def test_make_jaxpr_name(self):
def foo(x, y, z):
return x + y + z
jfoo = jax.make_jaxpr(foo)
self.assertEqual(jfoo.__name__, f"make_jaxpr({foo.__name__})")
self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})")
self.assertEqual(jfoo.__module__, "jax")
def test_inner_jit_function_retracing(self):
# https://github.com/google/jax/issues/7155
inner_count = outer_count = 0