We are going to add a C++ implementation, and this refactoring is useful to ease the transition. In short:
- `isinstance(x, DeviceArray)` will continue to work
- `type(x) is DeviceArray` will be replaced by `type_is_device_array(x)`
- the `DeviceArray(...)` constructor will be replaced by `get_device_array`.
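The intent of the three points above can be illustrated with a minimal sketch. Everything here is a stand-in written for illustration: the classes `_PyDeviceArray` and `_CppDeviceArray`, the `_USE_CPP` switch, and the signature of `get_device_array` are assumptions, not JAX's actual definitions.

```python
# Hypothetical stand-ins; none of these are JAX's real classes.
class DeviceArray:
    """Common base type that user code keeps checking with isinstance."""


class _PyDeviceArray(DeviceArray):
    """Stands in for the existing Python implementation."""

    def __init__(self, value):
        self.value = value


class _CppDeviceArray(DeviceArray):
    """Stands in for an implementation provided by a C++ extension."""

    def __init__(self, value):
        self.value = value


_USE_CPP = True  # assumed switch between the two implementations


def type_is_device_array(x):
    # Exact-type check that accepts either concrete implementation.
    return type(x) in (_PyDeviceArray, _CppDeviceArray)


def get_device_array(value):
    # Factory used instead of calling DeviceArray(...) directly.
    cls = _CppDeviceArray if _USE_CPP else _PyDeviceArray
    return cls(value)


x = get_device_array([1.0, 2.0])
print(isinstance(x, DeviceArray))   # the isinstance check keeps working
print(type(x) is DeviceArray)       # the exact-type comparison no longer holds
print(type_is_device_array(x))      # the helper replaces the exact-type comparison
```

Under these assumptions, callers that go through the helper and the factory do not need to know which concrete class backs the array.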
Add support for exact (remainder-free) multivariate polynomial division on polymorphic sizes, allowing `mask` to handle:
- `jnp.reshape` with a -1 size, and
- `lax.slice` with a polymorphic stride for sizes of the form `poly * stride`, e.g. `(n^2+2n)//n = n+2`
Also clean up the `Poly` class and improve error messages.
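As a worked illustration of the `(n^2+2n)//n = n+2` case, here is a minimal sketch of exact polynomial division. It is a univariate simplification and not the `Poly` implementation; the function name `divide_exact` and the coefficient-list representation are assumptions made for this example.

```python
def divide_exact(num, den):
    """Divide two polynomials in one variable with integer coefficients.

    Polynomials are coefficient lists, lowest degree first, so
    n^2 + 2n is [0, 2, 1]. Raises ValueError unless the division is exact.
    """
    num = list(num)  # work on a copy; it becomes the running remainder
    if not den or den[-1] == 0:
        raise ValueError("divisor must have a nonzero leading coefficient")
    quot = [0] * max(len(num) - len(den) + 1, 0)
    # Standard long division, eliminating the highest remaining term each step.
    for shift in range(len(num) - len(den), -1, -1):
        coeff, rem = divmod(num[shift + len(den) - 1], den[-1])
        if rem:
            raise ValueError("quotient would need a non-integer coefficient")
        quot[shift] = coeff
        for i, d in enumerate(den):
            num[shift + i] -= coeff * d
    if any(num):
        raise ValueError("division leaves a nonzero remainder")
    return quot


# (n^2 + 2n) // n, with the result read as 2 + 1*n, i.e. n + 2
print(divide_exact([0, 2, 1], [0, 1]))
```

A size such as `n^2 + 1` would be rejected for stride `n`, since the division leaves a remainder.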
All initial-style primitives currently use `batch_jaxpr` in their
batching rules, but that function wasn't updated to support
`axis_name` when I added support for vmap collectives.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant).
For the benchmark in #2952 on my workstation:
Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```
After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```
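The change between the two runs can be summarized with a short calculation. This is only arithmetic on the figures printed above; the variable names are chosen for this example.

```python
# Figures copied from the "Before" and "After" output above.
jax_before_ms = 74.93342399597168
jax_after_ms = 28.128278255462646
numpy_after_ms = 37.756404876708984
max_err_before = 4.362646594533903e-08
max_err_after = 1.9057386696477137e-12

speedup = jax_before_ms / jax_after_ms          # JAX after vs. JAX before
vs_numpy = numpy_after_ms / jax_after_ms        # JAX after vs. NumPy, same run
error_reduction = max_err_before / max_err_after

print(f"jax speedup: {speedup:.2f}x")
print(f"jax vs numpy (after): {vs_numpy:.2f}x")
print(f"max error reduced by a factor of {error_reduction:.0f}")
```

On these numbers the JAX FFT goes from about twice as slow as NumPy to faster than NumPy, and the maximum error drops by roughly four orders of magnitude.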
Fixes https://github.com/google/jax/issues/2952
PiperOrigin-RevId: 338743753