diff --git a/CHANGELOG.md b/CHANGELOG.md index 10bf7cebf..94cb4a2ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.38 (Dec 17, 2024) +* Breaking Changes + * `XlaExecutable.cost_analysis` now returns a `dict[str, float]` (instead of a + single-element `list[dict[str, float]]`). + * Changes: * `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added as shortcuts of the corresponding `tree_util` functions. diff --git a/docs/aot.md b/docs/aot.md index cf32716a0..f4e7e020c 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -57,7 +57,7 @@ module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : >>> compiled = lowered.compile() >>> # Query for cost analysis, print FLOP estimate ->>> compiled.cost_analysis()[0]['flops'] +>>> compiled.cost_analysis()['flops'] 2.0 >>> # Execute the compiled function! diff --git a/jax/_src/stages.py b/jax/_src/stages.py index c1a188670..ed6febd76 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -250,15 +250,13 @@ class XlaExecutable(Executable): else: raise - # TODO(b/384741132): this should return a single dict (I think returning a list - # was to support MPMD executables, which never fully landed). - def cost_analysis(self) -> list[dict[str, float]]: + def cost_analysis(self) -> dict[str, float]: xla_ext_exe = self.xla_extension_executable() # TODO(b/259255524): Unify/merge the two cost_analysis calls below. if hasattr(xla_ext_exe, "cost_analysis"): try: - return [xla_ext_exe.cost_analysis()] + return xla_ext_exe.cost_analysis() except xla_extension.XlaRuntimeError as e: msg, *_ = e.args if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")): @@ -276,11 +274,9 @@ class XlaExecutable(Executable): " were found)." ) - return [ - xla_extension.hlo_module_cost_analysis( - xla_ext_exe.client, hlo_modules[0] - ) - ] + return xla_extension.hlo_module_cost_analysis( + xla_ext_exe.client, hlo_modules[0] + ) except xla_extension.XlaRuntimeError as e: msg, *_ = e.args supported = not (type(msg) is str and @@ -295,7 +291,7 @@ class XlaExecutable(Executable): and hasattr(self.unsafe_call, "compiled") and hasattr(self.unsafe_call.compiled, "cost_analysis") ): - return [self.unsafe_call.compiled.cost_analysis()] + return self.unsafe_call.compiled.cost_analysis() raise NotImplementedError( f"cost analysis unsupported on current XLA backend: {type(xla_ext_exe)}" diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 7058cacd8..746485d68 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1617,7 +1617,7 @@ class PallasCallTest(PallasBaseTest): flops=1234, transcendentals=21, bytes_accessed=12345 ), ) - (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis() + analysis_result = jax.jit(f).lower(x).compile().cost_analysis() self.assertEqual(analysis_result['flops'], 1234) self.assertEqual(analysis_result['transcendentals'], 21) self.assertEqual(analysis_result['bytes accessed'], 12345) @@ -1635,7 +1635,7 @@ class PallasCallTest(PallasBaseTest): ), ) f = jax.vmap(f) - (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis() + analysis_result = jax.jit(f).lower(x).compile().cost_analysis() self.assertEqual(analysis_result['flops'], batch_size * 1234) self.assertEqual(analysis_result['transcendentals'], batch_size * 21) self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345)