From c2819cfd91a82bd8326383c16a5cfec5cba3a095 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 15 Feb 2023 01:49:55 +0000 Subject: [PATCH] MeshComputation.cost_analysis() isn't implemented with PJRT C API. This was caught via PJitTest.testLowerCostAnalysis (https://github.com/google/jax/blob/e74852f79695c9f2fcf06b4c8400b176bb899766/tests/pjit_test.py#L998). We don't need to change the test because NotImplementedError is already caught in Lowered.cost_analysis: https://github.com/google/jax/blob/e74852f79695c9f2fcf06b4c8400b176bb899766/jax/_src/stages.py#L659-L660 --- jax/_src/interpreters/pxla.py | 9 +++++++-- jax/_src/lib/xla_bridge.py | 4 ++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 3a1bb0d17..d25edca3d 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3259,8 +3259,13 @@ class MeshComputation(stages.XlaLowering): return self._executable def cost_analysis(self) -> Dict[str, float]: - return xe.hlo_module_cost_analysis(self.compile_args["backend"], - self.hlo().as_hlo_module()) + backend = self.compile_args["backend"] + if xb.using_pjrt_c_api(backend): + raise NotImplementedError( + "Lowered.cost_analysis not implemented on platform " + f"'{backend.platform}'. Use compile().cost_analysis() for " + "post-compilation cost estimates.") + return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module()) def get_input_metadata( global_in_avals: Sequence[ShapedArray], diff --git a/jax/_src/lib/xla_bridge.py b/jax/_src/lib/xla_bridge.py index 7cacc392b..86c2c12dd 100644 --- a/jax/_src/lib/xla_bridge.py +++ b/jax/_src/lib/xla_bridge.py @@ -631,3 +631,7 @@ def host_ids(backend=None): "instead. jax.host_ids will eventually be removed; please update your " "code.") return list(range(process_count(backend))) + + +def using_pjrt_c_api(backend=None): + return "PJRT C API" in get_backend(backend).platform_version