expose compiler_options on compile()

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 520782460
This commit is contained in:
Parker Schuh 2023-03-30 17:13:46 -07:00 committed by jax authors
parent d58c970f07
commit 0bb46856a8
6 changed files with 184 additions and 25 deletions

View File

@ -1287,10 +1287,14 @@ class PmapComputation(stages.XlaLowering):
return self._hlo
@profiler.annotate_function
def compile(self) -> PmapExecutable:
if self._executable is None:
self._executable = UnloadedPmapExecutable.from_hlo(
self._hlo, **self.compile_args)
def compile(self, compiler_options=None) -> PmapExecutable:
if self._executable is None or compiler_options is not None:
executable = UnloadedPmapExecutable.from_hlo(
self._hlo, **self.compile_args,
compiler_options=compiler_options)
if compiler_options is None:
self._executable = executable
return executable
return self._executable
@ -1317,7 +1321,8 @@ class UnloadedPmapExecutable:
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect],
host_callbacks: List[Any],
keepalive: Any):
keepalive: Any,
compiler_options=None):
devices = pci.devices
if devices is None:
if shards.num_global_shards > xb.device_count(pci.backend):
@ -1374,6 +1379,7 @@ class UnloadedPmapExecutable:
num_partitions=parts.num_partitions,
device_assignment=device_assignment,
use_spmd_partitioning=use_spmd_partitioning,
env_options_overrides=compiler_options,
)
compile_options.parameter_is_tupled_arguments = tuple_args
@ -2801,20 +2807,27 @@ class MeshComputation(stages.XlaLowering):
raise ValueError("A trivial computation has no StableHLO")
return self._hlo
def compile(self,
def compile(
self,
compiler_options=None,
_allow_propagation_to_outputs: Optional[Sequence[bool]] = None,
_allow_compile_replicated: bool = True) -> MeshExecutable:
if self._executable is None:
_allow_compile_replicated: bool = True,
) -> MeshExecutable:
if self._executable is None or compiler_options is not None:
if self.is_trivial:
self._executable = MeshExecutable.from_trivial_jaxpr(
executable = MeshExecutable.from_trivial_jaxpr(
**self.compile_args)
else:
self._executable = UnloadedMeshExecutable.from_hlo(
executable = UnloadedMeshExecutable.from_hlo(
self._name,
self._hlo,
**self.compile_args,
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
_allow_compile_replicated=_allow_compile_replicated)
_allow_compile_replicated=_allow_compile_replicated,
compiler_options=compiler_options)
if compiler_options is None:
self._executable = executable
return executable
return self._executable
def cost_analysis(self) -> Dict[str, float]:
@ -2967,7 +2980,8 @@ class UnloadedMeshExecutable:
backend: xb.XlaBackend,
device_assignment: Sequence[xc.Device],
committed: bool,
pmap_nreps: int = 1
pmap_nreps: int = 1,
compiler_options=None
) -> MeshExecutable:
dev: np.ndarray
@ -2997,6 +3011,7 @@ class UnloadedMeshExecutable:
device_assignment=xla_device_assignment,
use_spmd_partitioning=spmd_lowering,
use_auto_spmd_partitioning=auto_spmd_lowering,
env_options_overrides=compiler_options,
)
if auto_spmd_lowering:
assert mesh is not None

View File

