mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

Since `shardy`, sharding in types work, world 2 dagger is going in a direction of making Mesh and PartitionSpec a first class sharding type, let's pull the trigger right now to start fixing these bad user interactions. Some things that will break due to this change: Before passing NamedSharding and an equivalent PositionalSharding to the same jitted function one after another would lead to a lowering cache hit. But now we will cache miss. In other words: `f(ns); f(ps) # cache hit before` In followup CLs, we will make the tracing cache aware of the mesh shape too to fix some other issues related to tracing and lowering cache misses PiperOrigin-RevId: 660177423