20080 Commits

Author SHA1 Message Date
George Necula
a510f03ef8 [callback] Add a flag to implement host_callback in terms of io_callback.
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue #20385 for more details.

We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
2024-04-05 08:51:30 +01:00
jax authors
2512843a56 Merge pull request #20550 from Micky774:api_clip
PiperOrigin-RevId: 622045823
2024-04-04 19:58:06 -07:00
jax authors
f37e5037d5 Merge pull request #20175 from Micky774:array_api
PiperOrigin-RevId: 622040353
2024-04-04 19:23:42 -07:00
jax authors
8111f38c50 Merge pull request #20593 from google:upstream-nightly
PiperOrigin-RevId: 622039156
2024-04-04 19:12:59 -07:00
Jake VanderPlas
c34b118972 Restore upstream-nightly github action 2024-04-04 19:00:55 -07:00
jax authors
a5b8ce1208 Merge pull request #20598 from froystig:changelog
PiperOrigin-RevId: 622021033
2024-04-04 17:43:13 -07:00
Yash Katariya
55233a0029 device_local_layout can be None on a jax.Array for backends that don't implement certain required methods for a jax.Array to populate the device_local_layout.
Skip the error checks when arr.layout.device_local_layout is None.

PiperOrigin-RevId: 622007598
2024-04-04 16:42:27 -07:00
Roy Frostig
f247822977 changelog: note doc change to use jax.random.key over PRNGKey 2024-04-04 16:38:08 -07:00
Roy Frostig
2a36d75285 changelog: batching rule change for rng_bit_generator 2024-04-04 16:34:10 -07:00
Meekail Zain
8b7aae586b Update jnp.clip to Array API 2023 standard 2024-04-04 22:55:10 +00:00
Meekail Zain
2b1c3deee2 Update from_dlpack to match array API 2023 2024-04-04 22:51:25 +00:00
Yash Katariya
b322d399e1 Resolve a TODO now that in_shardings are chosen by XLA for inputs that don't have sharding specified or are uncommitted
PiperOrigin-RevId: 621991853
2024-04-04 15:39:41 -07:00
jax authors
033992867f Merge pull request #20588 from olupton:add-gpu-init
PiperOrigin-RevId: 621971232
2024-04-04 14:28:48 -07:00
jax authors
354c8aa0fb Merge pull request #20583 from pearu:pearu/numpy-1-assert_allclose
PiperOrigin-RevId: 621931219
2024-04-04 12:14:53 -07:00
jax authors
015edd56ad Merge pull request #20062 from Micky774:strip_ir
PiperOrigin-RevId: 621930992
2024-04-04 12:04:55 -07:00
Jake VanderPlas
685d97f5e9 Remove support for complex jnp.floor_divide
It is not well-defined, and the implementation is currently untested. Both Python and NumPy raise an error when attempting complex-valued floor division.

PiperOrigin-RevId: 621918604
2024-04-04 11:25:17 -07:00
jax authors
23b0fa8d82 Merge pull request #20572 from mattjj:marray-you
PiperOrigin-RevId: 621878367
2024-04-04 09:12:08 -07:00
Olli Lupton
c97d955771 cuInit before querying compute capability 2024-04-04 15:27:57 +00:00
Sergei Lebedev
498e81ab10 Pallas now exclusively uses XLA for compiling kernels on GPU
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

PiperOrigin-RevId: 621857046
2024-04-04 07:47:26 -07:00
Marvin Kim
722708052c [JAX] Fix typo in comment.
PiperOrigin-RevId: 621827985
2024-04-04 05:35:28 -07:00
Pearu Peterson
2ef5bc6075 Workaround numpy 1.x assert_allclose false-positive result in comparing complex infinities. 2024-04-04 11:19:57 +03:00
Yash Katariya
52f7de0969 Remove the unused return from prepare_axis_resources
PiperOrigin-RevId: 621738698
2024-04-03 22:39:42 -07:00
jax authors
bc0eff588a Update XLA dependency to use revision
55cdde97c2.

PiperOrigin-RevId: 621734632
2024-04-03 22:16:47 -07:00
jax authors
29a2762b64 Merge pull request #20558 from carlosgmartin:mish
PiperOrigin-RevId: 621708823
2024-04-03 19:50:58 -07:00
Yash Katariya
5cbb26f36d Make device_local_layout and sharding optional in Layout. Also only accept Layout class to _in_layouts and _out_layouts.
This is in preparation to get `jax.jit` to accept `Layout`.

PiperOrigin-RevId: 621697750
2024-04-03 18:37:32 -07:00
Yash Katariya
d790c88da9 Rename layout.AUTO to DeviceLocalLayout.AUTO
PiperOrigin-RevId: 621684185
2024-04-03 17:23:35 -07:00
Jieying Luo
783d5d2e14 [PJRT C API] Plumb plugin attributes from plugin to JAX python.
Also add a method for the plugin to return an xla_version plugin attribute.

