scipy accounts for around 400ms of the 900ms of JAX's import time. By loading scipy lazily, we can improve the timing of `import jax` down to about 500ms.