1196 Commits

Author SHA1 Message Date
Matthew Johnson
70b58bbd30 rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
2023-08-31 17:31:21 -07:00
Matthew Johnson
8a04dfd830 rolling back shard_map transposition change to fix a bug
Reverts 437d7be73534403f39fbee9d6391be1c532933a1

PiperOrigin-RevId: 561730581
2023-08-31 12:39:56 -07:00
Matthew Johnson
fdd252f6ca [shard-map] add rewrite for efficient transposition 2023-08-30 15:08:11 -07:00
Peter Hawkins
93900245aa Remove jax.interpreters.xla.register_collective_primitive.
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.

PiperOrigin-RevId: 561150259
2023-08-29 15:10:05 -07:00
Peter Hawkins
d0a6813ea2 Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call().
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.

Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules.

This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).

Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper.

PiperOrigin-RevId: 561042402
2023-08-29 08:50:07 -07:00
Jake VanderPlas
6cec5d4416 lax.pow: fix shape mismatch failure in jvp rule 2023-08-25 10:05:55 -07:00
jax authors
af42359433 Merge pull request #16419 from mattjj:pow-jvp
PiperOrigin-RevId: 559266945
2023-08-22 17:15:04 -07:00
Matthew Johnson
1f8fb2c8bd change lowering rule to satisfy jax2tf 2023-08-22 16:48:11 -07:00
jax authors
209b6b02f4 Merge pull request #17144 from jakevdp:zeta
PiperOrigin-RevId: 558193896
2023-08-18 11:04:43 -07:00
Peter Hawkins
889489206b Remove the canonicalize_dtypes argument from mlir.ir_constant(s).
Instead, force the caller to explicitly canonicalize the argument if that's what they want.

The current behavior (canonicalize by default) is not the behavior we want to encourage: we want to canonicalize exactly where we need to and nowhere else.

PiperOrigin-RevId: 557806903
2023-08-17 06:44:12 -07:00
Jake VanderPlas
6cd467fd57 Create lax.zeta with native HLO lowering 2023-08-16 13:43:41 -07:00
Jake VanderPlas
0ad6196ff0 Create lax.polygamma with native HLO lowering 2023-08-16 11:57:05 -07:00
Peter Hawkins
47651c6a59 Remove uses of XLA translation rules.
Remove translation_rule argument to standard_primitive.

PiperOrigin-RevId: 557220350
2023-08-15 12:53:36 -07:00
Peter Hawkins
78cfdd1b35 Add some more type annotations to lax_numpy.py.
These type annotations are of course mostly ignored because the pytype: skip-file comment, but they help readers if nothing else.

PiperOrigin-RevId: 555955257
2023-08-11 08:07:24 -07:00
Peter Hawkins
bfaffe3183 Add version guards after GPU tridiagonal solve change.
PiperOrigin-RevId: 555931222
2023-08-11 06:41:05 -07:00
Srinivas Vasudevan
7dfc8ff49d Add batching rules to jax.lax.linalg.tridiagonal_solve.
PiperOrigin-RevId: 555700103
2023-08-10 16:25:59 -07:00
jax authors
be543f020d Merge pull request #17041 from mtsokol:update-ninf-usage
PiperOrigin-RevId: 555185385
2023-08-09 09:26:12 -07:00
Mateusz Sokół
1fedf04ed5 API: Remove NINF and PINF usages 2023-08-09 14:16:33 +02:00
Peter Hawkins
c9cf6b4423 Remove allowlist for multihost collectives.
This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public), is on the list. So the allowlist serves no purpose any more.

PiperOrigin-RevId: 555124144
2023-08-09 04:43:51 -07:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Mateusz Sokół
d183a2c02f ENH: Update numpy exceptions imports 2023-08-07 19:08:41 +02:00
Patrick Kidger
808b0b26be Crash fix from ndim 2023-08-04 13:32:12 -07:00
Jérome Eertmans
6e55c20fbb chore(docs): improve jax.lax.scan
Make the docstring a bit more explicit about what is t

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-08-03 22:18:07 +02:00
Yash Katariya
4ddf6a9a54 Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54
PiperOrigin-RevId: 552816893
2023-08-01 08:53:28 -07:00
George Necula
2eaf545a47 [shape_poly] Fix handling of dot_general with different lhs_dtype and rhs_dtype
Add primitives tests for the case of dot_general with different lhs_dtype and
rhs_dtype. Then fix the lowering to work with dynamic shapes.
2023-07-31 10:54:05 +03:00
Matthew Johnson
69ad4df9a5 fix pow_p jvp rule at x=0. y=0
fixes #14397

