Add a .cost_analysis() on lowered but uncompiled computations.

Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.

PiperOrigin-RevId: 507560058
This commit is contained in:
Peter Hawkins 2023-02-06 12:57:30 -08:00 committed by jax authors
parent f37f00d620
commit 3d9ae6b467
8 changed files with 94 additions and 2 deletions

View File

@ -13,7 +13,7 @@ Classes
:special-members: __call__
.. autoclass:: Lowered
:members: in_tree, out_tree, compile, as_text, compiler_ir
:members: in_tree, out_tree, compile, as_text, compiler_ir, cost_analysis
.. autoclass:: Compiled
:members: in_tree, out_tree, as_text, cost_analysis, memory_analysis, runtime_executable

View File

@ -1010,6 +1010,11 @@ class XlaComputation(stages.XlaLowering):
return self._executable
def cost_analysis(self) -> Dict[str, float]:
return xe.hlo_module_cost_analysis(self.compile_args["backend"],
self.hlo().as_hlo_module())
@profiler.annotate_function
def backend_compile(backend, built_c, options, host_callbacks):
# we use a separate function call to ensure that XLA compilation appears

View File

@ -175,6 +175,26 @@ class Lowering(Protocol):
"""
raise NotImplementedError
def cost_analysis(self) -> Any:
"""A summary of execution cost estimates.
Intended for visualization and debugging purposes. The object output by
this is some simple data structure that can easily be printed or serialized
(e.g. nested dicts, lists, and tuples with numeric leaves). However, its
structure can be arbitrary: it need not be consistent across versions of JAX
and jaxlib, or even across invocations. It is relayed directly to external
callers.
This function estimates execution cost in the absence of compiler
optimizations, which may drastically affect the cost. For execution cost
estimates after optimizations, compile this lowering and see
``Compiled.cost_analysis``.
May raise ``NotImplementedError`` if unavailable, e.g. based on backend,
compiler, or runtime.
"""
# TODO(frostig): improve annotation (arbitrary pytree)
raise NotImplementedError
# -- Internal adapters from XLA-related objects to the above protocols
@ -311,6 +331,9 @@ class XlaLowering(Lowering):
else:
raise ValueError(f"unknown dialect: {dialect}")
def cost_analysis(self) -> Dict[str, float]:
raise NotImplementedError("must override")
# -- Public-facing API, plus helpers
@ -538,7 +561,7 @@ class Lowered(Stage):
def __init__(
self,
lowering: XlaLowering,
args_info, # PyTreee of ArgInfo
args_info, # PyTree of ArgInfo
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False):
self._lowering = lowering
@ -618,6 +641,24 @@ class Lowered(Stage):
except NotImplementedError:
return None
def cost_analysis(self) -> Optional[Any]:
"""A summary of execution cost estimates.
Intended for visualization and debugging purposes. The object output by
this is some simple data structure that can easily be printed or serialized
(e.g. nested dicts, lists, and tuples with numeric leaves). However, its
structure can be arbitrary: it may be inconsistent across versions of JAX
and jaxlib, or even across invocations.
Returns ``None`` if unavailable, e.g. based on backend, compiler, or
runtime.
"""
# TODO(frostig): improve annotation (basic pytree of arbitrary structure)
try:
return self._lowering.cost_analysis()
except NotImplementedError:
return None
class Wrapped(Protocol):
"""A function ready to be specialized, lowered, and compiled.

View File

@ -3206,6 +3206,9 @@ class MeshComputation(stages.XlaLowering):
self._executable = executable
return self._executable
def cost_analysis(self) -> Dict[str, float]:
return xe.hlo_module_cost_analysis(self.compile_args["backend"],
self.hlo().as_hlo_module())
def _get_input_metadata(
global_in_avals: Sequence[ShapedArray],

View File

@ -64,6 +64,7 @@ from jax._src import config as jax_config
from jax._src import custom_derivatives
from jax._src import device_array
from jax._src import prng
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src import test_util as jtu
from jax import tree_util
@ -1108,6 +1109,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIsInstance(f.as_text(), (str, type(None)))
self.assertIsInstance(g.as_text(), (str, type(None)))
@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_cost_analysis(self):
# TODO(b/261771737): add support for uncompiled cost analysis in C API.
if "PJRT C API" in xla_bridge.get_backend().platform_version:
raise unittest.SkipTest("C API does not support uncompiled cost analysis")
f = self.jit(lambda x: x).lower(1.)
g = self.jit(lambda x: x + 4).lower(1.)
f.cost_analysis() # doesn't raise
g.cost_analysis() # doesn't raise
@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_compile_cost_analysis(self):
f = self.jit(lambda x: x).lower(1.).compile()

View File

@ -982,6 +982,20 @@ class PJitTest(jtu.BufferDonationTestCase):
f = f.lower(x, x + 1).compile()
self.assertIsInstance(f.as_text(), (str, type(None)))
@jtu.with_mesh([('x', 2), ('y', 2)])
@jtu.skip_on_xla_cpu_mlir
def testLowerCostAnalysis(self):
@partial(pjit,
in_axis_resources=P(('x', 'y'),),
out_axis_resources=P(('x', 'y'),))
def f(x, y):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
f = f.lower(x, x + 1)
f.cost_analysis() # doesn't raise
@jtu.with_mesh([('x', 2), ('y', 2)])
@jtu.skip_on_xla_cpu_mlir
def testLowerCompileCostAnalysis(self):

View File

@ -299,6 +299,14 @@ class PythonPmapTest(jtu.JaxTestCase):
f = f.lower(x).compile()
self.assertIsInstance(f.as_text(), (str, type(None)))
@jtu.skip_on_xla_cpu_mlir
def testLowerCostAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x)
f.cost_analysis() # doesn't raise
@jtu.skip_on_xla_cpu_mlir
def testLowerCompileCostAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')

View File

@ -766,6 +766,16 @@ class XMapTest(XMapTestCase):
f = f.lower(x).compile()
self.assertIsInstance(f.as_text(), (str, type(None)))
@jtu.skip_on_xla_cpu_mlir
def testLowerCostAnalysis(self):
# TODO(b/261771737): add support for uncompiled cost analysis in C API.
if "PJRT C API" in xla_bridge.get_backend().platform_version:
raise SkipTest("C API does not support uncompiled cost analysis")
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
f = f.lower(x)
f.cost_analysis() # doesn't raise
@jtu.skip_on_xla_cpu_mlir
def testLowerCompileCostAnalysis(self):
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])