770 Commits

Author SHA1 Message Date
Jake VanderPlas
53676932e8 Error on numpy masked array inputs. 2022-12-27 15:42:49 -08:00
Peter Hawkins
73de02d5ce Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
2022-12-08 19:46:10 +00:00
Jake VanderPlas
09d1b6d8d5 Deprecate jnp.msort following deprecation of numpy.msort 2022-12-07 10:08:18 -08:00
Peter Hawkins
33a1b8866a Mark arguments to ufuncs as positional-only.
PiperOrigin-RevId: 493311821
2022-12-06 08:24:11 -08:00
Jake VanderPlas
3cf2924ed6 [x64] minor fixes for lax_numpy_test type safety 2022-12-01 13:56:42 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Peter Hawkins
99e1c3dd66 [JAX] Opt into high precision matrix multiplications in JAX tests that fail on A100.
With these changes the JAX test suite passes on A100, which uses TF32 math by default. As a side effect, we can also remove a number of TPU-specific tolerances once we have opted into high precision.

Fixes https://github.com/google/jax/issues/12008

PiperOrigin-RevId: 488749199
2022-11-15 13:50:21 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Jake VanderPlas
0a79fa4f1c jax.numpy: implement window functions in terms of lax ops
Including blackman, bartlett, hamming, hanning, kaiser.

Why? Previously these were implemented by computing the output on host at trace-time and embedding the result as a large constant array. Computing the results via lax operations is more in the spirit of jax.numpy.
2022-10-27 15:47:04 -07:00
Jake VanderPlas
51242bcc26 jax.numpy: implement window functions in terms of lax ops
Including blackman, bartlett, hamming, hanning, kaiser.

Why? Previously these were implemented by embedding large constants; this should be more performant.
2022-10-27 15:08:16 -07:00
Jake VanderPlas
2f27d516d7 [typing] annotate next part of lax_numpy.py 2022-10-25 12:36:26 -07:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
Jake VanderPlas
d94327c9e9 Move promote_like_jnp to jax.test_util 2022-10-06 10:20:26 -07:00
Jake VanderPlas
ff0810998e test: fix LaxNumpyTest:testConcatenate 2022-10-05 15:29:15 -07:00
Peter Hawkins
2c946b3b56 Migrate api_test, lax_numpy_test, and lax_vmap_test to
jtu.sample_product.

Gives a ~2x improvement in pytest --collect-only timing for
lax_numpy_test.
2022-10-05 13:46:19 +00:00
Jake VanderPlas
439217644a Split parts of lax_numpy_test.py into separate test files.
Why? The main test file is getting too big and this hinders iteration on individual tests

PiperOrigin-RevId: 478130215
2022-09-30 19:38:11 -07:00
Jake VanderPlas
d49c5c37ea jnp.take: add optional arguments forwarded to lax.gather 2022-09-29 09:33:38 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
jax authors
0e116888ea Merge pull request #12382 from jakevdp:reduction-dtype
PiperOrigin-RevId: 477179725
2022-09-27 08:38:46 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Jake VanderPlas
1860f6d839 [x64] add promote_integers argument to jnp.prod & jnp.sum 2022-09-26 13:31:43 -07:00
Peter Hawkins
8ee7129874 Fix jnp.unwrap() test failures on GPU.
A recent XLA change allows XLA to use excess precision on GPU, which caused CompileAndCheck to report noticeable numerical changes for bfloat16.

In passing, also enable the comparison against NumPy test for bfloat16 by using a wrapper function.

PiperOrigin-RevId: 476494989
2022-09-23 17:11:51 -07:00
Ke Wu
c823151771 Allow transpose axes to be negative to match (undocumented) NumPy behavior 2022-09-23 10:18:23 -07:00
Tres Popp
0c085471c7 Modify CorrCoef test to not rely on floating poing representation of 1/3
The operation computed an average while using the dimension of size 3. This is then changed to multiplying by 1/3 with compilers, but 1/3 cannot be represented perfectly. That made this test case rely on a very precise result from an unrepresentable calculation.

