4825 Commits

Author SHA1 Message Date
Jake VanderPlas
340e655ac2 Remove deprecated sym_pos argument from jax.scipy.linalg.solve
PiperOrigin-RevId: 580940755
2023-11-09 09:53:37 -08:00
Sharad Vikram
8fbcfce2dd [Pallas] Enable interpreter mode as default lowering for CPU
PiperOrigin-RevId: 580700740
2023-11-08 16:35:31 -08:00
jax authors
21260a7a65 Merge pull request #18396 from nouiz:custom_gpu_ops
PiperOrigin-RevId: 580680809
2023-11-08 15:39:22 -08:00
jax authors
6f1e488e54 Merge pull request #18402 from XuehaiPan:fix-hashable-partial
PiperOrigin-RevId: 580675684
2023-11-08 15:21:36 -08:00
jax authors
62e8f4d4aa Merge pull request #18431 from jakevdp:hist-ravel
PiperOrigin-RevId: 580597288
2023-11-08 11:37:15 -08:00
jax authors
1f8be3031a Merge pull request #18439 from jakevdp:dep-nn-normalize
PiperOrigin-RevId: 580597252
2023-11-08 11:28:56 -08:00
jax authors
6efcfe8fe0 Merge pull request #18386 from shacklettbp:pallas
PiperOrigin-RevId: 580561989
2023-11-08 09:48:27 -08:00
Jake VanderPlas
d59c1f1e21 jax.nn.normalize: deprecate using standard framework 2023-11-08 09:42:23 -08:00
Jieying Luo
0ce7c7b7bd Register plugin profiler for TPU and remove --config=tpu/--enable_tpu in jaxlib.
PiperOrigin-RevId: 580561059
2023-11-08 09:40:28 -08:00
Jake VanderPlas
a30d51ba2e jnp.histogram: avoid flattening input 2023-11-08 08:55:09 -08:00
Brennan Shacklett
aec0df9c3f [Pallas] Preserve order of grid provided to pallas_call, append and reverse vmapped dimensions 2023-11-07 18:23:07 -08:00
Enrique Piqueras
2b6be33ab9 Small Pallas bug fixes.
- Fix non determinism in axis name mapping in Pallas lowering.
- Implement mesh info support in Pallas transform jaxprs so blockspecs can access mesh info.
- Fix non compatibility of mesh info and scalar prefetching in Pallas lowering.
- Fix Pallas lowering of multi-dimensional topology remote DMAs.

PiperOrigin-RevId: 580357651
2023-11-07 17:55:19 -08:00
Peter Hawkins
647afbba3c Fix test for distributed backend initialization.
By default, _backends is {}, not None. So the "backends are initialized" test always failed.

PiperOrigin-RevId: 580262416
2023-11-07 12:20:44 -08:00
Adam Paszke
e66f4e94c4 [Mosaic] Add support for extracting the first element of a vector as a scalar
PiperOrigin-RevId: 580169469
2023-11-07 07:20:48 -08:00
Peter Hawkins
b85ea68fba Move test for backend initialization into jax.distributed.initialize() wrapper.
This allows us to skip the check for tests.

PiperOrigin-RevId: 580168674
2023-11-07 07:12:40 -08:00
Jake VanderPlas
e7932858d4 Fix dtype-related docs for cumsum/cumprod 2023-11-06 14:31:25 -08:00
jax authors
7e372944f9 Fix the missing cache_misses metric when min compile time is set to zero.
Remove the code which checks if the min compile time is greater than zero. After this change, we can catch cache_misses when min compile time is zero.

Testing: revised unit test.
PiperOrigin-RevId: 579951415
2023-11-06 14:04:35 -08:00
Peter Hawkins
eeafff5891 Raise an exception if jax.distributed.initialize() is called after backends have been initialized.
Fixes https://github.com/google/jax/issues/18237

PiperOrigin-RevId: 579936065
2023-11-06 13:12:26 -08:00
jax authors
28e33ca5d1 Switch to the new cache-key generation algorithm.
The new cache-key generation algorithm is more robust and
results in fewer stale entries being returned.

Testing: test workloads.
PiperOrigin-RevId: 579928158
2023-11-06 12:57:01 -08:00
jax authors
29a6262d11 Merge pull request #18380 from jakevdp:cross-2d
PiperOrigin-RevId: 579909035
2023-11-06 11:45:36 -08:00
jax authors
189d3aba2d Merge pull request #18379 from jakevdp:nanmean-dtype
PiperOrigin-RevId: 579908683
2023-11-06 11:37:34 -08:00
jax authors
79ca40ea05 Merge pull request #18397 from Xodarap:add-beta-function
PiperOrigin-RevId: 579861286
2023-11-06 09:05:11 -08:00
Peter Hawkins
390022a227 [JAX] Stop using a custom ducc kernel: instead just emit an Fft HLO operation and let XLA emit the call to ducc.
XLA now calls ducc itself as of da67903a4c, so we don't need a custom call in JAX any more. In addition, the DUCC call from XLA receives a thread pool and is parallelized.

