* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
With these changes the JAX test suite passes on A100, which uses TF32 math by default. As a side effect, we can also remove a number of TPU-specific tolerances once we have opted into high precision.
Fixes https://github.com/google/jax/issues/12008
PiperOrigin-RevId: 488749199
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670
Including blackman, bartlett, hamming, hanning, kaiser.
Why? Previously these were implemented by computing the output on host at trace-time and embedding the result as a large constant array. Computing the results via lax operations is more in the spirit of jax.numpy.
Including blackman, bartlett, hamming, hanning, kaiser.
Why? Previously these were implemented by embedding large constants; this should be more performant.
A recent XLA change allows XLA to use excess precision on GPU, which caused CompileAndCheck to report noticeable numerical changes for bfloat16.
In passing, also enable the comparison against NumPy test for bfloat16 by using a wrapper function.
PiperOrigin-RevId: 476494989
The operation computed an average while using the dimension of size 3. This is then changed to multiplying by 1/3 with compilers, but 1/3 cannot be represented perfectly. That made this test case rely on a very precise result from an unrepresentable calculation.
PiperOrigin-RevId: 476391389
--
0cf7b33e9e166f21a05bbddb04f95dad89a5f7a9 by Jake VanderPlas <jakevdp@google.com>:
jnp.remainder: match numpy's behavior for integer zero denominator
PiperOrigin-RevId: 468621345
We frequently use the pattern
try:
import m
except ImportError:
# do something else.
This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
Since the np.stack group is getting a dtype argument in numpy 1.24, they
should also have it in JAX.
Because they are just wrappers of np.concatenate, the changes are small.