Enable to set fdo_profile through XLA python client.

PiperOrigin-RevId: 547303330
This commit is contained in:
Tao Wang 2023-07-11 14:47:04 -07:00 committed by jax authors
parent a1a01dd86e
commit 6eb3096461
4 changed files with 30 additions and 9 deletions

View File

@ -34,6 +34,7 @@ from jax._src import path as pathlib
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.gfile_cache import GFileCache
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager as pm
@ -240,7 +241,10 @@ def _hash_compile_options(hash_obj, compile_options_obj):
def _hash_executable_build_options(hash_obj, executable_obj):
expected_options = 10
if xla_extension_version > 165:
expected_options = 11
else:
expected_options = 10
# Ignore private and built-in methods. These can unexpectedly change and lead
# to false positives, e.g. when different Python versions include different
# built-ins.
@ -269,6 +273,8 @@ def _hash_executable_build_options(hash_obj, executable_obj):
_hash_bool_list(
hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output
)
if xla_extension_version > 165 and executable_obj.fdo_profile is not None:
_hash_string(hash_obj, executable_obj.fdo_profile)
def _hash_debug_options(hash_obj, debug_obj):

View File

@ -94,6 +94,7 @@ flags.DEFINE_string(
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.')
def get_compile_options(
num_replicas: int,
num_partitions: int,
@ -103,6 +104,7 @@ def get_compile_options(
auto_spmd_partitioning_mesh_shape=[],
auto_spmd_partitioning_mesh_ids=[],
env_options_overrides: Optional[dict[str, str]] = None,
fdo_profile: Optional[bytes] = None,
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
@ -122,6 +124,8 @@ def get_compile_options(
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
env_options_overrides: dict of additional options parsed by the compiler
fdo_profile: Optional profile for feedback-directed optimization passed to
XLA.
"""
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas
@ -129,6 +133,8 @@ def get_compile_options(
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
if xla_extension_version > 165 and fdo_profile is not None:
build_options.fdo_profile = fdo_profile
if use_auto_spmd_partitioning:
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids

View File

@ -20,17 +20,14 @@ import random
import sys
import tempfile
import unittest
from unittest import mock, SkipTest
from unittest import SkipTest, mock
import warnings
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import config
from jax import jit, lax, pmap
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax._src import compilation_cache as cc
from jax._src import test_util as jtu
from jax._src import xla_bridge
@ -40,11 +37,12 @@ from jax._src.config import (
raise_persistent_cache_errors,
)
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
import numpy as np
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@ -482,6 +480,8 @@ class CompilationCacheTest(jtu.JaxTestCase):
compile_options.executable_build_options.device_assignment = (
device_assignment
)
if xla_extension_version > 165:
compile_options.executable_build_options.fdo_profile = b"test_profile"
return compile_options
def get_hashed_value(self, hash_function, hash_function_input):

View File

@ -50,6 +50,15 @@ class XlaBridgeTest(jtu.JaxTestCase):
self.assertEqual(compile_options.device_assignment.__repr__(),
expected_device_assignment)
def test_set_fdo_profile(self):
if xla_extension_version > 166:
compile_options = xb.get_compile_options(
num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
)
self.assertEqual(
compile_options.executable_build_options.fdo_profile, "test_profile"
)
def test_parameter_replication_default(self):
c = xc.XlaBuilder("test")
_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))