Adam Paszke f86bf12b5a Add support for axis names in jnp.{sum,min,max}
As with `jnp.einsum`, whenever we encounter an extension to the
positional NumPy API (for reductions, the extension is any non-integer
axis argument), we reroute the call to a parallel primitive instead of
the standard lax reductions.
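The rerouting rule can be sketched as a small dispatcher. This is a hypothetical helper, not the code from this commit. It assumes the parallel primitives in question are `lax.psum`, `lax.pmin` and `lax.pmax`, and it uses `jax.vmap(..., axis_name="i")` to bind the name `"i"`. Integer (or `None`) axes take the positional `jnp` path. Any other axis is treated as an axis name.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical dispatcher (illustration only, not JAX's implementation):
# integer/None axes use the positional jnp reduction, while any other
# axis value is treated as a named axis and sent to a parallel primitive.
_POSITIONAL = {"sum": jnp.sum, "min": jnp.min, "max": jnp.max}
_PARALLEL = {"sum": lax.psum, "min": lax.pmin, "max": lax.pmax}


def _is_positional(axis):
    if axis is None or isinstance(axis, int):
        return True
    return isinstance(axis, tuple) and all(isinstance(a, int) for a in axis)


def dispatch_reduce(kind, x, axis=None):
    if _is_positional(axis):
        return _POSITIONAL[kind](x, axis=axis)
    return _PARALLEL[kind](x, axis_name=axis)


xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]

# Positional reduction over axis 0.
positional_total = dispatch_reduce("sum", xs, axis=0)

# Named reduction: vmap binds "i" to the mapped axis. Each call sees one
# scalar, and the parallel primitive reduces across all mapped elements.
named = jax.vmap(
    lambda x: (
        dispatch_reduce("sum", x, "i"),
        dispatch_reduce("min", x, "i"),
        dispatch_reduce("max", x, "i"),
    ),
    axis_name="i",
)(xs)
```

Under `vmap`, each named reduction returns the same result for every mapped element. That is why `named` holds three length-4 arrays rather than three scalars.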

Note that this makes the parallel primitives implement a strict subset
of the functionality of the lax reductions, so in the future (once we
decide that we want axes to be truly first-class) we can always swap
out the implementation for the parallel version. For now, though, it
makes sense to keep them separate for ease of prototyping.
2021-02-01 11:41:05 +00:00