For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see https://github.com/google/jax/issues/14397#issuecomment-1426386290.

Instead of making a new primitive, we made the old one polymorphic and switch
its behavior on the element type of its second argument.

There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.

Co-authored-by: Roy Frostig <frostig@google.com>
2023-07-28 17:14:47 -07:00
jax authors
b716f433a5 Merge pull request #16883 from mattjj:exp2-primitive
PiperOrigin-RevId: 551946122
2023-07-28 14:09:05 -07:00
Matthew Johnson
560ede0ff1 add an exp2 primitive and lax.exp2
part of fixing https://github.com/jax-ml/jax-triton/issues/204
2023-07-28 12:33:49 -07:00
Peter Hawkins
9a21ff0780 Revert: [shape_poly] Fix handling of dot_general with different lhs_dtype and rhs_dtype
Add primitives tests for the case of dot_general with different lhs_dtype and
rhs_dtype. Then fix the lowering to work with dynamic shapes.

PiperOrigin-RevId: 551915175
2023-07-28 12:05:22 -07:00
George Necula
88e11ae98c [shape_poly] Add shape polymorphism support for TopK.
This relies on a newly introduced support for a custom
call @stablehlo.dynamic_top_k.

PiperOrigin-RevId: 551833809
2023-07-28 06:19:38 -07:00
Peter Hawkins
a480aa8dbd Work around pytype error.
An upcoming pytype release complains about unpacking a non-deterministic order iterable for this line of code. Work around pytype.

PiperOrigin-RevId: 551627521
2023-07-27 13:39:48 -07:00
jax authors
416814df2a Merge pull request #16826 from mattjj:issue16805
PiperOrigin-RevId: 551263673
2023-07-26 11:20:31 -07:00
George Necula
c9f9f28b2c [shape_poly] Fix handling of dot_general with different lhs_dtype and rhs_dtype
Add primitives tests for the case of dot_general with different lhs_dtype and
rhs_dtype. Then fix the lowering to work with dynamic shapes.
2023-07-26 12:29:12 +02:00
Jake VanderPlas
0dbda849ef lax.dynamic_slice: avoid negative index correction for unsigned indices 2023-07-25 13:09:09 -07:00
Jake VanderPlas
e1a1377cde replace use of has_opaque_dtype 2023-07-24 14:46:58 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
Matthew Johnson
9ddef5cf84 make _dot_general_batch_rule handle python builtin numeric types 2023-07-24 14:01:07 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jake VanderPlas
65751bb328 make jvp(asarray, (1.,), (2.,)) produce Arrays
fixes #15676

Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-20 09:21:55 -07:00
Peter Hawkins
cdb48134e5 [JAX] Add support for multiple pytree registries.
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.

One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.

PiperOrigin-RevId: 549301796
2023-07-19 06:48:21 -07:00
jax authors
6c699815bc Merge pull request #16718 from mattjj:scatter-apply-autodiff
PiperOrigin-RevId: 548144853
2023-07-14 09:24:12 -07:00
jax authors
0e538e559d Merge pull request #16713 from gnecula:poly_clean4
PiperOrigin-RevId: 548092798
2023-07-14 04:52:17 -07:00
Jake VanderPlas
2cfffb613f make scatter-apply jvp notimplemented (for now...)
cf. #16684

Co-authored-by: Jake Vanderplas <vanderplas@google.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-13 16:25:25 -07:00
jax authors
ed302cbdda Merge pull request #16685 from axch:ragged-jit
PiperOrigin-RevId: 547833923
2023-07-13 10:03:06 -07:00
George Necula
71ac0bb446 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 16:37:53 +03:00
George Necula
58d6c4c1ec Roll back #16689
PiperOrigin-RevId: 547773322
2023-07-13 06:05:50 -07:00
George Necula
d21a667235 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 09:59:41 +03:00
Alexey Radul
f97db31ead Fix type errors caught by pytype. 2023-07-12 15:11:14 -04:00
Alexey Radul
ef9f1cbec3 Force bint-typed arrays to int32 types underneath.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-12 10:56:22 -04:00
Alexey Radul
60bec7a17b Physical HLO sharding for bint is the same as for the base type. 2023-07-11 15:21:55 -04:00