A composite function can encapsulate an operation made up of other JAX functions. The semantics of the op are implemented by the `decomposition` function. For example, a `tangent` operation can be implemented as `sin(x) / cos(x)`.

This is what the HLO looks like for a tangent composite:

```
module @jit_my_tangent_composite {
  func.func public @main(%arg0: tensor<4xf64>) -> (tensor<4xf64>) {
    %0 = stablehlo.composite "my.tangent" %arg0 {decomposition = @my.tangent} : (tensor<4xf64>) -> tensor<4xf64>
    return %0 : tensor<4xf64>
  }
  func.func private @my.tangent(%arg0: tensor<4xf64>) -> tensor<4xf64> {
    %0 = stablehlo.sine %arg0 : tensor<4xf64>
    %1 = stablehlo.cosine %arg0 : tensor<4xf64>
    %2 = stablehlo.divide %0, %1 : tensor<4xf64>
    return %2 : tensor<4xf64>
  }
}
```

Similarly, this approach scales to larger operations such as Attention. Preserving such an abstraction greatly simplifies pattern matching: instead of matching the set of ops that represent Attention, a matcher can simply look for a uniquely identifying composite op like "MyAttention".

This is useful for preserving high-level abstractions that would otherwise be lost during lowering. A hardware-aware compiler can recognize the single composite op and emit efficient code directly, rather than pattern-matching a generic lowering and then replacing it with a custom efficient lowering. The decomposition function can then be removed by dead-code elimination (DCE). If the hardware does not have an efficient lowering, the compiler can instead inline the `decomposition`, which implements the semantics of the abstraction.

For more details on the API, refer to the documentation.

PiperOrigin-RevId: 707750633