12616 Commits

Author SHA1 Message Date
Dinghua Li
9bb3f79be5 Avoid unnecessary fori_loop when calculating the block indices.
PiperOrigin-RevId: 621580324
2024-04-03 11:08:45 -07:00
jax authors
85cb16936c Merge pull request #20562 from olupton:bind-to-all
PiperOrigin-RevId: 621562851
2024-04-03 10:27:14 -07:00
jax authors
57ee6b7550 Merge pull request #20560 from pearu:pearu/log1p-fixes
PiperOrigin-RevId: 621562841
2024-04-03 10:17:11 -07:00
Pearu Peterson
9a7fb898d4 Workaround mpmath bug (mpmath/mpmath#774) in log1p at complex infinities
Temporarily disable arctanh success tests that depend on log1p fixes
2024-04-03 18:48:26 +03:00
Peter Hawkins
e2f47748e3 Fix tests that fail if enable_checks is true under NumPy 2.0.0rc1.
np.vecdot is missing `__module__` under NumPy 2.0.0rc1.

PiperOrigin-RevId: 621532796
2024-04-03 08:35:20 -07:00
Olli Lupton
2dd1b3d6c8 jax.distributed.initialize: specify bind address.
By default, the coordinator process listens on all interfaces.
2024-04-03 17:13:27 +02:00
George Necula
35b1cb799a [callback] Allow external callbacks to return 64-bit values in 32-bit mode
Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  #20433 to add a canonicalization
step instead of an error.
2024-04-03 11:15:11 +01:00
jax authors
6b582f5977 Merge pull request #20552 from jakevdp:geomspace-complex
PiperOrigin-RevId: 621348396
2024-04-02 17:54:51 -07:00
jax authors
88dd29a0b5 Re-enable persistent cache on cpu.
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.

PiperOrigin-RevId: 621341638
2024-04-02 17:30:52 -07:00
Sharad Vikram
318ae8935a [Pallas TPU] Relax windowing restriction when lowering mapped grids
PiperOrigin-RevId: 621330022
2024-04-02 16:32:39 -07:00
Jake VanderPlas
fd7c85b349 jnp.geomspace: make complex behavior consistent with NumPy 2.0 2024-04-02 16:12:49 -07:00
Sergei Lebedev
f74f4ed48b Removed unnecessary BUILD dependencies from :ops_test
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable
to :pallas_test.
PiperOrigin-RevId: 621299158
2024-04-02 14:36:41 -07:00
jax authors
a54eb81d78 Merge pull request #20548 from jakevdp:uint-floordiv
PiperOrigin-RevId: 621292364
2024-04-02 14:14:55 -07:00
jax authors
e282bf57db Merge pull request #20536 from jakevdp:broadcast-to
PiperOrigin-RevId: 621287464
2024-04-02 13:59:12 -07:00
Jake VanderPlas
e99a3051ed jnp.floor_div: lower directly to div for unsigned int 2024-04-02 13:47:42 -07:00
jax authors
00489be23d Fix a bug where exceptions were thrown in debug message formatting, when sharding was set to None on arrays.
PiperOrigin-RevId: 621193460
2024-04-02 08:56:37 -07:00
Jake VanderPlas
6de6983d59 jnp.broadcast_to: better error for invalid shape 2024-04-02 08:38:51 -07:00
Sergei Lebedev
2ee4c0f644 Added installation instructions to the error in _pallas_call_lowering
PiperOrigin-RevId: 621168804
2024-04-02 07:36:28 -07:00
George Necula
bff24c6d6f [callback] Improve caching effectiveness in presence of callbacks.
Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This means that two separate io_callback
invocations with the same Python callable will generate different
Jaxprs. To prevent this we defer the flattening to lowering time.
2024-04-02 15:33:24 +02:00
jax authors
431015a14e Merge pull request #20383 from gnecula:doc_deprecation
PiperOrigin-RevId: 621153196
2024-04-02 06:19:58 -07:00
jax authors
b3fe9400fb Add round lowering rule.
PiperOrigin-RevId: 621110036
2024-04-02 02:55:34 -07:00
George Necula
84db689e39 A few more comments about how the deprecations work 2024-04-02 10:52:01 +02:00
George Necula
c491720ee1 Accelerate deprecation of jax.experimental.host_callback.id_print and stop_outfeed_receiver
`jax.experimental.host_callback` is deprecated and any API in that module will throw a DeprecationWarning. After this change the `id_print` and `stop_outfeed_receiver` will throw an `AttributeError` in internal code only.

Add a deprecation message for `barrier_wait`.

PiperOrigin-RevId: 621064083
2024-04-01 23:12:59 -07:00
Sergei Lebedev
16b3f00e42 Register GPU/TPU lowering for pallas_call_p lazily
Prior to this change we had to import jax.experimental.pallas.{gpu,tpu} in
jax.experimental.pallas only to get the lowering rules registered.

PiperOrigin-RevId: 620957622
2024-04-01 14:40:33 -07:00
jax authors
5a7e874339 Merge pull request #20524 from jakevdp:trapz
PiperOrigin-RevId: 620953434
2024-04-01 14:26:40 -07:00
Sergei Lebedev
c4f1a45205 Generalized the in_specs/out_specs types in PrefetchScalarGridSpec
PiperOrigin-RevId: 620949269
2024-04-01 14:11:55 -07:00
Yash Katariya
6557f680fd Rename SpecifiedLayout to DeviceLocalLayout
PiperOrigin-RevId: 620934348
2024-04-01 13:19:46 -07:00
Jake VanderPlas
9e01afe7af Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
Jieying Luo
68c674d106 [PJRT C API] Add a PJRT extension to register custom partitioner.
- This extension has one C API which registers a custom partitioner with callbacks from the input.
- Update xla_client.register_custom_call_partitioner to take an optional PJRT_Api* input.
- Add xla_bridge.register_plugin_initialization_callbacks to register callbacks to be called with PJRT_Api* after plugins are discovered.

PiperOrigin-RevId: 620357554
2024-03-29 15:40:26 -07:00
Trevor Gale
80c305da7b Add MegaBlox grouped matrix multiplication kernels for TPU.
PiperOrigin-RevId: 620331883
2024-03-29 13:50:49 -07:00
Jevin Jiang
7137b256af [Pallas] Fix a typo in error message of swap rule.
PiperOrigin-RevId: 620320550
2024-03-29 13:04:44 -07:00
George Necula
1012797127 Mark jax.experimental.host_callback.barrier_wait as deprecated.
The jax.experimental.host_callback module is deprecated and will be removed.

See https://github.com/google/jax/issues/20385.

The other API entry points have been marked as deprecated already, but barrier_wait was missed.

PiperOrigin-RevId: 620237286
2024-03-29 07:27:21 -07:00
Yash Katariya
84156f359f Add identity jit tests to go from pinned_host -> device and vice versa
PiperOrigin-RevId: 620114420
2024-03-28 18:20:32 -07:00
Dinghua Li
8bf3f47f02 Open source PagedAttention TPU kernel.
PiperOrigin-RevId: 620042536
2024-03-28 13:36:02 -07:00
Kanglan Tang
0a2e3cd3aa Move platform_mappings file to the root of github/jax
PiperOrigin-RevId: 620000024
2024-03-28 11:21:48 -07:00
jax authors
e03f1d4fd1 Allows for splitting the transpose of a scan into a scan and a map.
This is an experimental feature exposed as an extra parameter: `scan(..., _split_transpose:bool)`.

If the parameter is true then the transpose of scan generates not just 2 scans
(forward and transpose of the linearized forward), but rather 3 scans: (i)
forward (as before), (ii) transposed scan that only computes loop-carried state
required for back-propagation, but saves other intermediate gradients; (iii) a
scan (actually a map) that uses any saved activation gradients and original
residuals to compute any other gradients.

Warning: this feature is somewhat experimental and may evolve or be rolled back.
PiperOrigin-RevId: 619991098
2024-03-28 10:54:50 -07:00
Yash Katariya
9e86aa5329 Add custom call on output along with S(5) because XLA requires the custom call to show the transfer.
Enable paramater streaming and weight offloading

PiperOrigin-RevId: 619711649
2024-03-27 17:07:36 -07:00
jax authors
66877c9987 Allow allow_spmd_propagation_to_output to be generated for outputs annotated with pjit.AUTO
PiperOrigin-RevId: 619608022
2024-03-27 12:04:03 -07:00
George Necula
c0c918aa8b [export] Increase minimum serialization version to 9.
Stop supporting serializing older version. The current max serialization version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024.

This change could break clients that set a specific JAX serialization version lower than 9.

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

PiperOrigin-RevId: 619588685
2024-03-27 11:06:23 -07:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Matthew Johnson
fa9f02ba2f Reverts 0dde8f7f9607d09841ece7125dfc0773c3613fab
PiperOrigin-RevId: 619416732
2024-03-26 22:26:41 -07:00
jax authors
0dde8f7f96 Merge pull request #20445 from mattjj:scan-dont-traverse-body-jaxpr-in-lowering-2
PiperOrigin-RevId: 619387957
2024-03-26 19:52:11 -07:00
Matthew Johnson
9474b46012 [scan] don't traverse body jaxpr in lowering
This is an attempt to re-land #19819 aka cl/607570860 after a small number of
performance regressions.

As before, the main changes are:
 1. simplify the scan impl that we trace through to get the lowering, and
 2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body
    jaxpr we already have in hand.

The main motivation was (2), but (1) seems like a useful win too.

The way we achieve (2) is with a new trick: in our scan_impl function, which is
only ever traced to a jaxpr, instead of calling
`core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive
`eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging
rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and
that rule just generates a call into the jaxpr of interest. Therefore we will
not traverse into the jaxpr just to rebuild it inline (as before).

The code in #19819 was simpler in that it avoided reshapes, concats, and
un-concats. But it caused at least one apparent performance regression (an XLA
bug?) and it was unrelated to the original goal of reducing tracing time. So
here we just land the trace time improvement.
2024-03-26 17:17:58 -07:00
Jieying Luo
4a9c8d1a0a Removed obsolete call to libtpu_module.configure_library_path().
PiperOrigin-RevId: 619340619
2024-03-26 16:06:41 -07:00
Yash Katariya
6e0c95585a Remove the canonicalization to GSPMDSharding internally in jit. This is not required anymore since the caches are split into tracing, lowering and compilation.
The canonicalization doesn't provide any value anymore and only makes the internals more complicated.

The canonicalization can be done by lowering to HloSharding in places where required and there are utilities to help with that.

PiperOrigin-RevId: 619292757
2024-03-26 13:28:45 -07:00
Sergei Lebedev
18c885d090 Removed double-printing of TTIR in Pallas GPU lowering
PiperOrigin-RevId: 619208376
2024-03-26 09:11:39 -07:00
George Necula
75db481299 [callback] Fix io_callback for callbacks that return Python literals.
The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.

Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.

Reverts 19e6156ccec0df7a900471df7840bc421da2898b

PiperOrigin-RevId: 619156176
2024-03-26 05:32:41 -07:00
George Karpenkov
33cf53c413 [XLA:GPU] Add option to return FDO profile as textproto.
PiperOrigin-RevId: 619105468
2024-03-26 01:35:27 -07:00
Sharad Vikram
f93c320dcc Enable extra args with input output aliasing
PiperOrigin-RevId: 619041158
2024-03-25 20:05:33 -07:00
jax authors
69980a27bb Use the information in allow_spmd_sharding_propagation_to_output and allow_spmd_sharding_propagation_to_parameters to determine what input and output tuple elements we are allowed to modfy the shardings of.
PiperOrigin-RevId: 619013275
2024-03-25 17:46:52 -07:00