PiperOrigin-RevId: 476391389
2022-09-23 09:39:01 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Kuangyuan Chen
a09ef8a6a6 Temporarily skip LaxBackedNumpyTests.testUnwrap on gpu to unblock jaxlib build
PiperOrigin-RevId: 475970440
2022-09-21 18:06:02 -07:00
George Karpenkov
541aadcfe8 [XLA:GPU] Allow simplifying lowering-precision-conversions by default
This might lead to the output having higher precision than specified by HLO.

PiperOrigin-RevId: 475889141
2022-09-21 12:04:45 -07:00
Jake VanderPlas
74698048f3 Tracer: add missing __round__ and __reversed__ methods 2022-09-20 09:09:23 -07:00
Yash Katariya
09a3796d50 Enable testArrayCopy now that its fixed.
PiperOrigin-RevId: 473088085
2022-09-08 14:40:06 -07:00
Jon Barron
dc4591dd6c Fix NaNs in the gradient of jnp.interp when the spline being interpolated into contains knots that are small and nearby.
PiperOrigin-RevId: 472511203
2022-09-06 11:22:47 -07:00
Benjamin Kramer
d50c2599d4 Re-enable lax_numpy test that triggered a nonterminating LLVM compilation.
LLVM bug was fixed in 12b203ea7c

PiperOrigin-RevId: 469781517
2022-08-24 11:33:30 -07:00
Yash Katariya
314cf8a439 Use .device() to get the device and platform from the device and fix TODO to point to github issue
PiperOrigin-RevId: 468769708
2022-08-19 13:14:13 -07:00
jax authors
29482a2ef6 Copybara import of the project:
--
0cf7b33e9e166f21a05bbddb04f95dad89a5f7a9 by Jake VanderPlas <jakevdp@google.com>:

jnp.remainder: match numpy's behavior for integer zero denominator

PiperOrigin-RevId: 468621345
2022-08-18 22:08:21 -07:00
jax authors
4705073839 Merge pull request #11996 from jakevdp:fix-mod
PiperOrigin-RevId: 468595954
2022-08-18 18:33:30 -07:00
Yash Katariya
9244f3b1ba Add support for interoperability via dlpack for Array and also make pickle_tests and lax_numpy_test pass with Array.
PiperOrigin-RevId: 468568917
2022-08-18 16:04:22 -07:00
Jake VanderPlas
0cf7b33e9e jnp.remainder: match numpy's behavior for integer zero denominator 2022-08-18 13:18:21 -07:00
jax authors
840c96692e Internal change
PiperOrigin-RevId: 468509799
2022-08-18 11:39:07 -07:00
Peter Hawkins
1e241dcf16 Catch ModuleNotFoundError instead of ImportError.
We frequently use the pattern
try:
  import m
except ImportError:
  # do something else.

This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
2022-08-18 15:22:49 +00:00
Jake VanderPlas
d0ba24bc66 test: fix testStack for numpy 1.24 or newer 2022-08-17 11:51:02 -07:00
Albert Alonso
99c5e91874 add dtype arg to jnp.stack and friends
Since the np.stack group is getting a dtype argument in numpy 1.24, they
should also have it in JAX.

Because they are just wrappers of np.concatenate, the changes are small.
2022-08-16 19:45:41 +02:00
jax authors
d20dcf4b50 Merge pull request #11857 from jakevdp:fix-bool-args
PiperOrigin-RevId: 467259299
2022-08-12 11:45:35 -07:00
Jake VanderPlas
3f06195994 jax.numpy: improve support for boolean inputs 2022-08-12 09:51:25 -07:00
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
Jake VanderPlas
7cc6b4f62b Tests: remove obsolete dtype_promotion decorators 2022-08-11 14:31:30 -07:00
Jake VanderPlas
97c32f67fc Tests: reenable some ufunc input tests 2022-08-10 12:45:32 -07:00
Peter Hawkins
03590d86c0 Disable lax_numpy test that seems to lead to nonterminating LLVM compilation.
PiperOrigin-RevId: 466682802
2022-08-10 07:50:51 -07:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Penn
1987ca7389 Add dtype arg to jnp.concatenate and update tests 2022-08-01 15:48:40 -07:00
Jake VanderPlas
4a693400b9 BUG: make jnp.iscomplexobj compatible with jit 2022-07-21 16:56:29 -07:00
Jake VanderPlas
10411bfeae jnp.searchsorted: add optional method argument to control implementation 2022-07-21 09:40:18 -07:00