19136 Commits

Author SHA1 Message Date
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
jax authors
d4660a0972 Update XLA dependency to use revision
5f5fff92b0.

PiperOrigin-RevId: 601954633
2024-01-26 22:38:23 -08:00
Junwhan Ahn
0f760ee545 Avoid using lambda as the reducer fn
Lambdas are represented by their ids in the metadata of lowered HLO (see example below) and they change every time. This makes the compilation cache less effective as it causes the computation's fingerprint to change every time.

```
get-tuple-element.41724 = bf16[8]{0} get-tuple-element(reduce.41723), index=0, metadata={op_name="pjit(_wrapped_fn)/jit(main)/.../reduce[computation=<function _compute_argminmax.<locals>.reducer_fn at 0x7fa6ecfb2200> dimensions=(1,)]" source_file="..." source_line=...}
```

PiperOrigin-RevId: 601910715
2024-01-26 17:43:57 -08:00
jax authors
ccfe9c1ec2 Merge pull request #19540 from mattjj:remove-hypothesis-test-dependence
PiperOrigin-RevId: 601908297
2024-01-26 17:28:26 -08:00
jax authors
c42305a0a9 Merge pull request #19536 from jakevdp:key-reuse-cond
PiperOrigin-RevId: 601900128
2024-01-26 16:43:44 -08:00
Matthew Johnson
54d7d5c91c make hypothesis dependence optional 2024-01-26 16:31:01 -08:00
Jake VanderPlas
17935aff01 [key reuse] fix key reuse type for cond with sources 2024-01-26 14:42:55 -08:00
Roy Frostig
2478f311d3 remove key array's isinstance-overriding metaclass
We don't need to support `isinstance(..., PRNGKeyArray)` on tracers any longer, since `PRNGKeyArray` is no longer a public symbol.

PiperOrigin-RevId: 601815616
2024-01-26 11:16:56 -08:00
Jake VanderPlas
592809cf57 Roll-back 1ae054b003088d873902fa62cfa8099260471e16 to re-enable nextafter tests
Reverts 1ae054b003088d873902fa62cfa8099260471e16

PiperOrigin-RevId: 601814205
2024-01-26 11:08:43 -08:00
Sergei Lebedev
cc5f565b89 Ported a subset of binary operations to lower directly to Triton IR
PiperOrigin-RevId: 601806008
2024-01-26 10:57:01 -08:00
jax authors
8c050ac71e Merge pull request #19517 from ppham27:changelist/601457375
PiperOrigin-RevId: 601804604
2024-01-26 10:49:00 -08:00
jax authors
d0008fbe4a Merge pull request #19511 from jakevdp:fix-asarray
PiperOrigin-RevId: 601803214
2024-01-26 10:40:59 -08:00
jax authors
269ad9fa35 Merge pull request #19504 from jakevdp:full-like-device
PiperOrigin-RevId: 601803117
2024-01-26 10:32:53 -08:00
Sergei Lebedev
cb7a32a844 Fixed a bug in _reduction_lowering
The block argument of tt.reduce is always parameterized by scalars.

Note that this bug had no effect on the emitted Triton IR, because the
lowering code does not currently rely on avals.

PiperOrigin-RevId: 601801294
2024-01-26 10:24:08 -08:00
Sergei Lebedev
273cb27047 compat.tensor __*__ methods no longer do implicit broadcasting
This change makes it simpler to lower binary operations to Triton IR
bypassing Triton Python bindings.

PiperOrigin-RevId: 601796719
2024-01-26 10:13:51 -08:00
jax authors
2a8ce9ae9c Merge pull request #19518 from jakevdp:softmax
PiperOrigin-RevId: 601796039
2024-01-26 10:04:48 -08:00
Jake VanderPlas
9549c745af jnp.full_like & co: support device parameter 2024-01-26 10:01:54 -08:00
Jake VanderPlas
d989f502fd lax.asarray: avoid explicit device_put 2024-01-26 09:46:09 -08:00
Jake VanderPlas
a282d586b6 nn.softmax: use double-where when where is specified 2024-01-26 09:45:31 -08:00
Jake VanderPlas
1ae054b003 Temporarily disable flaky nextafter tests
These are currently failing at HEAD due to 72f10f7eb5

We can re-enable once b9483d30a7 is integrated.

PiperOrigin-RevId: 601788984
2024-01-26 09:36:07 -08:00
Philip Pham
3fc72d1f44 Fix jax.lax.fori_loop(..., unroll=True) with non-positive length 2024-01-26 17:06:30 +00:00
Sergei Lebedev
f34bcc326b Fixed a typo in Pallas GPU lowering
`abs` is not available in `triton.compat.math`.

PiperOrigin-RevId: 601709135
2024-01-26 02:33:53 -08:00
jax authors
890155246d Update XLA dependency to use revision
62156ca9ef.

PiperOrigin-RevId: 601679131
2024-01-25 23:37:10 -08:00
jax authors
70ea84d67f Merge pull request #19485 from ROCmSoftwarePlatform:rocm-enable_tridiagonal_solve
PiperOrigin-RevId: 601613417
2024-01-25 17:19:00 -08:00
jax authors
1264700e73 Merge pull request #19520 from jakevdp:fold-in-consume
PiperOrigin-RevId: 601609582
2024-01-25 17:01:28 -08:00
Jake VanderPlas
b069c20e56 [key reuse] don't consume key in fold_in
Why? We've found in practice that downstream projects use fold_in multiple
times with the same key. This is safe so long as the folded-in value is
different every time; in this sense fold_in() is similar to seed(), and
for now we must trust the user to not repeat seeds.
2024-01-25 15:35:51 -08:00
jax authors
45daced7c9 Merge pull request #19507 from jakevdp:wraps-implements
PiperOrigin-RevId: 601505827
2024-01-25 11:15:36 -08:00
jax authors
a6f26306b3 Update XLA dependency to use revision
56977c4a88.

