20056 Commits

Author SHA1 Message Date
carlosgmartin
e98612e2ab Add where argument to logsumexp. 2024-04-08 12:57:06 -04:00
jax authors
29a2762b64 Merge pull request #20558 from carlosgmartin:mish
PiperOrigin-RevId: 621708823
2024-04-03 19:50:58 -07:00
Yash Katariya
5cbb26f36d Make device_local_layout and sharding optional in Layout. Also only accept Layout class to _in_layouts and _out_layouts.
This is in preparation to get `jax.jit` to accept `Layout`.

PiperOrigin-RevId: 621697750
2024-04-03 18:37:32 -07:00
Yash Katariya
d790c88da9 Rename layout.AUTO to DeviceLocalLayout.AUTO
PiperOrigin-RevId: 621684185
2024-04-03 17:23:35 -07:00
Jieying Luo
783d5d2e14 [PJRT C API] Plumb plugin attributes from plugin to JAX python.
Also add a method for the plugin to return an xla_version plugin attribute.

Currently jaxlib pins a TPU/GPU backend, and uses `xla_extension_version` for backend version. As we want to stop pinning TPU/GPU backend and allow pip install different backend separately, we need this `xla_version` for features that are not capture by PJRT C API version. `xla_extension_version` will still be used for API changes such as xla_client.py, or any XLA changes in jaxlib that are not part of plugins.

PiperOrigin-RevId: 621672421
2024-04-03 16:31:25 -07:00
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Yash Katariya
24517ca3e0 Finish jax and jaxlib 0.4.26 release
PiperOrigin-RevId: 621658207
2024-04-03 15:40:24 -07:00
jax authors
bfdf3c4b5d Merge pull request #20573 from mattjj:make-jaxpr-effects-simplify
PiperOrigin-RevId: 621656942
2024-04-03 15:29:44 -07:00
Matthew Johnson
e682fa8fdd small simplification to asymptotic complexity of make_jaxpr_effects 2024-04-03 14:44:41 -07:00
carlosgmartin
f0314c70e8 Add jax.nn.mish. 2024-04-03 16:37:07 -04:00
Dinghua Li
026f309dcb Introduce an "inline_seq_dim" mode to paged attention kernel. The mode will fuse kernel instances along the sequence dimension into one kernel, hence reducing the number of kernels.
PiperOrigin-RevId: 621611672
2024-04-03 13:02:19 -07:00
jax authors
0624775f3a Merge pull request #20561 from superbobry:docs
PiperOrigin-RevId: 621608577
2024-04-03 12:52:40 -07:00
jax authors
2df89b2184 Merge pull request #20569 from jakevdp:fix-ks-test
PiperOrigin-RevId: 621608564
2024-04-03 12:42:05 -07:00
jax authors
fed7efd730 Merge pull request #20571 from hawkinsp:release
PiperOrigin-RevId: 621592689
jax-v0.4.26 jax-v0.4.26-rc
2024-04-03 11:45:56 -07:00
Peter Hawkins
61493263a9 Prepare for 0.4.26 release. 2024-04-03 14:38:58 -04:00
Jake VanderPlas
31e2358887 test: work around issue with kstest in scipy>1.12 2024-04-03 11:17:56 -07:00
Dinghua Li
9bb3f79be5 Avoid unnecessary fori_loop when calculating the block indices.
PiperOrigin-RevId: 621580324
2024-04-03 11:08:45 -07:00
jax authors
85cb16936c Merge pull request #20562 from olupton:bind-to-all
PiperOrigin-RevId: 621562851
2024-04-03 10:27:14 -07:00
jax authors
57ee6b7550 Merge pull request #20560 from pearu:pearu/log1p-fixes
PiperOrigin-RevId: 621562841
2024-04-03 10:17:11 -07:00
Peter Hawkins
1db4af1c3c Use scipy 1.13.0 instead of scipy 1.13.0rc1 in CI.
scipy 1.13.0 was just released.

