mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
MeshComputation.cost_analysis() isn't implemented with PJRT C API.
This was caught via PJitTest.testLowerCostAnalysis (e74852f796/tests/pjit_test.py (L998)
). We don't need to change the test because NotImplementedError is already caught in Lowered.cost_analysis:e74852f796/jax/_src/stages.py (L659-L660)
This commit is contained in:
parent
e74852f796
commit
c2819cfd91
@ -3259,8 +3259,13 @@ class MeshComputation(stages.XlaLowering):
|
||||
return self._executable
|
||||
|
||||
def cost_analysis(self) -> Dict[str, float]:
|
||||
return xe.hlo_module_cost_analysis(self.compile_args["backend"],
|
||||
self.hlo().as_hlo_module())
|
||||
backend = self.compile_args["backend"]
|
||||
if xb.using_pjrt_c_api(backend):
|
||||
raise NotImplementedError(
|
||||
"Lowered.cost_analysis not implemented on platform "
|
||||
f"'{backend.platform}'. Use compile().cost_analysis() for "
|
||||
"post-compilation cost estimates.")
|
||||
return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())
|
||||
|
||||
def get_input_metadata(
|
||||
global_in_avals: Sequence[ShapedArray],
|
||||
|
@ -631,3 +631,7 @@ def host_ids(backend=None):
|
||||
"instead. jax.host_ids will eventually be removed; please update your "
|
||||
"code.")
|
||||
return list(range(process_count(backend)))
|
||||
|
||||
|
||||
def using_pjrt_c_api(backend=None):
|
||||
return "PJRT C API" in get_backend(backend).platform_version
|
||||
|
Loading…
x
Reference in New Issue
Block a user