Since `shardy`, the sharding-in-types work, and world 2 dagger are all moving toward making Mesh and PartitionSpec a first-class sharding type, let's pull the trigger right now and start fixing these bad user interactions.
Some things that will break due to this change: Previously, passing a NamedSharding and then an equivalent PositionalSharding to the same jitted function would lead to a lowering cache hit. Now the second call will be a cache miss. In other words: `f(ns); f(ps) # cache hit before`
In follow-up CLs, we will make the tracing cache aware of the mesh shape too, to fix some other issues related to tracing and lowering cache misses.
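To illustrate the behavioral change, here is a toy model of a lowering cache in plain Python. It is not JAX's implementation: `ToySharding`, `key_before`, `key_after`, and `run` are hypothetical names, and the sketch assumes only that the old key compared shardings by their device layout while the new key also distinguishes the kind of sharding.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySharding:
    """Hypothetical stand-in for a sharding object (not a JAX class)."""
    kind: str      # e.g. "named" or "positional"
    layout: tuple  # device placement; equal for equivalent shardings


def key_before(s):
    # Assumed old behavior: only the resulting device layout matters.
    return s.layout


def key_after(s):
    # Assumed new behavior: the sharding kind is part of the key.
    return (s.kind, s.layout)


def run(shardings, key_fn):
    """Simulate consecutive calls and record a cache hit or miss for each."""
    cache = set()
    results = []
    for s in shardings:
        k = key_fn(s)
        results.append("hit" if k in cache else "miss")
        cache.add(k)
    return results


ns = ToySharding("named", ((0, 1), (2, 3)))
ps = ToySharding("positional", ((0, 1), (2, 3)))

before = run([ns, ps], key_before)  # second call hits
after = run([ns, ps], key_after)    # second call misses
```

In this model `f(ns); f(ps)` goes from miss-then-hit to miss-then-miss, which is the breakage described above.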
PiperOrigin-RevId: 660177423
The plugin is released and the flag is no longer needed.
Also set the default value of enable_gpu to False. enable_gpu will be removed in the next change.
PiperOrigin-RevId: 660059432
1. Each process now corresponds to an SM, showing how many blocks
are executing concurrently.
2. The timeline now accounts for the start offset of each block,
instead of aligning them together. This makes a lot more sense in
the SM view.
3. We now use inline PTX to emit profiler events. This sometimes slightly
pessimizes code generation, but allows us to predicate out the write on
all threads other than the leader of each warpgroup, improving the
trace quality.
4. We make sure each trace is monotonic. I can't explain why, but the clocks
can behave very weirdly, potentially due to rescheduling on the SASS level.
We now fix up all backward movements and emit a warning if big shifts are
detected.
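The fix-up in point 4 can be sketched roughly as follows. This is a minimal illustration in plain Python, not the profiler's actual code: the function name `make_monotonic`, the clamping strategy (pinning a backward timestamp to its predecessor), and the `warn_threshold` parameter are all assumptions.

```python
import warnings


def make_monotonic(timestamps, warn_threshold=1000):
    """Return a copy in which no timestamp is smaller than its predecessor.

    Backward movements are clamped to the previous value (an assumed
    strategy). A warning is emitted if the largest correction exceeds
    warn_threshold (a hypothetical parameter).
    """
    fixed = []
    largest_shift = 0
    for t in timestamps:
        if fixed and t < fixed[-1]:
            largest_shift = max(largest_shift, fixed[-1] - t)
            t = fixed[-1]  # clamp the backward movement
        fixed.append(t)
    if largest_shift > warn_threshold:
        warnings.warn(
            f"Clock moved backward by up to {largest_shift} ticks; "
            "the trace was adjusted to be monotonic."
        )
    return fixed


trace = [10, 20, 15, 30, 5, 40]
fixed_trace = make_monotonic(trace, warn_threshold=100)
```

Clamping keeps the number of events unchanged. Shifting all later events forward instead would preserve durations, at the cost of accumulating drift.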
PiperOrigin-RevId: 659911268
This is an alias for jax.lib.xla_extension. Why the deprecation warning
for this when #22844 removed other APIs without any warning? This one
is relatively commonly used (I found a few dozen downstream references)
so I felt that a deprecation warning might be helpful.
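For readers unfamiliar with how a module alias can warn on access, here is a generic sketch using a module-level `__getattr__` (PEP 562). It is not JAX's internal deprecation helper: `add_deprecated_alias`, the `toy_pkg` modules, and `old_alias` are hypothetical names made up for the illustration.

```python
import types
import warnings


def add_deprecated_alias(module, alias, target, message):
    """Make `module.<alias>` resolve to `target` with a DeprecationWarning."""

    def __getattr__(name):
        # Called only when normal attribute lookup on the module fails.
        if name == alias:
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return target
        raise AttributeError(
            f"module {module.__name__!r} has no attribute {name!r}"
        )

    module.__getattr__ = __getattr__


# Hypothetical modules standing in for the real alias and its target.
new_home = types.ModuleType("toy_pkg.new_home")
new_home.answer = 42
old_parent = types.ModuleType("toy_pkg.lib")

add_deprecated_alias(
    old_parent,
    "old_alias",
    new_home,
    "toy_pkg.lib.old_alias is deprecated; use toy_pkg.new_home instead.",
)
```

Existing downstream code that reads `old_parent.old_alias` keeps working but sees a warning, which is the gentler migration path argued for above.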
* Move CostEstimate from TPU-specific `compiler_params` to a platform-independent argument of `pallas_call`.
Passing a CostEstimate in `compiler_params` is now deprecated and will be removed in 3 months' time.
* Update the CostEstimate when batching a kernel by scaling it by the size of the batch axis.
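The batching rule in the last bullet amounts to multiplying each cost field by the batch size. The sketch below shows that arithmetic with a stand-in dataclass. `ToyCostEstimate` and `scale_for_batch` are hypothetical names, and the three fields are an assumption about what a cost estimate carries, not a copy of the Pallas class.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyCostEstimate:
    """Hypothetical stand-in for a kernel cost estimate."""
    flops: int
    transcendentals: int
    bytes_accessed: int


def scale_for_batch(cost, batch_size):
    """Scale a per-example estimate to cover a whole batch axis."""
    return ToyCostEstimate(
        flops=cost.flops * batch_size,
        transcendentals=cost.transcendentals * batch_size,
        bytes_accessed=cost.bytes_accessed * batch_size,
    )


per_example = ToyCostEstimate(flops=2_000, transcendentals=10, bytes_accessed=4_096)
batched = scale_for_batch(per_example, batch_size=8)
```

Linear scaling assumes every batch element does the same work and shares no memory traffic, which is the natural default when a kernel is batched along a new axis.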
PiperOrigin-RevId: 659560330
Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function.
Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive.
Added some more information to the `if debug` prints.
Set the MLIR module names so that the debug dumps are named properly.
For consistency, I changed `import pallas.core as pl_core` to `... as pallas_core` in a couple of modules.
PiperOrigin-RevId: 659506675
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 659492696
Triton has a restriction that all operations have arguments and results
that are tensors whose size is a power of 2. Added a lowering check
for this. Without it, violating the condition produces an
unfriendly crash.
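A check of this kind can be sketched in a few lines. This is an illustration only, not the code added to the lowering: `is_power_of_2` and `check_tensor_size` are hypothetical names, and "size" is assumed to mean the total number of elements in the tensor.

```python
import math


def is_power_of_2(n):
    """True for 1, 2, 4, 8, ...; False for zero and everything else."""
    return n > 0 and (n & (n - 1)) == 0


def check_tensor_size(shape, what="operand"):
    """Raise a readable error if the element count is not a power of 2."""
    size = math.prod(shape)  # an empty shape (a scalar) has size 1
    if not is_power_of_2(size):
        raise ValueError(
            f"The {what} has shape {tuple(shape)} with {size} elements, "
            "but the size must be a power of 2."
        )


check_tensor_size((8, 16))  # 128 elements, accepted

try:
    check_tensor_size((3, 4), what="result")  # 12 elements, rejected
    error_message = None
except ValueError as e:
    error_message = str(e)
```

Raising early with the offending shape in the message is what turns the unfriendly crash into an actionable error.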
PiperOrigin-RevId: 659483450