Squashed commit of the following:

commit 1abe9559d1ba7a6ec4e2081c52ebdf0eef6b5e56
Merge: 1e1cc3e07 1b2ba9d1c
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Sep 10 09:42:04 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 1e1cc3e0733cca77e2f1ee928f96edcf63f673cf
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Sep 10 09:37:22 2024 -0700

    added comment

commit 631c41fcbdbbac864fadd72c984b07801872f218
Merge: b93b52f27 ce3ea109a
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Aug 21 08:54:00 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit b93b52f27aacf7f58eba914a91810b5d0ac06316
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Aug 20 19:00:08 2024 -0700

    remove stray breakpoint

commit 9ee0842ea98557bcdca0ecfd9031a8ea5274e9a4
Merge: 799e359a5 be53ee10b
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Aug 7 18:09:19 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 799e359a522acd1a83dd7868a3a9278e189664f6
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Aug 7 17:31:27 2024 -0700

    added tests and minor changes

    fix

commit c973004493f633526b14a6b5acb3afe50d58c977
Merge: 5900969cc b3924da2a
Author: Keshav <keshavb@nvidia.com>
Date:   Thu Aug 1 11:28:59 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 5900969cc9178bf3629baa49c6a300446bf6d4a9
Author: Keshav <keshavb@nvidia.com>
Date:   Thu Aug 1 11:20:52 2024 -0700

    minor edits

commit a7cc85a1cb8ddd07b783cc538f25c56f5fb78543
Merge: 89b876270 091eba195
Author: Keshav <keshavb@nvidia.com>
Date:   Mon Jul 29 14:17:13 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 89b876270bf5f16dc10c2f8700d69715752ca184
Author: Keshav <keshavb@nvidia.com>
Date:   Mon Jul 29 14:11:39 2024 -0700

    native IR traversal instead of string manipulation

commit 3b161a414d9579c50e1902047dbd45bac840a767
Author: Keshav <keshavb@nvidia.com>
Date:   Sun Jul 28 20:12:30 2024 -0700

    longer match string and string search optimization

commit 224ee59d2115ec43000105b97bd6e73c40777ab9
Merge: c7664aa61 6a7822a73
Author: Keshav <keshavb@nvidia.com>
Date:   Sun Jul 28 17:08:29 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit c7664aa61fa9cec55fba9d5ee1d3ffb146a4c2b1
Author: Keshav <keshavb@nvidia.com>
Date:   Sun Jul 28 17:07:04 2024 -0700

    remove custom partitioning ptr from pre-compiled hlo during cache key computation

linter fixes

more linter fixes

more linter fixes

alternate imports
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Sep 10 10:02:05 2024 -0700
Parent: 9fa0164ad2
Commit: 7c660c4ea0

4 changed files with 83 additions and 2 deletions

--- a/jax/_src/cache_key.py
+++ b/jax/_src/cache_key.py

@@ -83,7 +83,8 @@ def get(module: ir.Module,
   'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
   """
   entries = [
-      ("computation", lambda hash_obj: _hash_computation(hash_obj, module)),
+      ("computation",
+       lambda hash_obj: _hash_computation(hash_obj, module)),
       ("jax_lib version",
        lambda hash_obj: hash_obj.update(
            bytes(jaxlib_version_str.encode("utf-8")))),
@@ -129,8 +130,26 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
   )


+def _remove_custom_partitioning_ptr(m: ir.Module):
+  """
+  Removes custom_partitioning callback pointer from precompiled IR.
+  Python function pointers are not deterministic across executions.
+  """
+  def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult:
+    if (op.name == "stablehlo.custom_call" and
+        op.attributes["call_target_name"].value == "CustomSPMDPartitioning"):
+      op.attributes["backend_config"] = ir.StringAttr.get("REMOVED")
+    return ir.WalkResult.ADVANCE
+
+  m.operation.walk(_update_bc_attribute)
+  return m
+
+
 def _serialize_ir(m: ir.Module) -> bytes:
   output = io.BytesIO()
+  if config.remove_custom_partitioning_ptr_from_cache_key.value:
+    m = _remove_custom_partitioning_ptr(type_cast(ir.Module,
+                                                  m.operation.clone()))
   m.operation.write_bytecode(file=output)
   return output.getvalue()

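A minimal sketch of the native-IR traversal technique used above (assumed: the MLIR Python bindings bundled with jaxlib and exposed as jax._src.lib.mlir.ir, which also provide the Operation.walk this change relies on). It lowers a trivial function and collects every operation name by walking the module, rather than string-searching its printed form:

import jax
from jax._src.lib.mlir import ir

# Lower a trivial function to a StableHLO ir.Module.
module = jax.jit(lambda x: x + 1.0).lower(1.0).compiler_ir()

seen = []

def _visit(op: ir.Operation) -> ir.WalkResult:
  seen.append(op.name)          # e.g. 'stablehlo.add', 'func.return'
  return ir.WalkResult.ADVANCE  # keep walking the rest of the module

with module.context:
  module.operation.walk(_visit)
print(seen)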

--- a/jax/_src/compilation_cache.py
+++ b/jax/_src/compilation_cache.py

@@ -265,7 +265,9 @@ def put_executable_and_time(
   cache.put(cache_key, executable_and_time)


-def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
+def get_cache_key(module: ir.Module,
+                  devices: np.ndarray,
+                  compile_options,
                   backend) -> str:
   return cache_key.get(module, devices, compile_options, backend,
                        "zstandard" if zstandard is not None else "zlib")

--- a/jax/_src/config.py
+++ b/jax/_src/config.py

@@ -1347,6 +1347,16 @@ compilation_cache_max_size = int_state(
           'size to grow indefinitely.'),
 )

+remove_custom_partitioning_ptr_from_cache_key = bool_state(
+    name='jax_remove_custom_partitioning_ptr_from_cache_key',
+    default=False,
+    help=('If set to True, remove the custom partitioning pointer '
+          'present in the precompiled StableHLO before hashing '
+          'during cache key computation. This flag is potentially '
+          'unsafe: set it only if you are sure of what you are '
+          'trying to achieve.'),
+)
+
 default_dtype_bits = enum_state(
     name='jax_default_dtype_bits',
     enum_values=['32', '64'],
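The new state can be toggled like any other bool_state-backed option (a usage sketch; both forms are the standard JAX config patterns):

import jax
from jax._src import config

# Globally, via the registered flag name:
jax.config.update("jax_remove_custom_partitioning_ptr_from_cache_key", True)

# Or scoped as a context manager, as the test below does:
with config.remove_custom_partitioning_ptr_from_cache_key(True):
  ...  # cache keys computed here hash the scrubbed IR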

--- a/tests/cache_key_test.py
+++ b/tests/cache_key_test.py

@@ -14,8 +14,10 @@

 import hashlib
 import os
+import re
 import sys
 import unittest
+from typing import cast as type_cast

 import numpy as np

@@ -29,6 +31,11 @@ from jax._src import config
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
+from jax._src.lib.mlir import ir
+from jax._src.mesh import Mesh
+from jax._src.partition_spec import PartitionSpec as P
+from jax._src.sharding_impls import NamedSharding
+from jax._src.custom_partitioning import custom_partitioning

 config.parse_flags_with_absl()

@@ -155,6 +162,49 @@ class CacheKeyTest(jtu.JaxTestCase):
         cache_key.get(computation2, devices, compile_options, backend),
     )

+  def test_custom_partitioning_ptr_removal(self):
+    def _partition(mesh, arg_shapes, result_shape):
+      arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
+      result_shardings = NamedSharding(mesh, arg_shapes[0].sharding.spec)
+      return mesh, jax.numpy.add, result_shardings, arg_shardings
+
+    def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      return NamedSharding(mesh, arg_shapes[0].sharding.spec)
+
+    @custom_partitioning
+    def _cp_add(x, y):
+      return jax.numpy.add(x, y)
+
+    _cp_add.def_partition(
+        infer_sharding_from_operands=_infer_sharding_from_operands,
+        partition=_partition)
+
+    devices = np.asarray(jax.devices())
+    with Mesh(devices, ('x',)) as m:
+      computation = jax.jit(
+          _cp_add,
+          in_shardings=(NamedSharding(m, P('x')),
+                        NamedSharding(m, P('x'))),
+          out_shardings=NamedSharding(m, P('x'))
+      ).lower(
+          jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
+          jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
+      ).compiler_ir()
+      pattern = (
+          r'stablehlo\.custom_call @CustomSPMDPartitioning\('
+          r'(.*?)\) \{'
+          r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
+          r'\}'
+      )
+      with config.remove_custom_partitioning_ptr_from_cache_key(True):
+        with computation.context:
+          updated_module = cache_key._remove_custom_partitioning_ptr(
+              type_cast(ir.Module, computation.operation.clone()))
+          bcs = [match[2] for
+                 match in re.findall(pattern, str(updated_module), re.DOTALL)]
+          for bc in bcs:
+            self.assertEqual(bc, "REMOVED")
+
   def test_different_device_assignment(self):
     computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
     devices = np.array([[jax.local_devices()[0]]])
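As a standalone illustration of what the test's regex extracts, here is a sketch run against a hand-written (not captured) line of custom_call text; the quoted integer stands in for the process-specific Python callback pointer that the rewrite replaces:

import re

# Illustrative StableHLO text, abbreviated; a real module prints more fields.
sample = ('stablehlo.custom_call @CustomSPMDPartitioning(%0, %1) '
          '{api_version = 2 : i32, backend_config = "94301278312096"}')
pattern = (
    r'stablehlo\.custom_call @CustomSPMDPartitioning\('
    r'(.*?)\) \{'
    r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
    r'\}'
)
# Group 3 is the backend_config payload.
print([m[2] for m in re.findall(pattern, sample, re.DOTALL)])
# -> ['94301278312096']; after _remove_custom_partitioning_ptr it is 'REMOVED'.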