Fixed printing order of results in jax.debug.print documentation.

This commit is contained in:
Sai-Suraj-27 2025-02-28 06:17:14 +00:00
parent d8953e5311
commit 56285aec6b

View File

@ -91,8 +91,8 @@ def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# Prints: x: 0.0
# x: 1.0
# OR
# Prints: x: 1.0
# x: 0.0