1175 Commits

Author SHA1 Message Date
Mateusz Sokół
d183a2c02f ENH: Update numpy exceptions imports 2023-08-07 19:08:41 +02:00
Patrick Kidger
808b0b26be Crash fix from ndim 2023-08-04 13:32:12 -07:00
Jérome Eertmans
6e55c20fbb chore(docs): improve jax.lax.scan
Make the docstring a bit more explicit about what is t

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-08-03 22:18:07 +02:00
Yash Katariya
4ddf6a9a54 Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54
PiperOrigin-RevId: 552816893
2023-08-01 08:53:28 -07:00
George Necula
2eaf545a47 [shape_poly] Fix handling of dot_general with different lhs_dtype and rhs_dtype
Add primitives tests for the case of dot_general with different lhs_dtype and
rhs_dtype. Then fix the lowering to work with dynamic shapes.
2023-07-31 10:54:05 +03:00
jax authors
b716f433a5 Merge pull request #16883 from mattjj:exp2-primitive
PiperOrigin-RevId: 551946122
2023-07-28 14:09:05 -07:00
Matthew Johnson
560ede0ff1 add an exp2 primitive and lax.exp2
part of fixing https://github.com/jax-ml/jax-triton/issues/204
2023-07-28 12:33:49 -07:00
Peter Hawkins
9a21ff0780 Revert: [shape_poly] Fix handling of dot_general with different lhs_dtype and rhs_dtype
Add primitives tests for the case of dot_general with different lhs_dtype and
rhs_dtype. Then fix the lowering to work with dynamic shapes.

PiperOrigin-RevId: 551915175
2023-07-28 12:05:22 -07:00
George Necula
88e11ae98c [shape_poly] Add shape polymorphism support for TopK.
This relies on a newly introduced support for a custom
call @stablehlo.dynamic_top_k.

PiperOrigin-RevId: 551833809
2023-07-28 06:19:38 -07:00
Peter Hawkins
a480aa8dbd Work around pytype error.
An upcoming pytype release complains about unpacking a non-deterministic order iterable for this line of code. Work around pytype.

PiperOrigin-RevId: 551627521
2023-07-27 13:39:48 -07:00
jax authors
416814df2a Merge pull request #16826 from mattjj:issue16805
PiperOrigin-RevId: 551263673
2023-07-26 11:20:31 -07:00
George Necula
c9f9f28b2c [shape_poly] Fix handling of dot_general with different lhs_dtype and rhs_dtype
Add primitives tests for the case of dot_general with different lhs_dtype and
rhs_dtype. Then fix the lowering to work with dynamic shapes.
2023-07-26 12:29:12 +02:00
Jake VanderPlas
0dbda849ef lax.dynamic_slice: avoid negative index correction for unsigned indices 2023-07-25 13:09:09 -07:00
Jake VanderPlas
e1a1377cde replace use of has_opaque_dtype 2023-07-24 14:46:58 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
Matthew Johnson
9ddef5cf84 make _dot_general_batch_rule handle python builtin numeric types 2023-07-24 14:01:07 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jake VanderPlas
65751bb328 make jvp(asarray, (1.,), (2.,)) produce Arrays
fixes #15676

Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-20 09:21:55 -07:00
Peter Hawkins
cdb48134e5 [JAX] Add support for multiple pytree registries.
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.

One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.

PiperOrigin-RevId: 549301796
2023-07-19 06:48:21 -07:00
jax authors
6c699815bc Merge pull request #16718 from mattjj:scatter-apply-autodiff
PiperOrigin-RevId: 548144853
2023-07-14 09:24:12 -07:00
jax authors
0e538e559d Merge pull request #16713 from gnecula:poly_clean4
PiperOrigin-RevId: 548092798
2023-07-14 04:52:17 -07:00
Jake VanderPlas
2cfffb613f make scatter-apply jvp notimplemented (for now...)
cf. #16684

