From 3d9ae6b46726fa986f3ae1a4d3d4bab875aab97d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 6 Feb 2023 12:57:30 -0800 Subject: [PATCH] Add a .cost_analysis() on lowered but uncompiled computations. Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it. PiperOrigin-RevId: 507560058 --- docs/jax.stages.rst | 2 +- jax/_src/dispatch.py | 5 +++++ jax/_src/stages.py | 43 +++++++++++++++++++++++++++++++++++++++- jax/interpreters/pxla.py | 3 +++ tests/api_test.py | 11 ++++++++++ tests/pjit_test.py | 14 +++++++++++++ tests/pmap_test.py | 8 ++++++++ tests/xmap_test.py | 10 ++++++++++ 8 files changed, 94 insertions(+), 2 deletions(-) diff --git a/docs/jax.stages.rst b/docs/jax.stages.rst index dddbc1135..f8adce32b 100644 --- a/docs/jax.stages.rst +++ b/docs/jax.stages.rst @@ -13,7 +13,7 @@ Classes :special-members: __call__ .. autoclass:: Lowered - :members: in_tree, out_tree, compile, as_text, compiler_ir + :members: in_tree, out_tree, compile, as_text, compiler_ir, cost_analysis .. autoclass:: Compiled :members: in_tree, out_tree, as_text, cost_analysis, memory_analysis, runtime_executable diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 2efd7d3e3..a715ca19f 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -1010,6 +1010,11 @@ class XlaComputation(stages.XlaLowering): return self._executable + def cost_analysis(self) -> Dict[str, float]: + return xe.hlo_module_cost_analysis(self.compile_args["backend"], + self.hlo().as_hlo_module()) + + @profiler.annotate_function def backend_compile(backend, built_c, options, host_callbacks): # we use a separate function call to ensure that XLA compilation appears diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 83cef1198..f1dd3001e 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -175,6 +175,26 @@ class Lowering(Protocol): """ raise NotImplementedError + def cost_analysis(self) -> Any: + """A summary of execution cost estimates. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + This function estimates execution cost in the absence of compiler + optimizations, which may drastically affect the cost. For execution cost + estimates after optimizations, compile this lowering and see + ``Compiled.cost_analysis``. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ + # TODO(frostig): improve annotation (arbitrary pytree) + raise NotImplementedError # -- Internal adapters from XLA-related objects to the above protocols @@ -311,6 +331,9 @@ class XlaLowering(Lowering): else: raise ValueError(f"unknown dialect: {dialect}") + def cost_analysis(self) -> Dict[str, float]: + raise NotImplementedError("must override") + # -- Public-facing API, plus helpers @@ -538,7 +561,7 @@ class Lowered(Stage): def __init__( self, lowering: XlaLowering, - args_info, # PyTreee of ArgInfo + args_info, # PyTree of ArgInfo out_tree: tree_util.PyTreeDef, no_kwargs: bool = False): self._lowering = lowering @@ -618,6 +641,24 @@ class Lowered(Stage): except NotImplementedError: return None + def cost_analysis(self) -> Optional[Any]: + """A summary of execution cost estimates. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it may be inconsistent across versions of JAX + and jaxlib, or even across invocations. + + Returns ``None`` if unavailable, e.g. based on backend, compiler, or + runtime. + """ + # TODO(frostig): improve annotation (basic pytree of arbitrary structure) + try: + return self._lowering.cost_analysis() + except NotImplementedError: + return None + class Wrapped(Protocol): """A function ready to be specialized, lowered, and compiled. diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index a605efb3c..11fbaa334 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -3206,6 +3206,9 @@ class MeshComputation(stages.XlaLowering): self._executable = executable return self._executable + def cost_analysis(self) -> Dict[str, float]: + return xe.hlo_module_cost_analysis(self.compile_args["backend"], + self.hlo().as_hlo_module()) def _get_input_metadata( global_in_avals: Sequence[ShapedArray], diff --git a/tests/api_test.py b/tests/api_test.py index 0dbe3ae44..2fe603142 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -64,6 +64,7 @@ from jax._src import config as jax_config from jax._src import custom_derivatives from jax._src import device_array from jax._src import prng +from jax._src.lib import xla_bridge from jax._src.lib import xla_client from jax._src import test_util as jtu from jax import tree_util @@ -1108,6 +1109,16 @@ class CPPJitTest(jtu.BufferDonationTestCase): self.assertIsInstance(f.as_text(), (str, type(None))) self.assertIsInstance(g.as_text(), (str, type(None))) + @jtu.skip_on_xla_cpu_mlir + def test_jit_lower_cost_analysis(self): + # TODO(b/261771737): add support for uncompiled cost analysis in C API. + if "PJRT C API" in xla_bridge.get_backend().platform_version: + raise unittest.SkipTest("C API does not support uncompiled cost analysis") + f = self.jit(lambda x: x).lower(1.) + g = self.jit(lambda x: x + 4).lower(1.) + f.cost_analysis() # doesn't raise + g.cost_analysis() # doesn't raise + @jtu.skip_on_xla_cpu_mlir def test_jit_lower_compile_cost_analysis(self): f = self.jit(lambda x: x).lower(1.).compile() diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6211ca2fd..a312999c6 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -982,6 +982,20 @@ class PJitTest(jtu.BufferDonationTestCase): f = f.lower(x, x + 1).compile() self.assertIsInstance(f.as_text(), (str, type(None))) + @jtu.with_mesh([('x', 2), ('y', 2)]) + @jtu.skip_on_xla_cpu_mlir + def testLowerCostAnalysis(self): + @partial(pjit, + in_axis_resources=P(('x', 'y'),), + out_axis_resources=P(('x', 'y'),)) + def f(x, y): + return x @ y + + shape = (8, 8) + x = jnp.arange(np.prod(shape)).reshape(shape) + f = f.lower(x, x + 1) + f.cost_analysis() # doesn't raise + @jtu.with_mesh([('x', 2), ('y', 2)]) @jtu.skip_on_xla_cpu_mlir def testLowerCompileCostAnalysis(self): diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 276a35906..3a17f271a 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -299,6 +299,14 @@ class PythonPmapTest(jtu.JaxTestCase): f = f.lower(x).compile() self.assertIsInstance(f.as_text(), (str, type(None))) + @jtu.skip_on_xla_cpu_mlir + def testLowerCostAnalysis(self): + f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') + shape = (jax.device_count(), 4) + x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + f = f.lower(x) + f.cost_analysis() # doesn't raise + @jtu.skip_on_xla_cpu_mlir def testLowerCompileCostAnalysis(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 2128130f5..9a5a91305 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -766,6 +766,16 @@ class XMapTest(XMapTestCase): f = f.lower(x).compile() self.assertIsInstance(f.as_text(), (str, type(None))) + @jtu.skip_on_xla_cpu_mlir + def testLowerCostAnalysis(self): + # TODO(b/261771737): add support for uncompiled cost analysis in C API. + if "PJRT C API" in xla_bridge.get_backend().platform_version: + raise SkipTest("C API does not support uncompiled cost analysis") + f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...]) + x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2)) + f = f.lower(x) + f.cost_analysis() # doesn't raise + @jtu.skip_on_xla_cpu_mlir def testLowerCompileCostAnalysis(self): f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])