mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 17:16:06 +00:00

As with `jnp.einsum`, whenever a call uses an extension to the positional NumPy API (for reductions, the extension is any non-integer axis), we reroute the call to a parallel primitive instead of the standard lax reduction. As a result, the parallel primitives implement a strict subset of the lax reductions' functionality. In the future, when we decide that we want axes to be truly first-class, we can always swap out the implementation for the parallel version. For now, though, keeping them separate makes prototyping easier in the near term.