Clarify docs on jax.lax.cond. (#3569)

This commit is contained in:
Malcolm Reynolds 2020-06-26 19:44:50 +01:00 committed by GitHub
parent 11caa21eca
commit 63ff6cb8e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -621,6 +621,17 @@ def cond(*args, **kwargs):
Pred must be a scalar type.
Note that true_fun/false_fun may not need to refer to an `operand` to compute
their result, but one must still be provided to the `cond` call and be
accepted by both the branch functions, e.g.:
jax.lax.cond(
get_predicate_value(),
lambda _: 23,
lambda _: 42,
operand=None)
Arguments:
pred: Boolean scalar type, indicating which branch function to
apply. Collections (list, tuple) are not supported.