diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6e49ab3d6..8b25ec33f 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1895,7 +1895,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG if logger.isEnabledFor(log_priority): logger.log(log_priority, - "Compiling %s for with global shapes and types %s. " + "Compiling %s with global shapes and types %s. " "Argument mapping: %s.", fun_name, global_in_avals, in_shardings)