21462 Commits

Author SHA1 Message Date
jax authors
348cbba6b2 Merge pull request #21991 from rajasekharporeddy:testbranch4
PiperOrigin-RevId: 645770273
2024-06-22 21:28:10 -07:00
jax authors
c5a1a02b44 Merge pull request #21966 from selamw1:complexobj_doc
PiperOrigin-RevId: 645770064
2024-06-22 21:24:21 -07:00
selamw1
7fb7ea2732 iscomplexobj_docstr_added
iscomplexobj_docstr_fixed

iscomplexobj_docstr_char_fixed

lint_and_typecheck_fixed

lint_white_space_fixed

lint_white_space_fixed_unfinished_docstring_removed

lint_white_space_fixed
2024-06-22 14:09:34 -07:00
jax authors
8038e2b00c Update XLA dependency to use revision
13611f856a.

PiperOrigin-RevId: 645710907
2024-06-22 13:11:43 -07:00
rajasekharporeddy
bad1610ac4 Improved docs for jnp.polyint and jnp.polyder 2024-06-22 23:01:23 +05:30
jax authors
300d06a505 Merge pull request #22042 from hawkinsp:numpy
PiperOrigin-RevId: 645519817
2024-06-21 15:33:50 -07:00
jax authors
56e8fe630e Merge pull request #22028 from rajasekharporeddy:stats-sem
PiperOrigin-RevId: 645518083
2024-06-21 15:30:04 -07:00
jax authors
8c7e0d4265 Merge pull request #21973 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 645517834
2024-06-21 15:26:13 -07:00
Peter Hawkins
7f24837eef Update minimum NumPy version to v1.24. 2024-06-21 15:17:17 -07:00
jax authors
fc1e1d4a65 Add freshness metablock to JAX OSS docs.
PiperOrigin-RevId: 645508135
2024-06-21 14:50:49 -07:00
rajasekharporeddy
c5de7bb92e Improve docs for jnp.poly and polyval 2024-06-22 02:49:43 +05:30
Keith Rush
694cafb72b Minimizes defensive psum in shard_map transpose with check_rep=False.
By summing up over fewer things, this version should be more numerically stable.

PiperOrigin-RevId: 645499243
2024-06-21 14:18:26 -07:00
rajasekharporeddy
edde7d9762 Fix the behavior of jax.scipy.stats.sem when keepdims=True 2024-06-22 02:39:00 +05:30
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
Tomás Longeri
a730f6bfd3 [Mosaic][infer-vector-layout] Allow non-32-bit types for vector.extract_strided_slice
PiperOrigin-RevId: 645481424
2024-06-21 13:17:37 -07:00
jax authors
79d30f682b Update XLA dependency to use revision
c9995d69b6.

PiperOrigin-RevId: 645477860
2024-06-21 13:04:50 -07:00
jax authors
47ab52d34f Merge pull request #22014 from selamw1:iscomplex_doc
PiperOrigin-RevId: 645465397
2024-06-21 12:19:24 -07:00
selamw1
400bcbb59d iscomplex_docstring_added
iscomplex_docstring_summary_in_one_line
2024-06-21 11:49:11 -07:00
jax authors
4a7b293bd9 Merge pull request #22027 from rajasekharporeddy:testbranch5
PiperOrigin-RevId: 645437879
2024-06-21 10:51:05 -07:00
rajasekharporeddy
8cb5fb5f7c Add code examples to jax.scipy.stats.mode docs 2024-06-21 22:12:48 +05:30
jax authors
c313cac964 Merge pull request #22032 from dfm:example-examples-docstrings
PiperOrigin-RevId: 645416034
2024-06-21 09:39:34 -07:00
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
jax authors
534d32a24d Merge pull request #22030 from gnecula:fix_platform_dependet
PiperOrigin-RevId: 645392458
2024-06-21 08:08:25 -07:00
George Necula
6e3fc9a768 Fix the eager mode execution for lax.platform_dependent
When we use lax.platform_dependent in eager mode, and some
of the branches contain custom calls that are not recognized on
some platforms, we must eagerly pick the required branch.
In jit mode, the constant folding that the XLA compiler already
does will eliminate the unnecessary branches.
2024-06-21 17:07:48 +03:00
Sharad Vikram
1a056823cf Check for static grid dimensions when partitioning nondivisible grid dimensions
PiperOrigin-RevId: 645253793
2024-06-20 21:30:35 -07:00
Sharad Vikram
1eb215eb87 Relax condition for partitioning dynamic grid dimensions over cores in pipeline emitter
PiperOrigin-RevId: 645240166
2024-06-20 20:14:56 -07:00
Sharad Vikram
787e747364 [Pallas/TPU] Add support for uneven partitioning over cores in emit_pipeline
PiperOrigin-RevId: 645173988
2024-06-20 15:25:49 -07:00
Peter Hawkins
d7a22d3720 [JAX] Teach jit fast path how to handle negative static_argnums correctly.
PiperOrigin-RevId: 645172085
2024-06-20 15:18:25 -07:00
Kyle Lucke
84d748f43c Stop using xla/statusor.h now that it just contains an alias for absl::Status.
In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free.

PiperOrigin-RevId: 645169743
2024-06-20 15:09:40 -07:00
Kyle Lucke
80de35514c Replace xla::Status with absl::Status in cuda_plugin_extension.cc.
PiperOrigin-RevId: 645144526
2024-06-20 13:51:03 -07:00
jax authors
cf41d11c2f Update XLA dependency to use revision
5261477acb.

