mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Jax: Stop returning a list of cost-analyses.
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify. This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API. PiperOrigin-RevId: 715837855
This commit is contained in:
parent
70c1ee5d9c
commit
2d72e8de84
@ -63,6 +63,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## jax 0.4.38 (Dec 17, 2024)
|
||||
|
||||
* Breaking Changes
|
||||
* `XlaExecutable.cost_analysis` now returns a `dict[str, float]` (instead of a
|
||||
single-element `list[dict[str, float]]`).
|
||||
|
||||
* Changes:
|
||||
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
|
||||
as shortcuts of the corresponding `tree_util` functions.
|
||||
|
@ -57,7 +57,7 @@ module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 :
|
||||
>>> compiled = lowered.compile()
|
||||
|
||||
>>> # Query for cost analysis, print FLOP estimate
|
||||
>>> compiled.cost_analysis()[0]['flops']
|
||||
>>> compiled.cost_analysis()['flops']
|
||||
2.0
|
||||
|
||||
>>> # Execute the compiled function!
|
||||
|
@ -250,15 +250,13 @@ class XlaExecutable(Executable):
|
||||
else:
|
||||
raise
|
||||
|
||||
# TODO(b/384741132): this should return a single dict (I think returning a list
|
||||
# was to support MPMD executables, which never fully landed).
|
||||
def cost_analysis(self) -> list[dict[str, float]]:
|
||||
def cost_analysis(self) -> dict[str, float]:
|
||||
xla_ext_exe = self.xla_extension_executable()
|
||||
|
||||
# TODO(b/259255524): Unify/merge the two cost_analysis calls below.
|
||||
if hasattr(xla_ext_exe, "cost_analysis"):
|
||||
try:
|
||||
return [xla_ext_exe.cost_analysis()]
|
||||
return xla_ext_exe.cost_analysis()
|
||||
except xla_extension.XlaRuntimeError as e:
|
||||
msg, *_ = e.args
|
||||
if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")):
|
||||
@ -276,11 +274,9 @@ class XlaExecutable(Executable):
|
||||
" were found)."
|
||||
)
|
||||
|
||||
return [
|
||||
xla_extension.hlo_module_cost_analysis(
|
||||
return xla_extension.hlo_module_cost_analysis(
|
||||
xla_ext_exe.client, hlo_modules[0]
|
||||
)
|
||||
]
|
||||
except xla_extension.XlaRuntimeError as e:
|
||||
msg, *_ = e.args
|
||||
supported = not (type(msg) is str and
|
||||
@ -295,7 +291,7 @@ class XlaExecutable(Executable):
|
||||
and hasattr(self.unsafe_call, "compiled")
|
||||
and hasattr(self.unsafe_call.compiled, "cost_analysis")
|
||||
):
|
||||
return [self.unsafe_call.compiled.cost_analysis()]
|
||||
return self.unsafe_call.compiled.cost_analysis()
|
||||
|
||||
raise NotImplementedError(
|
||||
f"cost analysis unsupported on current XLA backend: {type(xla_ext_exe)}"
|
||||
|
@ -1617,7 +1617,7 @@ class PallasCallTest(PallasBaseTest):
|
||||
flops=1234, transcendentals=21, bytes_accessed=12345
|
||||
),
|
||||
)
|
||||
(analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
|
||||
analysis_result = jax.jit(f).lower(x).compile().cost_analysis()
|
||||
self.assertEqual(analysis_result['flops'], 1234)
|
||||
self.assertEqual(analysis_result['transcendentals'], 21)
|
||||
self.assertEqual(analysis_result['bytes accessed'], 12345)
|
||||
@ -1635,7 +1635,7 @@ class PallasCallTest(PallasBaseTest):
|
||||
),
|
||||
)
|
||||
f = jax.vmap(f)
|
||||
(analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
|
||||
analysis_result = jax.jit(f).lower(x).compile().cost_analysis()
|
||||
self.assertEqual(analysis_result['flops'], batch_size * 1234)
|
||||
self.assertEqual(analysis_result['transcendentals'], batch_size * 21)
|
||||
self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345)
|
||||
|
Loading…
x
Reference in New Issue
Block a user