mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #26839 from Sai-Suraj-27:fix_jax.debug.print
PiperOrigin-RevId: 735511953
This commit is contained in:
commit
b8590816bf
@ -91,8 +91,8 @@ def f(x):
|
||||
jax.debug.print("x: {}", x)
|
||||
return x
|
||||
jax.pmap(f)(xs)
|
||||
# Prints: x: 1.0
|
||||
# x: 0.0
|
||||
# Prints: x: 0.0
|
||||
# x: 1.0
|
||||
# OR
|
||||
# Prints: x: 1.0
|
||||
# x: 0.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user