The backend support for the new custom call was added on June 28th.
Also add backwards compatibility test for the new custom call.
PiperOrigin-RevId: 658011228
There were two helper functions for implementing FFI calls that were included directly alongside jaxlib's CPU kernels that will be useful for the GPU kernels as well. This moves those functions into ffi_helpers so that they are accessible from there too.
PiperOrigin-RevId: 658002501
This is required to allow the use of subslices: e.g., the two halves
of a TPU slice. One of them will not include the device at
coordinates (0, 0, 0).
E.g., assume we have a TPU v4 1x2x1 slice.
BEFORE THIS CL, if we call _get_physical_tpu_mesh() (an auxiliary for
the public create_device_mesh()) with
jax_devices=[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]
we get the expected result
[[[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]
However, if we call it with
jax_devices=[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]
we get the wrong mesh
[[[None]
[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]
That's because the code before this CL assumed the the incoming
jax_devices are arranged in a cuboid that starts at (0, 0, 0). When
working with subslices (e.g., half of a TPU slice) that is not always
the case.
AFTER THIS CL, the second case will return
[[[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]
For each dimension from the TPU coordinates, this CL computes the min
/ max; we expect the provided devices to fill the [min, max] interval
(in that dimension). By requesting this for each dimension, we
request that the set of provided devices constitute a cuboid, but,
unlike before this CL, that cuboid does not need to include (0, 0, 0):
it can be "translated", which allows e.g., both half-slices of a big
slice.
PiperOrigin-RevId: 657902201
* Add the source location information for the index map function to
`BlockMapping`.
* Removed the `compute_index` wrapper around the index_map, so that
we can get the location information for the index_map, not the wrapper.
* Added source location to the errors related to index map functions.
* Added an error if the index map returns something other than integer
scalars.
* Construct BlockSpec origins for arguments using JAX helper functions
to get argument names
* Removed redundant API error tests from tpu_pallas_test.py
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.
Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.
PiperOrigin-RevId: 657225917
When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.
PiperOrigin-RevId: 657186057
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.
I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.
I added entries to pallas/CHANGELOG.