Since Shardy sits in the middle of the XLA pipeline, after converting down to HLO we need to run the Shardy export pipeline to preserve the SDY ops and sharding attributes, so that they are still available when we come back from HLO to MLIR to run Shardy propagation.
PiperOrigin-RevId: 658040672
The backend support for the new custom call was added on June 28th.
Also add a backwards-compatibility test for the new custom call.
PiperOrigin-RevId: 658011228
As reported in https://github.com/google/jax/issues/21303, using `remat`
with `custom_vjp` can produce inefficient results. The high-level
summary is that computing the grad of such a function results in the
`fwd` function of the `custom_vjp` being evaluated twice, even though
the first time the residuals are not actually used. In many cases this
isn't a problem because DCE will clean up the unnecessary computations.
But, when the fwd function requires an opaque call (e.g. pallas_call or
ffi_call), this no longer saves the day.
In this PR, I have added a parameter to `custom_vjp` called
`optimize_remat` (open for discussion!), which can be used to opt in to
automatic optimization of this operation. Setting this flag to true
results in the `fwd` function being wrapped in a new custom primitive
which will DCE into a call to the primal function whenever the residuals
are unused.
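A minimal sketch of the opt-in (the functions `f`, `f_fwd`, and `f_bwd` here are illustrative, not from this change):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Saves cos(x) as a residual for the backward pass.
    return jnp.sin(x), (jnp.cos(x),)

def f_bwd(res, g):
    (cos_x,) = res
    return (cos_x * g,)

# Opting in: the traced fwd is wrapped in a primitive that DCEs into a
# call to the primal f whenever its residuals are unused.
f = jax.custom_vjp(f, optimize_remat=True)
f.defvjp(f_fwd, f_bwd)

y = jax.grad(jax.remat(f))(0.5)
```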
This can be used to fix https://github.com/google/jax/issues/21303, and
I think it would make sense to eventually make this behavior the
default, but this implementation comes with a few caveats:
1. This feature is currently implemented in "initial style", which means
that the `fwd` function is traced to a jaxpr when it is initially
called. This means that when `optimize_remat=True`, the `custom_vjp`
function doesn't support data dependent conditionals within `fwd`.
This isn't a fundamental limitation of the method, but this
implementation is much simpler so it seemed like a good place to
start, and much of the complexity of the "final style" version of
this logic should be simplified by work that @dougalm is doing.
Furthermore, for the immediate use case of opaque calls, initial
style is not a serious limitation.
2. When `optimize_remat=True`, symbolic zeros are not supported. Again,
this isn't a required restriction, but I chose to start without this
added complexity and we can add support for symbolic zeros as needed
in the future.
3. More subtly, while this new primitive supports `vmap`, it doesn't
currently implement rules for composing with the AD system. This
means that a `custom_vjp` constructed with `optimize_remat=True`
won't currently work with some approaches to higher-order AD. I
expect I know how to fix that and will either include that here or in
a follow-up.
Two helper functions for implementing FFI calls were included directly alongside jaxlib's CPU kernels but will be useful for the GPU kernels as well. This moves those functions into ffi_helpers so that they are accessible from there too.
PiperOrigin-RevId: 658002501
This is required to allow the use of subslices: e.g., the two halves
of a TPU slice. One of them will not include the device at
coordinates (0, 0, 0).
E.g., assume we have a TPU v4 1x2x1 slice.
BEFORE THIS CL, if we call _get_physical_tpu_mesh() (an auxiliary for
the public create_device_mesh()) with
jax_devices=[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]
we get the expected result
[[[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]
However, if we call it with
jax_devices=[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]
we get the wrong mesh
[[[None]
[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]
That's because the code before this CL assumed that the incoming
jax_devices are arranged in a cuboid that starts at (0, 0, 0). When
working with subslices (e.g., half of a TPU slice) that is not always
the case.
AFTER THIS CL, the second case will return
[[[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]
For each dimension of the TPU coordinates, this CL computes the min
and max; we expect the provided devices to fill the [min, max]
interval in that dimension. Requiring this for each dimension means
that the set of provided devices must constitute a cuboid, but,
unlike before this CL, that cuboid does not need to include (0, 0, 0):
it can be "translated", which allows, e.g., using both half-slices of
a big slice.
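A minimal sketch of that min/max translation (assuming `d.coords` holds the device's physical TPU coordinates, as in the reprs above; the real helper reads richer device attributes):

```python
import numpy as np

def physical_mesh(jax_devices):  # sketch of _get_physical_tpu_mesh's logic
    # Use the first three coordinates as the (x, y, z) position.
    coords = np.array([d.coords[:3] for d in jax_devices])
    lo = coords.min(axis=0)      # per-dimension minimum
    hi = coords.max(axis=0)      # per-dimension maximum
    mesh = np.full(tuple(hi - lo + 1), None, dtype=object)
    for d, c in zip(jax_devices, coords):
        mesh[tuple(c - lo)] = d  # translate so the cuboid starts at (0, 0, 0)
    # The provided devices must fill the whole [min, max] cuboid.
    if any(d is None for d in mesh.flat):
        raise ValueError("provided devices do not form a contiguous cuboid")
    return mesh
```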
PiperOrigin-RevId: 657902201
Previously const capturing was allowed in Pallas kernels, but until
recently (#22550) it was not working correctly in many cases. We now
disallow const capturing because it can lead to surprises. Instead,
the kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
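For example, a minimal sketch (the kernel and shapes are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

table = jnp.arange(8, dtype=jnp.float32)

# Disallowed now: the kernel closes over `table` as a captured const.
#   def kernel(x_ref, o_ref):
#       o_ref[...] = x_ref[...] + table
# Instead, pass `table` as an explicit input with its own ref:
def kernel(x_ref, table_ref, o_ref):
    o_ref[...] = x_ref[...] + table_ref[...]

x = jnp.ones(8, dtype=jnp.float32)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x, table)
```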
* Added the source location information for the index map function to
`BlockMapping`.
* Removed the `compute_index` wrapper around the index map, so that
we can get the location information for the index map itself, not the wrapper.
* Added source location information to the errors related to index map
functions.
* Added an error if the index map returns something other than integer
scalars (see the sketch after this list).
* Constructed BlockSpec origins for arguments using JAX helper functions
to get argument names.
* Removed redundant API error tests from tpu_pallas_test.py.
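A sketch of the integer-scalars requirement (assuming the `BlockSpec(block_shape, index_map)` calling convention; the shapes are illustrative):

```python
from jax.experimental import pallas as pl

# OK: the index map returns one integer scalar per block dimension.
def index_map(i, j):
    return i, 0

spec = pl.BlockSpec((128, 128), index_map)

# Rejected now, with an error pointing at the index map's own source
# location: returning floats or non-scalar arrays instead of integer
# scalars.
def bad_index_map(i, j):
    return i / 2, 0  # float, not an integer scalar
```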