Make result_shardings's documentation more accurate

PiperOrigin-RevId: 538880048
This commit is contained in:
jax authors 2023-06-08 13:44:43 -07:00
parent 6374b73ce6
commit 03b06b6aa4

View File

@ -812,7 +812,7 @@ def lower_jaxpr_to_fun(
replaced with bool arrays of size [0].
replicated_args: if present, annotates arguments as replicated.
arg_shardings: sharding annotations for each argument (optional).
result_shardings: sharding annotations for each argument (optional).
result_shardings: sharding annotations for each result (optional).
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
parameters and return values to express sharding. If False, use
hlo.custom_call operators with sharding annotations.