Enable passing fdo_profile in compiler_options in pxla.py

PiperOrigin-RevId: 549109629
This commit is contained in:
Tao Wang 2023-07-18 14:17:56 -07:00 committed by jax authors
parent 579808d986
commit b7686f41aa
3 changed files with 95 additions and 0 deletions

View File

@ -2537,6 +2537,9 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
else:
compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
fdo_profile = (None if compiler_options is None else
compiler_options.pop("fdo_profile", None))
compile_options = xb.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
@ -2544,6 +2547,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
use_spmd_partitioning=spmd_lowering,
use_auto_spmd_partitioning=auto_spmd_lowering,
env_options_overrides=compiler_options,
fdo_profile=fdo_profile,
)
opts = compile_options.executable_build_options

View File

@ -195,6 +195,23 @@ jax_test(
],
)
jax_test(
name = "pgle_test",
srcs = ["pgle_test.py"],
disable_backends = [
"cpu",
"tpu",
],
env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"},
tags = [
"config-cuda-only",
"multiaccelerator",
],
deps = [
"//jax:experimental",
],
)
jax_test(
name = "array_test",
srcs = ["array_test.py"],

74
tests/pgle_test.py Normal file
View File

@ -0,0 +1,74 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import glob
import logging
import math
import os
import tempfile
from absl.testing import absltest
import jax
from jax import config
from jax._src import test_util as jtu
from jax.sharding import NamedSharding
from jax.experimental import profiler as exp_profiler
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('multiaccelerator')
class PgleTest(jtu.JaxTestCase):
def testPassingFDOProfile(self):
mesh = jtu.create_global_mesh((2,), ('x',))
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, P('x',)),
out_shardings=NamedSharding(mesh, P('x',)),
)
def f(x, y):
z = x @ y
return z @ y
shape = (8, 8)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
y = x + 1
f_lowered = f.lower(x, y)
compiled = f_lowered.compile()
with tempfile.TemporaryDirectory() as tmpdir:
jax.profiler.start_trace(tmpdir)
compiled(x, y)
jax.profiler.stop_trace()
directories = glob.glob(os.path.join(tmpdir, 'plugins/profile/**/'))
directories = [d for d in directories if os.path.isdir(d)]
rundir = directories[-1]
logging.info('rundir: %s', rundir)
fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)
if jtu.device_under_test() == 'gpu' and jtu.is_device_cuda():
self.assertIn(b'custom', fdo_profile)
logging.info('fdo_profile: %s', fdo_profile)
# Test pass fdo_profile as compiler_options API works.
f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())