230 Commits

Author SHA1 Message Date
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
Jake VanderPlas
c560f8e06c Unify abstractify & shaped_abstractify rules 2024-12-20 04:28:19 -08:00
Jake VanderPlas
676070f4cd Refactor: move shaped_abstractify to core 2024-12-18 19:14:46 -08:00
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Matthew Johnson
42ac4ca357 ref errors 2024-12-18 07:46:14 +00:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
Jake VanderPlas
2518c6233e Raise rather than return error 2024-12-17 11:20:55 -08:00
jax authors
0fa541972e Merge pull request #25456 from jakevdp:xla-abstractify
PiperOrigin-RevId: 707175097
2024-12-17 11:13:18 -08:00
Jake VanderPlas
2c722d9b13 Cleanup: toward merging core.concrete_aval & xla.abstractify 2024-12-17 09:27:00 -08:00
Yash Katariya
473e2bf527 Put abstract_mesh on every eqn so that we can preserve it during eval_jaxpr and check_jaxpr roundtrip.
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.

Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.

PiperOrigin-RevId: 707128096
2024-12-17 09:17:21 -08:00
Yash Katariya
41f490aef4 [sharding_in_types] Default axis_types to Auto for all axis_names if user does not set any AxisType. Also resolve some TODOs now that we have a way for user to set the mesh.
PiperOrigin-RevId: 704944255
2024-12-10 20:20:23 -08:00
Yash Katariya
b5e4fd161d [sharding_in_types] Enforce AxisTypes to always exist if set_mesh is used.
Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.

