In general, behavior should remain the same and this is not a breaking
change.
There are some minor changes to the API:
- checkify.ErrorCategory has changed type: it's no longer an Enum, but
the JaxException type. These have not been exposed as part of the
public API.
- some attributes on Error have changed and made private
- The raised error has changed type (JaxRuntimeError), and will have a
different traceback (pointing to the origin of the error + where the
error value was raised).
- `checkify.check` now supports formating error message with variable
size runtime info!
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
named_call does not specify donated_invars, this change handles this missing
param case.
For future reference: we might want to add a call_param_updater registry to define
how call params need to get updated wrt checkify, like eg. partial_eval/ad does.
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.
Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduces the number of jaxpr equations significantly.
PiperOrigin-RevId: 427517899
The assert primitive has an effectful API and so it can't be staged out;
it's only a trace-time primitive. It can be discharged to the functional
form.
We might want to have separate transforms for discharging errors and for
adding error checks. But right now they're just bundled together in the
checkify transform.