Sergei Lebedev 6e23c14f85 jax.debug.callback now passes arguments as jax.Arrays
Prior to this change, the behavior in eager mode and under jax.jit was inconsistent:

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.
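A minimal sketch of the now-consistent behavior, assuming a JAX release that includes this change. The helpers `record` and `f` are hypothetical names used only for illustration. The callback stores every value it receives, once in eager mode and once under jax.jit. Because the jitted computation may run asynchronously, jax.effects_barrier() waits for pending callbacks before the stored values are inspected.

```python
import jax

# Values received by the callback, one entry per invocation.
seen = []

def record(x):
    # Hypothetical host-side callback: keep the value exactly as JAX passed it.
    seen.append(x)

def f(x):
    jax.debug.callback(record, x)
    return x

f(42)                  # eager: the callback runs immediately
jax.jit(f)(42)         # traced: the callback runs when the compiled code executes
jax.effects_barrier()  # block until pending debug callbacks have finished

# With this change, both invocations should receive a jax.Array.
print([isinstance(v, jax.Array) for v in seen])
```

Before this change, the eager call passed the plain Python int through unchanged, so the first entry would not have been a jax.Array.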

Closes #20627.

PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00