There are two reasons for doing this:
* Avoid an extra allocation by producing the output directly on the sharding the user specified. If you `device_put` the output of `_convert_element_type`, you pay for two transfers. This path is performance-critical (it runs whenever users pass `device`), so we should avoid extra transfers at all costs.
* This will let us streamline adding `device` arguments to all `jnp` functions, since one place (`_convert_element_type`) will handle the logic of putting things on the right device.
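As a hedged illustration of the intended behavior (this uses `jnp.asarray`'s `device` argument; the set of `jnp` functions accepting `device` is still growing):
```
import jax
import jax.numpy as jnp

x = jnp.arange(8, dtype=jnp.int32)
dev = jax.devices()[-1]

# Two-transfer pattern this change avoids: the dtype conversion first
# materializes on the default device, then device_put pays a second transfer.
y = jax.device_put(x.astype(jnp.float32), dev)

# With the sharding threaded through _convert_element_type, the conversion
# itself places its output on `dev`, so there is only one transfer.
y = jnp.asarray(x, dtype=jnp.float32, device=dev)
```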
Also fixes: https://github.com/google/jax/issues/17422
PiperOrigin-RevId: 650621659
I suspect that in the past a lack of source info meant that the function also
had no signature, but this is no longer the case.
I also removed an unused parameter from ``explain_tracing_cache_miss`` as
a drive-by change.
This is a follow up to #22269.
We make the following improvements:
* pytree structural disequality messages now attempt to localize the
mismatch using `tree_util.KeyPath`.
* we generate a simpler error message for when `in_specs` is not
a sequence, instead of the current PyTreeDef mismatch error.
* we generate an error message for when the index map function
in a BlockSpec returns an unexpected number of results.
* added error localization to the existing shape polymorphism
check that the block shapes are static.
* we check that the kernel function returns None. Without this
we used to get `body_fun output and input must have same type structure`
in the interpreter, `assert len(jaxpr.outvars) == 0` on GPU,
and `INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0`
on TPU.
* we check that the rank of the block_shape matches the rank of
the overall array. Without this we used to get a `safe_zip`
error. We also carry the pytree paths to localize the error.
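As a hedged sketch of the invariants these checks enforce (written against the current `pallas_call` API; details may differ at this revision):
```
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Kernels write through the output ref and must return None.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 64), jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(8,),
    # in_specs must be a sequence of BlockSpecs, one per input.
    in_specs=[
        # The block_shape rank (2) must match the array rank (2), and the
        # index map must return one block index per array dimension.
        pl.BlockSpec((1, 64), lambda i: (i, 0)),
    ],
    out_specs=pl.BlockSpec((1, 64), lambda i: (i, 0)),
    interpret=True,  # run in the interpreter so the sketch works on any backend
)(x)
```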
To simplify the generation of these error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
pytrees. We then used this new helper function in `pjit.py` and `stages.py`.
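A small usage sketch of the new helper (it lives in a private module, so the import path and output format may change):
```
from jax._src import tree_util as tu

td1 = tu.tree_structure({"a": 1, "b": (2, 3)})
td2 = tu.tree_structure({"a": 1, "b": (2, 3, 4)})

# Like tu.equality_errors, but starting from PyTreeDefs: yields
# (path, thing1, thing2, explanation) tuples localizing each mismatch.
for path, thing1, thing2, explanation in tu.equality_errors_pytreedef(td1, td2):
    print(f"at {tu.keystr(path)}: {thing1} vs {thing2}, so {explanation}")
```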
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.
To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
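A minimal sketch of the representation change (the helper name here is illustrative, not JAX's exact internals):
```
from typing import Sequence, Union

# The common 1:1 case stores a bare ir.Value; only the zero-or-many case
# pays for a tuple, eliminating most singleton containers during lowering.
IrValues = Union["ir.Value", tuple["ir.Value", ...]]

def flatten_ir_values(xs: Sequence[IrValues]) -> list["ir.Value"]:
    flat: list["ir.Value"] = []
    for x in xs:
        if isinstance(x, tuple):  # rare: zero or several HLO values
            flat.extend(x)
        else:                     # common: exactly one ir.Value, unwrapped
            flat.append(x)
    return flat
```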
`sub_byte_element_size_in_bits` is a lowering-only thing for now (since we know the dtype of the aval, JAX can add the appropriate value). We can expose it in the user API if required.
Memory space is exposed via the JAX memories API, so it doesn't have to be in the layout API.
Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access its fields to create JAX layouts.
Add constructors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.
PiperOrigin-RevId: 647487510
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.
In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.
```
name old cpu/op new cpu/op delta
jit_add_chain 59.1ms ±14% 49.4ms ±10% -16.32% (p=0.008 n=5+5)
name old time/op new time/op delta
jit_add_chain 60.3ms ±14% 50.7ms ±11% -15.99% (p=0.008 n=5+5)
```
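For reference, a minimal sketch of the nested-jit pattern this speeds up (the `jit_add_chain` benchmark above traces many calls to an inner jitted function):
```
import jax

@jax.jit
def inner(x):
    return x + 1

@jax.jit
def outer(x):
    # While tracing `outer`, each call to `inner` used to redo most of
    # _infer_params; with the signature-based cache key, repeated calls
    # with the same argument types hit the cache instead.
    for _ in range(100):
        x = inner(x)
    return x

print(outer(0))  # 100
```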
PiperOrigin-RevId: 645491650
Refactoring only, NFC intended.
* add types to more places.
* don't unpack PjitInfo positionally, since it's a 23-tuple and that seems rather error-prone.
* change _infer_params to produce a new PjitParams NamedTuple, rather than having callers unpack a 9-tuple positionally.
* inline _pjit_jaxpr into its caller, since it only has one caller and the wrapper doesn't really clarify anything.
* note that the return type of `transformation_with_aux` is a `Callable`.
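Illustrative shape of the `PjitParams` change (field names here are hypothetical, not the actual 9 fields):
```
from typing import Any, NamedTuple

class PjitParams(NamedTuple):
    # Named fields instead of a positional 9-tuple, so call sites cannot
    # silently mix up adjacent entries.
    jaxpr: Any
    in_shardings: tuple
    out_shardings: tuple

params = PjitParams(jaxpr=None, in_shardings=(), out_shardings=())
print(params.in_shardings)  # access by name, not by index
```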
PiperOrigin-RevId: 645068326
The motivation for doing this is twofold:
1) This will help with deprecating and eventually deleting `jax.xla_computation`, which allows for cross-backend lowering.
2) Allow for cross-backend and multi-backend lowering via the JAX AOT APIs, which will help clean up some hacks implemented for `jax.export`.
Note that this is only available via `.trace(*args).lower(lowering_platforms=('tpu',))`; you cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to make composable AOT APIs easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.
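A hedged usage sketch (the `lowering_platforms` parameter is spelled as in this change; later releases may differ):
```
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2

traced = jax.jit(f).trace(jnp.ones((8,), jnp.float32))
# Cross-backend lowering: emit TPU-targeted StableHLO even on a CPU host.
lowered = traced.lower(lowering_platforms=('tpu',))
print(lowered.as_text())
```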
Designed with @froystig!
PiperOrigin-RevId: 644087787
This CL changes `shard_arg_handlers` to be batched: each handler now receives a list of objects and a list of shardings and returns a list of arrays. This makes it possible to batch backend calls whenever it's beneficial to do so.
Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.
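A rough Python sketch of the grouping idea (the real grouping is done in C++; `copy_arrays_batched` is a hypothetical stand-in for the batched backend call):
```
import jax
from collections import defaultdict

def copy_arrays_batched(arrs, shardings):
    # Hypothetical stand-in for the single batched backend call
    # (xla::ifrt::Client::CopyArrays in the real implementation).
    return [jax.device_put(a, s) for a, s in zip(arrs, shardings)]

def batched_shard_args(arrays, shardings):
    # Group arrays whose source devices, destination devices, and memory
    # kinds all match; each group can then be moved with one batched call.
    groups = defaultdict(list)
    for i, (a, s) in enumerate(zip(arrays, shardings)):
        key = (frozenset(a.devices()), frozenset(s.device_set), s.memory_kind)
        groups[key].append(i)
    results = [None] * len(arrays)
    for idxs in groups.values():
        outs = copy_arrays_batched([arrays[i] for i in idxs],
                                   [shardings[i] for i in idxs])
        for i, out in zip(idxs, outs):
            results[i] = out
    return results
```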
PiperOrigin-RevId: 643097852
The semantics of `make_jaxpr` are preserved here, i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.
Since we can keep the existing behavior and still merge the implementations, this is a nice cleanup!
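A small sketch of the two entry points that now share an implementation:
```
import jax
import jax.numpy as jnp

def f(x):
    return x * 2.0

# make_jaxpr still closes over tracers when used under an enclosing
# trace; jit(f).trace does not.
print(jax.make_jaxpr(f)(jnp.float32(3.0)))
print(jax.jit(f).trace(jnp.float32(3.0)).jaxpr)
```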
Fixes https://github.com/google/jax/issues/21116
PiperOrigin-RevId: 641347140
As of this change, `XLACompatibleSharding` is an alias of `jax.sharding.Sharding` but it will be deprecated in a follow up change.
Why do this?
* All shardings JAX has are XLA-compatible. `Sharding` was created to allow non-XLA shardings, but that hasn't happened in the past 2 years. So let's simplify!
* Having these two types makes things very confusing. One example:
  * `jax.jit` only accepts `XLACompatibleSharding`s.
  * `jax.device_put` accepts `jax.sharding.Sharding`, but if you use `device_put` inside `jax.jit` with a memory kind, then you can only pass an `XLACompatibleSharding`. This is contradictory and confusing, and we can simplify.
PiperOrigin-RevId: 640527070
--
f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8 by George Necula <gcnecula@gmail.com>:
[export] Fix
A user reported an error when trying to export a function
that has a "lower" attribute (to impersonate a jitted function)
but does not have a "__name__" attribute.
The solution is to use the default name "<unnamed function>".
While I was at it, I added a `util.fun_name` helper to get
the name of a Callable, and I use it in several places.
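Roughly, the helper behaves like this sketch (the actual implementation may differ):
```
def fun_name(fun) -> str:
    # Fall back to a default when the callable has no __name__, e.g. an
    # object that impersonates a jitted function via a `lower` attribute.
    return getattr(fun, "__name__", "<unnamed function>")
```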
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21572 from gnecula:exp_fix_name f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8
PiperOrigin-RevId: 639236990
Overall the idea is to collect profile data for each module a configurable number of times and then recompile the module with the aggregated profile data.
1. We need to track how many times each module was profiled and to collect the profiling results. For this I added a ProfileSessionRunner class in profile.py. The class can track how many times an instance of it was called to profile a session and can aggregate the profile results.
2. We need to associate a profiling session with the module at the interpreter. To do this I added a dictionary to pjit.py which associates a Jaxpr with its profile session runner.
3. The profile session runner is passed to pxla.py and then called.
4. We need to deal correctly with the fast path at the interpreter level, so that JAX won't use the HLO directly while PGLE data still needs to be collected, but also won't recompile the module just for PGLE. See the changes in pjit.py and in lru_cache.h.
5. Once the FDO profile is collected, we need to share it between hosts to keep compilation deterministic.
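For reference, a hedged usage sketch using the AutoPGLE config flags as documented in current JAX (the names may differ at this revision, and on GPU the XLA latency-hiding scheduler must also be enabled):
```
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_pgle", True)
jax.config.update("jax_pgle_profiling_runs", 3)  # profiled runs before recompiling

@jax.jit
def f(x):
    return x @ x.T

x = jnp.ones((1024, 1024))
for _ in range(4):
    # The first jax_pgle_profiling_runs calls collect FDO profiles; later
    # calls use the module recompiled with the aggregated profile.
    f(x)
```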
PiperOrigin-RevId: 638197166
These changes are necessary to ensure that `Lowered` carries all the
information that is needed for export and serialization.
These are in preparation for a cleanup of the exporting and serialization APIs
to integrate them with the AOT APIs. In particular, exporting will start
from a `Lowered` object and will no longer include its own lowering code.
We add the lowered function name and the Jaxpr to `Lowered` (as the
attributes `_fun_name` and `_jaxpr`), and we add the tuple of lowering
platforms (as `Lowered._lowering._platforms`).
The function name is useful for better error messages when exporting and
serializing. The Jaxpr makes it possible to also export the VJP of the function
and obtain an `Exported` that can be differentiated.
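As a usage sketch (these are private attributes and may change):
```
import jax
import jax.numpy as jnp

lowered = jax.jit(lambda x: x + 1).lower(jnp.ones(4))
print(lowered._fun_name)             # name of the lowered function
print(lowered._jaxpr)                # the Jaxpr, enabling export of the VJP
print(lowered._lowering._platforms)  # tuple of lowering platforms
```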