diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index b12620748..e6244a3b2 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -389,8 +389,11 @@ def log_elapsed_time(fmt: str): def should_tuple_args(num_args: int, platform: str): - # pass long arg lists as tuple for TPU - if platform == "tpu": + # CPU does not need a tuple as it uses a buffer table + # TPU only needs a tuple for very long lists + if platform == "cpu": + return False + elif platform == "tpu": return num_args > 2000 else: return num_args > 100