147 Commits

Author SHA1 Message Date
Jake VanderPlas
b2c45b8eb9 Improved errors when indexing with floats 2025-02-28 15:04:07 -08:00
Jake VanderPlas
54fbf0b3f2 Indexing: avoid dynamic_slice when mode='clip'
This causes issues in the backward pass, where effectively mode='promise_in_bounds'
2025-01-14 11:20:50 -08:00
Jake VanderPlas
f6f4ef06cd Fix indexing corner case with empty ellipses 2024-12-03 17:20:40 -08:00
Sergei Lebedev
4cf33c0239 Added scatter_sub_p
The new primitive is used for in-place subtract and update.

Closes #23933

PiperOrigin-RevId: 681754037
2024-10-03 00:27:31 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
b45f0fe50f Support empty boolean indexing 2024-08-06 09:56:03 -07:00
Jake VanderPlas
9a88ecb244 Improve error when indexing with too many indices 2024-07-31 13:57:48 -07:00
Jake VanderPlas
5b28170b94 Support scalar boolean indices in arr.at[idx].set(vals) 2024-05-20 05:33:36 -07:00
jax authors
daab7a0329 Handle ellipsis ... in _attempt_rewriting_take_via_slice.
Previously `model['some_array'][:,0,0,:]` would generate a `slice`, while `model['some_array'][...,0,0,:]` would generate a `gather`. Now both of these generate `slice` eqns.

PiperOrigin-RevId: 631469837
2024-05-07 10:30:08 -07:00
Jake VanderPlas
bbfd4f2c26 jax.numpy: implement scalar boolean indexing 2024-02-09 11:00:00 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences 2023-12-19 06:15:30 +01:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
George Necula
ec460585c8 Fix indexing with slices when the slice elements are jax.Array.
This fixes a bug introduced in #18679, for the case when some
elements of the slice are `jax.Array`. We add a new test also.
2023-12-05 08:02:50 +01:00
George Necula
d2f62612d7 Fix bug in indexing with slices that overflow, and add tests.
This bug was introduced in #18679, and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
2023-12-02 16:47:06 +02:00
Jake VanderPlas
d2b4800723 tests: improve warnings-related tests 2023-11-30 10:35:24 -08:00
Matthew Johnson
67677eb10e improve error message for e.g. jnp.zeros(5)[:, 0] 2023-11-21 15:59:21 -08:00
Jake VanderPlas
416b734567 Fix boolean indexing check with newaxis 2023-11-15 09:03:15 -08:00
Peter Hawkins
e7f1d29716 Relax some test tolerances for TPU.
PiperOrigin-RevId: 576192162
2023-10-24 10:45:40 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Adam Paszke
bb8d5a0121 Rewrite simple slicing to the static slicing primitive whenever possible
This makes it a lot easier to handle within Pallas and Mosaic.

PiperOrigin-RevId: 563128943
2023-09-06 09:43:00 -07:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
Mateusz Sokół
d183a2c02f ENH: Update numpy exceptions imports 2023-08-07 19:08:41 +02:00
Jake VanderPlas
52dad895fc remove stray print statement 2023-07-14 14:42:34 -07:00
Jake VanderPlas
1b3da85758 Fix scatter batching rule for scatter_apply
The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule such that it re-uses the input jaxpr rather than always
re-generating it.
2023-07-10 16:42:45 -07:00
Jake VanderPlas
d0e75ca117 Require index update optional arguments to be passed by keyword.
Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023)

