1459 Commits

Author SHA1 Message Date
jax authors
eed273c106 Merge pull request #23353 from jakevdp:lax-deps
PiperOrigin-RevId: 670523237
2024-09-03 05:59:26 -07:00
Paweł Paruzel
414eb90f5b Activate Householder Product to XLA's FFI
PiperOrigin-RevId: 670196460
2024-09-02 06:19:01 -07:00
Jake VanderPlas
7b41583414 refactor jax.lax to not depend on jax.numpy 2024-09-01 07:49:49 -07:00
Yash Katariya
164b884f33 Fix failing tests in CI
PiperOrigin-RevId: 669357019
2024-08-30 09:49:58 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we only support n-ary ops in forward-only sharding propagation. This functionality is experimental and hidden behind the jax_sharding_in_types config flag.
There will be more improvements and semantics clarifications in the future as we integrate it further into JAX.
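For illustration, a minimal sketch of opting in (the flag name comes from this change; the update call is JAX's standard config mechanism):

    # Opt in to the experimental forward-only sharding propagation.
    import jax
    jax.config.update("jax_sharding_in_types", True)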

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Jake VanderPlas
1cba0970d8 refactor lax.loops to avoid importing from jax.numpy 2024-08-28 14:41:59 -07:00
Paweł Paruzel
3c6103f2df Activate Eigenvalue Decompositions to XLA's FFI
Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for symmetric or Hermitian matrices - SYEVD/HEEVD.
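For reference, a usage sketch of the user-facing entry points backed by these solvers (not part of this change):

    import jax.numpy as jnp

    a = jnp.eye(3)
    w, v = jnp.linalg.eig(a)     # general matrices (GEEV)
    wh, vh = jnp.linalg.eigh(a)  # symmetric/Hermitian matrices (SYEVD/HEEVD)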

PiperOrigin-RevId: 668381949
2024-08-28 03:53:49 -07:00
Roy Frostig
b3e3115391 improve scan error message on non-concrete unroll argument 2024-08-24 23:09:12 -07:00
Roy Frostig
a9b41e9fe7 improve scan error message on non-concrete length argument
Specifically, make it speak concretely about the `length` argument.
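A minimal example of the case the new message targets (illustrative only):

    import jax
    from jax import lax

    @jax.jit
    def f(n):
        # `length` is a tracer here, so scan raises the (now clearer)
        # concretization error.
        return lax.scan(lambda c, _: (c + 1, None), 0, None, length=n)

    f(3)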
2024-08-24 22:30:33 -07:00
Paweł Paruzel
c430b0c5e3 Activate QR Factorization to XLA's FFI
PiperOrigin-RevId: 666722604
2024-08-23 03:21:43 -07:00
Dan Foreman-Mackey
e51848ea3d Activate GPU kernel for LU decomposition.
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.

One of the biggest changes here is to move all the lowering logic for the getrf call into jax (lax/linalg.py) instead of jaxlib (gpu_solver.py and lapack.py), since the lowering code is now identical for CPU and GPU (the only difference is the handler names).
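For context, a usage sketch of the user-facing op backed by the getrf call:

    import jax.numpy as jnp
    from jax import lax

    a = jnp.array([[4.0, 3.0], [6.0, 3.0]])
    lu, pivots, permutation = lax.linalg.lu(a)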

PiperOrigin-RevId: 665829252
2024-08-21 05:08:41 -07:00
Dan Foreman-Mackey
bd90968a25 Port the GPU Cholesky update custom call to the FFI.
PiperOrigin-RevId: 665319689
2024-08-20 05:46:03 -07:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding leads to incorrect outputs on GPU and a crash on TPU, which is not good.
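A sketch of the case this fixes (assumes a multi-device runtime; the function and shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(jax.devices(), ("x",))
    f = jax.jit(lambda a: a * 2, in_shardings=NamedSharding(mesh, P("x")))

    x = jnp.arange(8.0)  # uncommitted: created with no explicit device
    y = f(x)             # x is now resharded to match in_shardings first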

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
jax authors
f1974b6471 Merge pull request #22523 from mattjj:scan-attrs-fix
PiperOrigin-RevId: 664935564
2024-08-19 12:51:09 -07:00
Matthew Johnson
ef82cb21ae fix basic scan bug with attrs 2024-08-19 19:12:52 +00:00
Dan Foreman-Mackey
dad2f576ac Add support for shape polymorphism in ffi_lowering and move lu_pivots_to_permutation lowering out of jaxlib.
The lowering logic for all jaxlib custom calls is currently split between JAX and jaxlib for reasons that are harder to justify now that the compiled calls are split between jaxlib and the relevant plugins. As part of my project to update these calls and simplify the lowering logic, it makes sense to consolidate the lowering rules in JAX instead of jaxlib, since the logic is now the same for both GPU and CPU. This update tackles a simple kernel as a test case for what this would look like.

Since the full lowering rule is now implemented in JAX, we can take advantage of the MLIR helpers that are included there, including `jex.ffi.ffi_lowering`, which I needed to update to support shape polymorphism.

Of note: I think it is safe (in a compatibility sense) to delete the lowering code from jaxlib, but it does mean that it won't be possible to lower this operation when `jax.__version__ < jaxlib.__version__`. I think this is okay given our compatibility guarantees, but I'd love a sanity check on that!

Another note, this doesn't actually change the lowered HLO for this op, so we don't need to worry about export compatibility.
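As a rough sketch of the consolidated pattern (the primitive and handler name here are hypothetical, and a real rule would also need abstract eval etc.):

    import jax.extend as jex
    from jax import core
    from jax.interpreters import mlir

    # Hypothetical primitive bound to an FFI handler named "my_call".
    my_prim = core.Primitive("my_prim")
    mlir.register_lowering(
        my_prim, jex.ffi.ffi_lowering("my_call"), platform="gpu")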

PiperOrigin-RevId: 664680250
2024-08-19 01:05:31 -07:00
Peter Hawkins
ba5b081571 [numpy] Fix test failures under NumPy 2.0.
PiperOrigin-RevId: 664465687
2024-08-18 09:09:37 -07:00
jax authors
9785368c7f [Easy] Refactor ragged_dot transpose, combine ragged_to_dense
PiperOrigin-RevId: 663630185
2024-08-16 00:32:42 -07:00
Paweł Paruzel
354293da48 Activate Singular Value Decomposition to XLA's FFI
PiperOrigin-RevId: 662436635
2024-08-13 02:41:57 -07:00
Sergei Lebedev
c9142cbe75 Collapsed a few unnecessary `if TYPE_CHECKING` blocks 2024-08-12 13:08:55 +01:00
Dan Foreman-Mackey
3c014a4c27 Add support for shape polymorphism with lu_pivots_to_permutation.
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation, which is used for lowering on CPU, to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect differentiability (this op doesn't support AD), and the behavior won't change when static shapes are used.
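A small sketch of that conditional lowering (illustrative only):

    import jax
    from jax import lax

    def count(n):
        return lax.fori_loop(0, n, lambda i, c: c + 1, 0)

    print(jax.make_jaxpr(lambda: count(5))())  # static bound -> scan
    print(jax.make_jaxpr(count)(5))            # traced bound -> while_loop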

PiperOrigin-RevId: 662024940
2024-08-12 03:39:54 -07:00
Roy Frostig
c54ffd41bc in dot docstring, format and link to dot_general 2024-08-11 12:44:50 -07:00
Dan Foreman-Mackey
11d9c2de2c Update GPU implementation of lu_pivots_to_permutation to infer the permutation size directly from the input dimensions, instead of using an input parameter.
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, an input `permutation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyway.

In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.
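For context, the op's user-facing form (the change above is internal to the GPU kernel; the Python API keeps its size argument):

    import jax.numpy as jnp
    from jax import lax

    pivots = jnp.array([1, 2, 2], dtype=jnp.int32)  # e.g. from getrf
    perm = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size=3)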

PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
Dan Foreman-Mackey
23da11b609 Re-land FFI port of GPU LU decomposition after fixing XLA FFI memory leak.
PiperOrigin-RevId: 659867028
2024-08-06 02:13:21 -07:00
John Ryan
56ff247c2e Reverts 80560663d3fab4c0c3f87d7c8e52fb9931526dbb
PiperOrigin-RevId: 659334027
2024-08-04 12:11:30 -07:00
Matthew Johnson
bdcd358b65 improve while_loop carry pytree/type mismatch errors
Now we call into the same error utility as we use in scan.
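An example of the kind of mismatch the improved error reports:

    import jax.numpy as jnp
    from jax import lax

    # The body promotes the int32 carry to float32, so while_loop raises
    # the carry type-mismatch error with the clearer message.
    lax.while_loop(lambda c: c < 3, lambda c: c + 0.5, jnp.int32(0))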
2024-08-03 21:57:29 +00:00
Dan Foreman-Mackey
80560663d3 Enable FFI implementation of GPU Getrf FFI handler.
PiperOrigin-RevId: 658755392
2024-08-02 05:07:02 -07:00
jax authors
efba5f61b5 Merge pull request #22812 from superbobry:maint
PiperOrigin-RevId: 658751187
2024-08-02 04:43:33 -07:00
Paweł Paruzel
6b0b222a38 Activate LU Decomposition to XLA's FFI
PiperOrigin-RevId: 658721697
2024-08-02 02:22:53 -07:00
Abhinav Gunjal
dfe8d94170 Integrate StableHLO at openxla/stablehlo@fb18ee25
PiperOrigin-RevId: 658515936
2024-08-01 13:23:01 -07:00
Sergei Lebedev
92b1f71314 Removed various unused functions
To rerun the analysis do

    python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
2024-08-01 11:18:19 +01:00
George Necula
ffd2b00516 Add concretization error check in core.min_dim and core.max_dim
Fixes: #22751
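A rough sketch of these helpers (they operate on concrete ints as well as symbolic dimensions, and now raise a clear concretization error when given a traced value):

    from jax import core

    print(core.max_dim(3, 5))  # 5
    print(core.min_dim(3, 5))  # 3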
2024-08-01 07:27:35 +02:00
jax authors
aeff5b61a9 Merge pull request #22080 from vfdev-5:add-device-kwarg-linspace-array
PiperOrigin-RevId: 656467191
2024-07-26 11:18:24 -07:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because it turns out we have
quite a few of those.
2024-07-26 11:22:39 +01:00
vfdev-5
76d61f9d8f Added device kwargs to jnp.linspace, jnp.array, jnp.asarray 2024-07-26 00:36:34 +02:00
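Usage sketch of the keyword added above:

    import jax
    import jax.numpy as jnp

    dev = jax.devices()[0]
    x = jnp.linspace(0.0, 1.0, 5, device=dev)  # committed to dev at creation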
Matthew Johnson
88d1cd731d remove pdot and xeinsum (since xmap is gone) 2024-07-25 21:19:17 +00:00
Paweł Paruzel
ae40c87919 Activate Cholesky Factorization Kernel to XLA's FFI
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
jax authors
76b4c70c23 Merge pull request #22628 from hawkinsp:broadcast2
PiperOrigin-RevId: 655779730
2024-07-24 19:17:25 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Peter Hawkins
52fa165d75 Simplify promote_shapes.
We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to.

Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
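A sketch of the semantics using public APIs (the real helper lives in jax's internals):

    import jax.numpy as jnp
    from jax import lax

    def broadcast_to_rank(x, rank):
        if x.ndim == rank:  # fast path: no equation emitted
            return x
        return lax.broadcast(x, (1,) * (rank - x.ndim))  # add leading 1s

    print(broadcast_to_rank(jnp.ones((3, 4)), 4).shape)  # (1, 1, 3, 4)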
2024-07-24 19:42:16 -04:00
jax authors
50c5613641 Merge pull request #22610 from mattjj:12719
PiperOrigin-RevId: 655358145
2024-07-23 17:20:05 -07:00
Matthew Johnson
8db862c02e fix memory leak in cond jaxpr tracing
fixes #12719
2024-07-23 23:57:02 +00:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
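The x64 mode in question is the standard flag:

    import jax
    jax.config.update("jax_enable_x64", True)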
2024-07-23 09:19:01 +00:00
Sergei Lebedev
969431f1fc Removed unused `_broadcast_translate` 2024-07-22 22:47:49 +01:00
jax authors
9632a2d1a8 Add jvp and transpose rule for ragged dot.
The numerical accuracy test matches the reference implementation exactly, and is somewhat loose against the alternative gradient implementation used for testing.
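A sketch of differentiating the op now that the jvp/transpose rules exist (shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from jax import lax

    lhs = jnp.ones((6, 3))                      # (m, k), rows in groups
    rhs = jnp.ones((2, 3, 4))                   # (g, k, n)
    group_sizes = jnp.array([2, 4], jnp.int32)  # (g,), sums to m

    grad = jax.grad(lambda l: lax.ragged_dot(l, rhs, group_sizes).sum())(lhs)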

PiperOrigin-RevId: 654381378
2024-07-20 17:56:59 -07:00
jax authors
ac4ca35221 Merge pull request #22263 from hawkinsp:tuples
PiperOrigin-RevId: 653267867
2024-07-17 09:56:18 -07:00
Sergei Lebedev
c033653e28 Deduped three identical implementations of `hoist_consts_to_refs` 2024-07-12 12:48:28 +01:00
Roy Frostig
e8d9a54b1b extend type annotation for lax.convert_element_type
... to also accept extended dtypes (as defined internally).

PiperOrigin-RevId: 651372438
2024-07-11 05:27:11 -07:00
Yash Katariya
0426388d31 Add sharding to convert_element_type_p primitive.
There are 2 reasons for doing this:

* Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, you pay the cost of two transfers, which is not ideal since this path is critical (when users pass `device`), and we should avoid extra transfers at all costs.

* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.

Also fixes: https://github.com/google/jax/issues/17422
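Usage sketch of the user-visible effect (the converted output lands directly on the requested device, with a single transfer):

    import jax
    import jax.numpy as jnp

    dev = jax.devices()[0]
    x = jnp.arange(4)
    y = jnp.asarray(x, dtype=jnp.float32, device=dev)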

PiperOrigin-RevId: 650621659
2024-07-09 07:33:29 -07:00