From 640e15fe070a887197143b76a19a3dced816c8df Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Thu, 22 Sep 2022 01:28:45 -0700 Subject: [PATCH] Don't tuple arguments passed to XLA:CPU This is not needed and tuples are being avoided when possible for new code. This is tested by CPPJitTest.test_jit_with_many_args_works in jax/tests:api_test_cpu PiperOrigin-RevId: 476032228 --- jax/_src/dispatch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index b12620748..e6244a3b2 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -389,8 +389,11 @@ def log_elapsed_time(fmt: str): def should_tuple_args(num_args: int, platform: str): - # pass long arg lists as tuple for TPU - if platform == "tpu": + # CPU does not need a tuple as it uses a buffer table + # TPU only needs a tuple for very long lists + if platform == "cpu": + return False + elif platform == "tpu": return num_args > 2000 else: return num_args > 100