mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Don't tuple arguments passed to XLA:CPU
This is not needed and tuples are being avoided when possible for new code. This is tested by CPPJitTest.test_jit_with_many_args_works in jax/tests:api_test_cpu PiperOrigin-RevId: 476032228
This commit is contained in:
parent
405a2310ce
commit
640e15fe07
@ -389,8 +389,11 @@ def log_elapsed_time(fmt: str):
|
||||
|
||||
|
||||
def should_tuple_args(num_args: int, platform: str):
|
||||
# pass long arg lists as tuple for TPU
|
||||
if platform == "tpu":
|
||||
# CPU does not need a tuple as it uses a buffer table
|
||||
# TPU only needs a tuple for very long lists
|
||||
if platform == "cpu":
|
||||
return False
|
||||
elif platform == "tpu":
|
||||
return num_args > 2000
|
||||
else:
|
||||
return num_args > 100
|
||||
|
Loading…
x
Reference in New Issue
Block a user