diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2957fda8f..9b664a28d 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -812,7 +812,7 @@ def lower_jaxpr_to_fun( replaced with bool arrays of size [0]. replicated_args: if present, annotates arguments as replicated. arg_shardings: sharding annotations for each argument (optional). - result_shardings: sharding annotations for each argument (optional). + result_shardings: sharding annotations for each result (optional). use_sharding_annotations: if True, use "mhlo.sharding" annotations on parameters and return values to express sharding. If False, use hlo.custom_call operators with sharding annotations.