Co-authored-by: Jake Vanderplas <vanderplas@google.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-13 16:25:25 -07:00
jax authors
ed302cbdda Merge pull request #16685 from axch:ragged-jit
PiperOrigin-RevId: 547833923
2023-07-13 10:03:06 -07:00
George Necula
71ac0bb446 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 16:37:53 +03:00
George Necula
58d6c4c1ec Roll back #16689
PiperOrigin-RevId: 547773322
2023-07-13 06:05:50 -07:00
George Necula
d21a667235 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 09:59:41 +03:00
Alexey Radul
f97db31ead Fix type errors caught by pytype. 2023-07-12 15:11:14 -04:00
Alexey Radul
ef9f1cbec3 Force bint-typed arrays to int32 types underneath.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-12 10:56:22 -04:00
Alexey Radul
60bec7a17b Physical HLO sharding for bint is the same as for the base type. 2023-07-11 15:21:55 -04:00
Alexey Radul
924394297b Test and implement slicing not dropping raggedness information. 2023-07-11 15:21:55 -04:00
Jake VanderPlas
1b3da85758 Fix scatter batching rule for scatter_apply
The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule such that it re-uses the input jaxpr rather than always
re-generating it.
2023-07-10 16:42:45 -07:00
Alexey Radul
5077807c8b Abstract out and reuse the gather_shape_computation to predict which axes will end up ragged.
This should resolve worries about silently wrong metadata about
pile_mapped gather, but gather is complicated so it's hard to be sure.
2023-07-07 09:23:33 -04:00
Alexey Radul
9fdc14f0bf More type annotations, and make transpose_ragged_axes a top-level function instead of a method.
Keep move_stacked_axis as a method because it's a type-specific
version of a top-level function of the same name that already exists.
2023-07-07 09:23:33 -04:00
Alexey Radul
89dd69ea2d Test and implement ragged slicing.
This touches _gather_batching_rule because slicing is implemented as a
gather, but we only test the case exercised by the slice that occurs
in our test transformer model, namely the unstack operation
  q, k, v = qkv
(which turns into three slices on an non-batched and non-ragged axis).

Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-07 09:23:33 -04:00
Alexey Radul
6f09fe840e Better error message when broadcasting ragged to static shape.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-07 09:23:29 -04:00
Sharad Vikram
c446b42522 Add discharge rules for scan/while 2023-07-06 22:30:35 +00:00
Jake VanderPlas
7c0334ce15 DOC: improve documentation for lax slicing routines 2023-07-06 10:44:08 -07:00
jax authors
658e8ff3dd Merge pull request #16601 from gnecula:clean_api
PiperOrigin-RevId: 545137395
2023-07-02 23:58:12 -07:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Jake VanderPlas
a329f8b947 schur: fix broken jvp rule 2023-06-30 02:30:25 -07:00
Jake VanderPlas
18bbc96279 Fix integer overflow in gather batching rule 2023-06-27 21:45:45 -07:00
Roy Frostig
14f32653a1 resolve conditionals to default "shared operand form" more often
If both the second and third operand of a `lax.cond` call are callable, then
resolve it as a new-style (default) conditional, where both branches act on the
same operands.

This changes the behavior of five-argument `lax.cond` calls. It is a breaking
change for callers using the old-style `cond` calling convention (`pred`,
`true_arg`, `true_fn`, `false_arg`, `false_fn`) with a callable `true_arg`.

PiperOrigin-RevId: 543912445
2023-06-27 18:49:16 -07:00
George Necula
cb42fae810 [shape_poly] Shape polymorphism support for approx_top_k
PiperOrigin-RevId: 543633818
2023-06-26 22:02:41 -07:00
Parker Schuh
819f731e8d jax.lax.collapse now takes Nones for stop_dimension.
PiperOrigin-RevId: 543598626
2023-06-26 18:30:34 -07:00
George Necula
c6a60054b9 [shape_poly] linalg.schur: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 543533821
2023-06-26 13:59:01 -07:00
George Necula
a91412e1e7 [shape_poly] linalg.triangular_solve: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 543506845
2023-06-26 12:13:12 -07:00
George Necula
ea0e50f765 [shape_poly] Refactor support for dynamic shapes for linalg.eig and linalg.eigh
The support for dynamic shapes for linalg.eig and linalg.eigh has been added
before we added the helper function `mk_result_types_and_shapes`, which has
been used for all other linalg primitives. Here we refactor linalg.eig and
linalg.eigh support to use these helper functions and follow the same style
as for other linalg primitives.

PiperOrigin-RevId: 543495381
2023-06-26 11:31:31 -07:00
George Necula
2299f05b8b [shape_poly] Cleanup the evaluation of dynamic shapes
Previously, we used the following pattern to generate the 1D
tensors representing dynamic shapes:

```
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, shape))
```

Now we write:
```
mlir.eval_dynamic_shape_as_tensor(ctx, shape)
```
2023-06-25 18:20:50 +02:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
63415a9184 Merge pull request #16386 from axch:ragged-einsum
PiperOrigin-RevId: 542887557
2023-06-23 10:00:07 -07:00