Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Strip device_assignment on GPU platform.

This makes the hash invariant in the multi-process case.

PiperOrigin-RevId: 617093247
commit 0b28a4b168 (parent a4533010b0)
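Why the stripping matters: in a multi-process GPU run, each process builds CompileOptions whose device assignment names its own device ids, so hashing the raw options gives every process a different compilation-cache key for the same program. A minimal, self-contained sketch of the idea, in plain Python with a hypothetical dict-based stand-in for the options, not JAX's actual types:

import hashlib
import json

def cache_key(options: dict, strip_device_assignment: bool = False) -> str:
    opts = dict(options)  # never mutate the caller's options
    if strip_device_assignment:
        # Keep only the shape of the assignment, not the concrete device ids.
        da = opts["device_assignment"]
        opts["device_assignment"] = [len(da), len(da[0])]
    return hashlib.sha256(json.dumps(opts, sort_keys=True).encode()).hexdigest()

# Process 0 and process 1 see different device ids for the same program.
proc0 = {"num_replicas": 1, "num_partitions": 1, "device_assignment": [[0]]}
proc1 = {"num_replicas": 1, "num_partitions": 1, "device_assignment": [[1]]}

assert cache_key(proc0) != cache_key(proc1)              # cache miss across processes
assert cache_key(proc0, True) == cache_key(proc1, True)  # shared key after stripping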
@@ -90,7 +90,10 @@ def get(module: ir.Module,
        lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())),
       ("compile_options",
        lambda hash_obj: _hash_serialized_compile_options(
-           hash_obj, compile_options)),
+           hash_obj, compile_options,
+           # In case of GPU multi-process tasks we need to strip device
+           # assignment to use cache key as invariant between processes.
+           strip_device_assignment=(backend.platform == "gpu"))),
       ("accelerator_config",
        lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
       ("compression",
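For orientation: get() folds (name, hash_fn) entries like those above into a single digest, so adding or stripping any component changes the key. A condensed sketch of that folding, not JAX's exact implementation; the sample entries are hypothetical stand-ins:

import hashlib

def fold_entries(entries) -> str:
    hash_obj = hashlib.sha256()
    for _name, hash_fn in entries:
        hash_fn(hash_obj)  # each component mixes its bytes into the shared digest
    return hash_obj.hexdigest()

# Hypothetical stand-ins for the real entries ("computation", "compile_options", ...).
entries = [
    ("computation", lambda h: h.update(b"module @jit_fn ...")),
    ("compile_options", lambda h: h.update(b"replicas=1;partitions=1")),
]
print(fold_entries(entries))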
@@ -172,7 +175,8 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend):
   _hash_platform(hash_obj, backend)


-def _hash_serialized_compile_options(hash_obj, compile_options_obj):
+def _hash_serialized_compile_options(hash_obj, compile_options_obj,
+                                     strip_device_assignment=False):
   # Do not mess with the original CompileOptions object since it is passed to
   # the compiler. Create a deep copy for the purpose of cache key generation.
   compile_options_copy = copy.deepcopy(compile_options_obj)
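The comment in this hunk names the key invariant: the options that are actually handed to the compiler are never mutated; only a deep copy is scrubbed and hashed. A small sketch of that copy-then-scrub pattern, using a hypothetical stand-in for CompileOptions:

import copy
import hashlib
from dataclasses import dataclass, field

@dataclass
class FakeCompileOptions:  # hypothetical stand-in, not xla_client.CompileOptions
    num_replicas: int = 1
    device_assignment: list = field(default_factory=list)

def hash_options(hash_obj, options, strip_device_assignment=False):
    options = copy.deepcopy(options)  # scrub a copy; the caller's object stays intact
    if strip_device_assignment:
        options.device_assignment = []
    hash_obj.update(repr(options).encode())

opts = FakeCompileOptions(device_assignment=[[0]])
hash_options(hashlib.sha256(), opts, strip_device_assignment=True)
assert opts.device_assignment == [[0]]  # original survives unmodified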
@@ -211,6 +215,12 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj):
     debug_options.xla_gpu_cuda_data_dir = ""
   # LINT.ThenChange(:xla_flags)
 
+  if strip_device_assignment and compile_options_copy.device_assignment:
+    replica_count = compile_options_copy.device_assignment.replica_count()
+    computation_count = compile_options_copy.device_assignment.computation_count()
+    compile_options_copy.device_assignment = xla_client.DeviceAssignment.create(
+        np.ndarray([replica_count, computation_count])
+    )
   return hash_obj.update(compile_options_copy.SerializeAsString())
 
 
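Extracted from the hunk above: the stripping step replaces the concrete assignment with a placeholder that preserves only its (replica_count, computation_count) shape. A standalone sketch, assuming jaxlib's xla_client binding; it uses np.zeros so the placeholder's contents are deterministic across processes, whereas np.ndarray(shape), as written in the diff, allocates an array without initializing its values:

import numpy as np
from jax._src.lib import xla_client

def strip_device_assignment(
    da: xla_client.DeviceAssignment) -> xla_client.DeviceAssignment:
  # Keep the shape of the assignment but drop the process-specific device ids.
  shape = (da.replica_count(), da.computation_count())
  return xla_client.DeviceAssignment.create(np.zeros(shape, dtype=int))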
@@ -155,6 +155,23 @@ class CacheKeyTest(jtu.JaxTestCase):
         cache_key.get(computation2, devices, compile_options, backend),
     )
 
+  def test_different_device_assignment(self):
+    computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
+    devices = np.array([[jax.local_devices()[0]]])
+    compile_options_1 = compiler.get_compile_options(
+        num_replicas=1, num_partitions=1, device_assignment=np.array([[0]])
+    )
+    compile_options_2 = compiler.get_compile_options(
+        num_replicas=1, num_partitions=1, device_assignment=np.array([[1]])
+    )
+    backend = xla_bridge.get_backend()
+    hash_1 = cache_key.get(computation, devices, compile_options_1, backend)
+    hash_2 = cache_key.get(computation, devices, compile_options_2, backend)
+    if backend.platform == "gpu":
+      self.assertEqual(hash_1, hash_2)
+    else:
+      self.assertNotEqual(hash_1, hash_2)
+
   @parameterized.parameters([False, True])
   def test_identical_computations_different_metadata(self, include_metadata):
     f = lambda x, y: lax.mul(lax.add(x, y), 2)