Currently jaxlib pins a TPU/GPU backend, and uses `xla_extension_version` for backend version. As we want to stop pinning TPU/GPU backend and allow pip install different backend separately, we need this `xla_version` for features that are not capture by PJRT C API version. `xla_extension_version` will still be used for API changes such as xla_client.py, or any XLA changes in jaxlib that are not part of plugins.

PiperOrigin-RevId: 621672421
2024-04-03 16:31:25 -07:00
Matthew Johnson
46a516275f [mutable-arrays] enable refs without cps, and not just at top level
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-04-03 16:23:19 -07:00
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Yash Katariya
24517ca3e0 Finish jax and jaxlib 0.4.26 release
PiperOrigin-RevId: 621658207
2024-04-03 15:40:24 -07:00
jax authors
bfdf3c4b5d Merge pull request #20573 from mattjj:make-jaxpr-effects-simplify
PiperOrigin-RevId: 621656942
2024-04-03 15:29:44 -07:00
Matthew Johnson
e682fa8fdd small simplification to asymptotic complexity of make_jaxpr_effects 2024-04-03 14:44:41 -07:00
carlosgmartin
f0314c70e8 Add jax.nn.mish. 2024-04-03 16:37:07 -04:00
Dinghua Li
026f309dcb Introduce an "inline_seq_dim" mode to paged attention kernel. The mode will fuse kernel instances along the sequence dimension into one kernel, hence reducing the number of kernels.
PiperOrigin-RevId: 621611672
2024-04-03 13:02:19 -07:00
jax authors
0624775f3a Merge pull request #20561 from superbobry:docs
PiperOrigin-RevId: 621608577
2024-04-03 12:52:40 -07:00
jax authors
2df89b2184 Merge pull request #20569 from jakevdp:fix-ks-test
PiperOrigin-RevId: 621608564
2024-04-03 12:42:05 -07:00
jax authors
fed7efd730 Merge pull request #20571 from hawkinsp:release
PiperOrigin-RevId: 621592689
jax-v0.4.26 jax-v0.4.26-rc
2024-04-03 11:45:56 -07:00
Peter Hawkins
61493263a9 Prepare for 0.4.26 release. 2024-04-03 14:38:58 -04:00
Jake VanderPlas
31e2358887 test: work around issue with kstest in scipy>1.12 2024-04-03 11:17:56 -07:00
Dinghua Li
9bb3f79be5 Avoid unnecessary fori_loop when calculating the block indices.
PiperOrigin-RevId: 621580324
2024-04-03 11:08:45 -07:00
jax authors
85cb16936c Merge pull request #20562 from olupton:bind-to-all
PiperOrigin-RevId: 621562851
2024-04-03 10:27:14 -07:00
jax authors
57ee6b7550 Merge pull request #20560 from pearu:pearu/log1p-fixes
PiperOrigin-RevId: 621562841
2024-04-03 10:17:11 -07:00
Peter Hawkins
1db4af1c3c Use scipy 1.13.0 instead of scipy 1.13.0rc1 in CI.
scipy 1.13.0 was just released.

PiperOrigin-RevId: 621552425
2024-04-03 09:45:09 -07:00
Pearu Peterson
9a7fb898d4 Workaround mpmath bug (mpmath/mpmath#774) in log1p at complex infinities
Temporarily disable arctanh success tests that depend on log1p fixes
2024-04-03 18:48:26 +03:00
Peter Hawkins
e2f47748e3 Fix tests that fail if enable_checks is true under NumPy 2.0.0rc1.
np.vecdot is missing `__module__` under NumPy 2.0.0rc1.

PiperOrigin-RevId: 621532796
2024-04-03 08:35:20 -07:00
Olli Lupton
2dd1b3d6c8 jax.distributed.initialize: specify bind address.
By default, the coordinator process listens on all interfaces.
2024-04-03 17:13:27 +02:00
jax authors
d89f0d6684 Merge pull request #20534 from gnecula:callback_64bit
PiperOrigin-RevId: 621507392
2024-04-03 06:49:23 -07:00
George Necula
35b1cb799a [callback] Allow external callbacks to return 64-bit values in 32-bit mode
Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  #20433 to add a canonicalization
step instead of an error.
2024-04-03 11:15:11 +01:00
Sergei Lebedev
ea8e393c0e Fixed a few typos in the matmul example in "Pallas Design" 2024-04-03 10:46:05 +01:00
jax authors
dcd45c8d20 Update XLA dependency to use revision
4e8e23f16b.

PiperOrigin-RevId: 621401398
2024-04-02 22:35:23 -07:00