During lowering of shardings with auto axes, we mark the auto dims are `unspecifed_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further which is not what we want if some of the dims are user sharded.

PiperOrigin-RevId: 704911253
2024-12-10 18:03:21 -08:00
Dougal
fc2edbfac8 Add a freeze primitive to delimit ref lifetimes for AD.
Also some basic AD through mutable_array/freeze.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-12-09 20:57:07 -05:00
Yash Katariya
a735bf83e5 Simply abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py
PiperOrigin-RevId: 702852769
2024-12-04 14:04:25 -08:00
Yash Katariya
0d2dfea4b1 Add a private set_mesh API to enter into sharding_in_types mode. This is how users will enable sharding in types mode (with correct axis types set too but that doesn't work yet).
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.

PiperOrigin-RevId: 700537898
2024-11-26 20:01:04 -08:00
Yash Katariya
355589f32b [sharding_in_types] Add scan support to sharding_in_types. There are a couple of changes here
* Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during it's lowering in `_scan_impl` (do the same for AOT path)

* Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager.

* Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them.

* scan only allows `xs` where the 0th dim is full replicated i.e. None.

PiperOrigin-RevId: 699014167
2024-11-21 20:13:23 -08:00
jax authors
e707edeafa Merge pull request #25034 from gnecula:poly_state
PiperOrigin-RevId: 698820458
2024-11-21 09:57:55 -08:00
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives. 2024-11-21 06:17:01 -08:00
Yash Katariya
40fc6598f9 [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted which I will fix in follow up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.

Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.

PiperOrigin-RevId: 698493184
2024-11-20 13:07:30 -08:00
Yash Katariya
e904c177f7 Delete _normalized_spec from NamedSharding
PiperOrigin-RevId: 697779844
2024-11-18 15:35:38 -08:00
Yash Katariya
05716b58b0 [sharding_in_types] Support shard_map with sharding in types. Right now only full manual mode is supported.
This change also adds AxisTypes to Mesh which are `User`, `Auto` and `Collective`.

In the following changes, I'll remove the `config.sharding_in_types` flag and we'll enter into various modes via AxisTypes mentioned on the mesh.

PiperOrigin-RevId: 696559375
2024-11-14 09:58:03 -08:00
Peter Hawkins
bc82203a5c Avoid using a contextmanager in Primitive.bind.
It's slightly faster to inline the context manager code into the implementation of bind.

PiperOrigin-RevId: 696142743
2024-11-13 08:20:36 -08:00
George Necula
c92507772c Cleanup more remnants of the jax.experimental.host_callback
Removes the outfeed rewriter mechanism and helper functions
`jaxpr_uses_outfeed`, which were used only by
`jax.experimental.host_callback`.
2024-11-12 03:27:10 -08:00
Dougal
763952a607 Fix buggy and confusing logic in the C++/pjit caching path.
When we have a cache miss in `_cpp_pjit` we want to compile the function and
store the executable. Previously we had a roundabout way of getting hold of that
executable. We'd trace the function to a jaxpr but we wouldn't lower and compile
it ourselves. Instead, we'd call `pjit_p.bind`. The layers of the tracing onion
would be peeled off and eventually we'd hit the `pjit_p` impl rule,
`_pjit_call_impl`. This rule has its own cache. With luck we'd also miss *that*
cache, and then `_pjit_call_impl` would lower and compile the jaxpr and store
the executable in `most_recent_pjit_call_executable`. We'd eventually pop the
stack back up to the `_cpp_pjit` cache miss and then we'd get hold of the
compiled object by looking up `most_recent_pjit_call_executable`.

There's room for bugs here if we hit one cache but not the other. For example,
if we miss the `_cpp_pjit` cache but we hit the `_pjit_call_impl` cache then we
won't compile the executable. Normally that would just mean that the `_cpp_pjit`
cache won't be populated. But if we've previously hit a function with the same
jaxpr but slightly different compilation parameters (e.g. device IDs) then we'll
get a bogus hit in `most_recent_call_exectuable` and we'll add an incorrect
cache entry. The divergent cache behavior you need to trigger this started
happening with the "stackless" change because the tracing context became a
bigger part of the cache key and `_cpp_pjit` and `_pjit_call_impl` will in
general have different tracing contexts.

With this change, we remove the whole `most_recent_pjit_call_executable` system.
Instead `_cpp_pjit` lowers, compiles and runs the jaxpr itself and obtains the
executable directly rather than calling into `pjit_p.bind`. We do call into
`pjit_p.bind` if we're not in an eval context, but in that case we don't expect
to be able to populate the `_cpp_pjit` cache anyway.
2024-11-11 00:42:47 -05:00
Dougal Maclaurin
d352f4f245 Put the set of current spmd axis names in the axis env instead of spelunking
through the trace stack to find it.

PiperOrigin-RevId: 694710181
2024-11-08 18:15:26 -08:00
Jake VanderPlas
095bb0e742 Make Tracers non-hashable 2024-11-05 09:08:33 -08:00
Peter Hawkins
0e8acff5c6 Reverts a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461
PiperOrigin-RevId: 693360032
2024-11-05 08:32:25 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
jax authors
a913fbf2fd rollback due to data race
Reverts ab47d4687f647de3aa145a9a782fb7b4aaf92af4

PiperOrigin-RevId: 693191298
2024-11-04 21:05:33 -08:00
Peter Hawkins
ab47d4687f [JAX] [XLA:Python] Move JAX configuration objects into C++.
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.

There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.

PiperOrigin-RevId: 693114411
2024-11-04 15:39:06 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Dougal Maclaurin
f355dcf34b Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Dougal Maclaurin
32bf19ac6f Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs.
debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors.

PiperOrigin-RevId: 691494516
2024-10-30 11:34:08 -07:00
Dougal Maclaurin
a45b0856c5 Relax leak checks under the jax_data_dependent_tracing_fallback flag.
PiperOrigin-RevId: 691409392
2024-10-30 07:22:29 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Yash Katariya
34611be53d Add sharding rules to some more primitives so that backward pass of minformer passes. There are a couple of changes here:
* Handled transpose of `dot_general` correctly with shardings
* Handled transpose of `reduce_sum` correctly with shardings
* `ShapedArray.to_tangent_aval` now sets the sharding of the tangent (not handling unreduced yet).
* `ConcreteArray.aval` correctly sets the sharding which is extracted from the `val` attribute.
* (Paired with Dougal!) Added sharding rule for `reshape_p` only when singleton dims are added/removed.
* Added sharding rule for `select_n_p` because it gets called during `jax.grad` of minformer.
* Added `sharding` attribute to `broadcast_in_dim` because we need to provide the correct sharding to it during `full` and transpose of `reduce_sum`.

PiperOrigin-RevId: 689837320
2024-10-25 10:35:25 -07:00
Sergei Lebedev
3ad1985e1a Bumped mypy and ruff versions used by pre-commit 2024-10-21 21:58:41 +01:00
Yash Katariya
e92e1191b3 [sharding_in_types] Add broadcast_in_dim rule.
PiperOrigin-RevId: 687054181
2024-10-17 14:55:10 -07:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Yash Katariya
8139c531a3 Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes
PiperOrigin-RevId: 685215764
2024-10-12 09:56:02 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example: (for array for shape (8, 2) with dtype float32

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
Sharad Vikram
80f963c003 Fix mutable array effects not being tracked properly
PiperOrigin-RevId: 680801564
2024-09-30 18:55:15 -07:00
Matthew Johnson
0a73d74a4e simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).

This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-25 00:10:01 +00:00
Dougal Maclaurin
d2ac88c193 Expose some APIs for querying trace state. This will let us move users away from
depending on our internals. Prep work for "stackless".

PiperOrigin-RevId: 678288660
2024-09-24 09:48:41 -07:00
Matthew Johnson
43cc70b7a1 add jax.experimental.primal_tangent_dtype helper
useful for constructing new dtypes which have a distinct tangent type (e.g. for
quantization)
2024-09-21 20:35:20 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00