mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Enable to set fdo_profile through XLA python client.
PiperOrigin-RevId: 547303330
This commit is contained in:
parent
a1a01dd86e
commit
6eb3096461
@ -34,6 +34,7 @@ from jax._src import path as pathlib
|
||||
from jax._src.compilation_cache_interface import CacheInterface
|
||||
from jax._src.gfile_cache import GFileCache
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir import passmanager as pm
|
||||
@ -240,7 +241,10 @@ def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
|
||||
|
||||
def _hash_executable_build_options(hash_obj, executable_obj):
|
||||
expected_options = 10
|
||||
if xla_extension_version > 165:
|
||||
expected_options = 11
|
||||
else:
|
||||
expected_options = 10
|
||||
# Ignore private and built-in methods. These can unexpectedly change and lead
|
||||
# to false positives, e.g. when different Python versions include different
|
||||
# built-ins.
|
||||
@ -269,6 +273,8 @@ def _hash_executable_build_options(hash_obj, executable_obj):
|
||||
_hash_bool_list(
|
||||
hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output
|
||||
)
|
||||
if xla_extension_version > 165 and executable_obj.fdo_profile is not None:
|
||||
_hash_string(hash_obj, executable_obj.fdo_profile)
|
||||
|
||||
|
||||
def _hash_debug_options(hash_obj, debug_obj):
|
||||
|
@ -94,6 +94,7 @@ flags.DEFINE_string(
|
||||
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
|
||||
'comma-separate list of integer device IDs.')
|
||||
|
||||
|
||||
def get_compile_options(
|
||||
num_replicas: int,
|
||||
num_partitions: int,
|
||||
@ -103,6 +104,7 @@ def get_compile_options(
|
||||
auto_spmd_partitioning_mesh_shape=[],
|
||||
auto_spmd_partitioning_mesh_ids=[],
|
||||
env_options_overrides: Optional[dict[str, str]] = None,
|
||||
fdo_profile: Optional[bytes] = None,
|
||||
) -> xla_client.CompileOptions:
|
||||
"""Returns the compile options to use, as derived from flag values.
|
||||
|
||||
@ -122,6 +124,8 @@ def get_compile_options(
|
||||
auto_spmd_partitioning_mesh_ids: device ids used to create
|
||||
auto_spmd_partitioning search space.
|
||||
env_options_overrides: dict of additional options parsed by the compiler
|
||||
fdo_profile: Optional profile for feedback-directed optimization passed to
|
||||
XLA.
|
||||
"""
|
||||
compile_options = xla_client.CompileOptions()
|
||||
compile_options.num_replicas = num_replicas
|
||||
@ -129,6 +133,8 @@ def get_compile_options(
|
||||
build_options = compile_options.executable_build_options
|
||||
build_options.use_spmd_partitioning = use_spmd_partitioning
|
||||
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
|
||||
if xla_extension_version > 165 and fdo_profile is not None:
|
||||
build_options.fdo_profile = fdo_profile
|
||||
if use_auto_spmd_partitioning:
|
||||
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
|
||||
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids
|
||||
|
@ -20,17 +20,14 @@ import random
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import mock, SkipTest
|
||||
from unittest import SkipTest, mock
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import config
|
||||
from jax import jit, lax, pmap
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src import compilation_cache as cc
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
@ -40,11 +37,12 @@ from jax._src.config import (
|
||||
raise_persistent_cache_errors,
|
||||
)
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import numpy as np
|
||||
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
@ -482,6 +480,8 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
compile_options.executable_build_options.device_assignment = (
|
||||
device_assignment
|
||||
)
|
||||
if xla_extension_version > 165:
|
||||
compile_options.executable_build_options.fdo_profile = b"test_profile"
|
||||
return compile_options
|
||||
|
||||
def get_hashed_value(self, hash_function, hash_function_input):
|
||||
|
@ -50,6 +50,15 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(compile_options.device_assignment.__repr__(),
|
||||
expected_device_assignment)
|
||||
|
||||
def test_set_fdo_profile(self):
|
||||
if xla_extension_version > 166:
|
||||
compile_options = xb.get_compile_options(
|
||||
num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
|
||||
)
|
||||
self.assertEqual(
|
||||
compile_options.executable_build_options.fdo_profile, "test_profile"
|
||||
)
|
||||
|
||||
def test_parameter_replication_default(self):
|
||||
c = xc.XlaBuilder("test")
|
||||
_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
|
||||
|
Loading…
x
Reference in New Issue
Block a user