mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make result_shardings
's documentation more accurate
PiperOrigin-RevId: 538880048
This commit is contained in:
parent
6374b73ce6
commit
03b06b6aa4
@ -812,7 +812,7 @@ def lower_jaxpr_to_fun(
|
||||
replaced with bool arrays of size [0].
|
||||
replicated_args: if present, annotates arguments as replicated.
|
||||
arg_shardings: sharding annotations for each argument (optional).
|
||||
result_shardings: sharding annotations for each argument (optional).
|
||||
result_shardings: sharding annotations for each result (optional).
|
||||
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
|
||||
parameters and return values to express sharding. If False, use
|
||||
hlo.custom_call operators with sharding annotations.
|
||||
|
Loading…
x
Reference in New Issue
Block a user