Strip device_assignment on GPU platform.

This makes the cache key hash invariant across processes in multi-process runs.

PiperOrigin-RevId: 617093247
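As a minimal illustration of the intended invariance (plain Python with hashlib and numpy, not the actual JAX cache-key code; all names below are made up for the sketch): each GPU process sees a different device-id assignment for the same program, so hashing the raw assignment yields a different key per process; reducing the assignment to a same-shape placeholder makes the keys agree.

import hashlib
import numpy as np

def toy_cache_key(device_assignment, strip):
  # Toy stand-in for hashing the serialized compile options: only the device
  # assignment is hashed here. With strip=True, keep the assignment's shape
  # (replicas x computations) but drop the concrete device ids, which is what
  # the commit does for GPU backends.
  if strip:
    device_assignment = np.zeros_like(device_assignment)
  return hashlib.sha256(device_assignment.tobytes()).hexdigest()

proc0_assignment = np.array([[0]])  # process 0 holds GPU 0
proc1_assignment = np.array([[1]])  # process 1 holds GPU 1

# Without stripping, each process computes a different key; with stripping,
# the keys match across processes.
assert toy_cache_key(proc0_assignment, False) != toy_cache_key(proc1_assignment, False)
assert toy_cache_key(proc0_assignment, True) == toy_cache_key(proc1_assignment, True)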
jax authors 2024-03-19 01:49:55 -07:00
parent a4533010b0
commit 0b28a4b168
2 changed files with 29 additions and 2 deletions

jax/_src/cache_key.py

@@ -90,7 +90,10 @@ def get(module: ir.Module,
        lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())),
       ("compile_options",
        lambda hash_obj: _hash_serialized_compile_options(
-           hash_obj, compile_options)),
+           hash_obj, compile_options,
+           # In case of GPU multi-process tasks we need to strip device
+           # assignment to use cache key as invariant between processes.
+           strip_device_assignment=(backend.platform == "gpu"))),
       ("accelerator_config",
        lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
       ("compression",
@@ -172,7 +175,8 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend):
   _hash_platform(hash_obj, backend)
 
 
-def _hash_serialized_compile_options(hash_obj, compile_options_obj):
+def _hash_serialized_compile_options(hash_obj, compile_options_obj,
+                                     strip_device_assignment=False):
   # Do not mess with the original CompileOptions object since it is passed to
   # the compiler. Create a deep copy for the purpose of cache key generation.
   compile_options_copy = copy.deepcopy(compile_options_obj)
@@ -211,6 +215,12 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj):
   debug_options.xla_gpu_cuda_data_dir = ""
   # LINT.ThenChange(:xla_flags)
+  if strip_device_assignment and compile_options_copy.device_assignment:
+    replica_count = compile_options_copy.device_assignment.replica_count()
+    computation_count = compile_options_copy.device_assignment.computation_count()
+    compile_options_copy.device_assignment = xla_client.DeviceAssignment.create(
+        np.ndarray([replica_count, computation_count])
+    )
   return hash_obj.update(compile_options_copy.SerializeAsString())
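For reference, here is the stripping step in isolation as a hedged sketch: the helper name _strip_device_assignment is illustrative and not part of the commit, and a zero-filled placeholder is used here for determinism where the commit passes np.ndarray([replica_count, computation_count]). The replacement keeps the original replicas-by-computations shape, so differently sized topologies still produce distinct keys, while runs that differ only in which concrete device ids each process holds now share one key.

import numpy as np
from jax._src.lib import xla_client

def _strip_device_assignment(compile_options):
  # Illustrative helper mirroring the committed logic: swap the concrete
  # device ids for a placeholder with the same (replicas, computations)
  # shape before the options are serialized into the cache key.
  assignment = compile_options.device_assignment
  placeholder = np.zeros(
      (assignment.replica_count(), assignment.computation_count()), dtype=int
  )
  compile_options.device_assignment = xla_client.DeviceAssignment.create(placeholder)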

tests/cache_key_test.py

@@ -155,6 +155,23 @@ class CacheKeyTest(jtu.JaxTestCase):
         cache_key.get(computation2, devices, compile_options, backend),
     )
 
+  def test_different_device_assignment(self):
+    computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
+    devices = np.array([[jax.local_devices()[0]]])
+    compile_options_1 = compiler.get_compile_options(
+        num_replicas=1, num_partitions=1, device_assignment=np.array([[0]])
+    )
+    compile_options_2 = compiler.get_compile_options(
+        num_replicas=1, num_partitions=1, device_assignment=np.array([[1]])
+    )
+    backend = xla_bridge.get_backend()
+    hash_1 = cache_key.get(computation, devices, compile_options_1, backend)
+    hash_2 = cache_key.get(computation, devices, compile_options_2, backend)
+    if backend.platform == "gpu":
+      self.assertEqual(hash_1, hash_2)
+    else:
+      self.assertNotEqual(hash_1, hash_2)
+
   @parameterized.parameters([False, True])
   def test_identical_computations_different_metadata(self, include_metadata):
     f = lambda x, y: lax.mul(lax.add(x, y), 2)