This change is in preparation for emitting source map information (https://tc39.es/source-map/) for jaxprs, so that the relationship between a jaxpr and its Python code can be visualized using existing source map tooling.
This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
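For readers unfamiliar with the target format, here is a minimal sketch of how the `mappings` field of a source map is encoded. It is not the `source_map()` document added by this change; the helper names `vlq` and `encode_mappings` and the file name are hypothetical, and only the base64 VLQ segment encoding from the source map specification is assumed.

```python
import json

_B64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"


def vlq(value):
    """Encode one signed integer as base64 VLQ (sign bit in the lowest bit)."""
    n = (abs(value) << 1) | (1 if value < 0 else 0)
    out = ""
    while True:
        digit = n & 0b11111
        n >>= 5
        if n:
            digit |= 0b100000  # continuation bit: more 5-bit groups follow
        out += _B64[digit]
        if not n:
            return out


def encode_mappings(lines):
    """Encode zero-based (gen_col, src_idx, src_line, src_col) segments.

    `lines` holds one list of segments per generated line. All fields are
    delta-encoded; only the generated column resets at each new line.
    """
    prev_src = prev_line = prev_col = 0
    encoded = []
    for segments in lines:
        prev_gen_col = 0
        parts = []
        for gen_col, src, line, col in segments:
            parts.append(
                vlq(gen_col - prev_gen_col)
                + vlq(src - prev_src)
                + vlq(line - prev_line)
                + vlq(col - prev_col)
            )
            prev_gen_col, prev_src, prev_line, prev_col = gen_col, src, line, col
        encoded.append(",".join(parts))
    return ";".join(encoded)


# Hypothetical example: two printed lines, each pointing into model.py.
source_map = {
    "version": 3,
    "sources": ["model.py"],
    "names": [],
    "mappings": encode_mappings([[(0, 0, 0, 0)], [(4, 0, 2, 1)]]),
}
print(json.dumps(source_map))
```

Each printed equation would contribute one or more segments like these, which is what lets external tooling jump from a line of the pretty-printed jaxpr back to a Python source location.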
The fix in #21032 was not correct because it assumed that the set of all mesh
axis names appearing in in_specs was an upper bound on the set of mesh axes
over which residuals could be device-varying. But collectives can introduce
device variance! So it's not an upper bound.
We track device variance when check_rep=True, but often people set
check_rep=False (e.g. when using pallas_call in a shard_map). So relying on our
device variance tracking would be limiting. That may be a decent long term
solution, if we can make it easy to annotate pallas_calls with device variance
information. But it's not a great short term one to unblock things.
So instead I temporarily went with context sensitivity: instead of making
residuals sharded over all mesh.axis_names (as we did before these patches), we
make them sharded over all mesh axis names _excluding_ any spmd_axis_names in
our dynamic context (by looking at the traces in our trace stack). It's illegal
to mention any spmd_axis_names in collectives (indeed anywhere in the body of
the function being vmapped), but I don't think we check it.
TODO(mattjj): add more testing (maybe in follow-ups)
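The context-sensitive rule above can be sketched in plain Python. This is only an illustration of the set arithmetic, not the shard_map implementation: the function name `residual_axis_names` is hypothetical, and the spmd_axis_names are passed in explicitly here rather than read off the trace stack.

```python
def residual_axis_names(mesh_axis_names, spmd_axis_names_in_context):
    """Return the mesh axes that residuals are sharded over.

    Keeps every mesh axis name, in mesh order, except those that an enclosing
    vmap has claimed via spmd_axis_name.
    """
    excluded = set(spmd_axis_names_in_context)
    return tuple(name for name in mesh_axis_names if name not in excluded)


mesh_axes = ("data", "model", "replica")

# No enclosing vmap with spmd_axis_name: same as sharding over all mesh axes.
print(residual_axis_names(mesh_axes, []))

# Inside vmap(f, spmd_axis_name="replica"): "replica" is left out.
print(residual_axis_names(mesh_axes, ["replica"]))
```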
When we write `vmap(f, spmd_axis_name=A)`, we require that `f` does not mention
A in specs, like the `PartitionSpec` in a `with_sharding_constraint` or the
`in_specs`/`out_specs` of `shard_map`. Previously, shard_map autodiff violated
that requirement, since we gave residuals sharding over all mesh axes (i.e.
including axis name A present in the mesh). As a result, the vmap rule could
then insert a redundant appearance of A.
This commit fixes the problem by only sharding over mesh axes mentioned in
in_specs; residuals can at most be sharded over those mesh axes. Then the vmap
rule is free to introduce an occurrence of A in the specs.
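As a rough sketch of the rule in this commit, the set of mesh axes mentioned in in_specs can be collected as below. Plain tuples stand in for `PartitionSpec` objects, where each entry is `None`, an axis name, or a tuple of axis names; the function name `axes_mentioned` is hypothetical.

```python
def axes_mentioned(in_specs):
    """Collect every mesh axis name that appears in any of the given specs."""
    names = set()
    for spec in in_specs:
        for entry in spec:
            if entry is None:
                continue  # this dimension is not sharded
            if isinstance(entry, tuple):
                names.update(entry)  # dimension sharded over several axes
            else:
                names.add(entry)
    return names


# Mesh with axes ("data", "model", "A"); the specs never mention "A".
in_specs = [("data", None), (("data", "model"),), (None,)]
print(sorted(axes_mentioned(in_specs)))
```

Because "A" is absent from the result, residuals are not sharded over it, and the vmap rule can add it without creating a duplicate.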
Currently only 2D shapes are supported in dot() lowering, but the exception raised for other shapes does not say so.
The raised exception lists the offending shapes, but without knowing about the 2D limitation, the user gets little direction on how to remedy the problem.
This change converts the raised exception to read something like:
`Exception: Only 2D tensors supported in dot; received: [ShapedArray(float32[128,128]), ShapedArray(float32[128])]`
rather than:
`Exception: [ShapedArray(float32[128,128]), ShapedArray(float32[128])]`
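A minimal sketch of a check that produces the new message, assuming a stand-in class for abstract values. `FakeAval` and `check_dot_operands` are hypothetical names for illustration and are not the lowering code itself.

```python
class FakeAval:
    """Hypothetical stand-in for ShapedArray, with a matching repr."""

    def __init__(self, dtype, shape):
        self.dtype = dtype
        self.shape = shape

    def __repr__(self):
        dims = ",".join(str(d) for d in self.shape)
        return f"ShapedArray({self.dtype}[{dims}])"


def check_dot_operands(avals):
    """Raise with a message that names the 2D limitation and the operands."""
    if any(len(aval.shape) != 2 for aval in avals):
        raise Exception(f"Only 2D tensors supported in dot; received: {avals}")


try:
    check_dot_operands([FakeAval("float32", (128, 128)), FakeAval("float32", (128,))])
except Exception as e:
    print(f"Exception: {e}")
```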
PiperOrigin-RevId: 630116468
The change to tpu.td is not backwards compatible on its own, but I made it
compatible using the newly added Mosaic stability layer. It's been a good
exercise and it seems to be working just fine.
Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 630060418
This is useful in the case of ahead of time compilation, when libtpu is present but there may not be any TPU chips, so we shouldn't attempt to initialize a TPU backend.
PiperOrigin-RevId: 630055511
Note that this adds only a minimal safety net to protect against
non-backwards-compatible changes. We really should have more tests
that cover more of the Triton MLIR.
Also enable serialization of such calls.
PiperOrigin-RevId: 630033989
Previously, the dtype lived in the metadata field of the TensorStore spec because that was the legacy way to configure it. That field doesn't understand the "str" name of a dtype, hence the special logic to translate bfloat, for example.
This CL moves the dtype out of the metadata field and puts it directly into the TensorStore spec, which eliminates the special dtype translation logic. This also adds support for other quantized types such as int4.
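The difference between the two layouts can be sketched with plain dictionaries. The `zarr` driver, the `file` kvstore, the path, and both function names are assumptions made for illustration; only the placement of the dtype (inside `metadata` versus at the top level of the spec) reflects the change described here.

```python
def legacy_spec(path, shape, zarr_dtype):
    """Old layout: dtype inside metadata, in the driver's own encoding."""
    return {
        "driver": "zarr",
        "kvstore": {"driver": "file", "path": path},
        "metadata": {"shape": list(shape), "dtype": zarr_dtype},
    }


def new_spec(path, shape, dtype_name):
    """New layout: dtype at the top level of the spec, by its plain name."""
    return {
        "driver": "zarr",
        "kvstore": {"driver": "file", "path": path},
        "dtype": dtype_name,
        "metadata": {"shape": list(shape)},
    }


# The legacy form needs a driver-specific string such as "<f4" for float32.
print(legacy_spec("/tmp/ckpt", (8, 128), "<f4"))

# The new form takes the dtype name directly, so no translation table is needed.
print(new_spec("/tmp/ckpt", (8, 128), "float32"))
print(new_spec("/tmp/ckpt", (8, 128), "int4"))
```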
PiperOrigin-RevId: 629845048
Add the exception to the formatted string which is being re-raised, so that we present the problem more clearly.
This makes debugging significantly easier -- for exceptions like "Unimplemented primitive in Pallas TPU lowering: sign", such text currently does not appear in the error output.
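A minimal sketch of the pattern, assuming hypothetical function names and message wording; the point is only that the text of the original exception is embedded in the re-raised one.

```python
def lower_primitive(name):
    """Hypothetical rule lookup that fails for an unsupported primitive."""
    raise NotImplementedError(
        f"Unimplemented primitive in Pallas TPU lowering: {name}"
    )


def lower_equation(name):
    """Re-raise lowering failures with the original message included."""
    try:
        return lower_primitive(name)
    except Exception as e:
        # Embedding {e} keeps the root cause visible in the final message,
        # and `from e` preserves the original traceback as the cause.
        raise RuntimeError(f"Exception while lowering eqn: {name}\n{e}") from e


try:
    lower_equation("sign")
except RuntimeError as err:
    print(err)
```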
PiperOrigin-RevId: 629767145