MeshComputation.cost_analysis() isn't implemented with PJRT C API.

This was caught via PJitTest.testLowerCostAnalysis
(e74852f796/tests/pjit_test.py (L998)). We
don't need to change the test because NotImplementedError is already
caught in Lowered.cost_analysis:
e74852f796/jax/_src/stages.py (L659-L660)
This commit is contained in:
Skye Wanderman-Milne 2023-02-15 01:49:55 +00:00
parent e74852f796
commit c2819cfd91
2 changed files with 11 additions and 2 deletions

View File

@ -3259,8 +3259,13 @@ class MeshComputation(stages.XlaLowering):
return self._executable
def cost_analysis(self) -> Dict[str, float]:
return xe.hlo_module_cost_analysis(self.compile_args["backend"],
self.hlo().as_hlo_module())
backend = self.compile_args["backend"]
if xb.using_pjrt_c_api(backend):
raise NotImplementedError(
"Lowered.cost_analysis not implemented on platform "
f"'{backend.platform}'. Use compile().cost_analysis() for "
"post-compilation cost estimates.")
return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())
def get_input_metadata(
global_in_avals: Sequence[ShapedArray],

View File

@ -631,3 +631,7 @@ def host_ids(backend=None):
"instead. jax.host_ids will eventually be removed; please update your "
"code.")
return list(range(process_count(backend)))
def using_pjrt_c_api(backend=None):
return "PJRT C API" in get_backend(backend).platform_version