19208 Commits

Author SHA1 Message Date
Jake VanderPlas
84ee045f55 [key reuse] handle polymorphic shapes in slice 2024-01-29 13:59:44 -08:00
jax authors
cd1332eec6 Removes extra parens from sharding comparison check (the original code always produced 'false').
PiperOrigin-RevId: 602482106
2024-01-29 13:23:19 -08:00
Roy Frostig
a04332504b remove PRNGKeyArray ABC
We don't expose the `PRNGKeyArray` symbol publicly any longer and we only implement the interface in one place.

PiperOrigin-RevId: 602470550
2024-01-29 12:41:26 -08:00
jax authors
37b6d22a82 Merge pull request #19515 from gnecula:poly_decision
PiperOrigin-RevId: 602465741
2024-01-29 12:23:51 -08:00
Anlun Xu
5e009f9ff1 Make triton kernels compatible with command buffers
Autotuning is not compatible with graph capture because it requires synchronizing.

We use cuThreadExchangeStreamCaptureMode to execute a sequence of commands that are not recorded to graphs, similar to what NCCL does here: b6d7438d31/src/include/alloc.h (L171)

PiperOrigin-RevId: 602436960
2024-01-29 11:00:29 -08:00
jax authors
c7a1a095cf Merge pull request #19556 from google:dependabot/github_actions/actions/upload-artifact-4.3.0
PiperOrigin-RevId: 602436797
2024-01-29 11:00:10 -08:00
jax authors
9c6574a51c Merge pull request #19555 from google:dependabot/github_actions/styfle/cancel-workflow-action-0.12.1
PiperOrigin-RevId: 602436583
2024-01-29 10:51:57 -08:00
jax authors
dc10b622ce Merge pull request #19554 from jakevdp:fix-ogrid-test
PiperOrigin-RevId: 602433842
2024-01-29 10:43:21 -08:00
dependabot[bot]
39422ca9cb
Bump actions/upload-artifact from 4.2.0 to 4.3.0
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.2.0 to 4.3.0.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](694cdabd8b...26f96dfa69)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-01-29 17:51:10 +00:00
dependabot[bot]
2b075fab75
Bump styfle/cancel-workflow-action from 0.12.0 to 0.12.1
Bumps [styfle/cancel-workflow-action](https://github.com/styfle/cancel-workflow-action) from 0.12.0 to 0.12.1.
- [Release notes](https://github.com/styfle/cancel-workflow-action/releases)
- [Commits](01ce38bf96...85880fa030)

---
updated-dependencies:
- dependency-name: styfle/cancel-workflow-action
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-01-29 17:51:05 +00:00
Tomás Longeri
1cbda65afc [Mosaic][NFC] Fix comment format
PiperOrigin-RevId: 602416317
2024-01-29 09:47:41 -08:00
Jake VanderPlas
bd5e9bef33 testOgrid: make test compatible with NumPy 2.0 2024-01-29 09:21:23 -08:00
jax authors
0d152dcfab Merge pull request #19528 from superbobry:strict-abc
PiperOrigin-RevId: 602392902
2024-01-29 08:18:50 -08:00
Sergei Lebedev
fad3e749a1 Migrated remaining operations from the math namespace to lower directly to Triton IR
PiperOrigin-RevId: 602390761
2024-01-29 08:10:03 -08:00
Sergei Lebedev
07f8f700ca Migrated atomic operations to lower directly to Triton IR
PiperOrigin-RevId: 602384705
2024-01-29 07:45:31 -08:00
George Necula
e20afac46a [shape_poly] Add a decision procedure for inequalities.
In a previous PR (#19285) we added support for inequality
constaints on symbolic expressions, but with limited support
for the cases when a constrain contains more than one term,
e.g., "a >= b".

Here we add a simple decision procedure for such inequalities,
based on the elimination algorithm based on the following properties:

  * if we have two constraints "a + b >= 0" and "-a + c >= 0" we can
    eliminate "a" and infer the derived constraint "b + c >= 0".
  * the lower bound of "a + c", in presence of a constraint "a >= b"
    it greater-or-equal to "b + c".

The above rules can be generalized to cases when the eliminated
terms have coefficients different than 1.

This algorithm is exponential in the number of constraints, but
we implement a limited form. When we add a constraint we combine
it with already added constraints, but the result of the combination
is not combined further. This is sufficient for the cases we
have encountered so far.

The termination of the algorithm is ensured by always eliminating
the largest (leading) term, ensuring that the result of a combination of
constraints has a smaller leading term.

With this added power for reasoning, we can retire the previous
heuristics for handling "min", "max", "floordiv" and "mod" and replace
them with the addition of some implicit constraints for them,
e.g., "max(a, b) >= a", etc., and then letting the decision procedure
do its job.

We moved the logic for deciding inequalities, to a new file: shape_poly_decision.py.
2024-01-29 17:26:35 +02:00
jax authors
2518a6f6d2 Merge pull request #19551 from superbobry:main
PiperOrigin-RevId: 602379291
2024-01-29 07:20:07 -08:00
Sergei Lebedev
3cea57d9d1 Slice.from_slice now works for slices with a negative start index
The implementation still requires the step to be 1, so any slice
with a negative start index has size 0.
2024-01-29 13:13:06 +00:00
Sergei Lebedev
078bb00fdb Replaced most usages of abc.ABC with util.StrictABC
StrictABC does not allow registering virtual subclasses and can thus avoid
using relatively expensive __instancecheck__/__sublclasscheck__ defined in
abc.ABCMeta.

The only abc.ABC subclass left is jax.Array which *does* use virtual
subclasses for natively-defined array types.
2024-01-29 12:40:43 +00:00
jax authors
f910a0d4f8 Update XLA dependency to use revision
2a0061b67b.

PiperOrigin-RevId: 602106890
2024-01-27 21:50:14 -08:00
jax authors
8bbcbb6e12 Merge pull request #19532 from mattjj:jax-attrs2
PiperOrigin-RevId: 602079647
2024-01-27 18:07:04 -08:00
Matthew Johnson
22160dfe65 add test 2024-01-27 17:44:43 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
jax authors
d4660a0972 Update XLA dependency to use revision
5f5fff92b0.

PiperOrigin-RevId: 601954633
2024-01-26 22:38:23 -08:00
Junwhan Ahn
0f760ee545 Avoid using lambda as the reducer fn
Lambdas are represented by their ids in the metadata of lowered HLO (see example below) and they change every time. This makes the compilation cache less effective as it causes the computation's fingerprint to change every time.

```
get-tuple-element.41724 = bf16[8]{0} get-tuple-element(reduce.41723), index=0, metadata={op_name="pjit(_wrapped_fn)/jit(main)/.../reduce[computation=<function _compute_argminmax.<locals>.reducer_fn at 0x7fa6ecfb2200> dimensions=(1,)]" source_file="..." source_line=...}
```

PiperOrigin-RevId: 601910715
2024-01-26 17:43:57 -08:00
jax authors
ccfe9c1ec2 Merge pull request #19540 from mattjj:remove-hypothesis-test-dependence
PiperOrigin-RevId: 601908297
2024-01-26 17:28:26 -08:00
jax authors
c42305a0a9 Merge pull request #19536 from jakevdp:key-reuse-cond
PiperOrigin-RevId: 601900128
2024-01-26 16:43:44 -08:00
Matthew Johnson
54d7d5c91c make hypothesis dependence optional 2024-01-26 16:31:01 -08:00
Jake VanderPlas
17935aff01 [key reuse] fix key reuse type for cond with sources 2024-01-26 14:42:55 -08:00
Roy Frostig
2478f311d3 remove key array's isinstance-overriding metaclass
We don't need to support `isinstance(..., PRNGKeyArray)` on tracers any longer, since `PRNGKeyArray` is no longer a public symbol.

PiperOrigin-RevId: 601815616
2024-01-26 11:16:56 -08:00
Jake VanderPlas
592809cf57 Roll-back 1ae054b003088d873902fa62cfa8099260471e16 to re-enable nextafter tests
Reverts 1ae054b003088d873902fa62cfa8099260471e16

PiperOrigin-RevId: 601814205
2024-01-26 11:08:43 -08:00
Sergei Lebedev
cc5f565b89 Ported a subset of binary operations to lower directly to Triton IR
PiperOrigin-RevId: 601806008
2024-01-26 10:57:01 -08:00
jax authors
8c050ac71e Merge pull request #19517 from ppham27:changelist/601457375
PiperOrigin-RevId: 601804604
2024-01-26 10:49:00 -08:00
jax authors
d0008fbe4a Merge pull request #19511 from jakevdp:fix-asarray
PiperOrigin-RevId: 601803214
2024-01-26 10:40:59 -08:00
jax authors
269ad9fa35 Merge pull request #19504 from jakevdp:full-like-device
PiperOrigin-RevId: 601803117
2024-01-26 10:32:53 -08:00
Sergei Lebedev
cb7a32a844 Fixed a bug in _reduction_lowering
The block argument of tt.reduce is always parameterized by scalars.

Note that this bug had no effect on the emitted Triton IR, because the
lowering code does not currently rely on avals.

PiperOrigin-RevId: 601801294
2024-01-26 10:24:08 -08:00
Sergei Lebedev
273cb27047 compat.tensor __*__ methods no longer do implicit broadcasting
This change makes it simpler to lower binary operations to Triton IR
bypassing Triton Python bindings.

PiperOrigin-RevId: 601796719
2024-01-26 10:13:51 -08:00
jax authors
2a8ce9ae9c Merge pull request #19518 from jakevdp:softmax
PiperOrigin-RevId: 601796039
2024-01-26 10:04:48 -08:00
Jake VanderPlas
9549c745af jnp.full_like & co: support device parameter 2024-01-26 10:01:54 -08:00
Jake VanderPlas
d989f502fd lax.asarray: avoid explicit device_put 2024-01-26 09:46:09 -08:00
Jake VanderPlas
a282d586b6 nn.softmax: use double-where when where is specified 2024-01-26 09:45:31 -08:00
Jake VanderPlas
1ae054b003 Temporarily disable flaky nextafter tests
These are currently failing at HEAD due to 72f10f7eb5

We can re-enable once b9483d30a7 is integrated.

PiperOrigin-RevId: 601788984
2024-01-26 09:36:07 -08:00
Philip Pham
3fc72d1f44 Fix jax.lax.fori_loop(..., unroll=True) with non-positive length 2024-01-26 17:06:30 +00:00
Sergei Lebedev
f34bcc326b Fixed a typo in Pallas GPU lowering
`abs` is not available in `triton.compat.math`.

PiperOrigin-RevId: 601709135
2024-01-26 02:33:53 -08:00
jax authors
890155246d Update XLA dependency to use revision
62156ca9ef.

PiperOrigin-RevId: 601679131
2024-01-25 23:37:10 -08:00
jax authors
70ea84d67f Merge pull request #19485 from ROCmSoftwarePlatform:rocm-enable_tridiagonal_solve
PiperOrigin-RevId: 601613417
2024-01-25 17:19:00 -08:00
jax authors
1264700e73 Merge pull request #19520 from jakevdp:fold-in-consume
PiperOrigin-RevId: 601609582
2024-01-25 17:01:28 -08:00
Jake VanderPlas
b069c20e56 [key reuse] don't consume key in fold_in
Why? We've found in practice that downstream projects use fold_in multiple
times with the same key. This is safe so long as the folded-in value is
different every time; in this sense fold_in() is similar to seed(), and
for now we must trust the user to not repeat seeds.
2024-01-25 15:35:51 -08:00
jax authors
45daced7c9 Merge pull request #19507 from jakevdp:wraps-implements
PiperOrigin-RevId: 601505827
2024-01-25 11:15:36 -08:00
jax authors
a6f26306b3 Update XLA dependency to use revision
56977c4a88.

PiperOrigin-RevId: 601343380
2024-01-24 22:42:37 -08:00