33 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
lenamartens
3134797968 Add checkify.debug_check which is a noop outside of checkify. 2022-12-14 11:15:34 +00:00
lenamartens
e4757e8410 Rewrite Checkify to support tracking different error types.
In general, behavior should remain the same and this is not a breaking
change.

There are some minor changes to the API:
  - checkify.ErrorCategory has changed type: it's no longer an Enum, but
    the JaxException type. These have not been exposed as part of the
    public API.
  - some attributes on Error have changed and made private
  - The raised error has changed type (JaxRuntimeError), and will have a
    different traceback (pointing to the origin of the error + where the
    error value was raised).
  - `checkify.check` now supports formating error message with variable
    size runtime info!

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-11-25 15:31:54 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Lena Martens
e187428a54 Restructure checkify files.
PiperOrigin-RevId: 441726310
2022-04-14 04:32:24 -07:00
Lena Martens
73f23705d0 Checkify: explicitly export public API, hide private symbols.
PiperOrigin-RevId: 429277551
2022-02-17 04:30:59 -08:00
Lena Martens
758c721605 Checkify: fix nd-error case when array only has 1 element. 2022-02-16 19:15:48 +00:00
Lena Martens
b15c7f609a Checkify: fix check_error of nd-error.
PiperOrigin-RevId: 428857813
2022-02-15 13:12:53 -08:00
Lena Martens
a4cacf5729 Checkify: handle named_call in process_call.
named_call does not specify donated_invars, this change handles this missing
param case.

For future reference: we might want to add a call_param_updater registry to define
how call params need to get updated wrt checkify, like eg. partial_eval/ad does.
2022-02-14 15:24:55 +00:00
Lena Martens
2ba8aec274 Checkify: Fix docstring formatting and polish enabled_errors sets. 2022-02-10 16:54:43 +00:00
Matthew Johnson
d9270b242d [checkify] add custom_vjp support 2022-02-09 12:31:16 -08:00
jax authors
b82ef91f42 Merge pull request #9509 from mattjj:checkify-custom-jvp
PiperOrigin-RevId: 427541020
2022-02-09 12:25:22 -08:00
Matthew Johnson
4ce749e681 [checkify] handle custom_jvp 2022-02-09 12:12:58 -08:00
Peter Hawkins
8ca6622c0b Change lax.select_p to be an n-ary predicate, 'lax.select_n_p'. Change lax.select() to be a thin shim around the new n-ary version.
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.

Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduces the number of jaxpr equations significantly.

PiperOrigin-RevId: 427517899
2022-02-09 11:03:09 -08:00
lenamartens
4d0db5d975 Fix build and address suggestion 2022-02-08 20:57:51 +00:00
Lena Martens
b2cf12aa7e
Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-02-08 20:23:40 +00:00
Lena Martens
0042edb5f4 Checkify: rename some symbols and add some docstrings. 2022-02-08 17:40:04 +00:00
Matthew Johnson
9488c5ae72 checkify: fix and test post_process_call/map 2022-01-19 14:44:30 -08:00
Lena Martens
8ea85769ea Checkify: add way to disable categories of errors.
By default only user_asserts are lifted into the checked function.
2022-01-18 17:59:50 +00:00
Matthew Johnson
6850833c3a checkify: tweak some organization and names 2022-01-10 21:29:12 -08:00
Lena Martens
7b5b9cefbd Add scatter OOB error. 2022-01-07 17:22:34 +00:00
Lena Martens
bbd127f8fa Add division by zero check. 2022-01-04 19:17:08 +00:00
Matthew Johnson
9e28bd5f4e small changes to checkify
Co-authored-by: Lena Martens <lenamartens@google.com>
2021-12-22 12:24:00 -08:00
Lena Martens
03e0deac04 Add NaN checkify rule to all lax primitives. 2021-12-20 17:51:24 +00:00
Lena Martens
98a5461132 Make sure while_loop cond generates an error even if it returns False. 2021-12-16 21:50:45 +00:00
Lena Martens
0dc5a33a88 Add checkify rule for while_loop. 2021-12-08 19:34:20 +00:00
Lena Martens
bbf1a9ba97 Add checkify rule for scan. 2021-12-07 12:36:44 +00:00
Matthew Johnson
c1f71d17c0 generalize assert primitive, allow recharging 2021-12-02 14:35:23 -08:00
Matthew Johnson
768b076420 add an assert primitive
The assert primitive has an effectful API and so it can't be staged out;
it's only a trace-time primitive. It can be discharged to the functional
form.

We might want to have separate transforms for discharging errors and for
adding error checks. But right now they're just bundled together in the
checkify transform.
2021-12-02 11:33:56 -08:00
jax authors
7869a6cb75 Merge pull request #8753 from mattjj:checkify
PiperOrigin-RevId: 413513067
2021-12-01 14:34:17 -08:00
Matthew Johnson
659f8b794f add skeleton checkify transformation 2021-12-01 10:44:58 -08:00