diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index 73ac02628..85580120c 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -91,8 +91,8 @@ def f(x): jax.debug.print("x: {}", x) return x jax.pmap(f)(xs) -# Prints: x: 1.0 -# x: 0.0 +# Prints: x: 0.0 +# x: 1.0 # OR # Prints: x: 1.0 # x: 0.0