PiperOrigin-RevId: 621552425
2024-04-03 09:45:09 -07:00
Pearu Peterson
9a7fb898d4 Workaround mpmath bug (mpmath/mpmath#774) in log1p at complex infinities
Temporarily disable arctanh success tests that depend on log1p fixes
2024-04-03 18:48:26 +03:00
Peter Hawkins
e2f47748e3 Fix tests that fail if enable_checks is true under NumPy 2.0.0rc1.
np.vecdot is missing `__module__` under NumPy 2.0.0rc1.

PiperOrigin-RevId: 621532796
2024-04-03 08:35:20 -07:00
Olli Lupton
2dd1b3d6c8 jax.distributed.initialize: specify bind address.
By default, the coordinator process listens on all interfaces.
2024-04-03 17:13:27 +02:00
jax authors
d89f0d6684 Merge pull request #20534 from gnecula:callback_64bit
PiperOrigin-RevId: 621507392
2024-04-03 06:49:23 -07:00
George Necula
35b1cb799a [callback] Allow external callbacks to return 64-bit values in 32-bit mode
Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  #20433 to add a canonicalization
step instead of an error.
2024-04-03 11:15:11 +01:00
Sergei Lebedev
ea8e393c0e Fixed a few typos in the matmul example in "Pallas Design" 2024-04-03 10:46:05 +01:00
jax authors
dcd45c8d20 Update XLA dependency to use revision
4e8e23f16b.

PiperOrigin-RevId: 621401398
2024-04-02 22:35:23 -07:00
Peter Hawkins
1159691fc6 [JAX] Update JAX CI dockerfiles to use NumPy 2.0.0rc1, SciPy 1.13.0rc1, and ml_dtypes 0.4.0.
Change in preparation for releasing JAX with NumPy 2.0 support.

PiperOrigin-RevId: 621354875
2024-04-02 18:26:30 -07:00
Peter Hawkins
b7401872d5 Disable a random distribution test that appears to fail under scipy 1.13.0rc1.
PiperOrigin-RevId: 621352047
2024-04-02 18:11:30 -07:00
jax authors
6b582f5977 Merge pull request #20552 from jakevdp:geomspace-complex
PiperOrigin-RevId: 621348396
2024-04-02 17:54:51 -07:00
jax authors
88dd29a0b5 Re-enable persistent cache on cpu.
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.

PiperOrigin-RevId: 621341638
2024-04-02 17:30:52 -07:00
Peter Hawkins
1baed9b285 [PJRT:CPU] Replace references to pjrt/tfrt_cpu_pjrt_client with pjrt/cpu/cpu_client.h.
The two are aliases and the former is a forwarding header pointing to the latter.

Cleanup only, no functional changes.

PiperOrigin-RevId: 621341188
2024-04-02 17:20:16 -07:00
Sharad Vikram
318ae8935a [Pallas TPU] Relax windowing restriction when lowering mapped grids
PiperOrigin-RevId: 621330022
2024-04-02 16:32:39 -07:00
Jake VanderPlas
fd7c85b349 jnp.geomspace: make complex behavior consistent with NumPy 2.0 2024-04-02 16:12:49 -07:00
Sergei Lebedev
f74f4ed48b Removed unnecessary BUILD dependencies from :ops_test
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable
to :pallas_test.
PiperOrigin-RevId: 621299158
2024-04-02 14:36:41 -07:00
jax authors
a54eb81d78 Merge pull request #20548 from jakevdp:uint-floordiv
PiperOrigin-RevId: 621292364
2024-04-02 14:14:55 -07:00
jax authors
e282bf57db Merge pull request #20536 from jakevdp:broadcast-to
PiperOrigin-RevId: 621287464
2024-04-02 13:59:12 -07:00
Jake VanderPlas
e99a3051ed jnp.floor_div: lower directly to div for unsigned int 2024-04-02 13:47:42 -07:00
jax authors
3c8081231f Merge pull request #20542 from jakevdp:solve-test
PiperOrigin-RevId: 621275261
2024-04-02 13:31:21 -07:00
Sharad Vikram
87aee90e67 Fix typo in Pallas design
PiperOrigin-RevId: 621275025
2024-04-02 13:20:46 -07:00
Jake VanderPlas
37a34bf6a3 test: remove solve test case that is invalid under NumPy 2.0 2024-04-02 13:10:26 -07:00
jax authors
87b869ddd8 Merge pull request #20543 from jakevdp:skip-geomspace
PiperOrigin-RevId: 621267031
2024-04-02 12:53:28 -07:00
Jake VanderPlas
c79f54f7ff test: skip complex geomspace test under numpy 2.0 2024-04-02 12:32:02 -07:00
jax authors
fb55d59143 This CL introduces 'PluginProgram' in IFRT and exposes this in python via xla_client.compile_ifrt_program().
The IFRT `PluginProgram` is simply a wrapper for arbitrary byte-strings: an IFRT backend that recognizes `PluginProgram` can interpret the byte-string in any way it sees fit.

PiperOrigin-RevId: 621258245
2024-04-02 12:20:35 -07:00
Peter Hawkins
60458bb36a Fail gracefully in lobpcg_test if matplotlib isn't installed.
There isn't yet a matplotlib that supports NumPy 2.0, so we need to support running tests without it.

PiperOrigin-RevId: 621228727
2024-04-02 10:48:46 -07:00
jax authors
00489be23d Fix a bug where exceptions were thrown in debug message formatting, when sharding was set to None on arrays.
PiperOrigin-RevId: 621193460
2024-04-02 08:56:37 -07:00
Jake VanderPlas
6de6983d59 jnp.broadcast_to: better error for invalid shape 2024-04-02 08:38:51 -07:00
Sergei Lebedev
2ee4c0f644 Added installation instructions to the error in _pallas_call_lowering
PiperOrigin-RevId: 621168804
2024-04-02 07:36:28 -07:00
jax authors
4c41c12e21 Merge pull request #20514 from gnecula:callback_cache
PiperOrigin-RevId: 621160168
2024-04-02 06:55:45 -07:00
George Necula
bff24c6d6f [callback] Improve caching effectiveness in presence of callbacks.
Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This means that two separate io_callback
invocations with the same Python callable will generate different
Jaxprs. To prevent this we defer the flattening to lowering time.
2024-04-02 15:33:24 +02:00