mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #18023 from jakevdp:make-jaxpr-name
PiperOrigin-RevId: 572094635
This commit is contained in:
commit
4c306381c9
@ -2480,7 +2480,11 @@ def make_jaxpr(fun: Callable,
|
||||
return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat)
|
||||
return closed_jaxpr
|
||||
|
||||
make_jaxpr_f.__name__ = f"make_jaxpr({make_jaxpr.__name__})"
|
||||
make_jaxpr_f.__module__ = "jax"
|
||||
if hasattr(fun, "__qualname__"):
|
||||
make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
|
||||
if hasattr(fun, "__name__"):
|
||||
make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
|
||||
return make_jaxpr_f
|
||||
|
||||
def _infer_src_sharding(src, x):
|
||||
|
@ -4351,6 +4351,14 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
jax.make_jaxpr(Foo(1))(3) # don't crash
|
||||
|
||||
def test_make_jaxpr_name(self):
|
||||
def foo(x, y, z):
|
||||
return x + y + z
|
||||
jfoo = jax.make_jaxpr(foo)
|
||||
self.assertEqual(jfoo.__name__, f"make_jaxpr({foo.__name__})")
|
||||
self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})")
|
||||
self.assertEqual(jfoo.__module__, "jax")
|
||||
|
||||
def test_inner_jit_function_retracing(self):
|
||||
# https://github.com/google/jax/issues/7155
|
||||
inner_count = outer_count = 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user