mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Add a .cost_analysis() on lowered but uncompiled computations.
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it. PiperOrigin-RevId: 507560058
This commit is contained in:
parent
f37f00d620
commit
3d9ae6b467
@ -13,7 +13,7 @@ Classes
|
||||
:special-members: __call__
|
||||
|
||||
.. autoclass:: Lowered
|
||||
:members: in_tree, out_tree, compile, as_text, compiler_ir
|
||||
:members: in_tree, out_tree, compile, as_text, compiler_ir, cost_analysis
|
||||
|
||||
.. autoclass:: Compiled
|
||||
:members: in_tree, out_tree, as_text, cost_analysis, memory_analysis, runtime_executable
|
||||
|
@ -1010,6 +1010,11 @@ class XlaComputation(stages.XlaLowering):
|
||||
|
||||
return self._executable
|
||||
|
||||
def cost_analysis(self) -> Dict[str, float]:
|
||||
return xe.hlo_module_cost_analysis(self.compile_args["backend"],
|
||||
self.hlo().as_hlo_module())
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
def backend_compile(backend, built_c, options, host_callbacks):
|
||||
# we use a separate function call to ensure that XLA compilation appears
|
||||
|
@ -175,6 +175,26 @@ class Lowering(Protocol):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cost_analysis(self) -> Any:
|
||||
"""A summary of execution cost estimates.
|
||||
|
||||
Intended for visualization and debugging purposes. The object output by
|
||||
this is some simple data structure that can easily be printed or serialized
|
||||
(e.g. nested dicts, lists, and tuples with numeric leaves). However, its
|
||||
structure can be arbitrary: it need not be consistent across versions of JAX
|
||||
and jaxlib, or even across invocations. It is relayed directly to external
|
||||
callers.
|
||||
|
||||
This function estimates execution cost in the absence of compiler
|
||||
optimizations, which may drastically affect the cost. For execution cost
|
||||
estimates after optimizations, compile this lowering and see
|
||||
``Compiled.cost_analysis``.
|
||||
|
||||
May raise ``NotImplementedError`` if unavailable, e.g. based on backend,
|
||||
compiler, or runtime.
|
||||
"""
|
||||
# TODO(frostig): improve annotation (arbitrary pytree)
|
||||
raise NotImplementedError
|
||||
|
||||
# -- Internal adapters from XLA-related objects to the above protocols
|
||||
|
||||
@ -311,6 +331,9 @@ class XlaLowering(Lowering):
|
||||
else:
|
||||
raise ValueError(f"unknown dialect: {dialect}")
|
||||
|
||||
def cost_analysis(self) -> Dict[str, float]:
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
|
||||
# -- Public-facing API, plus helpers
|
||||
|
||||
@ -538,7 +561,7 @@ class Lowered(Stage):
|
||||
def __init__(
|
||||
self,
|
||||
lowering: XlaLowering,
|
||||
args_info, # PyTreee of ArgInfo
|
||||
args_info, # PyTree of ArgInfo
|
||||
out_tree: tree_util.PyTreeDef,
|
||||
no_kwargs: bool = False):
|
||||
self._lowering = lowering
|
||||
@ -618,6 +641,24 @@ class Lowered(Stage):
|
||||
except NotImplementedError:
|
||||
return None
|
||||
|
||||
def cost_analysis(self) -> Optional[Any]:
|
||||
"""A summary of execution cost estimates.
|
||||
|
||||
Intended for visualization and debugging purposes. The object output by
|
||||
this is some simple data structure that can easily be printed or serialized
|
||||
(e.g. nested dicts, lists, and tuples with numeric leaves). However, its
|
||||
structure can be arbitrary: it may be inconsistent across versions of JAX
|
||||
and jaxlib, or even across invocations.
|
||||
|
||||
Returns ``None`` if unavailable, e.g. based on backend, compiler, or
|
||||
runtime.
|
||||
"""
|
||||
# TODO(frostig): improve annotation (basic pytree of arbitrary structure)
|
||||
try:
|
||||
return self._lowering.cost_analysis()
|
||||
except NotImplementedError:
|
||||
return None
|
||||
|
||||
|
||||
class Wrapped(Protocol):
|
||||
"""A function ready to be specialized, lowered, and compiled.
|
||||
|
@ -3206,6 +3206,9 @@ class MeshComputation(stages.XlaLowering):
|
||||
self._executable = executable
|
||||
return self._executable
|
||||
|
||||
def cost_analysis(self) -> Dict[str, float]:
|
||||
return xe.hlo_module_cost_analysis(self.compile_args["backend"],
|
||||
self.hlo().as_hlo_module())
|
||||
|
||||
def _get_input_metadata(
|
||||
global_in_avals: Sequence[ShapedArray],
|
||||
|
@ -64,6 +64,7 @@ from jax._src import config as jax_config
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import device_array
|
||||
from jax._src import prng
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
@ -1108,6 +1109,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIsInstance(f.as_text(), (str, type(None)))
|
||||
self.assertIsInstance(g.as_text(), (str, type(None)))
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def test_jit_lower_cost_analysis(self):
|
||||
# TODO(b/261771737): add support for uncompiled cost analysis in C API.
|
||||
if "PJRT C API" in xla_bridge.get_backend().platform_version:
|
||||
raise unittest.SkipTest("C API does not support uncompiled cost analysis")
|
||||
f = self.jit(lambda x: x).lower(1.)
|
||||
g = self.jit(lambda x: x + 4).lower(1.)
|
||||
f.cost_analysis() # doesn't raise
|
||||
g.cost_analysis() # doesn't raise
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def test_jit_lower_compile_cost_analysis(self):
|
||||
f = self.jit(lambda x: x).lower(1.).compile()
|
||||
|
@ -982,6 +982,20 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
f = f.lower(x, x + 1).compile()
|
||||
self.assertIsInstance(f.as_text(), (str, type(None)))
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def testLowerCostAnalysis(self):
|
||||
@partial(pjit,
|
||||
in_axis_resources=P(('x', 'y'),),
|
||||
out_axis_resources=P(('x', 'y'),))
|
||||
def f(x, y):
|
||||
return x @ y
|
||||
|
||||
shape = (8, 8)
|
||||
x = jnp.arange(np.prod(shape)).reshape(shape)
|
||||
f = f.lower(x, x + 1)
|
||||
f.cost_analysis() # doesn't raise
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def testLowerCompileCostAnalysis(self):
|
||||
|
@ -299,6 +299,14 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = f.lower(x).compile()
|
||||
self.assertIsInstance(f.as_text(), (str, type(None)))
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def testLowerCostAnalysis(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
f = f.lower(x)
|
||||
f.cost_analysis() # doesn't raise
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def testLowerCompileCostAnalysis(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
|
@ -766,6 +766,16 @@ class XMapTest(XMapTestCase):
|
||||
f = f.lower(x).compile()
|
||||
self.assertIsInstance(f.as_text(), (str, type(None)))
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def testLowerCostAnalysis(self):
|
||||
# TODO(b/261771737): add support for uncompiled cost analysis in C API.
|
||||
if "PJRT C API" in xla_bridge.get_backend().platform_version:
|
||||
raise SkipTest("C API does not support uncompiled cost analysis")
|
||||
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
|
||||
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
|
||||
f = f.lower(x)
|
||||
f.cost_analysis() # doesn't raise
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def testLowerCompileCostAnalysis(self):
|
||||
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
|
||||
|
Loading…
x
Reference in New Issue
Block a user