151 Commits

Author SHA1 Message Date
Charles Hofer
63e6442bdf Merge branch 'rocm-main' into ci-upstream-sync-97_1 2025-01-27 17:19:08 +00:00
jax authors
aed79707e2 Merge pull request #25791 from mattjj:logsumexp-where-grad-nan
PiperOrigin-RevId: 714118085
2025-01-10 11:27:35 -08:00
Dan Foreman-Mackey
5f3e0d9e5e Add sph_harm_y to jax.scipy.special and deprecate sph_harm. 2025-01-09 12:53:00 -05:00
Matthew Johnson
f0392a1535 fix grad(logsumexp) to produce 0s where where is False 2025-01-08 23:38:06 +00:00
Charles Hofer
a94ee1fcb9 Unskip unit tests that are now fixed 2025-01-07 16:56:35 +00:00
Jake VanderPlas
245a13a329 Deprecate scipy.special.lpmn & lpmn_values 2025-01-06 09:31:15 -08:00
Charles Hofer
708f48dad6 Skip one more test 2025-01-06 17:04:39 +00:00
Charles Hofer
307f0db702 Skip failing tests 2025-01-06 16:40:38 +00:00
Dan Foreman-Mackey
512d5450ae Temporarily allow deprecation warnings for scipy.special.lpmn and scipy.special.sph_harm.
These functions are deprecated in scipy 1.15.0. I'll fix this properly soon, but let's start by getting CI working again!

PiperOrigin-RevId: 712512363
2025-01-06 06:50:51 -08:00
jax authors
629be0b701 Tighten test tolerances after the underlying issue causing nondeterministic results for _nrm2 in Eigen BLAS was fixed in https://gitlab.com/libeigen/eigen/-/merge_requests/1667 -> cl/663346025
PiperOrigin-RevId: 676881791
2024-09-20 10:03:46 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
1b2ba9d1c2 Disable two lax_scipy_test testcases that fail on TPU v6e.
PiperOrigin-RevId: 672973757
2024-09-10 08:26:27 -07:00
Peter Hawkins
1516d59744 Reverts 6fc57c0eb6f06b2da20c94f5f127fe4a551bda09
PiperOrigin-RevId: 663334727
2024-08-15 09:33:31 -07:00
Peter Hawkins
323e257f67 Fix test failures.
PiperOrigin-RevId: 662703221
2024-08-13 17:02:14 -07:00
Sergei Lebedev
6fc57c0eb6 Rolling forward #22836
This version, proposed by @dfm, does not have a custom JVP for the whole
logsumexp and instead fixes #22398 directly.

Reverts e416c6675acfd82866a6e83e8c221640c4d02f29

PiperOrigin-RevId: 660438802
2024-08-07 10:17:55 -07:00
Sergei Lebedev
e416c6675a Reverts 0f103d33849ca017e6a199d0f79fa0d83b373995
PiperOrigin-RevId: 659670593
2024-08-05 13:52:04 -07:00
Sergei Lebedev
0a48aca965 Added a custom JVP rule for jax.nn.logsumexp
Fixes #22398 where the Jacobian of jax.nn.logsumexp was wrong if b= contained
exact zeros.
2024-08-05 17:05:03 +01:00
Peter Hawkins
858dc54590 Fix or disable some tests that fail when using a Eigen BLAS with AVX vectorization.
PiperOrigin-RevId: 658047868
2024-07-31 10:06:45 -07:00
Jake VanderPlas
f556a17033 TST: fix Lpmn test for new scipy 2024-05-08 15:55:20 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
carlosgmartin
e98612e2ab Add where argument to logsumexp. 2024-04-08 12:57:06 -04:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Jake VanderPlas
7d6a134f4e logsumexp: use NumPy 2.0 convention for complex sign 2024-01-16 16:15:06 -08:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Adrian Kuegel
d4965cd335 [XLA:GPU] Clean up Target util.
We have some differences between Triton codegen and other fusion codegen,
namely for Remainder/Fmod and Cbrt. Unify that.

- Remove two unused math functions.
- Add mapping from kRemainder to kFmod.
- Use kCbrt device function in elemental_ir_emitter.

PiperOrigin-RevId: 567274915
2023-09-21 05:12:06 -07:00
Peter Hawkins
f863cfbaad Relax some test tolerances to fix failures on Linux aarch64.
PiperOrigin-RevId: 565930178
2023-09-16 06:55:22 -07:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
Peter Hawkins
975dae34a4 Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Jake VanderPlas
6cd467fd57 Create lax.zeta with native HLO lowering 2023-08-16 13:43:41 -07:00
Jake VanderPlas
ad7878df9d Add test of private xlogx gradient 2023-04-25 14:31:00 -07:00
Jake VanderPlas
9ac3781c7e grad(entr)(0.0): return inf instead of NaN 2023-04-25 08:32:37 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
dd023e266e jax.scipy.special: fix gradient for xlogy & xlog1py 2023-04-18 15:56:32 -07:00
pizzud
ef28dcf091 lax_scipy_test: Split into three targets, take 2.
The goal is to ensure that all shards fit into a medium timeout in sanitizer
configurations.

Running 256 entry vectors in spectral_dac is too slow, so let's replace that
with a smaller vector that isn't a power of 2. Avoiding a power of 2 requires
us to widen the tolerance a bit due to vectorization changes.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 514440062
2023-03-06 09:53:23 -08:00
pizzud
0292f5d0a6 lax_scipy_test: Revert split into three targets.
Somehow the spectral_dac functionality is flaky on its own when run on CPU.

PiperOrigin-RevId: 512195860
2023-02-24 16:56:40 -08:00
pizzud
09afbac6ff lax_scipy_test: Split into three so that each target is small enough to fit within a medium timeout.
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.

This allows them to run in ASAN in more situations.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 511829646
2023-02-23 10:51:58 -08:00
Lucas Hofer
4636276214 added scipy special spence
added dtype to arrays in the _spence_poly function
2023-02-10 20:33:47 +00:00
Jake VanderPlas
924894fdd6 [x64] make tests more type-safe 2022-12-02 13:21:35 -08:00
Ian Horn
a35fe206a1 Added more accurate version of the betaln function. 2022-11-29 11:56:07 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Tianjian Lu
3b1ddf2881 [linalg] Add jax.scipy.special.bessel_jn (Bessel function of the first kind).
PiperOrigin-RevId: 487146250
2022-11-08 23:03:21 -08:00
Peter Hawkins
72f4f389be Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.
2022-10-12 15:20:53 +00:00
Peter Hawkins
9bb2c999d6 Reenable some tests disabled in the past because of an LLVM bug.
The issue no longer reproduces at head.

PiperOrigin-RevId: 480505525
2022-10-11 18:59:37 -07:00
jax authors
363cc124e3 Merge pull request #12197 from ROCmSoftwarePlatform:fixedRocmUnitTestsSkip
PiperOrigin-RevId: 479566021
2022-10-07 06:36:11 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Chao Chen
1c69f594fb testSphHarmOrderZeroDegreeOne and test_custom_linear_solve_cholesky have been fixed in ROCm, no need to skip 2022-09-01 13:27:23 +00:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Matthew Johnson
68e3f58041 un-skip polar/qdwh decomp tests skipped on gpu in ad6ce74
On an A100 machine, these tests seem to run fine now. See https://github.com/google/jax/issues/8628#issuecomment-1215651697.
2022-08-15 12:31:43 -07:00
Jake VanderPlas
7cc6b4f62b Tests: remove obsolete dtype_promotion decorators 2022-08-11 14:31:30 -07:00