mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make result_shardings
's documentation more accurate
PiperOrigin-RevId: 538880048
This commit is contained in:
parent
6374b73ce6
commit
03b06b6aa4
@ -812,7 +812,7 @@ def lower_jaxpr_to_fun(
|
|||||||
replaced with bool arrays of size [0].
|
replaced with bool arrays of size [0].
|
||||||
replicated_args: if present, annotates arguments as replicated.
|
replicated_args: if present, annotates arguments as replicated.
|
||||||
arg_shardings: sharding annotations for each argument (optional).
|
arg_shardings: sharding annotations for each argument (optional).
|
||||||
result_shardings: sharding annotations for each argument (optional).
|
result_shardings: sharding annotations for each result (optional).
|
||||||
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
|
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
|
||||||
parameters and return values to express sharding. If False, use
|
parameters and return values to express sharding. If False, use
|
||||||
hlo.custom_call operators with sharding annotations.
|
hlo.custom_call operators with sharding annotations.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user