Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:
```
def f_host():
return 42.
io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```
However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.
In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.
In some sense this is replacing the change in #20433 to add a canonicalization
step instead of an error.
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.
PiperOrigin-RevId: 621341638
Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This means that two separate io_callback
invocations with the same Python callable will generate different
Jaxprs. To prevent this we defer the flattening to lowering time.
The default thread pool size is too small on Mac OS.
An older version of this runtime based on StreamExecutor set a 2MiB stack size as well, but that change was most likely lost during the TFRT rewrite.
Fixes https://github.com/google/jax/issues/20428
PiperOrigin-RevId: 620853544
- This extension has one C API which registers a custom partitioner with callbacks from the input.
- Update xla_client.register_custom_call_partitioner to take an optional PJRT_Api* input.
- Add xla_bridge.register_plugin_initialization_callbacks to register callbacks to be called with PJRT_Api* after plugins are discovered.
PiperOrigin-RevId: 620357554
There is no data dependence between these breakpoints (the breakpoints are lowered into custom call that returns nothing, so there is no way to enforce their relative order)
Thus we are relaxing this ordering constraint in debugger test for all backends.
PiperOrigin-RevId: 620355448
This is an experimental feature exposed as an extra parameter: `scan(..., _split_transpose:bool)`.
If the parameter is true then the transpose of scan generates not just 2 scans
(forward and transpose of the linearized forward), but rather 3 scans: (i)
forward (as before), (ii) transposed scan that only computes loop-carried state
required for back-propagation, but saves other intermediate gradients; (iii) a
scan (actually a map) that uses any saved activation gradients and original
residuals to compute any other gradients.
Warning: this feature is somewhat experimental and may evolve or be rolled back.
PiperOrigin-RevId: 619991098
Unfortunately, upstream Triton has decided to drop support for NVIDIA GPUs
below Ampere, so we bump the GPU version requirements for using Triton.
PiperOrigin-RevId: 619899728
The canonicalization doesn't provide any value anymore and only makes the internals more complicated.
The canonicalization can be done by lowering to HloSharding in places where required and there are utilities to help with that.
PiperOrigin-RevId: 619292757
"test_pipeline_all_gather_matmul" is the best demo example of nested pallas pipelines, but it's hard to follow the logic in the existing test.
A few changes were made there:
- rename things to avoid confusion between outer and inner loop prologues / epilogues.
- give clear names for the outer iteration space: (step, phase) to help clarify sequencing of compute and DMAs.
- simplify and lift out all async copy definitions and add commentary on their function
- remove some incorrect comments about the rDMA schedule, and generally add a ton of commentary about when things happen in the outer pipeline.
- lift all the outer prologue work into an integrated prologue function
- various other small things.
PiperOrigin-RevId: 619254981
The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.
Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.
Reverts 19e6156ccec0df7a900471df7840bc421da2898b
PiperOrigin-RevId: 619156176