PiperOrigin-RevId: 645140537
2024-06-20 13:37:54 -07:00
Peter Hawkins
637a7cbcc1 pjit.py cleanups.
Refactoring only, NFC intended.

* add types to more places.
* don't unpack PjitInfo positionally, since it's a 23-tuple and that seems rather error prone.
* change _infer_params to produce a new PjitParams NamedTuple, rather than having callers unpack a 9-tuple positionally.
* inline _pjit_jaxpr into its caller, since it only has one caller and the wrapper doesn't really clarify anything.
* note the return type of transformation_with_aux is a Callable.

PiperOrigin-RevId: 645068326
2024-06-20 09:58:22 -07:00
Adam Paszke
97ce128313 [Mosaic GPU] Call mbarrier.try_wait only once mbarrier.test_wait fails
The llvm.expect intrinsic puts the loop at the end of the program, allowing
the whole barrier to be compiled to a test_wait + predicated branch that is
immediately followed by the continuation. This seems to make the happy path
a little faster which can help reduce the barrier overhead for compute-bound
kernels.

PiperOrigin-RevId: 645007019
2024-06-20 06:24:27 -07:00
jax authors
f2956a4002 [Mosaic GPU] Add kwargs support to profiler.measure.
PiperOrigin-RevId: 644984111
2024-06-20 04:54:02 -07:00
Greg Olechwierowicz
274245ceac [JAX] Fix FDO profile deserialization.
Passing .c_str() to the ParseFromString can lead to inconsistent behavior when c string is not properly null terminated. This diff initializes an std::string explicitly by providing a size of a buffer to be parsed.

PiperOrigin-RevId: 644979040
2024-06-20 04:31:34 -07:00
Adam Paszke
fbdc50878c [Mosaic GPU] Simplify the flash attention kernel by never replicating the accumulator
I can't replicate any performance gains anymore and it's making it unnecessarily
more complicated.

PiperOrigin-RevId: 644977685
2024-06-20 04:23:17 -07:00
Adam Paszke
f976f1f224 [Mosaic GPU] Use explicit WGMMA/ALU scheduling in the flash attention kernel
With this change we reach state of the art performance (as far as I can tell)
of 50%+ TC util for head_dim 128 and 256.

I also added a little tuning harness to try out different block sizes.

PiperOrigin-RevId: 644927079
2024-06-20 00:56:44 -07:00
Paweł Paruzel
63aab133f1 Port LU Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 644845277
2024-06-19 17:31:25 -07:00
jax authors
2ac1cfada9 Update XLA dependency to use revision
4960accc10.

PiperOrigin-RevId: 644807999
2024-06-19 13:21:18 -07:00
Chris Jones
de8fd3b00d [mosaic:gpu] Fix MLIR canonicalization pass region-simplify option.
`region-simplify` now has `normal` and `aggressive` modes (using `normal` for now).

PiperOrigin-RevId: 644724434
2024-06-19 06:02:11 -07:00
Yash Katariya
d3bfd32667 Remove jax.xla_computation tests from jax2tf. api_test.py has enough coverage for jax.xla_computation
PiperOrigin-RevId: 644605636
2024-06-18 21:01:52 -07:00
Jevin Jiang
103d620856 [Pallas] Expose internal scratch size config.
This flag is useful - user can increase the internal scratch size to get more efficient pl.roll op and relayout in Mosaic.

PiperOrigin-RevId: 644576369
2024-06-18 18:32:30 -07:00
Jevin Jiang
3a3014b5a8 [Pallas] Add support for dynamic shift in roll op.
PiperOrigin-RevId: 644536582
2024-06-18 15:52:23 -07:00
Yash Katariya
175183775b Replace jax.xla_computation with the AOT API and add a way to unaccelerate the deprecation in jax tests.
PiperOrigin-RevId: 644535402
2024-06-18 15:47:24 -07:00
Sharad Vikram
39ffc29b8a Add support to pipeline emitter for shapes that don't perfectly divide the block shapes
Reverts 1669b99505deecdced51527b4d9f8041a1745bec

PiperOrigin-RevId: 644511841
2024-06-18 14:26:31 -07:00
Jevin Jiang
cac1791f7c [XLA:Mosaic] Support dynamic roll
We will choose the best solution based on the size of internal scratch memory.
- Sol 1: Convert dynamic roll to Log(N) static ops
- Sol 2: Static Store + Dynamic Load with internal scratch

PiperOrigin-RevId: 644509328
2024-06-18 14:18:56 -07:00
Jevin Jiang
c180b86bbd [XLA:Mosaic] Fix ext rule with large native tile.
PiperOrigin-RevId: 644495447
2024-06-18 13:34:40 -07:00
Sergei Lebedev
598c686d5f _make_dispatch_table calls in Pallas GPU lowering are now slightly more compact
PiperOrigin-RevId: 644492759
2024-06-18 13:25:54 -07:00
jax authors
3dae4096d3 Update XLA dependency to use revision
8361e4bfcb.

PiperOrigin-RevId: 644482850
2024-06-18 12:54:04 -07:00
Sergei Lebedev
2bb80d540c Removed unused which_linear= param from pallas_call_p
As far as I can tell, it was threaded through everywhere, but never actually used.

PiperOrigin-RevId: 644457293
2024-06-18 11:31:54 -07:00