20094 Commits

Author SHA1 Message Date
Yash Katariya
c6804f92d0 Remove deprecated code from JAX lowering and compilation
PiperOrigin-RevId: 622530123
2024-04-06 19:43:29 -07:00
Yash Katariya
3b5980fd73 Share lowering code between jit and aot jit path
PiperOrigin-RevId: 622487044
2024-04-06 13:44:18 -07:00
jax authors
e8b86cd81d Merge pull request #20266 from mattjj:earray
PiperOrigin-RevId: 622485196
2024-04-06 13:29:34 -07:00
Matthew Johnson
89f26db36d start adding EArray, a jax.Array analog that can contain extended dtypes 2024-04-06 13:09:25 -07:00
Sergei Lebedev
9616900cc9 jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is two-fold

* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
  of always copying it to the host. Note that the version here still always
  copies to the host.

If this change breaks you, you can recover the old behavior by changing

    jax.pure_callback(
        f,
        result_shape_dtypes,
        *args,
        **kwargs,
    )

to

    jax.pure_callback(
        lambda *args: f(*jax.tree.map(np.asarray, args)),
        result_shape_dtypes,
        *args,
        **kwargs,
    )

so that the callback function is called with NumPy arrays as before.

I will update the "External callbacks" tutorial in a follow up.

PiperOrigin-RevId: 622457378
2024-04-06 09:30:08 -07:00
jax authors
63aee94792 Update XLA dependency to use revision
04e2731152.

PiperOrigin-RevId: 622378182
2024-04-05 23:10:50 -07:00
Yash Katariya
c125442644 Add Layout support to jax.jit.
`jax.jit` now accepts `Layout` instances to the `in_shardings` and `out_shardings` argument. Major changes are just plumbing `in_layouts` and `out_layouts` everywhere.

Note that public api is `Layout(device_local_layout, sharding)` which is how users will pass us the Layout but internally we split them apart into device_local_layout and sharding.

Docs are coming up on how to use the API and what Layouts mean and how to make sense of them (especially on TPU).

PiperOrigin-RevId: 622352537
2024-04-05 20:09:34 -07:00
Jieying Luo
f88139bf67 Add a fallback when GetDefaultLayout is unimplemented for that backend.
PiperOrigin-RevId: 622278710
2024-04-05 14:13:08 -07:00
jax authors
7413894b86 Merge pull request #20599 from mattjj:temp-config-to-disable-custom-vjp-shape-check
PiperOrigin-RevId: 622224003
2024-04-05 10:52:45 -07:00
jax authors
eff8a47fbb Merge pull request #20608 from jakevdp:fix-ufunc
PiperOrigin-RevId: 622216821
2024-04-05 10:30:57 -07:00
jax authors
3ae79ea36d Merge pull request #20516 from gnecula:hcb_new
PiperOrigin-RevId: 622101882
2024-04-05 10:30:29 -07:00
Jake VanderPlas
1f9a2dddb8 ufunc: fix implements wrapper for at 2024-04-05 09:42:49 -07:00
George Necula
a510f03ef8 [callback] Add a flag to implement host_callback in terms of io_callback.
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue #20385 for more details.

We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
2024-04-05 08:51:30 +01:00
jax authors
3ec45e9ef9 Update XLA dependency to use revision
be5c637c7f.

PiperOrigin-RevId: 622076324
2024-04-04 23:01:15 -07:00
jax authors
2512843a56 Merge pull request #20550 from Micky774:api_clip
PiperOrigin-RevId: 622045823
2024-04-04 19:58:06 -07:00
jax authors
f37e5037d5 Merge pull request #20175 from Micky774:array_api
PiperOrigin-RevId: 622040353
2024-04-04 19:23:42 -07:00
jax authors
8111f38c50 Merge pull request #20593 from google:upstream-nightly
PiperOrigin-RevId: 622039156
2024-04-04 19:12:59 -07:00
Jake VanderPlas
c34b118972 Restore upstream-nightly github action 2024-04-04 19:00:55 -07:00
Matthew Johnson
3d4687fbfc add a temporary config option to disable custom_vjp shape checking 2024-04-04 18:21:10 -07:00
jax authors
a5b8ce1208 Merge pull request #20598 from froystig:changelog
PiperOrigin-RevId: 622021033
2024-04-04 17:43:13 -07:00
Yash Katariya
55233a0029 device_local_layout can be None on a jax.Array for backends that don't implement certain required methods for a jax.Array to populate the device_local_layout.
Skip the error checks when arr.layout.device_local_layout is None.

