mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Squashed commit of the following:
commit 1abe9559d1ba7a6ec4e2081c52ebdf0eef6b5e56 Merge: 1e1cc3e07 1b2ba9d1c Author: Keshav <keshavb@nvidia.com> Date: Tue Sep 10 09:42:04 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit 1e1cc3e0733cca77e2f1ee928f96edcf63f673cf Author: Keshav <keshavb@nvidia.com> Date: Tue Sep 10 09:37:22 2024 -0700 added comment commit 631c41fcbdbbac864fadd72c984b07801872f218 Merge: b93b52f27 ce3ea109a Author: Keshav <keshavb@nvidia.com> Date: Wed Aug 21 08:54:00 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit b93b52f27aacf7f58eba914a91810b5d0ac06316 Author: Keshav <keshavb@nvidia.com> Date: Tue Aug 20 19:00:08 2024 -0700 remove stray breakpoint commit 9ee0842ea98557bcdca0ecfd9031a8ea5274e9a4 Merge: 799e359a5 be53ee10b Author: Keshav <keshavb@nvidia.com> Date: Wed Aug 7 18:09:19 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit 799e359a522acd1a83dd7868a3a9278e189664f6 Author: Keshav <keshavb@nvidia.com> Date: Wed Aug 7 17:31:27 2024 -0700 added tests and minor changes fix commit c973004493f633526b14a6b5acb3afe50d58c977 Merge: 5900969cc b3924da2a Author: Keshav <keshavb@nvidia.com> Date: Thu Aug 1 11:28:59 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit 5900969cc9178bf3629baa49c6a300446bf6d4a9 Author: Keshav <keshavb@nvidia.com> Date: Thu Aug 1 11:20:52 2024 -0700 minor edits commit a7cc85a1cb8ddd07b783cc538f25c56f5fb78543 Merge: 89b876270 091eba195 Author: Keshav <keshavb@nvidia.com> Date: Mon Jul 29 14:17:13 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit 89b876270bf5f16dc10c2f8700d69715752ca184 Author: Keshav <keshavb@nvidia.com> Date: Mon Jul 29 14:11:39 2024 -0700 native IR traversal instead of string manipulation commit 3b161a414d9579c50e1902047dbd45bac840a767 Author: Keshav <keshavb@nvidia.com> Date: Sun Jul 28 20:12:30 2024 -0700 longer match string and string search optimization commit 224ee59d2115ec43000105b97bd6e73c40777ab9 Merge: c7664aa61 6a7822a73 Author: Keshav <keshavb@nvidia.com> Date: Sun Jul 28 17:08:29 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit c7664aa61fa9cec55fba9d5ee1d3ffb146a4c2b1 Author: Keshav <keshavb@nvidia.com> Date: Sun Jul 28 17:07:04 2024 -0700 remove custom partitioning ptr from pre-compiled hlo during cache key computation linter fixes more linter fixes more linter fixes alternate imports
This commit is contained in:
parent
9fa0164ad2
commit
7c660c4ea0
@ -83,7 +83,8 @@ def get(module: ir.Module,
|
||||
'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
|
||||
"""
|
||||
entries = [
|
||||
("computation", lambda hash_obj: _hash_computation(hash_obj, module)),
|
||||
("computation",
|
||||
lambda hash_obj: _hash_computation(hash_obj, module)),
|
||||
("jax_lib version",
|
||||
lambda hash_obj: hash_obj.update(
|
||||
bytes(jaxlib_version_str.encode("utf-8")))),
|
||||
@ -129,8 +130,26 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
|
||||
)
|
||||
|
||||
|
||||
def _remove_custom_partitioning_ptr(m: ir.Module):
|
||||
"""
|
||||
Removes custom_partitioning callback pointer from precompiled IR.
|
||||
Python function pointers are not deterministic across executions.
|
||||
"""
|
||||
def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult:
|
||||
if (op.name == "stablehlo.custom_call" and
|
||||
op.attributes["call_target_name"].value == "CustomSPMDPartitioning"):
|
||||
op.attributes["backend_config"] = ir.StringAttr.get("REMOVED")
|
||||
return ir.WalkResult.ADVANCE
|
||||
|
||||
m.operation.walk(_update_bc_attribute)
|
||||
return m
|
||||
|
||||
|
||||
def _serialize_ir(m: ir.Module) -> bytes:
|
||||
output = io.BytesIO()
|
||||
if config.remove_custom_partitioning_ptr_from_cache_key.value:
|
||||
m = _remove_custom_partitioning_ptr(type_cast(ir.Module,
|
||||
m.operation.clone()))
|
||||
m.operation.write_bytecode(file=output)
|
||||
return output.getvalue()
|
||||
|
||||
|
@ -265,7 +265,9 @@ def put_executable_and_time(
|
||||
cache.put(cache_key, executable_and_time)
|
||||
|
||||
|
||||
def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
|
||||
def get_cache_key(module: ir.Module,
|
||||
devices: np.ndarray,
|
||||
compile_options,
|
||||
backend) -> str:
|
||||
return cache_key.get(module, devices, compile_options, backend,
|
||||
"zstandard" if zstandard is not None else "zlib")
|
||||
|
@ -1347,6 +1347,16 @@ compilation_cache_max_size = int_state(
|
||||
'size to grow indefinitely.'),
|
||||
)
|
||||
|
||||
remove_custom_partitioning_ptr_from_cache_key = bool_state(
|
||||
name='jax_remove_custom_partitioning_ptr_from_cache_key',
|
||||
default=False,
|
||||
help=('If set to True, remove the custom partitioning pointer '
|
||||
'present in the precompiled stableHLO before hashing '
|
||||
'during cache key computation. This is a potentially '
|
||||
'unsafe flag to set and only users who are sure of '
|
||||
'what they are trying to achieve should set it.'),
|
||||
)
|
||||
|
||||
default_dtype_bits = enum_state(
|
||||
name='jax_default_dtype_bits',
|
||||
enum_values=['32', '64'],
|
||||
|
@ -14,8 +14,10 @@
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
from typing import cast as type_cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -29,6 +31,11 @@ from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.mesh import Mesh
|
||||
from jax._src.partition_spec import PartitionSpec as P
|
||||
from jax._src.sharding_impls import NamedSharding
|
||||
from jax._src.custom_partitioning import custom_partitioning
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -155,6 +162,49 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
cache_key.get(computation2, devices, compile_options, backend),
|
||||
)
|
||||
|
||||
def test_custom_partitioning_ptr_removal(self):
|
||||
def _partition(mesh, arg_shapes, result_shape):
|
||||
arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
|
||||
result_shardings = NamedSharding(mesh, arg_shapes[0].sharding.spec)
|
||||
return mesh, jax.numpy.add, result_shardings, arg_shardings
|
||||
|
||||
def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):
|
||||
return NamedSharding(mesh, arg_shapes[0].sharding.spec)
|
||||
|
||||
@custom_partitioning
|
||||
def _cp_add(x, y):
|
||||
return jax.numpy.add(x, y)
|
||||
|
||||
_cp_add.def_partition(
|
||||
infer_sharding_from_operands=_infer_sharding_from_operands,
|
||||
partition=_partition)
|
||||
|
||||
devices = np.asarray(jax.devices())
|
||||
with Mesh(devices, ('x',)) as m:
|
||||
computation = jax.jit(
|
||||
_cp_add,
|
||||
in_shardings=(NamedSharding(m, P('x')),
|
||||
NamedSharding(m, P('x'))),
|
||||
out_shardings=NamedSharding(m, P('x'))
|
||||
).lower(
|
||||
jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
|
||||
jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
|
||||
).compiler_ir()
|
||||
pattern = (
|
||||
r'stablehlo\.custom_call @CustomSPMDPartitioning\('
|
||||
r'(.*?)\) \{'
|
||||
r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
|
||||
r'\}'
|
||||
)
|
||||
with config.remove_custom_partitioning_ptr_from_cache_key(True):
|
||||
with computation.context:
|
||||
updated_module = cache_key._remove_custom_partitioning_ptr(
|
||||
type_cast(ir.Module, computation.operation.clone()))
|
||||
bcs = [match[2] for
|
||||
match in re.findall(pattern, str(updated_module), re.DOTALL)]
|
||||
for bc in bcs:
|
||||
self.assertEqual(bc, "REMOVED")
|
||||
|
||||
def test_different_device_assignment(self):
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
|
Loading…
x
Reference in New Issue
Block a user