jax authors
7b73e8ecf7
Merge pull request #22083 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 646550598
2024-06-25 11:36:26 -07:00
vfdev-5
70b4823348
Updated jnp.ceil/floor/trunc to preserve int dtypes
...
Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes
- Updated tests
- For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case
2024-06-25 20:26:53 +02:00
Piseth Ky
cc80f63521
better floor_division doc
...
corrected equivalent operator scope, added See Also:
del whitespace
added operator equivalence note
2024-06-25 11:08:22 -07:00
jax authors
2bdfd0cebb
Merge pull request #22081 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 646539354
2024-06-25 11:05:33 -07:00
Jake VanderPlas
c099f714cc
Fix bug in bcoo_spdot_general abstract_eval
2024-06-25 10:44:55 -07:00
rajasekharporeddy
ee0bf776fe
Update jnp.max/amax docs
2024-06-25 23:07:59 +05:30
rajasekharporeddy
b975ca1932
Better doc for jnp.sum
2024-06-25 22:25:23 +05:30
Jake VanderPlas
fa73077146
jax.config: validate on set()
2024-06-25 09:02:32 -07:00
Kevin Gleason
a4c92a454b
Clean up gather/scatter StableHLO lowering.
...
PiperOrigin-RevId: 646491586
2024-06-25 08:39:50 -07:00
jax authors
50407e536e
Fix a small comparison bug. This allows every last reserved byte to be used, and also avoids nonsensical error messages of the form: "Requested more bytes than we reserved space for: X > X".
...
PiperOrigin-RevId: 646487396
2024-06-25 08:25:48 -07:00
George Necula
0b0d5a08ea
[pallas] Simplify the generation of the hlo.CustomCall
...
We use the `mlir.custom_call` library that most other
JAX lowering rules use.
We remove the use of `hlo.TupleType` because custom calls
support multiple results. Tuples are not supported
by the `mlir.custom_call` and they complicate the logic
for input_output_aliasing and for the dynamically-shaped
results.
PiperOrigin-RevId: 646462934
2024-06-25 06:57:03 -07:00
Junwhan Ahn
817eb7a9ee
Skip broadcast_one_to_all
for single-process JAX execution
...
There is no need to actually perform exchange across processes if there's only one process.
PiperOrigin-RevId: 646309814
2024-06-24 20:15:32 -07:00
Yash Katariya
15ed2a8bcd
Fix device_put
of a scalar with PositionalSharding
...
Fixes https://github.com/google/jax/issues/22073
PiperOrigin-RevId: 646279569
2024-06-24 17:53:32 -07:00
Piseth Ky
ee9290b9fd
updating copysign docstring
...
grammar fix
grammar fix
2024-06-24 15:14:05 -07:00
jax authors
eec6b49488
Merge pull request #21984 from pkgoogle:better_rint_doc
...
PiperOrigin-RevId: 646230695
2024-06-24 14:57:19 -07:00
Piseth Ky
0808824bad
updating rint doc
...
mentions output always being promoted to inexact
2024-06-24 14:00:30 -07:00
Matthew Johnson
789a0c7999
faster inline jaxpr, but maybe we shouldnt inline at all
2024-06-24 20:21:57 +00:00
jax authors
3501ad9a8d
Merge pull request #22066 from jakevdp:binop-boilerplate
...
PiperOrigin-RevId: 646188935
2024-06-24 12:45:40 -07:00
jax authors
7f5f771370
Merge pull request #21978 from gnecula:exp_platforms
...
PiperOrigin-RevId: 646162753
2024-06-24 11:25:02 -07:00
Jake VanderPlas
0522276d62
Internal: remove binary operation ufunc factories
2024-06-24 11:22:28 -07:00
Justin Fu
8ba8f3bf65
[Pallas] Implement block-invariant sampling.
...
PiperOrigin-RevId: 646161271
2024-06-24 11:20:39 -07:00
Jake VanderPlas
549ac2d865
Internal: remove one_to_one_unop factory
2024-06-24 10:20:38 -07:00
Jake VanderPlas
a43994d464
Fix type annotations for jnp.poly* functions
2024-06-24 09:44:50 -07:00
jax authors
e119fe933b
[Mosaic GPU] Allow __init__.py to run without _src.lib.mosaic_gpu being available.
...
PiperOrigin-RevId: 646124431
2024-06-24 09:37:32 -07:00
jax authors
e0b2144000
Merge pull request #22056 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 646118091
2024-06-24 09:18:12 -07:00
rajasekharporeddy
da334e37d0
Add code examples to jax.scipy.stats.sem docs
2024-06-24 20:47:49 +05:30
rajasekharporeddy
eba891e3fa
Improve docs for jnp.roots and jnp.polyfit
2024-06-24 16:39:55 +05:30
Yimei Sun
b37f51487d
Remove the blocking for float16 dot on CPU platform to take advantage of CPU
...
platforms supporting float16 matmul computation for performance optimization.
With this PR change, JAX will allow dot float16 HLO being created. When the
HLO modules are processed during cpu compile stage in open xla, the
ChangeOpDataType pass will upcast the dot to float type if the CPU platform
does not support float16 computation, but for the platform supporting float16
computation, dot will stay as float16 type for execution.
2024-06-23 23:51:30 -07:00
jax authors
8c602cc3d0
Merge pull request #20731 from NeilGirdhar:softmax
...
PiperOrigin-RevId: 645774372
2024-06-22 21:51:22 -07:00
jax authors
348cbba6b2
Merge pull request #21991 from rajasekharporeddy:testbranch4
...
PiperOrigin-RevId: 645770273
2024-06-22 21:28:10 -07:00
George Necula
d737abda48
[export] Fix multi-platform lowering for unknown platform, with donated_argnums
...
I had to ensure that the check for platforms supporting donation
only kicks in when we actually have donation.
2024-06-23 07:26:12 +03:00
jax authors
c5a1a02b44
Merge pull request #21966 from selamw1:complexobj_doc
...
PiperOrigin-RevId: 645770064
2024-06-22 21:24:21 -07:00
selamw1
7fb7ea2732
iscomplexobj_docstr_added
...
iscomplexobj_docstr_fixed
iscomplexobj_docstr_char_fixed
lint_and_typecheck_fixed
lint_white_space_fixed
lint_white_space_fixed_unfinished_docstring_removed
lint_white_space_fixed
2024-06-22 14:09:34 -07:00
rajasekharporeddy
bad1610ac4
Improved docs for jnp.polyint and jnp.polyder
2024-06-22 23:01:23 +05:30
Neil Girdhar
56fdb42e9d
Copy nn.{softmax,log_softmax} to scipy.special
2024-06-22 09:32:14 -04:00
jax authors
56e8fe630e
Merge pull request #22028 from rajasekharporeddy:stats-sem
...
PiperOrigin-RevId: 645518083
2024-06-21 15:30:04 -07:00
jax authors
8c7e0d4265
Merge pull request #21973 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 645517834
2024-06-21 15:26:13 -07:00
Ruturaj4
f787941ade
[ROCM] Fix rocm platform name in nm.py
2024-06-21 16:28:55 -05:00
rajasekharporeddy
c5de7bb92e
Improve docs for jnp.poly and polyval
2024-06-22 02:49:43 +05:30
Keith Rush
694cafb72b
Minimizes defensive psum
in shard_map transpose with check_rep=False
.
...
By summing up over fewer things, this version should be more numerically stable.
PiperOrigin-RevId: 645499243
2024-06-21 14:18:26 -07:00
rajasekharporeddy
edde7d9762
Fix the behavior of jax.scipy.stats.sem when keepdims=True
2024-06-22 02:39:00 +05:30
Peter Hawkins
9e30079dba
[JAX] Add caching to pjit._infer_params.
...
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.
In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.
```
name old cpu/op new cpu/op delta
jit_add_chain 59.1ms ±14% 49.4ms ±10% -16.32% (p=0.008 n=5+5)
name old time/op new time/op delta
jit_add_chain 60.3ms ±14% 50.7ms ±11% -15.99% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
jax authors
47ab52d34f
Merge pull request #22014 from selamw1:iscomplex_doc
...
PiperOrigin-RevId: 645465397
2024-06-21 12:19:24 -07:00
selamw1
400bcbb59d
iscomplex_docstring_added
...
iscomplex_docstring_summary_in_one_line
2024-06-21 11:49:11 -07:00
jax authors
4a7b293bd9
Merge pull request #22027 from rajasekharporeddy:testbranch5
...
PiperOrigin-RevId: 645437879
2024-06-21 10:51:05 -07:00
rajasekharporeddy
8cb5fb5f7c
Add code examples to jax.scipy.stats.mode docs
2024-06-21 22:12:48 +05:30
Dan Foreman-Mackey
6d35b109fd
Rename "Example" to "Examples" in docstrings.
...
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
George Necula
6e3fc9a768
Fix the eager mode execution for lax.platform_dependent
...
When we use lax.platform_dependent in eager mode, and some
of the branches contain custom calls that are not recognized on
some platforms, we must eagerly pick the required branch.
In jit mode, the constant folding that the XLA compiler already
does will eliminate the unnecessary branches.
2024-06-21 17:07:48 +03:00
Sharad Vikram
1a056823cf
Check for static grid dimensions when partitioning nondivisible grid dimensions
...
PiperOrigin-RevId: 645253793
2024-06-20 21:30:35 -07:00
Sharad Vikram
1eb215eb87
Relax condition for partitioning dynamic grid dimensions over cores in pipeline emitter
...
PiperOrigin-RevId: 645240166
2024-06-20 20:14:56 -07:00