@ -145,7 +145,7 @@ class Executable(Protocol):
class Lowering(Protocol):
"""Protocol for lowerings, which a user-facing ``Lowered`` encapsulates."""
def compile(self) -> Executable:
def compile(self, compiler_options=None) -> Executable:
"""Compile and return a corresponding ``Executable``."""
raise NotImplementedError
@ -295,7 +295,7 @@ class XlaLowering(Lowering):
"""Return a StableHLO representation of this computation."""
raise NotImplementedError("must override")
def compile(self) -> Executable:
def compile(self, compiler_options=None) -> Executable:
raise NotImplementedError("must override")
def as_text(self, dialect: Optional[str] = None) -> str:
@ -583,24 +583,26 @@ class Lowered(Stage):
out_tree,
no_kwargs=no_kwargs)
def compile(self) -> Compiled:
def compile(self, compiler_options=None) -> Compiled:
"""Compile, returning a corresponding ``Compiled`` instance."""
from jax._src.interpreters import pxla
kw = {"compiler_options": compiler_options}
if isinstance(self._lowering, pxla.MeshComputation):
kw = dict(
kw.update(
_allow_propagation_to_outputs=[
pxla._is_unspecified(o)
for o in self._lowering.compile_args["out_shardings"]]
for o in self._lowering.compile_args["out_shardings"]
]
)
else:
kw = {}
return Compiled(
self._lowering.compile(**kw),
self._lowering.compile(**kw), # pytype: disable=wrong-keyword-args
self.args_info,
self.out_tree,
no_kwargs=self._no_kwargs)
no_kwargs=self._no_kwargs,
)
def as_text(self, dialect: Optional[str] = None) -> str:
"""A human-readable text representation of this lowering.

View File

@ -28,7 +28,7 @@ import platform as py_platform
import threading
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import warnings
from jax._src.lib import xla_extension_version
import numpy as np
from jax._src import lib
@ -93,7 +93,9 @@ def get_compile_options(
use_spmd_partitioning: bool = True,
use_auto_spmd_partitioning: bool = False,
auto_spmd_partitioning_mesh_shape=[],
auto_spmd_partitioning_mesh_ids=[]) -> xla_client.CompileOptions:
auto_spmd_partitioning_mesh_ids=[],
env_options_overrides: Optional[Dict[str, str]] = None,
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
Args:
@ -111,6 +113,7 @@ def get_compile_options(
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
env_options_overrides: dict of additional options parsed by the compiler
"""
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas
@ -147,6 +150,13 @@ def get_compile_options(
assert device_assignment.computation_count() == num_partitions
compile_options.device_assignment = device_assignment
if env_options_overrides is not None:
if xla_extension_version >= 145:
compile_options.env_option_overrides = list(env_options_overrides.items())
else:
raise TypeError(
"`env_options_overrides` is only supported in later versions of jaxlib")
debug_options = compile_options.executable_build_options.debug_options
if lib.cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

View File

@ -23,6 +23,7 @@ from jax.experimental.compilation_cache.gfile_cache import GFileCache
from jax._src import path as pathlib
from jax._src.lib import xla_client
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_extension_version
_cache = None
@ -128,6 +129,9 @@ def _hash_computation(hash_obj, xla_computation):
hash_obj.update(scrubbed_hlo)
def _hash_compile_options(hash_obj, compile_options_obj):
if xla_extension_version >= 145:
expected_num_compile_options = 12
else:
expected_num_compile_options = 11
# Ignore private and built-in methods. These can unexpectedly change and lead
# to false positives, e.g. when different Python versions include different
@ -152,6 +156,16 @@ def _hash_compile_options(hash_obj, compile_options_obj):
if compile_options_obj.device_assignment is not None:
hash_obj.update(compile_options_obj.device_assignment.serialize())
_hash_bool(hash_obj, compile_options_obj.compile_portable_executable)
if xla_extension_version >= 145:
_hash_int(hash_obj, len(compile_options_obj.env_option_overrides))
for kv in compile_options_obj.env_option_overrides:
_hash_string(hash_obj, kv[0])
if isinstance(kv[1], str):
_hash_string(hash_obj, kv[1])
elif isinstance(kv[1], bool):
_hash_bool(hash_obj, kv[1])
else:
raise RuntimeError("Invalid type: %s" % repr(type(kv[1])))
def _hash_executable_build_options(hash_obj, executable_obj):
expected_options = 10

View File

@ -64,6 +64,8 @@ from jax import custom_derivatives as custom_derivatives_public
from jax._src import prng
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src import test_util as jtu
from jax import tree_util
from jax._src import linear_util as lu
@ -1168,6 +1170,64 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
f_jit = self.jit(f)
lowered = f_jit.lower(1.)
lowered.compile( # doesn't crash
compiler_options={"xla_embed_ir_in_executable": True})
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options_invalid(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
f_jit = self.jit(f)
lowered = f_jit.lower(1.)
self.assertRaisesRegex(
xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'",
lambda: lowered.compile(
compiler_options={"invalid_key": "invalid_value"}))
self.assertRaisesRegex(
xla_extension.XlaRuntimeError, "is not a valid bool value.",
lambda: lowered.compile(
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options_multiple(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
f_jit = self.jit(f)
lowered = f_jit.lower(1.)
l1 = lowered.compile()
l2 = lowered.compile(
compiler_options={"xla_embed_ir_in_executable": True})
l3 = lowered.compile(
compiler_options={"xla_embed_ir_in_executable": False})
# Ideally we could test that these objects are different only in
# that they respect the different options. Object identity is a
# heuristic proxy for that.
self.assertTrue(l1 is not l2)
self.assertTrue(l1 is not l3)
self.assertTrue(l2 is not l3)
# We should still error on invalid options after some valid compiles
self.assertRaisesRegex(
xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'",
lambda: lowered.compile(
compiler_options={"invalid_key": "invalid_value"}))
def test_jit_enum_as_dict_keys_fails(self):
class E(enum.Enum):
A = 0

View File

@ -45,6 +45,8 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax._src import config as jax_config
from jax._src import xla_bridge
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import safe_map, safe_zip
from jax._src.interpreters import pxla
from jax.interpreters import xla
@ -332,6 +334,62 @@ class PythonPmapTest(jtu.JaxTestCase):
f = f.lower(x).compile()
self.assertIsNotNone(f.runtime_executable())
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
lowered = f.lower(x)
lowered.compile( # doesn't crash
compiler_options={"xla_embed_ir_in_executable": True})
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options_invalid(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
lowered = f.lower(x)
self.assertRaisesRegex(
xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'",
lambda: lowered.compile(
compiler_options={"invalid_key": "invalid_value"}))
self.assertRaisesRegex(
xla_extension.XlaRuntimeError, "is not a valid bool value.",
lambda: lowered.compile(
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options_multiple(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
lowered = f.lower(x)
l1 = lowered.compile()
l2 = lowered.compile(
compiler_options={"xla_embed_ir_in_executable": True})
l3 = lowered.compile(
compiler_options={"xla_embed_ir_in_executable": False})
# Ideally we could test that these objects are different only in
# that they respect the different options. Object identity is a
# heuristic proxy for that.
self.assertTrue(l1 is not l2)
self.assertTrue(l1 is not l3)
self.assertTrue(l2 is not l3)
# We should still error on invalid options after some valid compiles
self.assertRaisesRegex(
xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'",
lambda: lowered.compile(
compiler_options={"invalid_key": "invalid_value"}))
def testLowerShapedArray(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)