Mirror of https://github.com/ROCm/jax.git, synced 2025-04-25 02:46:11 +00:00

The goal is to ensure that the HLO generated by jax2tf->TF/XLA carries the same metadata as the HLO that JAX generates. This includes `op_type`, `op_name`, and source information, which are used for debugging and profiling.

To carry this metadata from JAX tracing time through to TF/XLA, we save it in custom TF op attributes. These attributes are automatically preserved through SavedModel. This relies on a separate change in TF/XLA that looks for these custom attributes and uses them to override the metadata it would otherwise generate by default.

For the source information, we use essentially the same code as xla.py. HLO OpMetadata has room for only one source location. JAX (xla.py) picks the top-most user frame, which it obtains by filtering out the stack frames in the JAX source tree. When used with jax2tf, we also need to filter out the stack frames in the TensorFlow source tree.

The hardest part is generating the `op_name`, which is a hierarchical name with components separated by '/', e.g., `jax2tf(top_func)/while/cond/le`. We carry the current `name_stack` in thread-local state. Unfortunately, there is no easy way to share the exact code that achieves this in xla.py. At the same time, it is not crucial that our name stacks be exactly identical to those in JAX. I also tried carrying this state in the JAX `MainTrace`, but could not fully control the name stack that way. E.g., when calling a jitted function we have to reuse the current `MainTrace`, even though we want to push an element onto the name stack.

For now, this feature is not enabled; it will be enabled once we make the necessary changes in TensorFlow.
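The frame filtering described for source information can be sketched as follows. This is a minimal illustration, not the xla.py or jax2tf code: the helper name `user_frame` and the idea of passing the excluded source trees as plain directory prefixes are assumptions. It returns the innermost (most recent) frame whose file lies outside every excluded directory, e.g. the JAX and TensorFlow install locations.

```python
import os
import traceback
from typing import Iterable, List, Optional


def user_frame(exclude_dirs: Iterable[str],
               stack: Optional[List[traceback.FrameSummary]] = None,
               ) -> Optional[traceback.FrameSummary]:
    """Hypothetical helper: innermost frame not under any of exclude_dirs.

    OpMetadata holds a single source location, so only one frame is returned.
    """
    if stack is None:
        # Drop the last entry, which is this helper's own frame.
        stack = traceback.extract_stack()[:-1]
    # Compare against "dir/" so that e.g. "/lib/jaxlib" is not matched by "/lib/jax".
    prefixes = tuple(os.path.abspath(d) + os.sep for d in exclude_dirs)
    # extract_stack() lists frames outermost first; scan from the innermost end.
    for frame in reversed(stack):
        if not os.path.abspath(frame.filename).startswith(prefixes):
            return frame
    return None  # Every frame was inside an excluded source tree.
```

With jax2tf, `exclude_dirs` would contain both the JAX and the TensorFlow source directories, so that a frame inside a TF op constructor is skipped just like a frame inside JAX.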
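Carrying the name stack in thread-local state can be sketched with a context manager that pushes one component on entry and pops it on exit, so that nested constructs such as `while` and `cond` build names like `jax2tf(top_func)/while/cond/le`. The names `extend_name_stack` and `op_name` are hypothetical and do not correspond to the actual jax2tf implementation; this only shows the general technique.

```python
import contextlib
import threading
from typing import Iterator, List

# Each thread sees its own independent name stack.
_state = threading.local()


def _components() -> List[str]:
    if not hasattr(_state, "name_stack"):
        _state.name_stack = []
    return _state.name_stack


@contextlib.contextmanager
def extend_name_stack(component: str) -> Iterator[None]:
    """Pushes `component` for the duration of the `with` block."""
    stack = _components()
    stack.append(component)
    try:
        yield
    finally:
        stack.pop()  # Restore the stack even if the body raises.


def op_name(primitive_name: str) -> str:
    """Joins the current name stack and the primitive name with '/'."""
    return "/".join(_components() + [primitive_name])


# Usage: nesting mirrors the structure of the traced program.
with extend_name_stack("jax2tf(top_func)"):
    with extend_name_stack("while"):
        with extend_name_stack("cond"):
            print(op_name("le"))  # prints jax2tf(top_func)/while/cond/le
```

Because the state lives outside any trace object, a call to a jitted function can push its own component without needing a new `MainTrace`, which is the limitation described above.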