13610 Commits

Author SHA1 Message Date
jax authors
7b73e8ecf7 Merge pull request #22083 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 646550598
2024-06-25 11:36:26 -07:00
vfdev-5
70b4823348 Updated jnp.ceil/floor/trunc to preserve int dtypes
Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes
- Updated tests
  - For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case
2024-06-25 20:26:53 +02:00
Piseth Ky
cc80f63521 better floor_division doc
corrected equivalent operator scope, added See Also:

del whitespace

added operator equivalence note
2024-06-25 11:08:22 -07:00
jax authors
2bdfd0cebb Merge pull request #22081 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 646539354
2024-06-25 11:05:33 -07:00
Jake VanderPlas
c099f714cc Fix bug in bcoo_spdot_general abstract_eval 2024-06-25 10:44:55 -07:00
rajasekharporeddy
ee0bf776fe Update jnp.max/amax docs 2024-06-25 23:07:59 +05:30
rajasekharporeddy
b975ca1932 Better doc for jnp.sum 2024-06-25 22:25:23 +05:30
Jake VanderPlas
fa73077146 jax.config: validate on set() 2024-06-25 09:02:32 -07:00
Kevin Gleason
a4c92a454b Clean up gather/scatter StableHLO lowering.
PiperOrigin-RevId: 646491586
2024-06-25 08:39:50 -07:00
jax authors
50407e536e Fix a small comparison bug. This allows every last reserved byte to be used, and also avoids nonsensical error messages of the form: "Requested more bytes than we reserved space for: X > X".
PiperOrigin-RevId: 646487396
2024-06-25 08:25:48 -07:00
George Necula
0b0d5a08ea [pallas] Simplify the generation of the hlo.CustomCall
We use the `mlir.custom_call` library that most other
JAX lowering rules use.

We remove the use of `hlo.TupleType` because custom calls
support multiple results. Tuples are not supported
by the `mlir.custom_call` and they complicate the logic
for input_output_aliasing and for the dynamically-shaped
results.

PiperOrigin-RevId: 646462934
2024-06-25 06:57:03 -07:00
Junwhan Ahn
817eb7a9ee Skip broadcast_one_to_all for single-process JAX execution
There is no need to actually perform exchange across processes if there's only one process.

PiperOrigin-RevId: 646309814
2024-06-24 20:15:32 -07:00
Yash Katariya
15ed2a8bcd Fix device_put of a scalar with PositionalSharding
Fixes https://github.com/google/jax/issues/22073

PiperOrigin-RevId: 646279569
2024-06-24 17:53:32 -07:00
Piseth Ky
ee9290b9fd updating copysign docstring
grammar fix

grammar fix
2024-06-24 15:14:05 -07:00
jax authors
eec6b49488 Merge pull request #21984 from pkgoogle:better_rint_doc
PiperOrigin-RevId: 646230695
2024-06-24 14:57:19 -07:00
Piseth Ky
0808824bad updating rint doc
mentions output always being promoted to inexact
2024-06-24 14:00:30 -07:00
Matthew Johnson
789a0c7999 faster inline jaxpr, but maybe we shouldnt inline at all 2024-06-24 20:21:57 +00:00
jax authors
3501ad9a8d Merge pull request #22066 from jakevdp:binop-boilerplate
PiperOrigin-RevId: 646188935
2024-06-24 12:45:40 -07:00
jax authors
7f5f771370 Merge pull request #21978 from gnecula:exp_platforms
PiperOrigin-RevId: 646162753
2024-06-24 11:25:02 -07:00
Jake VanderPlas
0522276d62 Internal: remove binary operation ufunc factories 2024-06-24 11:22:28 -07:00
Justin Fu
8ba8f3bf65 [Pallas] Implement block-invariant sampling.
PiperOrigin-RevId: 646161271
2024-06-24 11:20:39 -07:00
Jake VanderPlas
549ac2d865 Internal: remove one_to_one_unop factory 2024-06-24 10:20:38 -07:00
Jake VanderPlas
a43994d464 Fix type annotations for jnp.poly* functions 2024-06-24 09:44:50 -07:00
jax authors
e119fe933b [Mosaic GPU] Allow __init__.py to run without _src.lib.mosaic_gpu being available.
PiperOrigin-RevId: 646124431
2024-06-24 09:37:32 -07:00
jax authors
e0b2144000 Merge pull request #22056 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 646118091
2024-06-24 09:18:12 -07:00
rajasekharporeddy
da334e37d0 Add code examples to jax.scipy.stats.sem docs 2024-06-24 20:47:49 +05:30
rajasekharporeddy
eba891e3fa Improve docs for jnp.roots and jnp.polyfit 2024-06-24 16:39:55 +05:30
Yimei Sun
b37f51487d Remove the blocking for float16 dot on CPU platform to take advantage of CPU
platforms supporting float16 matmul computation for performance optimization.
With this PR change, JAX will allow dot float16 HLO being created. When the
HLO modules are processed during cpu compile stage in open xla, the
ChangeOpDataType pass will upcast the dot to float type if the CPU platform
does not support float16 computation, but for the platform supporting float16
computation, dot will stay as float16 type for execution.
2024-06-23 23:51:30 -07:00
jax authors
8c602cc3d0 Merge pull request #20731 from NeilGirdhar:softmax
PiperOrigin-RevId: 645774372
2024-06-22 21:51:22 -07:00
jax authors
348cbba6b2 Merge pull request #21991 from rajasekharporeddy:testbranch4
PiperOrigin-RevId: 645770273
2024-06-22 21:28:10 -07:00
George Necula
d737abda48 [export] Fix multi-platform lowering for unknown platform, with donated_argnums
I had to ensure that the check for platforms supporting donation
only kicks in when we actually have donation.
2024-06-23 07:26:12 +03:00
jax authors
c5a1a02b44 Merge pull request #21966 from selamw1:complexobj_doc
PiperOrigin-RevId: 645770064
2024-06-22 21:24:21 -07:00
selamw1
7fb7ea2732 iscomplexobj_docstr_added
iscomplexobj_docstr_fixed