Fixes https://github.com/google/jax/issues/14664

PiperOrigin-RevId: 579829580
2023-11-06 06:55:06 -08:00
Xuehai Pan
de6fbdc69c Fix hashcode for HashablePartial to get equal hashes for equal objects 2023-11-06 17:27:19 +08:00
Ben West
02f6fcb9da Add beta function 2023-11-05 15:37:38 -08:00
Frederic Bastien
1f7df8008e Better error message. 2023-11-05 08:17:33 -08:00
Jake VanderPlas
96d9f89415 [random] better errors for unsupported operations on prng keys 2023-11-03 19:23:18 -07:00
Brennan Shacklett
094579910f [Pallas]: Fix kernel grid dimensions that are too large in Y and Z 2023-11-03 17:17:42 -07:00
Jake VanderPlas
4f863e9148 jnp.cross: account for numpy 2.0 deprecation 2023-11-03 14:15:23 -07:00
Jake VanderPlas
6d6c12a14d jnp.nanmean: avoid integer overflow for large arrays 2023-11-03 13:15:18 -07:00
jax authors
808289d52a Merge pull request #18373 from jakevdp:sharding-doc
PiperOrigin-RevId: 579224607
2023-11-03 10:23:44 -07:00
Jake VanderPlas
cd3ea05665 Ensure sharding-related array properties are documented 2023-11-03 09:56:33 -07:00
jax authors
db07f40233 Fall-back to original device/backend hashing if topology-desc is unavailable.
The original cache-key generation algorithm hashed devices and backend as
part of generating the key. The new algorithm relies on serialized
PjRtTopologyDescription instead. Not all backends support serialized
PjRtTopologyDescription. Fall back to the original device/backend hashing
if the needed backend does not support it.

Testing: unit testing + test workloads.
PiperOrigin-RevId: 579039803
2023-11-02 18:43:48 -07:00
jax authors
62741d9744 Reverts 81ac67f38164b7626d733d081a87ff49b235b9d0
PiperOrigin-RevId: 579010408
2023-11-02 16:17:29 -07:00
Parker Schuh
c8b7c1b80b Remove temporary flag for forcing arg tuplization of lowered functions.
PiperOrigin-RevId: 578910366
2023-11-02 10:53:16 -07:00
jax authors
1f6264896d Merge pull request #18354 from jakevdp:dep-opaque
PiperOrigin-RevId: 578904025
2023-11-02 10:37:47 -07:00
jax authors
6d5eaa6ec3 Merge pull request #18295 from gnecula:lax_multi
PiperOrigin-RevId: 578892592
2023-11-02 10:07:34 -07:00
Jake VanderPlas
0111dcbda3 Finish deprecation of allow_opaque_dtype 2023-11-02 09:51:06 -07:00
Etienne Pot
81ac67f381 Fix typing annotations for @jax.named_call
PiperOrigin-RevId: 578852649
2023-11-02 07:55:04 -07:00
George Necula
8feb413211 Add a lax.platform_dependent API for writing platform-dependent code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.

See more details in the docstring of lax.platform_dependent.
2023-11-02 14:31:38 +01:00
Reed Wanderman-Milne
d41078fb95 Properly pack and unpack int4 arrays on CPU in PJRT.
Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays.

PiperOrigin-RevId: 578692796
2023-11-01 17:39:24 -07:00
jax authors
a009f8d6c1 Pass flags from kernel into HLO backend config.
PiperOrigin-RevId: 578390868
2023-10-31 21:32:19 -07:00
Roy Frostig
16d082b002 [jex] replace extend.random.PRNGImpl with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Roy Frostig
b22e75716f add threefry_partitionable config setting to thread-local JIT context 2023-10-31 13:45:49 -07:00
Yash Katariya
20255dce84 Delete cached_call_jaxpr_lowerings since a more general cached_primitive_lowerings is available
PiperOrigin-RevId: 577993595
2023-10-30 16:38:57 -07:00
Yash Katariya
85af862efd [Try again] For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the lowering time.
Reverts 4a5c6f82009dee9c30ca4a85359a702d745ed035

PiperOrigin-RevId: 577974380
2023-10-30 15:28:43 -07:00
Jake VanderPlas
fbacebc11e jnp.einsum: mention default value for optimize param 2023-10-30 09:22:37 -07:00
Sergei Lebedev
fd3a8b2cc6 Deprecated define_* and DEFINE_* methods on jax.config
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
2023-10-29 20:58:19 +00:00
Parker Schuh
19c65353d2 Do not init backends from topology construction, instead directly init the
plugin.

PiperOrigin-RevId: 577331743
2023-10-27 16:21:01 -07:00
Yash Katariya
8ee58117e2 Don't print all the devices in the mesh during ResourceEnv's repr. Just print the mesh shape.
PiperOrigin-RevId: 577305337
2023-10-27 14:25:34 -07:00