521 Commits

Author SHA1 Message Date
Yash Katariya
dcb28f1218 [sharding_in_types] Add vmap + explicit sharding support. The main changes are:
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to the the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.

This should eventually help us get rid of `spmd_axis_name` from `vmap`.

PiperOrigin-RevId: 721007659
2025-01-29 09:34:27 -08:00
Yash Katariya
23d360bded Remove axis_name from unmapped_aval
PiperOrigin-RevId: 718558713
2025-01-22 15:49:04 -08:00
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding in types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
Yash Katariya
799eb98cac Add reshard API in experimental. Currently for sharding_in_types we have 2 APIs: mesh_cast and reshard. Both work in sharding_in_types mode and affect the sharding of the aval. Following are the semantics of both:
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.

* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.

We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is because `device_put` does a lot of stuff right now (and is going to get even more powers in the near future like cross-host transfers) and it's semantics would be very confusing if we keep piling sharding-in-types stuff on it.

PiperOrigin-RevId: 717588253
2025-01-20 11:39:25 -08:00
George Necula
dcf72b01f4 [better_errors] Improvements in propagation of debugging info
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).

Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.

In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.

Added more type declarations.

Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
2025-01-20 15:09:51 +01:00
Yash Katariya
c7f8d17f5a Expose hidden_axes via jax namespace as public API. Also mention it as a workaround for primitives we don't support yet.
PiperOrigin-RevId: 716839003
2025-01-17 16:48:58 -08:00
jax authors
ee724565bf Merge pull request #25827 from gnecula:debug_info_2
PiperOrigin-RevId: 715407809
2025-01-14 09:12:37 -08:00
Yash Katariya
c72ed260fe [sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by registering the sharding on the aval properly created by SDS in it's pytype_aval_mapping.
Also If we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.

PiperOrigin-RevId: 715383866
2025-01-14 08:03:50 -08:00
Dougal
7d11d12bcd Mention expected tangent aval in error message, see #25517. 2025-01-14 08:51:12 -05:00
George Necula
b30df36d7d [better_errors] Add debug_info to DynamicJaxprTrace and JaxprStackFrame
This is part of a sequence of changes to ensure that the debugging information
is propagated properly.

Additional cleanup:
* Rename `result_paths` to `result_paths_thunk` in `TracingDebugInfo` to clarify the
  difference from the similar field in `JaxprDebugInfo`
* Added more type declarations
2025-01-14 13:49:18 +00:00
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
Jake VanderPlas
c560f8e06c Unify abstractify & shaped_abstractify rules 2024-12-20 04:28:19 -08:00
Jake VanderPlas
5dc37d3f70 Remove internal uses of api_util.shaped_abstractify 2024-12-19 07:06:36 -08:00
Yash Katariya
473e2bf527 Put abstract_mesh on every eqn so that we can preserve it during eval_jaxpr and check_jaxpr roundtrip.
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.

Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.

PiperOrigin-RevId: 707128096
2024-12-17 09:17:21 -08:00
Jake VanderPlas
40367a9eaf Cleanup: remove uses of no-op raise_to_shaped 2024-12-12 09:49:06 -08:00
Yash Katariya
deab6fbd80 Remove _pjit_lower_cached cache. We can simplify the caching of jit as we have downstream caches and a cpp cache too.
If you drop out of cpp cache, things are going to be slow anyways.

PiperOrigin-RevId: 700052522
2024-11-25 11:40:50 -08:00
jax authors
4363bb65d7 Merge pull request #24770 from jakevdp:extended-device-get
PiperOrigin-RevId: 695671688
2024-11-12 03:58:23 -08:00
Jake VanderPlas
58dee3ea33 jax.device_get: handle generic extended dtypes 2024-11-07 16:01:22 -08:00
Yash Katariya
0bb30f0777 Propagate CopySemantics from python to C++ transfer APIs so that device_put works correctly in presence of copy/donate options that user specified.
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.

This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.

Fixes https://github.com/jax-ml/jax/issues/24521

PiperOrigin-RevId: 694274616
2024-11-07 15:51:54 -08:00
Peter Hawkins
0e8acff5c6 Reverts a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461
PiperOrigin-RevId: 693360032
2024-11-05 08:32:25 -08:00
jax authors
a913fbf2fd rollback due to data race
Reverts ab47d4687f647de3aa145a9a782fb7b4aaf92af4

PiperOrigin-RevId: 693191298
2024-11-04 21:05:33 -08:00
Peter Hawkins
ab47d4687f [JAX] [XLA:Python] Move JAX configuration objects into C++.
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.

There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.

PiperOrigin-RevId: 693114411
2024-11-04 15:39:06 -08:00
Yash Katariya
fff33f90b2 Add compiler_options argument to jax.jit.
This exists on `Compiled` object via AOT too i.e. `jit(f).lower(*args).compile(compiler_options={})`

PiperOrigin-RevId: 692283964
2024-11-01 14:01:19 -07:00
Yash Katariya
07858fa98d [sharding_in_types] Allow device_put to reshard inputs. device_put is a good choice for resharding since it already handles transpose correctly because it tracks the src sharding too.
PiperOrigin-RevId: 692274137
2024-11-01 13:25:08 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Jake VanderPlas
0181cb396d Re-land #24589 with fixes to handle dtype that is not compatible with NumPy.
Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case.

Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d

PiperOrigin-RevId: 691568933
2024-10-30 15:13:00 -07:00
Thomas Köppe
2bed1e88e4 Reverts 6dd1417d4a0a9ee31d8a014352b3a0fb2bcfcbaf
PiperOrigin-RevId: 691417832
2024-10-30 07:54:00 -07:00
jax authors
6dd1417d4a Merge pull request #24589 from jakevdp:device-get-key
PiperOrigin-RevId: 691154098
2024-10-29 14:03:18 -07:00
Jake VanderPlas
b9ad519a29 Implement device_get for typed PRNG keys 2024-10-29 12:34:46 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
jax authors
47bacfab5e Merge pull request #24031 from garymm:garymm/vmap-error-msg
PiperOrigin-RevId: 689940504
2024-10-25 15:59:57 -07:00
Gary Miguel
9f7f08eccb Fix vmap error message when args passed by keyword
See the new test for a case that used to produce the wrong message.

Fixes: #24406
2024-10-25 15:17:03 -07:00
Brian Wieder
7db4b254e0 Clear extra_jit_context when exiting.
In for some reason, extra_jit_context was leaking when `pallas.core` no longer imported `pallas.pallas_call`, leading to leaking XLA Clients.

PiperOrigin-RevId: 689857071
2024-10-25 11:35:47 -07:00
Brian Wieder
633ac7eaa9 Clear caches on jax exit.
PiperOrigin-RevId: 682288160
2024-10-04 05:55:30 -07:00
Yash Katariya
1efca33187 Add donate and may_alias as an argument to device_put to allow for donation and aliasing.
The end state we want to work towards is to remove `may_alias` and **always copy by default**. But there is some work needed to get to that state.

**Definition:**

* donate: The input buffer will be marked as deleted (see below for some caveats). The output buffer may or may not reuse the input buffer's underlying memory.

* may_alias: If True, we may return the original buffer depending on the implementation.

**What problem are we solving?**

Eventually, we want `device_put` to always copy so introducing `may_alias` as a transition state to help towards that goal. We might end up deciding to keep `may_alias` but now you have an explicit option to **always copy** i.e. set `may_alias=False` which is what some users want.

Adding `donate` allows users to avoid this pattern of code:

```
inp = ...
out = device_put(inp, sharding)
jax.block_until_ready(out)
jax.tree.map(lambda x: x.delete(), inp)
```

Now it can just be: `jax.device_put(inp, sharding, donate=True)`

**So what are the semantics of these 2 options?** Let's create a table:

| may-alias \= None (default) | donate \= False (default) | Result |
| :---- | :---- | :---- |
| True | True | Error |
| True | False | May return the original buffer. Input Array marked as deleted: No. Reuses input buffer for output: Maybe |
| False | True | Original buffer deleted i.e. Donation. Input Array marked as deleted: Yes. Reuses input buffer for output: Maybe |
| False | False | Pure copy. Input Array marked as deleted: No. Reuses input buffer for output: No |
| None | True | `may_alias` will be marked as False. See Row 2 i.e. may\_alias \= False, donate \= True |
| None | False | `may_alias` will be marked as True. See Row 1 i.e. may\_alias \= True, donate \= False |

`donate` is best effort for now until we fix the following things:

 * Delete input when `donate=True` regardless of whether XLA could donate or not. This will affect `jax.jit` too but it's a good thing to do.

 * Plumb donate to PJRT/IFRT APIs so we can donate where transfers are not happening via `jit`.

PiperOrigin-RevId: 681073828
2024-10-01 10:28:23 -07:00
Jake VanderPlas
cf51ee7ef0 Improve documentation for jax.jacobian 2024-09-26 05:09:47 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Parker Schuh
86fe463ad7 [Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.
This allows us to get more cache hits globally. For example:

Before:

jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
After:

jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit

Reverts b615266175effe4aefeb903620a19f3719a604da

PiperOrigin-RevId: 675746175
2024-09-17 16:11:28 -07:00
Sergei Lebedev
83bccdd289 sharding and weak_type parameters of ShapeDtypeStruct are now keyword-only
We decided not to go through a deprecation cycle for this change, because
in the vast majority of cases internally these parameters are bound via a
keyword argument anyway.

PiperOrigin-RevId: 674324964
2024-09-13 09:24:38 -07:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Yash Katariya
de9b98e0a8 Delete jax.xla_computation since it's been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
jax authors
4957ab9a5e Clean up JAX backend for all backends to avoid dangling PyClient references.
PiperOrigin-RevId: 673102539
2024-09-10 14:19:00 -07:00
Yash Katariya
b615266175 Reverts 82c9da020a78997862a8f7ccd494bed363f7ed01
PiperOrigin-RevId: 668969133
2024-08-29 09:43:19 -07:00
Yash Katariya
dd6f0e2e2e Add weak_type to ShapeDtypeStruct because jax.Array also has it and SDS is a duck of jax.Array
This fixes a tracing cache miss issue when you eval shape with a weak_type input and get a strong type output back and pass that back in leading to a cache miss.

Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
2024-08-29 08:35:42 -07:00
Yash Katariya
82c9da020a Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.
This allows us to get more cache hits globally. For example:

Before:

```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
```

After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit
```

Also, we can remove the hack (which I didn't like) in multihost_utils.py.

PiperOrigin-RevId: 665574475
2024-08-20 16:18:58 -07:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
Yue Sheng
09beb33226 Don't call api.clean_up when there is no default backend.
PiperOrigin-RevId: 658936536
2024-08-02 16:14:29 -07:00
Yue Sheng
88c8bacdca Add util.clear_all_caches to api.clear_backends and let api.clear_backends be called before process terminates on JAX CPU. This could make the PjRt CPU client object to be successfully destroyed during Python garbage collection.
PiperOrigin-RevId: 658843789
2024-08-02 11:08:48 -07:00
Sergei Lebedev
fb1dbf15df Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI 2024-08-01 22:30:24 +01:00