1440 Commits

Author SHA1 Message Date
Sergei Lebedev
c9142cbe75 Collapsed a few unnecessary `if TYPE_CHECKING` blocks 2024-08-12 13:08:55 +01:00
Dan Foreman-Mackey
3c014a4c27 Add support for shape polymorphism with lu_pivots_to_permutation.
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.

PiperOrigin-RevId: 662024940
2024-08-12 03:39:54 -07:00
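
A minimal usage sketch for the op touched by the commit above, assuming the public `lax.linalg.lu_pivots_to_permutation` wrapper and its `permutation_size` argument; the pivot values are made up for illustration:

    import jax.numpy as jnp
    from jax import lax

    # Row-swap pivots as an LU factorization would return them (0-based).
    pivots = jnp.array([2, 2, 3], dtype=jnp.int32)
    perm = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size=4)
    print(perm)  # a permutation of [0, 1, 2, 3]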
Roy Frostig
c54ffd41bc in dot docstring, format and link to dot_general 2024-08-11 12:44:50 -07:00
Dan Foreman-Mackey
11d9c2de2c Update GPU implementation of lu_pivots_to_permutation to infer the permutation size directly from the input dimensions, instead of using an input parameter.
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, an input `permutation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyway.

In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.

PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
Dan Foreman-Mackey
23da11b609 Re-land FFI port of GPU LU decomposition after fixing XLA FFI memory leak.
PiperOrigin-RevId: 659867028
2024-08-06 02:13:21 -07:00
John Ryan
56ff247c2e Reverts 80560663d3fab4c0c3f87d7c8e52fb9931526dbb
PiperOrigin-RevId: 659334027
2024-08-04 12:11:30 -07:00
Matthew Johnson
bdcd358b65 improve while_loop carry pytree/type mismatch errors
Now we call into the same error utility as we use in scan.
2024-08-03 21:57:29 +00:00
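
A small example of the kind of carry mismatch whose error message the commit above improves; the error type and wording are the library's, so it is caught generically here:

    import jax.numpy as jnp
    from jax import lax

    init = {"i": jnp.int32(0), "x": jnp.zeros(3)}
    try:
        lax.while_loop(
            lambda c: c["i"] < 5,
            lambda c: (c["i"] + 1, c["x"]),  # wrong: tuple instead of dict
            init,
        )
    except Exception as e:  # the improved pytree/type mismatch error
        print(e)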
Dan Foreman-Mackey
80560663d3 Enable FFI implementation of GPU Getrf FFI handler.
PiperOrigin-RevId: 658755392
2024-08-02 05:07:02 -07:00
jax authors
efba5f61b5 Merge pull request #22812 from superbobry:maint
PiperOrigin-RevId: 658751187
2024-08-02 04:43:33 -07:00
Paweł Paruzel
6b0b222a38 Activate LU Decomposition to XLA's FFI
PiperOrigin-RevId: 658721697
2024-08-02 02:22:53 -07:00
Abhinav Gunjal
dfe8d94170 Integrate StableHLO at openxla/stablehlo@fb18ee25
PiperOrigin-RevId: 658515936
2024-08-01 13:23:01 -07:00
Sergei Lebedev
92b1f71314 Removed various unused functions
To rerun the analysis do

    python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
2024-08-01 11:18:19 +01:00
George Necula
ffd2b00516 Add concretization error check in core.min_dim and core.max_dim
Fixes: #22751
2024-08-01 07:27:35 +02:00
jax authors
aeff5b61a9 Merge pull request #22080 from vfdev-5:add-device-kwarg-linspace-array
PiperOrigin-RevId: 656467191
2024-07-26 11:18:24 -07:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary-cast checking, because it turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
vfdev-5
76d61f9d8f Added device kwargs to jnp.linspace, jnp.array, jnp.asarray 2024-07-26 00:36:34 +02:00
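
A minimal sketch of the new `device` keyword described in the commit above; picking the first CPU device is just an illustrative choice:

    import jax
    import jax.numpy as jnp

    cpu0 = jax.devices("cpu")[0]
    x = jnp.linspace(0.0, 1.0, 5, device=cpu0)
    y = jnp.asarray([1, 2, 3], device=cpu0)
    print(x.devices(), y.devices())  # both live on cpu0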
Matthew Johnson
88d1cd731d remove pdot and xeinsum (since xmap is gone) 2024-07-25 21:19:17 +00:00
Paweł Paruzel
ae40c87919 Activate Cholesky Factorization Kernel to XLA's FFI
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
jax authors
76b4c70c23 Merge pull request #22628 from hawkinsp:broadcast2
PiperOrigin-RevId: 655779730
2024-07-24 19:17:25 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Peter Hawkins
52fa165d75 Simplify promote_shapes.
We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to.

Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
2024-07-24 19:42:16 -04:00
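
A rough sketch of the rank-promotion behavior described above, written with public ops since `broadcast_to_rank` itself is an internal `lax` helper:

    import jax.numpy as jnp
    from jax import lax

    def broadcast_to_rank_sketch(x, rank):
        # Fast path: emit nothing when the rank is already correct.
        ndim = jnp.ndim(x)
        if ndim == rank:
            return x
        # Otherwise prepend size-1 dimensions up to the requested rank.
        return lax.expand_dims(x, tuple(range(rank - ndim)))

    print(broadcast_to_rank_sketch(jnp.ones((3, 4)), 4).shape)  # (1, 1, 3, 4)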
jax authors
50c5613641 Merge pull request #22610 from mattjj:12719
PiperOrigin-RevId: 655358145
2024-07-23 17:20:05 -07:00
Matthew Johnson
8db862c02e fix memory leak in cond jaxpr tracing
fixes #12719
2024-07-23 23:57:02 +00:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
2024-07-23 09:19:01 +00:00
Sergei Lebedev
969431f1fc Removed unused `_broadcast_translate` 2024-07-22 22:47:49 +01:00
jax authors
9632a2d1a8 Add jvp and transpose rule for ragged dot.
The numerical accuracy test is perfect against the reference implementation, and somewhat loose against the alternative gradient implementation used for testing.

PiperOrigin-RevId: 654381378
2024-07-20 17:56:59 -07:00
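
A small differentiation sketch for the commit above, assuming the public `lax.ragged_dot` signature (lhs `(m, k)`, rhs `(g, k, n)`, `group_sizes` summing to `m`) and the jvp rule this change adds:

    import jax
    import jax.numpy as jnp
    from jax import lax

    m, k, n, g = 6, 4, 5, 3
    lhs = jnp.ones((m, k))
    rhs = jnp.ones((g, k, n))
    group_sizes = jnp.array([2, 1, 3], dtype=jnp.int32)  # sums to m

    f = lambda lhs: lax.ragged_dot(lhs, rhs, group_sizes)
    out, tangent = jax.jvp(f, (lhs,), (jnp.ones_like(lhs),))
    print(out.shape, tangent.shape)  # (6, 5) (6, 5)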
jax authors
ac4ca35221 Merge pull request #22263 from hawkinsp:tuples
PiperOrigin-RevId: 653267867
2024-07-17 09:56:18 -07:00
Sergei Lebedev
c033653e28 Deduped three identical implementations of `hoist_consts_to_refs` 2024-07-12 12:48:28 +01:00
Roy Frostig
e8d9a54b1b extend type annotation for lax.convert_element_type
... to also accept extended dtypes (as defined internally).

PiperOrigin-RevId: 651372438
2024-07-11 05:27:11 -07:00
Yash Katariya
0426388d31 Add sharding to convert_element_type_p primitive.
There are 2 reasons for doing this:

* Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, you pay the cost of two transfers, which is not ideal since this path is critical (when users pass `device`) and we should avoid extra transfers at all costs.

* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.

Also fixes: https://github.com/google/jax/issues/17422

PiperOrigin-RevId: 650621659
2024-07-09 07:33:29 -07:00
Peter Hawkins
3d5784a343 Don't wrap singleton ir.Types during HLO lowering.
This is similar to https://github.com/google/jax/pull/22211, but for MLIR types instead of MLIR values.
2024-07-08 12:24:45 -04:00
jax authors
a8f22f6e34 Merge pull request #19614 from cgarciae:batch_map
PiperOrigin-RevId: 649208823
2024-07-03 15:01:52 -07:00
jax authors
dffd72e290 Merge pull request #22211 from hawkinsp:singletons
PiperOrigin-RevId: 649135349
2024-07-03 11:07:00 -07:00
Cristian Garcia
557b273707 support axes and batching in map 2024-07-03 17:46:10 +01:00
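
A minimal sketch of the batching behavior referenced in the merged PR above, assuming the `batch_size` keyword that PR #19614 adds to `lax.map`:

    import jax.numpy as jnp
    from jax import lax

    xs = jnp.arange(10.0)
    # The mapped function is applied in vmapped chunks of `batch_size`.
    ys = lax.map(lambda x: x * 2.0, xs, batch_size=4)
    print(ys)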
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
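
An illustrative sketch (not JAX's actual lowering code) of the representation choice described above: keep the common single-value case unwrapped and only build a tuple when a JAX value maps to zero or several IR values:

    def wrap_ir_values(ir_values):
        # Unwrap the common singleton case; otherwise keep a tuple.
        return ir_values[0] if len(ir_values) == 1 else tuple(ir_values)

    def flatten_ir_values(wrapped):
        # A bare (non-tuple) entry is treated as a singleton.
        return [wrapped] if not isinstance(wrapped, tuple) else list(wrapped)

    print(flatten_ir_values(wrap_ir_values(["v0"])))        # ['v0']
    print(flatten_ir_values(wrap_ir_values(["v0", "v1"])))  # ['v0', 'v1']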
Dan Foreman-Mackey
6becf716f3 Remove linear parameter from lax.cond_p.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
2024-07-01 10:25:42 -04:00
Peter Hawkins
ac3cb6f954 Simplify mlir.dense_int_array.
The NumPy array conversion here is pointless and slightly slower than not doing it.

PiperOrigin-RevId: 647520922
2024-06-27 19:33:06 -07:00
jax authors
43dc4c1ff8 Fix the jax.lax.Precision documentation.
- Make it clear that this only impacts FP32 computations.
- Remove incorrect aliases, e.g. 'bfloat16' for the default. This does not do as advertised on GPU.
- Explicitly specify GPU and TPU device-dependent behaviour.

PiperOrigin-RevId: 647342888
2024-06-27 09:22:07 -07:00
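
A short usage sketch of the documented API: per the doc fix above, `Precision` only affects float32 computations and its exact effect is device-dependent:

    import jax
    import jax.numpy as jnp

    a = jnp.ones((128, 128), dtype=jnp.float32)
    b = jnp.ones((128, 128), dtype=jnp.float32)
    c = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)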
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
jax authors
fe3c8e15a8 Merge pull request #21806 from cgarciae:cond-passthrough-outputs
PiperOrigin-RevId: 646970169
2024-06-26 09:13:07 -07:00
Cristian Garcia
dae7e41ade fix cond passthrough outputs 2024-06-26 16:17:45 +01:00
jax authors
fb4ab2baa1 Merge pull request #22055 from Intel-tensorflow:yimei/remove_block_fp16_oncpu
PiperOrigin-RevId: 646951088
2024-06-26 08:12:42 -07:00
Kevin Gleason
a4c92a454b Clean up gather/scatter StableHLO lowering.
PiperOrigin-RevId: 646491586
2024-06-25 08:39:50 -07:00
Yimei Sun
b37f51487d Remove the blocking of float16 dot on the CPU platform, to take advantage of CPU
platforms that support float16 matmul computation, for performance optimization.
With this change, JAX will allow a float16 dot HLO to be created. When the
HLO modules are processed during the CPU compile stage in OpenXLA, the
ChangeOpDataType pass will upcast the dot to float32 if the CPU platform
does not support float16 computation; on platforms that do support float16,
the dot stays in float16 for execution.
2024-06-23 23:51:30 -07:00
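
A minimal sketch of the user-visible effect described above: the float16 dot HLO is now emitted as-is, and whether it executes in float16 or gets upcast depends on the CPU's support:

    import jax.numpy as jnp

    a = jnp.ones((64, 64), dtype=jnp.float16)
    b = jnp.ones((64, 64), dtype=jnp.float16)
    c = jnp.dot(a, b)
    print(c.dtype)  # float16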
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
George Necula
6e3fc9a768 Fix the eager mode execution for lax.platform_dependent
When we use lax.platform_dependent in eager mode, and some
of the branches contain custom calls that are not recognized on
some platforms, we must eagerly pick the required branch.
In jit mode, the constant folding that the XLA compiler already
does will eliminate the unnecessary branches.
2024-06-21 17:07:48 +03:00
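
A small usage sketch of `lax.platform_dependent` as described above; the branch bodies are placeholders, and the per-platform keyword names ('cpu', 'cuda') follow JAX's platform naming:

    import jax
    import jax.numpy as jnp
    from jax import lax

    def f(x):
        return lax.platform_dependent(
            x,
            cpu=lambda x: x + 1.0,
            cuda=lambda x: x + 2.0,
            default=lambda x: x + 3.0,
        )

    print(f(jnp.float32(0.0)))           # eager: only the current platform's branch runs
    print(jax.jit(f)(jnp.float32(0.0)))  # jit: XLA constant-folds the branch selection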
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
jax authors
be1f4ba380 Merge pull request #21905 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 644068464
2024-06-17 11:04:28 -07:00
Junwhan Ahn
cec796f5dc Batch pxla.shard_args calls triggered by jax.device_put
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.

The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.

PiperOrigin-RevId: 644051624
2024-06-17 10:17:25 -07:00
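
A minimal sketch of the call pattern this change speeds up: one `jax.device_put` over a whole pytree, which now binds `device_put_p` once instead of once per leaf; the array shapes are arbitrary:

    import jax
    import jax.numpy as jnp

    tree = {"w": jnp.ones((1024, 1024)), "b": jnp.zeros(1024)}
    on_device = jax.device_put(tree, jax.devices()[0])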