mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
expose compiler_options
on compile()
Co-authored-by: Roy Frostig <frostig@google.com> PiperOrigin-RevId: 520782460
This commit is contained in:
parent
d58c970f07
commit
0bb46856a8
@ -1287,10 +1287,14 @@ class PmapComputation(stages.XlaLowering):
|
||||
return self._hlo
|
||||
|
||||
@profiler.annotate_function
|
||||
def compile(self) -> PmapExecutable:
|
||||
if self._executable is None:
|
||||
self._executable = UnloadedPmapExecutable.from_hlo(
|
||||
self._hlo, **self.compile_args)
|
||||
def compile(self, compiler_options=None) -> PmapExecutable:
|
||||
if self._executable is None or compiler_options is not None:
|
||||
executable = UnloadedPmapExecutable.from_hlo(
|
||||
self._hlo, **self.compile_args,
|
||||
compiler_options=compiler_options)
|
||||
if compiler_options is None:
|
||||
self._executable = executable
|
||||
return executable
|
||||
return self._executable
|
||||
|
||||
|
||||
@ -1317,7 +1321,8 @@ class UnloadedPmapExecutable:
|
||||
unordered_effects: List[core.Effect],
|
||||
ordered_effects: List[core.Effect],
|
||||
host_callbacks: List[Any],
|
||||
keepalive: Any):
|
||||
keepalive: Any,
|
||||
compiler_options=None):
|
||||
devices = pci.devices
|
||||
if devices is None:
|
||||
if shards.num_global_shards > xb.device_count(pci.backend):
|
||||
@ -1374,6 +1379,7 @@ class UnloadedPmapExecutable:
|
||||
num_partitions=parts.num_partitions,
|
||||
device_assignment=device_assignment,
|
||||
use_spmd_partitioning=use_spmd_partitioning,
|
||||
env_options_overrides=compiler_options,
|
||||
)
|
||||
compile_options.parameter_is_tupled_arguments = tuple_args
|
||||
|
||||
@ -2801,20 +2807,27 @@ class MeshComputation(stages.XlaLowering):
|
||||
raise ValueError("A trivial computation has no StableHLO")
|
||||
return self._hlo
|
||||
|
||||
def compile(self,
|
||||
def compile(
|
||||
self,
|
||||
compiler_options=None,
|
||||
_allow_propagation_to_outputs: Optional[Sequence[bool]] = None,
|
||||
_allow_compile_replicated: bool = True) -> MeshExecutable:
|
||||
if self._executable is None:
|
||||
_allow_compile_replicated: bool = True,
|
||||
) -> MeshExecutable:
|
||||
if self._executable is None or compiler_options is not None:
|
||||
if self.is_trivial:
|
||||
self._executable = MeshExecutable.from_trivial_jaxpr(
|
||||
executable = MeshExecutable.from_trivial_jaxpr(
|
||||
**self.compile_args)
|
||||
else:
|
||||
self._executable = UnloadedMeshExecutable.from_hlo(
|
||||
executable = UnloadedMeshExecutable.from_hlo(
|
||||
self._name,
|
||||
self._hlo,
|
||||
**self.compile_args,
|
||||
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
|
||||
_allow_compile_replicated=_allow_compile_replicated)
|
||||
_allow_compile_replicated=_allow_compile_replicated,
|
||||
compiler_options=compiler_options)
|
||||
if compiler_options is None:
|
||||
self._executable = executable
|
||||
return executable
|
||||
return self._executable
|
||||
|
||||
def cost_analysis(self) -> Dict[str, float]:
|
||||
@ -2967,7 +2980,8 @@ class UnloadedMeshExecutable:
|
||||
backend: xb.XlaBackend,
|
||||
device_assignment: Sequence[xc.Device],
|
||||
committed: bool,
|
||||
pmap_nreps: int = 1
|
||||
pmap_nreps: int = 1,
|
||||
compiler_options=None
|
||||
) -> MeshExecutable:
|
||||
|
||||
dev: np.ndarray
|
||||
@ -2997,6 +3011,7 @@ class UnloadedMeshExecutable:
|
||||
device_assignment=xla_device_assignment,
|
||||
use_spmd_partitioning=spmd_lowering,
|
||||
use_auto_spmd_partitioning=auto_spmd_lowering,
|
||||
env_options_overrides=compiler_options,
|
||||
)
|
||||
if auto_spmd_lowering:
|
||||
assert mesh is not None
|
||||
|
@ -145,7 +145,7 @@ class Executable(Protocol):
|
||||
class Lowering(Protocol):
|
||||
"""Protocol for lowerings, which a user-facing ``Lowered`` encapsulates."""
|
||||
|
||||
def compile(self) -> Executable:
|
||||
def compile(self, compiler_options=None) -> Executable:
|
||||
"""Compile and return a corresponding ``Executable``."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -295,7 +295,7 @@ class XlaLowering(Lowering):
|
||||
"""Return a StableHLO representation of this computation."""
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def compile(self) -> Executable:
|
||||
def compile(self, compiler_options=None) -> Executable:
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def as_text(self, dialect: Optional[str] = None) -> str:
|
||||
@ -583,24 +583,26 @@ class Lowered(Stage):
|
||||
out_tree,
|
||||
no_kwargs=no_kwargs)
|
||||
|
||||
def compile(self) -> Compiled:
|
||||
def compile(self, compiler_options=None) -> Compiled:
|
||||
"""Compile, returning a corresponding ``Compiled`` instance."""
|
||||
from jax._src.interpreters import pxla
|
||||
|
||||
kw = {"compiler_options": compiler_options}
|
||||
|
||||
if isinstance(self._lowering, pxla.MeshComputation):
|
||||
kw = dict(
|
||||
kw.update(
|
||||
_allow_propagation_to_outputs=[
|
||||
pxla._is_unspecified(o)
|
||||
for o in self._lowering.compile_args["out_shardings"]]
|
||||
for o in self._lowering.compile_args["out_shardings"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
kw = {}
|
||||
|
||||
return Compiled(
|
||||
self._lowering.compile(**kw),
|
||||
self._lowering.compile(**kw), # pytype: disable=wrong-keyword-args
|
||||
self.args_info,
|
||||
self.out_tree,
|
||||
no_kwargs=self._no_kwargs)
|
||||
no_kwargs=self._no_kwargs,
|
||||
)
|
||||
|
||||
def as_text(self, dialect: Optional[str] = None) -> str:
|
||||
"""A human-readable text representation of this lowering.
|
||||
|
@ -28,7 +28,7 @@ import platform as py_platform
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
||||
import warnings
|
||||
|
||||
from jax._src.lib import xla_extension_version
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lib
|
||||
@ -93,7 +93,9 @@ def get_compile_options(
|
||||
use_spmd_partitioning: bool = True,
|
||||
use_auto_spmd_partitioning: bool = False,
|
||||
auto_spmd_partitioning_mesh_shape=[],
|
||||
auto_spmd_partitioning_mesh_ids=[]) -> xla_client.CompileOptions:
|
||||
auto_spmd_partitioning_mesh_ids=[],
|
||||
env_options_overrides: Optional[Dict[str, str]] = None,
|
||||
) -> xla_client.CompileOptions:
|
||||
"""Returns the compile options to use, as derived from flag values.
|
||||
|
||||
Args:
|
||||
@ -111,6 +113,7 @@ def get_compile_options(
|
||||
auto_spmd_partitioning search space.
|
||||
auto_spmd_partitioning_mesh_ids: device ids used to create
|
||||
auto_spmd_partitioning search space.
|
||||
env_options_overrides: dict of additional options parsed by the compiler
|
||||
"""
|
||||
compile_options = xla_client.CompileOptions()
|
||||
compile_options.num_replicas = num_replicas
|
||||
@ -147,6 +150,13 @@ def get_compile_options(
|
||||
assert device_assignment.computation_count() == num_partitions
|
||||
compile_options.device_assignment = device_assignment
|
||||
|
||||
if env_options_overrides is not None:
|
||||
if xla_extension_version >= 145:
|
||||
compile_options.env_option_overrides = list(env_options_overrides.items())
|
||||
else:
|
||||
raise TypeError(
|
||||
"`env_options_overrides` is only supported in later versions of jaxlib")
|
||||
|
||||
debug_options = compile_options.executable_build_options.debug_options
|
||||
if lib.cuda_path is not None:
|
||||
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path
|
||||
|
@ -23,6 +23,7 @@ from jax.experimental.compilation_cache.gfile_cache import GFileCache
|
||||
from jax._src import path as pathlib
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
_cache = None
|
||||
|
||||
@ -128,6 +129,9 @@ def _hash_computation(hash_obj, xla_computation):
|
||||
hash_obj.update(scrubbed_hlo)
|
||||
|
||||
def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
if xla_extension_version >= 145:
|
||||
expected_num_compile_options = 12
|
||||
else:
|
||||
expected_num_compile_options = 11
|
||||
# Ignore private and built-in methods. These can unexpectedly change and lead
|
||||
# to false positives, e.g. when different Python versions include different
|
||||
@ -152,6 +156,16 @@ def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
if compile_options_obj.device_assignment is not None:
|
||||
hash_obj.update(compile_options_obj.device_assignment.serialize())
|
||||
_hash_bool(hash_obj, compile_options_obj.compile_portable_executable)
|
||||
if xla_extension_version >= 145:
|
||||
_hash_int(hash_obj, len(compile_options_obj.env_option_overrides))
|
||||
for kv in compile_options_obj.env_option_overrides:
|
||||
_hash_string(hash_obj, kv[0])
|
||||
if isinstance(kv[1], str):
|
||||
_hash_string(hash_obj, kv[1])
|
||||
elif isinstance(kv[1], bool):
|
||||
_hash_bool(hash_obj, kv[1])
|
||||
else:
|
||||
raise RuntimeError("Invalid type: %s" % repr(type(kv[1])))
|
||||
|
||||
def _hash_executable_build_options(hash_obj, executable_obj):
|
||||
expected_options = 10
|
||||
|
@ -64,6 +64,8 @@ from jax import custom_derivatives as custom_derivatives_public
|
||||
from jax._src import prng
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src import linear_util as lu
|
||||
@ -1168,6 +1170,64 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
|
||||
f_jit = self.jit(f)
|
||||
lowered = f_jit.lower(1.)
|
||||
lowered.compile( # doesn't crash
|
||||
compiler_options={"xla_embed_ir_in_executable": True})
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options_invalid(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
|
||||
f_jit = self.jit(f)
|
||||
lowered = f_jit.lower(1.)
|
||||
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'",
|
||||
lambda: lowered.compile(
|
||||
compiler_options={"invalid_key": "invalid_value"}))
|
||||
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError, "is not a valid bool value.",
|
||||
lambda: lowered.compile(
|
||||
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options_multiple(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
|
||||
f_jit = self.jit(f)
|
||||
lowered = f_jit.lower(1.)
|
||||
|
||||
l1 = lowered.compile()
|
||||
l2 = lowered.compile(
|
||||
compiler_options={"xla_embed_ir_in_executable": True})
|
||||
l3 = lowered.compile(
|
||||
compiler_options={"xla_embed_ir_in_executable": False})
|
||||
|
||||
# Ideally we could test that these objects are different only in
|
||||
# that they respect the different options. Object identity is a
|
||||
# heuristic proxy for that.
|
||||
self.assertTrue(l1 is not l2)
|
||||
self.assertTrue(l1 is not l3)
|
||||
self.assertTrue(l2 is not l3)
|
||||
|
||||
# We should still error on invalid options after some valid compiles
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'",
|
||||
lambda: lowered.compile(
|
||||
compiler_options={"invalid_key": "invalid_value"}))
|
||||
|
||||
def test_jit_enum_as_dict_keys_fails(self):
|
||||
class E(enum.Enum):
|
||||
A = 0
|
||||
|
@ -45,6 +45,8 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
linearize, device_put)
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
@ -332,6 +334,62 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = f.lower(x).compile()
|
||||
self.assertIsNotNone(f.runtime_executable())
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
lowered = f.lower(x)
|
||||
|
||||
lowered.compile( # doesn't crash
|
||||
compiler_options={"xla_embed_ir_in_executable": True})
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options_invalid(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
lowered = f.lower(x)
|
||||
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'",
|
||||
lambda: lowered.compile(
|
||||
compiler_options={"invalid_key": "invalid_value"}))
|
||||
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError, "is not a valid bool value.",
|
||||
lambda: lowered.compile(
|
||||
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options_multiple(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
lowered = f.lower(x)
|
||||
|
||||
l1 = lowered.compile()
|
||||
l2 = lowered.compile(
|
||||
compiler_options={"xla_embed_ir_in_executable": True})
|
||||
l3 = lowered.compile(
|
||||
compiler_options={"xla_embed_ir_in_executable": False})
|
||||
|
||||
# Ideally we could test that these objects are different only in
|
||||
# that they respect the different options. Object identity is a
|
||||
# heuristic proxy for that.
|
||||
self.assertTrue(l1 is not l2)
|
||||
self.assertTrue(l1 is not l3)
|
||||
self.assertTrue(l2 is not l3)
|
||||
|
||||
# We should still error on invalid options after some valid compiles
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'",
|
||||
lambda: lowered.compile(
|
||||
compiler_options={"invalid_key": "invalid_value"}))
|
||||
|
||||
def testLowerShapedArray(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user