mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add temporary flag for forcing arg tuplization of lowered functions.
PiperOrigin-RevId: 569814851
This commit is contained in:
parent
bf46b7427f
commit
0ae2a63426
@ -55,6 +55,14 @@ from jax._src.sharding_impls import (
|
||||
UNSPECIFIED, GSPMDSharding, TransferToMemoryKind)
|
||||
|
||||
|
||||
# TODO(b/300274285): Remove when performance of non-tuple and tuple args
|
||||
# is matched within each supported backend.
|
||||
_JAX_FORCE_TUPLE_ARGS = config.DEFINE_bool(
|
||||
"jax_force_tuple_args",
|
||||
False,
|
||||
help="Force tuplization of arguments to lowered functions.",
|
||||
)
|
||||
|
||||
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
||||
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
|
||||
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
||||
@ -270,6 +278,10 @@ def should_tuple_args(num_args: int, platform: str) -> bool:
|
||||
# do not have small bounds.
|
||||
# TPU only needs a tuple for very long lists
|
||||
if platform == "tpu":
|
||||
if _JAX_FORCE_TUPLE_ARGS.value:
|
||||
# TODO(b/300274285): Remove when performance of non-tuple and tuple args
|
||||
# is matched within each supported backend.
|
||||
return True
|
||||
return num_args > 2000
|
||||
else:
|
||||
return False
|
||||
|
Loading…
x
Reference in New Issue
Block a user