Prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value) and `result_shape_dtypes` specified
a 32-bit result, we would silently get garbage results. In #20433, I introduced
an error for this situation. However, when trying to port the internal code that
uses `host_callback` to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:
```
def f_host():
  return 42.
io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```
However, if `f_host` were called directly, JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
than a direct call.
In this PR we add a canonicalization step on the returned values of
Python callbacks, which casts the values to 32 bits.
In some sense this replaces the error introduced in #20433 with a
canonicalization step.
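The dtype mismatch described above can be sketched as follows. This is a minimal illustration, not the code in this PR: `f_host_f32` is a hypothetical variant of `f_host` that performs the cast explicitly on the host side, and x64 mode is disabled explicitly so that the dtypes are predictable.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Assumption for this sketch: x64 mode is off (the JAX default).
jax.config.update("jax_enable_x64", False)

def f_host():
  return 42.  # a Python float, natively 64-bit

# NumPy treats the Python scalar as float64 ...
host_dtype = np.asarray(f_host()).dtype
# ... while a direct JAX conversion canonicalizes it to float32.
direct_dtype = jnp.asarray(f_host()).dtype

def f_host_f32():
  # Hypothetical variant: cast on the host so that the returned value
  # already matches the declared float32 result.
  return np.float32(f_host())

@jax.jit
def call_host():
  return io_callback(f_host_f32, jax.ShapeDtypeStruct((), np.float32))

result = call_host()
```

The explicit cast in `f_host_f32` is what the canonicalization step is meant to make unnecessary for callbacks that return plain Python scalars.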
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.
PiperOrigin-RevId: 621341638
Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This meant that two separate `io_callback`
invocations with the same Python callable would generate different
Jaxprs. To prevent this, we now defer the flattening to lowering time.
`jax.experimental.host_callback` is deprecated and any API in that module will emit a `DeprecationWarning`. After this change, `id_print` and `stop_outfeed_receiver` will raise an `AttributeError` in internal code only.
Add a deprecation message for `barrier_wait`.
PiperOrigin-RevId: 621064083
Prior to this change we had to import jax.experimental.pallas.{gpu,tpu} in
jax.experimental.pallas only to get the lowering rules registered.
PiperOrigin-RevId: 620957622
- This extension has one C API which registers a custom partitioner with callbacks from the input.
- Update xla_client.register_custom_call_partitioner to take an optional PJRT_Api* input.
- Add xla_bridge.register_plugin_initialization_callbacks to register callbacks to be called with PJRT_Api* after plugins are discovered.
PiperOrigin-RevId: 620357554
The jax.experimental.host_callback module is deprecated and will be removed.
See https://github.com/google/jax/issues/20385.
The other API entry points have been marked as deprecated already, but barrier_wait was missed.
PiperOrigin-RevId: 620237286
This is an experimental feature exposed as an extra parameter: `scan(..., _split_transpose: bool)`.
If the parameter is true then the transpose of scan generates not just 2 scans
(forward and transpose of the linearized forward), but rather 3 scans: (i)
forward (as before), (ii) transposed scan that only computes loop-carried state
required for back-propagation, but saves other intermediate gradients; (iii) a
scan (actually a map) that uses any saved activation gradients and original
residuals to compute any other gradients.
Warning: this feature is somewhat experimental and may evolve or be rolled back.
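As a rough usage sketch, the flag is passed directly to `jax.lax.scan`; it should only change how the transpose is structured, not the gradient values. The `loss` function, its body, and the inputs below are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w, xs, split_transpose):
  def body(h, x):
    h = jnp.tanh(w * h + x)
    return h, h  # new carry and per-step output

  # `_split_transpose` is the experimental parameter described above.
  h_final, hs = jax.lax.scan(body, jnp.float32(0.), xs,
                             _split_transpose=split_transpose)
  return h_final + hs.sum()

xs = jnp.linspace(-1., 1., 8)
w = jnp.float32(0.5)

# Differentiate with respect to `w` only; the flag stays a Python bool.
grad_default = jax.grad(loss)(w, xs, False)
grad_split = jax.grad(loss)(w, xs, True)
```

Both gradients are expected to agree up to floating-point tolerance, since the split only reorganizes the backward computation.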
PiperOrigin-RevId: 619991098
This is an attempt to re-land #19819 aka cl/607570860 after a small number of
performance regressions.
As before, the main changes are:
1. simplify the scan impl that we trace through to get the lowering, and
2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body
jaxpr we already have in hand.
The main motivation was (2), but (1) seems like a useful win too.
The way we achieve (2) is with a new trick: in our scan_impl function, which is
only ever traced to a jaxpr, instead of calling
`core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive
`eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging
rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and
that rule just generates a call into the jaxpr of interest. Therefore we will
not traverse into the jaxpr just to rebuild it inline (as before).
The code in #19819 was simpler in that it avoided reshapes, concats, and
un-concats. But it caused at least one apparent performance regression (an XLA
bug?) and it was unrelated to the original goal of reducing tracing time. So
here we just land the trace time improvement.
The canonicalization doesn't provide any value anymore and only makes the internals more complicated.
The canonicalization can be done by lowering to HloSharding in places where required and there are utilities to help with that.
PiperOrigin-RevId: 619292757
The internal implementation of `io_callback` and friends currently uses `.shape` and `.dtype` on the result of the callback. This fails if the callback returns a Python literal.
Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.
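The failure mode and one way to check results robustly can be sketched as below. `check_result` is a hypothetical helper written for illustration, not the check used inside JAX; it converts the value with `np.asarray` first so that Python literals are handled like arrays.

```python
import numpy as np

def check_result(value, expected_shape, expected_dtype):
  """Hypothetical check: validate a callback result against a declared spec."""
  # Python literals have no .shape or .dtype, so convert before inspecting.
  arr = np.asarray(value)
  if arr.shape != tuple(expected_shape):
    raise ValueError(
        f"Expected shape {tuple(expected_shape)}, got {arr.shape}")
  if arr.dtype != np.dtype(expected_dtype):
    raise ValueError(
        f"Expected dtype {np.dtype(expected_dtype)}, got {arr.dtype}")
  return arr

# A Python float carries neither attribute.
literal_has_attrs = hasattr(42., "shape") or hasattr(42., "dtype")

# A NumPy float32 scalar passes the check against a float32 spec.
ok = check_result(np.float32(42.), (), np.float32)
```

With this kind of check, a bare `42.` is seen as float64 and is rejected against a float32 spec rather than failing with an `AttributeError`.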
Reverts 19e6156ccec0df7a900471df7840bc421da2898b
PiperOrigin-RevId: 619156176