23 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
28bbbf894f Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.

PiperOrigin-RevId: 683302687
2024-10-07 13:21:34 -07:00
Parker Schuh
030b6c655d Update the docs for conv_general_dilated to clarify 'W' 'H'. 2024-09-10 14:02:06 -07:00
Jake VanderPlas
7b41583414 refactor jax.lax to not depend on jax.numpy 2024-09-01 07:49:49 -07:00
Jake VanderPlas
1cba0970d8 refactor lax.loops to avoid importing from jax.numpy 2024-08-28 14:41:59 -07:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
Peter Hawkins
6f7be3cf04 Define lax.Precision directly in Python, rather than inheriting from a C++ type in jaxlib.
Historically, we defined Precision to be an enum exported from jaxlib using pybind11, since that was the type the old XLA ComputationBuilder classes expected as input. But we build IR using StableHLO MLIR builders these days, and there's no reason for the JAX-level Precision type to match the XLA-internal one.

In a future change I plan to change the definition of Precision in jaxlib to be defined using nanobind instead of pybind11. Nanobind defines its enum classes to be final by default, which precludes this inheritance, and that's probably a good design decision by nanobind. But as discussed above, there's no good reason to inherit in the first place.

PiperOrigin-RevId: 612575404
2024-03-04 14:01:31 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Anselm Levskaya
f6a5f0dca2 Use real Precision type for lax.PrecisionType
PiperOrigin-RevId: 432413742
2022-03-04 04:21:25 -08:00
Roman Novak
8977998b5c
Update type annotations and use the new convolution.py file 2022-01-07 14:54:41 -08:00
Roman Novak
b9b759d4ff
Merge branch 'main' into conv_local 2022-01-07 09:51:46 -08:00
Peter Hawkins
4204a25c91 Split convolution functions out of jax._src.lax.lax and into a separate module (jax._src.lax.convolution).
No public API changes.

PiperOrigin-RevId: 411871903
2021-11-23 12:35:50 -08:00
Jake VanderPlas
f2a959054a Document jax.lax.Precision 2021-11-08 14:15:31 -08:00
Rebecca Chen
5065e1bb93 Add missing typing.Optional type annotations to function parameters.
PiperOrigin-RevId: 376300297
2021-05-27 20:10:23 -07:00
Roman Novak
bc84c9fe8f Add lax.conv_general_dilated_local 2021-05-13 12:20:35 -07:00
Lukas Geiger
f7f42694d9 Add support for preferred_element_type arg in convolutions 2021-04-22 10:29:31 +02:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
ef5f6c9f4a jnp.eye(): improve input validation 2021-01-08 12:51:03 -08:00
Roman Novak
da0bff2fa8 Add lax.conv_general_dilated_patches 2020-10-20 22:58:53 -07:00