Mirror of https://github.com/ROCm/jax.git, synced 2025-04-24 20:46:05 +00:00

Also allow users to enter `Auto`/`User` mode inside jit along all or some axes. Add checks to make sure that avals inside a context match the surrounding context. These checks currently happen inside the `abstract_eval` rules; a more central place for them may be needed and can be created later on.

PiperOrigin-RevId: 707128096