Remove temporary flag for forcing arg tuplization of lowered functions.

PiperOrigin-RevId: 578910366
This commit is contained in:
Parker Schuh 2023-11-02 10:52:32 -07:00 committed by jax authors
parent 1f6264896d
commit c8b7c1b80b

View File

@ -55,14 +55,6 @@ from jax._src.sharding_impls import (
UNSPECIFIED, GSPMDSharding, TransferToMemoryKind)
# TODO(b/300274285): Remove when performance of non-tuple and tuple args
# is matched within each supported backend.
_JAX_FORCE_TUPLE_ARGS = config.DEFINE_bool(
"jax_force_tuple_args",
False,
help="Force tuplization of arguments to lowered functions.",
)
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
@ -278,10 +270,6 @@ def should_tuple_args(num_args: int, platform: str) -> bool:
# do not have small bounds.
# TPU only needs a tuple for very long lists
if platform == "tpu":
if _JAX_FORCE_TUPLE_ARGS.value:
# TODO(b/300274285): Remove when performance of non-tuple and tuple args
# is matched within each supported backend.
return True
return num_args > 2000
else:
return False