Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)

Why? np.prod is generally used here for static operations on shapes, but it has an unfortunate corner case: np.prod([]) returns a float (1.0) rather than an int. math.prod is available as of Python 3.8, returns the int 1 for an empty input, and is a better solution here.
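A minimal sketch of the difference, for illustration only. The helper name `num_elements` is hypothetical and is not taken from the JAX codebase; it just stands in for any static computation on a shape tuple.

```python
import math


def num_elements(shape):
    """Hypothetical helper: element count of a static shape (a tuple of ints)."""
    # math.prod multiplies the entries as Python ints; for an empty
    # iterable it returns its default start value, the int 1.
    return math.prod(shape)


scalar_size = num_elements(())         # empty shape, i.e. a scalar
tensor_size = num_elements((2, 3, 4))  # ordinary shape

print(scalar_size, type(scalar_size).__name__)
print(tensor_size, type(tensor_size).__name__)

# For comparison (needs NumPy, so it is left as a comment here):
#   import numpy as np
#   np.prod([])         # a NumPy float, not an int
#   np.prod((2, 3, 4))  # a NumPy integer, not a Python int
```

With math.prod the result stays a plain Python int in both cases, so it can be used directly wherever a static dimension or size is expected.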