PiperOrigin-RevId: 544620172
2023-06-30 04:30:34 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
01045b3b42 BUG: fix x.at[i].apply() with non-unit slice sizes 2023-04-17 15:27:03 -07:00
Jake VanderPlas
dca23d4d8f jax.numpy indexing: lower to dynamic_slice for more cases 2023-04-17 11:07:18 -07:00
Jake VanderPlas
e061b91ffc Fix uint32 scatter assignment 2023-04-10 14:24:26 -07:00
Jake VanderPlas
dd8033bdd4 Improve error for indexing with string 2023-03-20 08:55:16 -07:00
Jake VanderPlas
6dd0e0153a jnp.ndarray.at: deprecate passing additional arguments by position 2023-03-13 10:04:39 -07:00
jax authors
78599e65d1 Roll-back https://github.com/google/jax/pull/14144 due to downstream test failures
PiperOrigin-RevId: 504628432
2023-01-25 12:15:36 -08:00
jax authors
d14e144651 Use pareto optimal step size for computing numerical Jacobians in JAX. This allows us to tighten the tolerances in gradient unit testing significantly, especially for float64 and complex128.
PiperOrigin-RevId: 504579516
2023-01-25 09:12:52 -08:00
Jake VanderPlas
b037feb105 [x64] more type safety for lax_numpy-related tests 2022-12-01 11:18:02 -08:00
Jake VanderPlas
376c55d66b Add JaxTestCase.assertNoWarnings 2022-10-12 10:19:21 -07:00
Peter Hawkins
8107e3600e Switch lax_numpy_indexing_test to use jtu.sample_product. 2022-10-06 17:44:17 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
cc8d406bdb Copybara import of the project:
--
d5ee729a7f5d583c8e212c6a9f1b98221b13cbdc by Jake VanderPlas <jakevdp@google.com>:

generate lax.slice instead of lax.gather for more indexing cases

PiperOrigin-RevId: 470094833
2022-08-25 15:14:33 -07:00
Jake VanderPlas
d5ee729a7f generate lax.slice instead of lax.gather for more indexing cases 2022-08-25 13:04:16 -07:00
Jake VanderPlas
dec2e8c577 [x64] make lax_numpy_indexing_test pass with strict dtype promotion 2022-06-16 14:00:11 -07:00
Jake VanderPlas
d2f80ef117 [x64] deprecate unsafe type casting in scatter-update operations 2022-06-09 15:21:49 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Lukas Geiger
f13b69c41d Avoid generating trivial gathers when reversing array 2022-05-11 23:16:40 +01:00
Anselm Levskaya
882a2d5dd3 Rollback of PR #10393 "Improve performance of array integer indexing"
This PR has broken some user models so needs to be investigated further before merging.

PiperOrigin-RevId: 447756000
2022-05-10 09:44:10 -07:00
Lukas Geiger
167c6a9f0c Expand constant indexing test to check slice 2022-05-05 23:47:00 +01:00
Lukas Geiger
5e2dd9ccd4 Add jaxpr test to ensure that no normalization happens for constant indices 2022-05-05 23:47:00 +01:00
Peter Hawkins
7c6a550333 Change the default scatter mode to FILL_OR_DROP.
This is a reasonably safe change, because it has no effect on the forward pass of a computation: the default behavior (PROMISE_IN_BOUNDS) also drops out-of-bounds scatters.

This change does however affect the transpose (gradient) of a scatter with out-of-bounds indices: the gradient of a PROMISE_IN_BOUNDS scatter is a PROMISE_IN_BOUNDS gather, and a PROMISE_IN_BOUNDS gather clips out-of-bounds indices into range. This is not mathematically correct: a dropped scatter index does not contribute to the primal output, and so its transpose should yield a zero cotangent.

After this change, the gradient of a default scatter is a gather with a fill value of 0: i.e., the indices that were dropped do not make gradient contributions, which is mathematically correct.

Separately, I am working towards switching out-of-bounds gather() operations to also have FILL_OR_DROP semantics, although that change is more disruptive because a number of users have out-of-bounds indices in their gather()s.

Issues: https://github.com/google/jax/issues/278 https://github.com/google/jax/issues/9839
PiperOrigin-RevId: 444935241
2022-04-27 12:26:55 -07:00
Jake VanderPlas
92ca76a039 Skip normalization of unsigned indices 2022-04-20 16:04:12 -07:00