Mirror of https://github.com/ROCm/jax.git (synced 2025-04-25 03:06:04 +00:00)

Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`

The latter is much more important, so whenever the two are in tension we prefer the latter. (Also we already can't do a perfect job following what NumPy does, e.g. around its value-dependent dtype promotion logic.)

Issue #732 showed our code had a special behavior that essentially handled a case of the former desideratum but also broke the latter. #732 also showed us (again) that our tests really should cover Python scalars.

In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
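One way to check the `@jit`-invariance desideratum is to run the same computation eagerly and under `jax.jit`, then compare the result dtypes. This is a standalone sketch, not the lax_numpy test code from this commit. It assumes a current JAX release, whose exact promotion results may differ from the version this commit targets. `same_dtype_under_jit` is a hypothetical helper written for illustration and is not part of JAX. Python scalars are passed as function arguments because `jax.jit` traces non-static arguments as arrays, which is where the eager and jitted paths could disagree.

```python
import jax
import jax.numpy as jnp


def same_dtype_under_jit(f, *args):
    """Hypothetical helper: apply f eagerly and under jax.jit, return both result dtypes."""
    eager = f(*args)
    jitted = jax.jit(f)(*args)
    return eager.dtype, jitted.dtype


def add(a, b):
    # Plain binary op; b may be a Python scalar.
    return a + b


def kron(a, b):
    return jnp.kron(a, b)


x = jnp.arange(3, dtype=jnp.int32)

# Python scalars passed as arguments: under jit they are traced as arrays.
for scalar in (1, 2.0):
    eager_dtype, jit_dtype = same_dtype_under_jit(add, x, scalar)
    assert eager_dtype == jit_dtype, (scalar, eager_dtype, jit_dtype)
    print(f"int32 array + {scalar!r}: eager={eager_dtype}, jit={jit_dtype}")

# kron with mixed-dtype inputs, the other area this commit touches.
a = jnp.ones((2, 2), dtype=jnp.int32)
b = jnp.ones((2, 2), dtype=jnp.float32)
kron_eager, kron_jit = same_dtype_under_jit(kron, a, b)
assert kron_eager == kron_jit, (kron_eager, kron_jit)
print(f"kron(int32, float32): eager={kron_eager}, jit={kron_jit}")
```

If any of these assertions fails, the eager and traced promotion paths disagree. That is the kind of `@jit` dependence the first bullet above is meant to rule out.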