20041 Commits

Author SHA1 Message Date
jax authors
fed7efd730 Merge pull request #20571 from hawkinsp:release
PiperOrigin-RevId: 621592689
jax-v0.4.26 jax-v0.4.26-rc
2024-04-03 11:45:56 -07:00
Peter Hawkins
61493263a9 Prepare for 0.4.26 release. 2024-04-03 14:38:58 -04:00
Dinghua Li
9bb3f79be5 Avoid unnecessary fori_loop when calculating the block indices.
PiperOrigin-RevId: 621580324
2024-04-03 11:08:45 -07:00
jax authors
85cb16936c Merge pull request #20562 from olupton:bind-to-all
PiperOrigin-RevId: 621562851
2024-04-03 10:27:14 -07:00
jax authors
57ee6b7550 Merge pull request #20560 from pearu:pearu/log1p-fixes
PiperOrigin-RevId: 621562841
2024-04-03 10:17:11 -07:00
Peter Hawkins
1db4af1c3c Use scipy 1.13.0 instead of scipy 1.13.0rc1 in CI.
scipy 1.13.0 was just released.

PiperOrigin-RevId: 621552425
2024-04-03 09:45:09 -07:00
Pearu Peterson
9a7fb898d4 Workaround mpmath bug (mpmath/mpmath#774) in log1p at complex infinities
Temporarily disable arctanh success tests that depend on log1p fixes
2024-04-03 18:48:26 +03:00
Peter Hawkins
e2f47748e3 Fix tests that fail if enable_checks is true under NumPy 2.0.0rc1.
np.vecdot is missing `__module__` under NumPy 2.0.0rc1.

PiperOrigin-RevId: 621532796
2024-04-03 08:35:20 -07:00
Olli Lupton
2dd1b3d6c8 jax.distributed.initialize: specify bind address.
By default, the coordinator process listens on all interfaces.
2024-04-03 17:13:27 +02:00
jax authors
d89f0d6684 Merge pull request #20534 from gnecula:callback_64bit
PiperOrigin-RevId: 621507392
2024-04-03 06:49:23 -07:00
George Necula
35b1cb799a [callback] Allow external callbacks to return 64-bit values in 32-bit mode
Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  #20433 to add a canonicalization
step instead of an error.
2024-04-03 11:15:11 +01:00
jax authors
dcd45c8d20 Update XLA dependency to use revision
4e8e23f16b.

PiperOrigin-RevId: 621401398
2024-04-02 22:35:23 -07:00
Peter Hawkins
1159691fc6 [JAX] Update JAX CI dockerfiles to use NumPy 2.0.0rc1, SciPy 1.13.0rc1, and ml_dtypes 0.4.0.
Change in preparation for releasing JAX with NumPy 2.0 support.

PiperOrigin-RevId: 621354875
2024-04-02 18:26:30 -07:00
Peter Hawkins
b7401872d5 Disable a random distribution test that appears to fail under scipy 1.13.0rc1.
PiperOrigin-RevId: 621352047
2024-04-02 18:11:30 -07:00
jax authors
6b582f5977 Merge pull request #20552 from jakevdp:geomspace-complex
PiperOrigin-RevId: 621348396
2024-04-02 17:54:51 -07:00
jax authors
88dd29a0b5 Re-enable persistent cache on cpu.
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.

PiperOrigin-RevId: 621341638
2024-04-02 17:30:52 -07:00
Peter Hawkins
1baed9b285 [PJRT:CPU] Replace references to pjrt/tfrt_cpu_pjrt_client with pjrt/cpu/cpu_client.h.
The two are aliases and the former is a forwarding header pointing to the latter.

Cleanup only, no functional changes.

PiperOrigin-RevId: 621341188
2024-04-02 17:20:16 -07:00
Sharad Vikram
318ae8935a [Pallas TPU] Relax windowing restriction when lowering mapped grids
PiperOrigin-RevId: 621330022
2024-04-02 16:32:39 -07:00
Jake VanderPlas
fd7c85b349 jnp.geomspace: make complex behavior consistent with NumPy 2.0 2024-04-02 16:12:49 -07:00
Sergei Lebedev
f74f4ed48b Removed unnecessary BUILD dependencies from :ops_test
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable
to :pallas_test.
PiperOrigin-RevId: 621299158
2024-04-02 14:36:41 -07:00
jax authors
a54eb81d78 Merge pull request #20548 from jakevdp:uint-floordiv
PiperOrigin-RevId: 621292364
2024-04-02 14:14:55 -07:00
jax authors
e282bf57db Merge pull request #20536 from jakevdp:broadcast-to
PiperOrigin-RevId: 621287464
2024-04-02 13:59:12 -07:00
Jake VanderPlas
e99a3051ed jnp.floor_div: lower directly to div for unsigned int 2024-04-02 13:47:42 -07:00
jax authors
3c8081231f Merge pull request #20542 from jakevdp:solve-test
PiperOrigin-RevId: 621275261
2024-04-02 13:31:21 -07:00
Sharad Vikram
87aee90e67 Fix typo in Pallas design
PiperOrigin-RevId: 621275025
2024-04-02 13:20:46 -07:00
Jake VanderPlas
37a34bf6a3 test: remove solve test case that is invalid under NumPy 2.0 2024-04-02 13:10:26 -07:00
jax authors
87b869ddd8 Merge pull request #20543 from jakevdp:skip-geomspace
PiperOrigin-RevId: 621267031
2024-04-02 12:53:28 -07:00
Jake VanderPlas
c79f54f7ff test: skip complex geomspace test under numpy 2.0 2024-04-02 12:32:02 -07:00
jax authors
fb55d59143 This CL introduces 'PluginProgram' in IFRT and exposes this in python via xla_client.compile_ifrt_program().
The IFRT `PluginProgram` is simply a wrapper for arbitrary byte-strings: an IFRT backend that recognizes `PluginProgram` can interpret the byte-string in any way it sees fit.

PiperOrigin-RevId: 621258245
2024-04-02 12:20:35 -07:00
Peter Hawkins
60458bb36a Fail gracefully in lobpcg_test if matplotlib isn't installed.
There isn't yet a matplotlib that supports NumPy 2.0, so we need to support running tests without it.

PiperOrigin-RevId: 621228727
2024-04-02 10:48:46 -07:00
jax authors
00489be23d Fix a bug where exceptions were thrown in debug message formatting, when sharding was set to None on arrays.
PiperOrigin-RevId: 621193460
2024-04-02 08:56:37 -07:00
Jake VanderPlas
6de6983d59 jnp.broadcast_to: better error for invalid shape 2024-04-02 08:38:51 -07:00
Sergei Lebedev
2ee4c0f644 Added installation instructions to the error in _pallas_call_lowering
PiperOrigin-RevId: 621168804
2024-04-02 07:36:28 -07:00
jax authors
4c41c12e21 Merge pull request #20514 from gnecula:callback_cache
PiperOrigin-RevId: 621160168
2024-04-02 06:55:45 -07:00
George Necula
bff24c6d6f [callback] Improve caching effectiveness in presence of callbacks.
Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This means that two separate io_callback
invocations with the same Python callable will generate different
Jaxprs. To prevent this we defer the flattening to lowering time.
2024-04-02 15:33:24 +02:00
jax authors
431015a14e Merge pull request #20383 from gnecula:doc_deprecation
PiperOrigin-RevId: 621153196
2024-04-02 06:19:58 -07:00
jax authors
b3fe9400fb Add round lowering rule.
PiperOrigin-RevId: 621110036
2024-04-02 02:55:34 -07:00
George Necula
84db689e39 A few more comments about how the deprecations work 2024-04-02 10:52:01 +02:00
jax authors
1d221d1f14 Update XLA dependency to use revision
7ff97ee607.

PiperOrigin-RevId: 621065700
2024-04-01 23:24:17 -07:00
George Necula
c491720ee1 Accelerate deprecation of jax.experimental.host_callback.id_print and stop_outfeed_receiver
`jax.experimental.host_callback` is deprecated and any API in that module will throw a DeprecationWarning. After this change the `id_print` and `stop_outfeed_receiver` will throw an `AttributeError` in internal code only.

Add a deprecation message for `barrier_wait`.

PiperOrigin-RevId: 621064083
2024-04-01 23:12:59 -07:00
Sergei Lebedev
16b3f00e42 Register GPU/TPU lowering for pallas_call_p lazily
Prior to this change we had to import jax.experimental.pallas.{gpu,tpu} in
jax.experimental.pallas only to get the lowering rules registered.

PiperOrigin-RevId: 620957622
2024-04-01 14:40:33 -07:00
jax authors
5a7e874339 Merge pull request #20524 from jakevdp:trapz
PiperOrigin-RevId: 620953434
2024-04-01 14:26:40 -07:00
Sergei Lebedev
c4f1a45205 Generalized the in_specs/out_specs types in PrefetchScalarGridSpec
PiperOrigin-RevId: 620949269
2024-04-01 14:11:55 -07:00
Yash Katariya
6557f680fd Rename SpecifiedLayout to DeviceLocalLayout
PiperOrigin-RevId: 620934348
2024-04-01 13:19:46 -07:00
Jake VanderPlas
9e01afe7af Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
Peter Hawkins
011ced4431 Guard test that requires two devices with device_count() check.
PiperOrigin-RevId: 620921563
2024-04-01 12:32:54 -07:00
jax authors
a4cccd98c9 Merge pull request #20520 from google:dependabot/github_actions/actions/setup-python-5.1.0
PiperOrigin-RevId: 620889663
2024-04-01 10:47:57 -07:00
jax authors
5ece588fc3 Merge pull request #20474 from rajasekharporeddy:test_branch3
PiperOrigin-RevId: 620888823
2024-04-01 10:38:06 -07:00
dependabot[bot]
7a57b39a35
Bump actions/setup-python from 5.0.0 to 5.1.0
Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.0.0 to 5.1.0.
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](0a5c615913...82c7e631bb)

---
updated-dependencies:
- dependency-name: actions/setup-python
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-04-01 17:31:33 +00:00
jax authors
3e9bb51d5f Merge pull request #20426 from google:dependabot/github_actions/actions/cache-4.0.2
PiperOrigin-RevId: 620878220
2024-04-01 10:05:40 -07:00