Yash Katariya 473e2bf527 Put abstract_mesh on every eqn so that we can preserve it during the eval_jaxpr and check_jaxpr roundtrip.
Also allow users to enter `Auto`/`User` mode inside jit along all or some axes.

Add checks to make sure that avals inside a context match the surrounding context. For now this check happens inside the `abstract_eval` rules, but we may need a more central place for it, which we can create later on.

PiperOrigin-RevId: 707128096
2024-12-17 09:17:21 -08:00