mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Clarify docs on jax.lax.cond. (#3569)
This commit is contained in:
parent
11caa21eca
commit
63ff6cb8e9
@ -621,6 +621,17 @@ def cond(*args, **kwargs):
|
||||
|
||||
Pred must be a scalar type.
|
||||
|
||||
Note that true_fun/false_fun may not need to refer to an `operand` to compute
|
||||
their result, but one must still be provided to the `cond` call and be
|
||||
accepted by both the branch functions, e.g.:
|
||||
|
||||
jax.lax.cond(
|
||||
get_predicate_value(),
|
||||
lambda _: 23,
|
||||
lambda _: 42,
|
||||
operand=None)
|
||||
|
||||
|
||||
Arguments:
|
||||
pred: Boolean scalar type, indicating which branch function to
|
||||
apply. Collections (list, tuple) are not supported.
|
||||
|
Loading…
x
Reference in New Issue
Block a user