iscomplexobj_docstr_char_fixed

lint_and_typecheck_fixed

lint_white_space_fixed

lint_white_space_fixed_unfinished_docstring_removed

lint_white_space_fixed
2024-06-22 14:09:34 -07:00
rajasekharporeddy
bad1610ac4 Improved docs for jnp.polyint and jnp.polyder 2024-06-22 23:01:23 +05:30
Neil Girdhar
56fdb42e9d Copy nn.{softmax,log_softmax} to scipy.special 2024-06-22 09:32:14 -04:00
jax authors
56e8fe630e Merge pull request #22028 from rajasekharporeddy:stats-sem
PiperOrigin-RevId: 645518083
2024-06-21 15:30:04 -07:00
jax authors
8c7e0d4265 Merge pull request #21973 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 645517834
2024-06-21 15:26:13 -07:00
Ruturaj4
f787941ade [ROCM] Fix rocm platform name in nm.py 2024-06-21 16:28:55 -05:00
rajasekharporeddy
c5de7bb92e Improve docs for jnp.poly and polyval 2024-06-22 02:49:43 +05:30
Keith Rush
694cafb72b Minimizes defensive psum in shard_map transpose with check_rep=False.
By summing up over fewer things, this version should be more numerically stable.

PiperOrigin-RevId: 645499243
2024-06-21 14:18:26 -07:00
rajasekharporeddy
edde7d9762 Fix the behavior of jax.scipy.stats.sem when keepdims=True 2024-06-22 02:39:00 +05:30
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
jax authors
47ab52d34f Merge pull request #22014 from selamw1:iscomplex_doc
PiperOrigin-RevId: 645465397
2024-06-21 12:19:24 -07:00
selamw1
400bcbb59d iscomplex_docstring_added
iscomplex_docstring_summary_in_one_line
2024-06-21 11:49:11 -07:00
jax authors
4a7b293bd9 Merge pull request #22027 from rajasekharporeddy:testbranch5
PiperOrigin-RevId: 645437879
2024-06-21 10:51:05 -07:00
rajasekharporeddy
8cb5fb5f7c Add code examples to jax.scipy.stats.mode docs 2024-06-21 22:12:48 +05:30
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
George Necula
6e3fc9a768 Fix the eager mode execution for lax.platform_dependent
When we use lax.platform_dependent in eager mode, and some
of the branches contain custom calls that are not recognized on
some platforms, we must eagerly pick the required branch.
In jit mode, the constant folding that the XLA compiler already
does will eliminate the unnecessary branches.
2024-06-21 17:07:48 +03:00
Sharad Vikram
1a056823cf Check for static grid dimensions when partitioning nondivisible grid dimensions
PiperOrigin-RevId: 645253793
2024-06-20 21:30:35 -07:00
Sharad Vikram
1eb215eb87 Relax condition for partitioning dynamic grid dimensions over cores in pipeline emitter
PiperOrigin-RevId: 645240166
2024-06-20 20:14:56 -07:00