PiperOrigin-RevId: 601343380
2024-01-24 22:42:37 -08:00
Oleg Shyshkov
fb80d2abcb [XLA][NFC] Make interface of module loaders consistent.
LoadModuleFromData has (data, format, config, ...) signature while FromFile has (path, config, format, ...). Change the latter so `format` becomes the second argument in both cases.

Since I'm touching this file:
* Use `std::string_view` and `absl::Status`
* Change `ovr_config` parameter to `const &`

PiperOrigin-RevId: 601304308
2024-01-24 19:16:43 -08:00
Skye Wanderman-Milne
5fb8b29dd9 Remove xla_gpu_cuda_data_dir compile option from persistent cache key.
See the new comment for justification.

PiperOrigin-RevId: 601282907
2024-01-24 17:26:59 -08:00
Michael Levesque-Dion
ebfce197ea Emit dense arrays for StableHLO ops migrating to dense arrays
We are migrating some attrs on some StableHLO ops to use DenseI64ArrayAttr instead of DenseIntElementsAttr. Using DenseI64ArrayAttr enforces that the attr values are 1-dimensional and provides nicer APIs. (see https://github.com/openxla/stablehlo/issues/1578 for additional context)

Unfortunately, we have to duplicate the `dense_int_array` function because we migrated the ops in batches. We can't use the existing `dense_int_array` function because it would produce arrays for ops that hadn't yet been migrated. This PR makes the final batch of changes, so no additional methods should be added going forward.

We also have to introduce a new `dense_bool_array` function, with a similar version check.

When the minimum supported jaxlib version uses a recent enough version of StableHLO  (v6 or above), it will be possible to remove the version checks and remove the duplicated `dense_int_array_v6` function.

PiperOrigin-RevId: 601271749
2024-01-24 16:41:37 -08:00
jax authors
c7425ef967 Merge pull request #19509 from jakevdp:array-register
PiperOrigin-RevId: 601264991
2024-01-24 16:14:49 -08:00
Sergei Lebedev
f15cad4651 Lower a subset of math primitives directly to Triton IR
Note that all primitives are now lowered to libdevice calls. Previously,
some of them were lowered to the MLIR arith dialect, and some to libdevice
calls, without any apparent reason for doing so.

PiperOrigin-RevId: 601259707
2024-01-24 15:55:09 -08:00
Jake VanderPlas
78f27dfa9d Remove unnecessary Array.register 2024-01-24 14:59:25 -08:00
Jieying Luo
cfb6250158 Add build instructions to build jaxlib with cuda plugin from source.
PiperOrigin-RevId: 601231525
2024-01-24 14:15:54 -08:00
Jake VanderPlas
43a9faa06a Rename _wraps to implements 2024-01-24 14:14:19 -08:00
Sharad Vikram
4646c64f54 [Pallas/TPU] Add support for input/output aliasing
PiperOrigin-RevId: 601219571
2024-01-24 13:37:19 -08:00
jax authors
8b81555850 Merge pull request #19470 from jakevdp:device-arg
PiperOrigin-RevId: 601205471
2024-01-24 12:50:37 -08:00
Jake VanderPlas
d55cd7c9e2 jax.numpy: support device argument for full, empty, zeros, ones 2024-01-24 12:01:09 -08:00
jax authors
831c25ff8c Merge pull request #19497 from ROCmSoftwarePlatform:rocm-log-pytest-html-report
PiperOrigin-RevId: 601188997
2024-01-24 11:51:28 -08:00
jax authors
b83f2b2595 Merge pull request #19503 from jakevdp:np2-signatures
PiperOrigin-RevId: 601177940
2024-01-24 11:15:08 -08:00
Jake VanderPlas
cedd67d611 Test: add weights to unsupported arguments 2024-01-24 10:44:51 -08:00
jax authors
322f8b22bd Merge pull request #19495 from 8bitmp3:patch-1
PiperOrigin-RevId: 601150181
2024-01-24 09:48:29 -08:00
zahiqbal
ef7694f26a [ROCM]: Generating pytest html logs from unit-tests. 2024-01-24 15:08:35 +00:00
8bitmp3
df3cc491ce
Update JAX docs copyright 2024-01-24 14:19:45 +00:00
Yash Katariya
a63197fed8 Add an internal _device_list parameter to GSPMDSharding so that we can save on the initialization cost of PyDeviceList when creating GSPMDSharding from other shardings
PiperOrigin-RevId: 601055733
2024-01-24 02:29:22 -08:00
jax authors
a74b04a43f Merge pull request #19492 from gnecula:poly_tests
PiperOrigin-RevId: 601050430
2024-01-24 02:04:25 -08:00
George Necula
0bd511d621 [shape_poly] Add more tests for reasoning about inequalities.
As I explore more powerful ways to reason about inequalities,
I came up with more tests of inequalities that I wish we can handle.
This PR adds the tests I have so far, even if they do not produce
the correct result yet. I write the expected values for tests as

   _expect(best=v1, current=v2)

to document that the current logic produces `v2` but the best value
we can hope for is `v1`.

This PR also adds more support for profiling tests.
2024-01-24 09:57:49 +01:00
jax authors
eab0dd1901 Update XLA dependency to use revision
627d8c3f9d.

PiperOrigin-RevId: 601014027
2024-01-23 22:52:12 -08:00
Yash Katariya
6f96c963ff Preserve single device NamedSharding/PositionalSharding on the output instead of always return SingleDeviceShardings.
Fixes https://github.com/google/jax/issues/19459

PiperOrigin-RevId: 600999853
2024-01-23 21:29:14 -08:00