This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays.
PiperOrigin-RevId: 570189648
There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each of them are passed as separate arguments
through several layers of lowering internal functions. This is tedious,
and error prone. In fact, override_lowering_rules was not plumbed
in all places, and due to using default arguments in all places,
this leads to silent errors.
We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.
Here is pack all such parameters into a `mlir.LoweringParameters`
dataclass and we plumb that through.
This assumes less about whether the thread that destructs `CacheEntry` has GIL or not, which is difficult to reason about due to the `xla::LRUCache`'s use of `std::shared_ptr<CacheEntry>`.
The following changes have been made in JAX to accommodate the behavior differences from direct destruction to GC:
* Since `PyLoadedExecutable`s cached in `WeakRefLRUCache` are now destructed out of band, `PyClient::LiveExecutables()` calls `GlobalPyRefManager()->CollectGarbage()` to make the returned information accurate and up to date.
* `test_jit_reference_dropping` has been updated to call `gc.collect()` before verifying the live executable counts since the destruction of executables owned by weak ref maps is now done out of band as part of `GlobalPyRefManager`'s GC.
PiperOrigin-RevId: 569062402
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.
If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.
PiperOrigin-RevId: 564457393
Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces less trivial computations, so there is need for this to exist.
In the future, if we want to support forwarding of inputs to outputs, there would need to be a different way which the C++ dispatch path knows about.
```
jit_trivial_dispatch 246µs ± 3% 4µs ± 1% -98.52% (p=0.008 n=5+5)
jit_trivial 250µs ± 3% 5µs ± 1% -98.19% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 560141018
514dddbeba allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.
Example of new output:
```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
%0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
return %0 : tensor<i32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```
Note debug information must be enabled.
PiperOrigin-RevId: 549325621
Note that if donate_argnames is not None and donate_argnums is None, then JAX will infer donate_argnums from the names which will then we used to find the donation_vector. This is fine because currently, the same thing happens from static_argnums and static_argnames.
I'll fix the TODOs, etc in follow up CLs.
Fixes https://github.com/google/jax/issues/10539
PiperOrigin-RevId: 547612861
* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
LoadedExecutable.cost_analysis, then fallback to the client method.
PiperOrigin-RevId: 542671990