1401 Commits

Author SHA1 Message Date
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
jax authors
fe3c8e15a8 Merge pull request #21806 from cgarciae:cond-passthrough-outputs
PiperOrigin-RevId: 646970169
2024-06-26 09:13:07 -07:00
Cristian Garcia
dae7e41ade fix cond passthrough outputs 2024-06-26 16:17:45 +01:00
jax authors
fb4ab2baa1 Merge pull request #22055 from Intel-tensorflow:yimei/remove_block_fp16_oncpu
PiperOrigin-RevId: 646951088
2024-06-26 08:12:42 -07:00
Kevin Gleason
a4c92a454b Clean up gather/scatter StableHLO lowering.
PiperOrigin-RevId: 646491586
2024-06-25 08:39:50 -07:00
Yimei Sun
b37f51487d Remove the blocking for float16 dot on CPU platform to take advantage of CPU
platforms supporting float16 matmul computation for performance optimization.
With this PR change, JAX will allow dot float16 HLO being created. When the
HLO modules are processed during cpu compile stage in open xla, the
ChangeOpDataType pass will upcast the dot to float type if the CPU platform
does not support float16 computation, but for the platform supporting float16
computation, dot will stay as float16 type for execution.
2024-06-23 23:51:30 -07:00
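A minimal sketch of what this enables (the shapes and the lowering check are illustrative assumptions, not from the commit):

```python
import jax
import jax.numpy as jnp

# After this change, a float16 matmul keeps a float16 dot in the lowered HLO;
# XLA's ChangeOpDataType pass upcasts it on CPUs without native float16 matmul.
x = jnp.ones((128, 128), dtype=jnp.float16)
y = jnp.ones((128, 128), dtype=jnp.float16)

print(jax.jit(jnp.dot).lower(x, y).as_text())  # the dot appears as f16 in StableHLO
```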
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
George Necula
6e3fc9a768 Fix the eager mode execution for lax.platform_dependent
When we use lax.platform_dependent in eager mode, and some
of the branches contain custom calls that are not recognized on
some platforms, we must eagerly pick the required branch.
In jit mode, the constant folding that the XLA compiler already
does will eliminate the unnecessary branches.
2024-06-21 17:07:48 +03:00
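For context, a small usage sketch of `lax.platform_dependent` (the branch bodies are hypothetical):

```python
import jax.numpy as jnp
from jax import lax

def cpu_impl(x):   # hypothetical branch; imagine it wraps a CPU-only custom call
  return x * 2.0

def default_impl(x):
  return x + x

# In eager mode the branch for the current platform is now picked eagerly,
# so branches whose custom calls this platform cannot lower are never touched.
y = lax.platform_dependent(jnp.ones(3), cpu=cpu_impl, default=default_impl)
```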
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
jax authors
be1f4ba380 Merge pull request #21905 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 644068464
2024-06-17 11:04:28 -07:00
Junwhan Ahn
cec796f5dc Batch pxla.shard_args calls triggered by jax.device_put
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.

The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.

PiperOrigin-RevId: 644051624
2024-06-17 10:17:25 -07:00
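A sketch of the pattern this speeds up (the tree contents are an assumption):

```python
import jax
import jax.numpy as jnp

# One jax.device_put over a pytree with many leaves now issues a single
# device_put_p.bind(), letting backends batch the underlying transfers.
tree = {f"w{i}": jnp.ones((8, 8)) for i in range(100)}
tree_on_device = jax.device_put(tree, jax.devices()[0])
```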
Sergei Lebedev
4913fff971 Rollback #21888, because it breaks multiple internal users
Reverts 193591b5c0b90ce498015b2e3d48950615253380

PiperOrigin-RevId: 643965549
2024-06-17 05:01:04 -07:00
rajasekharporeddy
b93da3873b Fix Typos 2024-06-17 13:55:46 +05:30
Jake VanderPlas
4f7cd03893 lax.mul: accept boolean inputs 2024-06-14 13:47:11 -07:00
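A one-line illustration (the exact arrays are an assumption):

```python
import jax.numpy as jnp
from jax import lax

# Boolean inputs are now accepted; multiplication of bools acts like logical AND.
lax.mul(jnp.array([True, True, False]), jnp.array([True, False, False]))
# -> Array([ True, False, False], dtype=bool)
```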
jax authors
c839b268d2 Get rid of the is_hermitian argument for lax.qdwh. If it were known that H was also positive semi-definite, the polar decomposition would simply be I*H. But for indefinite H, the QDWH algorithm for Hermitian inputs does not differ from the general case.
PiperOrigin-RevId: 643141687
2024-06-13 15:33:49 -07:00
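A sketch via the public polar-decomposition wrapper, which is backed by QDWH (the matrix is an illustrative assumption):

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

# QDWH now treats all Hermitian inputs through the same iteration; there is
# no is_hermitian hint to pass.
a = jnp.array([[2.0, 1.0], [1.0, 3.0]])
u, p = polar(a, method="qdwh")  # a == u @ p, u unitary, p positive semi-definite
```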
Jake VanderPlas
6b8e2f3467 DOC: jax.lax.top_k: fix docstring rendering & add example 2024-06-10 13:57:21 -07:00
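In the spirit of the added docstring example, a minimal `lax.top_k` usage (the values here are illustrative):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([9.0, 3.0, 6.0, 4.0, 10.0])
values, indices = lax.top_k(x, 2)
# values  -> [10., 9.]
# indices -> [4, 0]
```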
George Necula
dbad518d2b [shape_poly] Add limited support for lax.approx_top_k.
This relies on newly introduced support for dynamic `k`
for approx_top_k, using the `stablehlo.dynamic_approx_top_k`
custom call.

We also add a backwards compatibility test.

PiperOrigin-RevId: 640557581
2024-06-05 09:51:47 -07:00
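A small usage sketch of the API in question (shapes and k are assumptions; the symbolic-shape export step is omitted):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(1000, dtype=jnp.float32)
# With aggregate_to_topk=True (the default), the result has exactly k entries,
# which is what makes a dynamic k representable via the custom call above.
values, indices = lax.approx_max_k(x, k=10)
```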
Matthew Johnson
7c125701c5 make cond forward inputs to outputs, reduces vmap lifting
Co-authored-by: Cristian Garcia <cgarciae@google.com>
2024-06-05 16:39:55 +00:00
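A sketch of the pattern this optimizes (the branch bodies are assumptions):

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(pred, x, y):
  # Both branches return x unchanged; cond now forwards it straight to the
  # output, so vmap over pred need not lift the passthrough value.
  return lax.cond(pred,
                  lambda x, y: (x, y + 1.0),
                  lambda x, y: (x, y - 1.0),
                  x, y)

out = jax.vmap(f, in_axes=(0, None, 0))(jnp.array([True, False]), 1.0, jnp.ones(2))
```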
Yash Katariya
1273028018 Simplify extended dtypes rules, part 1. Start by removing sharding-specific rules from extended dtypes, because we always want to replicate the trailing dims introduced by extended dtypes.
PiperOrigin-RevId: 639920049
2024-06-03 14:52:50 -07:00
Yash Katariya
0591620932 Fix copy.deepcopy support for arrays in pinned_host memory.
PiperOrigin-RevId: 639145872
2024-05-31 14:04:02 -07:00
Michael Levesque-Dion
9309592ac3 Integrate StableHLO at openxla/stablehlo@c44d9af8
PiperOrigin-RevId: 638559828
2024-05-30 01:04:35 -07:00
jax authors
f72b0f0ca6 Merge pull request #21504 from gnecula:poly_approx
PiperOrigin-RevId: 638550165
2024-05-30 00:22:24 -07:00
George Necula
c6a47316be [shape_poly] Fixes for approx_top_k when aggregate_to_topk=True
When `aggregate_to_topk=True` (the default) the output reduction
dimension size is `k`, and we do not need to invoke `ApproxTopKReductionOutputSize`.

Add a set of test cases for shape polymorphism for approx_top_k.

The case when `aggregate_to_topk=True` and `k` is symbolic will
be fixed separately.

The case when `aggregate_to_topk=False` raises a clearer NotImplementedError.
2024-05-30 04:17:13 +03:00
George Necula
87b81fc768 [shape_poly] Add support for jnp.tril. 2024-05-30 02:53:00 +03:00
jax authors
ede94c3c81 Rollback of https://github.com/google/jax/pull/20705
Causing pmap_test.py failures.

Reverts a7bce471440dda2a8bbeed1fe01dd9f733ef5bbc

PiperOrigin-RevId: 638437174
2024-05-29 15:46:55 -07:00
Shuhan Ding
56fed63b70 Roll back cumulative reduction lowering on METAL 2024-05-29 13:39:52 -07:00
Matthew Johnson
a4622b6a29 fix weak key cache stuff
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-05-29 17:53:56 +00:00
Dougal
122924fdf3 Make attrs work with pytrees
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-28 23:23:51 -04:00
jax authors
a7bce47144 Merge pull request #20705 from chaserileyroberts:chase/pbroadcast_channel_fix
PiperOrigin-RevId: 637986186
2024-05-28 12:29:40 -07:00
Michael Levesque-Dion
43f51d73ce Clean up version switches from dense array migration
PiperOrigin-RevId: 637955865
2024-05-28 10:58:51 -07:00
Chase Roberts
af6970e432 Pipe channel handle 2024-05-28 10:20:50 -07:00
jax authors
720d2b8708 Merge pull request #21376 from ROCm:ci_f8
PiperOrigin-RevId: 637884483
2024-05-28 06:56:26 -07:00
Matthew Johnson
a24b73802f avoid singleton dim in scan lowering when unroll==1 2024-05-25 19:07:49 +00:00
Sergei Lebedev
15b974c90b Another attempt to land #20445
Reverts fa9f02ba2fd7e874edee0169773923e162ed0ea1

PiperOrigin-RevId: 636926775
2024-05-24 08:24:17 -07:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Ruturaj4
41e4c25dc1 [ROCm] Add float8_e4m3fnuz and float8_e5m2fnuz support for ROCm 2024-05-22 05:50:28 +00:00
George Karpenkov
e0a6453a39 Simplify JAX lowering rules for cumulative sum
Upstream fix has landed => removing CPU workaround.

PiperOrigin-RevId: 635505632
2024-05-20 10:52:29 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Dan Foreman-Mackey
09a4b38ae2 Add informative error for invalid unroll in scan
As reported in #20481, setting `unroll=0` in `lax.scan` resulted in an
uninformative `ZeroDivisionError`. This PR adds a check which raises a
`ValueError` for `unroll<=0`.
2024-05-15 15:40:27 -04:00
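A quick sketch of valid and invalid `unroll` values (the scan body is an assumption):

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
  return carry + x, carry

xs = jnp.arange(8.0)
carry, ys = lax.scan(step, 0.0, xs, unroll=2)   # fine: 2 steps per iteration
# lax.scan(step, 0.0, xs, unroll=0)             # now a ValueError, not ZeroDivisionError
```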
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
jax authors
a56eb5681e Merge pull request #21211 from elfiegg:main
PiperOrigin-RevId: 633386931
2024-05-13 17:53:26 -07:00
Peter Hawkins
72a81e58e6 Readd a default lowering rule for cumsum et al.
A previous change removed the only non-constrained lowering rule, breaking lowering for platforms without explicit lowering rules.

PiperOrigin-RevId: 633297839
2024-05-13 12:34:51 -07:00
Elfie Guo
43d19161ac Remove type promotion for mixed fp8 matmuls. 2024-05-13 16:50:52 +00:00
George Karpenkov
de14e3b32e Reverts 49bd4d6f01d6cda00f9b1bdfbda156636baae928
PiperOrigin-RevId: 633221195
2024-05-13 08:35:40 -07:00
jax authors
1fed78499f Merge pull request #20940 from piotrfilipiuk:changelist/623910451
PiperOrigin-RevId: 633170419
2024-05-13 05:03:28 -07:00
Peter Hawkins
49bd4d6f01 Reverts 586568f4fe44cf9ad8b1bd022148a10c4b69f33a
PiperOrigin-RevId: 632818524
2024-05-11 12:24:06 -07:00
piotrfilipiuk
93dfe05aec Implements Ragged Dot API 2024-05-11 06:40:18 -07:00
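A sketch of the shape contract (sizes are illustrative assumptions):

```python
import jax.numpy as jnp
from jax import lax

# lhs is [m, k]; rhs holds one [k, n] matrix per group; group_sizes sums to m
# and says how many consecutive rows of lhs each group's matrix multiplies.
m, k, n, g = 6, 4, 3, 2
lhs = jnp.ones((m, k))
rhs = jnp.ones((g, k, n))
group_sizes = jnp.array([2, 4], dtype=jnp.int32)
out = lax.ragged_dot(lhs, rhs, group_sizes)  # shape [m, n]
```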
Yash Katariya
a4693db6cf Add a jaxpr interpreter for propagating memory kinds to outputs. It only triggers if we detect multiple memory kinds in the jaxpr.
This hopefully should go away when XLA implements its own memory space propagation pass, or when JAX adds memory_kind to the type system of jaxprs, i.e., on avals.

It's required to treat the following code blocks (1) and (2) as equivalent when lowering to StableHLO. In general, shardings should also be treated the same way, but we'll cross that bridge later.

1. `jit(f, out_shardings=s_host)`

2. ```
   @jax.jit
   def f(x):
     return jax.device_put(x, s_host)
   ```

PiperOrigin-RevId: 632621025
2024-05-10 15:34:57 -07:00
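A sketch of the two forms side by side (the single-device mesh and the availability of pinned_host memory on the backend are assumptions about the runtime):

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:1], ("x",))
s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")

f1 = jax.jit(lambda x: x * 2, out_shardings=s_host)    # form (1)
f2 = jax.jit(lambda x: jax.device_put(x * 2, s_host))  # form (2), now lowered the same
```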
George Karpenkov
586568f4fe Simplify JAX lowering rules for cumulative sum
Rely on XLA decomposition.

# JAX GPU microbenchmarks

285us vs. 449us for cumsum over 1e8 elements.

# JAX CPU microbenchmarks:

1.8s vs. 0.7s for 50 iterations of cumsum over 1e7 elements

PiperOrigin-RevId: 632547166
2024-05-10 11:03:28 -07:00
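A rough sketch of the kind of CPU microbenchmark quoted above (the exact harness is an assumption; the commit does not include it):

```python
import time
import jax
import jax.numpy as jnp

x = jnp.ones((10**7,), dtype=jnp.float32)
cumsum = jax.jit(jnp.cumsum)
cumsum(x).block_until_ready()  # compile and warm up once

start = time.perf_counter()
for _ in range(50):
  cumsum(x).block_until_ready()
print(f"{time.perf_counter() - start:.3f}s for 50 iterations")
```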