Jax: Stop returning a list of cost-analyses.

As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.

This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.

PiperOrigin-RevId: 715837855
This commit is contained in:
Zac Mustin 2025-01-15 09:53:29 -08:00 committed by jax authors
parent 70c1ee5d9c
commit 2d72e8de84
4 changed files with 13 additions and 13 deletions

View File

@ -63,6 +63,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.38 (Dec 17, 2024)
* Breaking Changes
* `XlaExecutable.cost_analysis` now returns a `dict[str, float]` (instead of a
single-element `list[dict[str, float]]`).
* Changes:
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
as shortcuts of the corresponding `tree_util` functions.

View File

@ -57,7 +57,7 @@ module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 :
>>> compiled = lowered.compile()
>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
>>> compiled.cost_analysis()['flops']
2.0
>>> # Execute the compiled function!

View File

@ -250,15 +250,13 @@ class XlaExecutable(Executable):
else:
raise
# TODO(b/384741132): this should return a single dict (I think returning a list
# was to support MPMD executables, which never fully landed).
def cost_analysis(self) -> list[dict[str, float]]:
def cost_analysis(self) -> dict[str, float]:
xla_ext_exe = self.xla_extension_executable()
# TODO(b/259255524): Unify/merge the two cost_analysis calls below.
if hasattr(xla_ext_exe, "cost_analysis"):
try:
return [xla_ext_exe.cost_analysis()]
return xla_ext_exe.cost_analysis()
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")):
@ -276,11 +274,9 @@ class XlaExecutable(Executable):
" were found)."
)
return [
xla_extension.hlo_module_cost_analysis(
return xla_extension.hlo_module_cost_analysis(
xla_ext_exe.client, hlo_modules[0]
)
]
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
supported = not (type(msg) is str and
@ -295,7 +291,7 @@ class XlaExecutable(Executable):
and hasattr(self.unsafe_call, "compiled")
and hasattr(self.unsafe_call.compiled, "cost_analysis")
):
return [self.unsafe_call.compiled.cost_analysis()]
return self.unsafe_call.compiled.cost_analysis()
raise NotImplementedError(
f"cost analysis unsupported on current XLA backend: {type(xla_ext_exe)}"

View File

@ -1617,7 +1617,7 @@ class PallasCallTest(PallasBaseTest):
flops=1234, transcendentals=21, bytes_accessed=12345
),
)
(analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
analysis_result = jax.jit(f).lower(x).compile().cost_analysis()
self.assertEqual(analysis_result['flops'], 1234)
self.assertEqual(analysis_result['transcendentals'], 21)
self.assertEqual(analysis_result['bytes accessed'], 12345)
@ -1635,7 +1635,7 @@ class PallasCallTest(PallasBaseTest):
),
)
f = jax.vmap(f)
(analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
analysis_result = jax.jit(f).lower(x).compile().cost_analysis()
self.assertEqual(analysis_result['flops'], batch_size * 1234)
self.assertEqual(analysis_result['transcendentals'], batch_size * 21)
self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345)