If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.
If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.
PiperOrigin-RevId: 564457393
Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces less trivial computations, so there is need for this to exist.
In the future, if we want to support forwarding of inputs to outputs, there would need to be a different way which the C++ dispatch path knows about.
```
jit_trivial_dispatch 246µs ± 3% 4µs ± 1% -98.52% (p=0.008 n=5+5)
jit_trivial 250µs ± 3% 5µs ± 1% -98.19% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 560141018
514dddbeba allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.
Example of new output:
```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
%0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
return %0 : tensor<i32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```
Note debug information must be enabled.
PiperOrigin-RevId: 549325621
Note that if donate_argnames is not None and donate_argnums is None, then JAX will infer donate_argnums from the names which will then we used to find the donation_vector. This is fine because currently, the same thing happens from static_argnums and static_argnames.
I'll fix the TODOs, etc in follow up CLs.
Fixes https://github.com/google/jax/issues/10539
PiperOrigin-RevId: 547612861
* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
LoadedExecutable.cost_analysis, then fallback to the client method.
PiperOrigin-RevId: 542671990
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:
[Rollback] Update required Python version to 3.9
PiperOrigin-RevId: 528905991
These are to allow PJRT plugin developers an inline way to determine the number of replicas/partitions to which the module is targeted. There are no stability guarantees on these attributes at the moment.
PiperOrigin-RevId: 524013922
--
75a7e7a07d58e14de73190d060414fd3a1ba3d52 by Matthew Johnson <mattjj@google.com>:
Handle jaxpr-round-tripping of custom jvp rules w/ sym zero
fixes#14833
Co-authored-by: Roy Frostig <frostig@google.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15426 from mattjj:custom-jvp-symbolic-zeros-3 75a7e7a07d58e14de73190d060414fd3a1ba3d52
PiperOrigin-RevId: 523817551