Merge pull request #26839 from Sai-Suraj-27:fix_jax.debug.print

PiperOrigin-RevId: 735511953
This commit is contained in:
jax authors 2025-03-10 14:26:45 -07:00
commit b8590816bf

View File

@ -91,8 +91,8 @@ def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# Prints: x: 0.0
# x: 1.0
# OR
# Prints: x: 1.0
# x: 0.0