Don't tuple arguments passed to XLA:CPU

This is not needed and tuples are being avoided when possible for new code.

This is tested by CPPJitTest.test_jit_with_many_args_works in jax/tests:api_test_cpu

PiperOrigin-RevId: 476032228
This commit is contained in:
Tres Popp 2022-09-22 01:28:45 -07:00 committed by jax authors
parent 405a2310ce
commit 640e15fe07

View File

@ -389,8 +389,11 @@ def log_elapsed_time(fmt: str):
def should_tuple_args(num_args: int, platform: str):
# pass long arg lists as tuple for TPU
if platform == "tpu":
# CPU does not need a tuple as it uses a buffer table
# TPU only needs a tuple for very long lists
if platform == "cpu":
return False
elif platform == "tpu":
return num_args > 2000
else:
return num_args > 100