PiperOrigin-RevId: 622007598
2024-04-04 16:42:27 -07:00
Roy Frostig
f247822977 changelog: note doc change to use jax.random.key over PRNGKey 2024-04-04 16:38:08 -07:00
Roy Frostig
2a36d75285 changelog: batching rule change for rng_bit_generator 2024-04-04 16:34:10 -07:00
Meekail Zain
8b7aae586b Update jnp.clip to Array API 2023 standard 2024-04-04 22:55:10 +00:00
Meekail Zain
2b1c3deee2 Update from_dlpack to match array API 2023 2024-04-04 22:51:25 +00:00
Yash Katariya
b322d399e1 Resolve a TODO now that in_shardings are chosen by XLA for inputs that don't have sharding specified or are uncommitted
PiperOrigin-RevId: 621991853
2024-04-04 15:39:41 -07:00
jax authors
033992867f Merge pull request #20588 from olupton:add-gpu-init
PiperOrigin-RevId: 621971232
2024-04-04 14:28:48 -07:00
jax authors
354c8aa0fb Merge pull request #20583 from pearu:pearu/numpy-1-assert_allclose
PiperOrigin-RevId: 621931219
2024-04-04 12:14:53 -07:00
jax authors
015edd56ad Merge pull request #20062 from Micky774:strip_ir
PiperOrigin-RevId: 621930992
2024-04-04 12:04:55 -07:00
Jake VanderPlas
685d97f5e9 Remove support for complex jnp.floor_divide
It is not well-defined, and the implementation is currently untested. Both Python and NumPy raise an error when attempting complex-valued floor division.

PiperOrigin-RevId: 621918604
2024-04-04 11:25:17 -07:00
jax authors
23b0fa8d82 Merge pull request #20572 from mattjj:marray-you
PiperOrigin-RevId: 621878367
2024-04-04 09:12:08 -07:00
Olli Lupton
c97d955771 cuInit before querying compute capability 2024-04-04 15:27:57 +00:00
Sergei Lebedev
498e81ab10 Pallas now exclusively uses XLA for compiling kernels on GPU
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

PiperOrigin-RevId: 621857046
2024-04-04 07:47:26 -07:00
Marvin Kim
722708052c [JAX] Fix typo in comment.
PiperOrigin-RevId: 621827985
2024-04-04 05:35:28 -07:00
Pearu Peterson
2ef5bc6075 Workaround numpy 1.x assert_allclose false-positive result in comparing complex infinities. 2024-04-04 11:19:57 +03:00
Yash Katariya
52f7de0969 Remove the unused return from prepare_axis_resources
PiperOrigin-RevId: 621738698
2024-04-03 22:39:42 -07:00
jax authors
bc0eff588a Update XLA dependency to use revision
55cdde97c2.

PiperOrigin-RevId: 621734632
2024-04-03 22:16:47 -07:00
jax authors
29a2762b64 Merge pull request #20558 from carlosgmartin:mish
PiperOrigin-RevId: 621708823
2024-04-03 19:50:58 -07:00
Yash Katariya
5cbb26f36d Make device_local_layout and sharding optional in Layout. Also only accept Layout class to _in_layouts and _out_layouts.
This is in preparation to get `jax.jit` to accept `Layout`.

PiperOrigin-RevId: 621697750
2024-04-03 18:37:32 -07:00
Yash Katariya
d790c88da9 Rename layout.AUTO to DeviceLocalLayout.AUTO
PiperOrigin-RevId: 621684185
2024-04-03 17:23:35 -07:00
Jieying Luo
783d5d2e14 [PJRT C API] Plumb plugin attributes from plugin to JAX python.
Also add a method for the plugin to return an xla_version plugin attribute.

Currently jaxlib pins a TPU/GPU backend, and uses `xla_extension_version` for backend version. As we want to stop pinning TPU/GPU backend and allow pip install different backend separately, we need this `xla_version` for features that are not capture by PJRT C API version. `xla_extension_version` will still be used for API changes such as xla_client.py, or any XLA changes in jaxlib that are not part of plugins.

PiperOrigin-RevId: 621672421
2024-04-03 16:31:25 -07:00
Matthew Johnson
46a516275f [mutable-arrays] enable refs without cps, and not just at top level
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-04-03 16:23:19 -07:00
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Yash Katariya
24517ca3e0 Finish jax and jaxlib 0.4.26 release
PiperOrigin-RevId: 621658207
2024-04-03 15:40:24 -07:00
jax authors
bfdf3c4b5d Merge pull request #20573 from mattjj:make-jaxpr-effects-simplify
PiperOrigin-RevId: 621656942
2024-04-03 15:29:44 -07:00
Matthew Johnson
e682fa8fdd small simplification to asymptotic complexity of make_jaxpr_effects 2024-04-03 14:44:41 -07:00
carlosgmartin
f0314c70e8 Add jax.nn.mish. 2024-04-03 16:37:07 -04:00
Dinghua Li
026f309dcb Introduce an "inline_seq_dim" mode to paged attention kernel. The mode will fuse kernel instances along the sequence dimension into one kernel, hence reducing the number of kernels.
PiperOrigin-RevId: 621611672
2024-04-03 13:02:19 -07:00
jax authors
0624775f3a Merge pull request #20561 from superbobry:docs
PiperOrigin-RevId: 621608577
2024-04-03 12:52:40 -07:00
jax authors
2df89b2184 Merge pull request #20569 from jakevdp:fix-ks-test
PiperOrigin-RevId: 621608564
2024-04-03 12:42:05 -07:00