216 Commits

Author SHA1 Message Date
Yash Katariya
a735bf83e5 Simply abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py
PiperOrigin-RevId: 702852769
2024-12-04 14:04:25 -08:00
Yash Katariya
0d2dfea4b1 Add a private set_mesh API to enter into sharding_in_types mode. This is how users will enable sharding in types mode (with correct axis types set too but that doesn't work yet).
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.

PiperOrigin-RevId: 700537898
2024-11-26 20:01:04 -08:00
Yash Katariya
355589f32b [sharding_in_types] Add scan support to sharding_in_types. There are a couple of changes here
* Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during it's lowering in `_scan_impl` (do the same for AOT path)

* Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager.

* Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them.

* scan only allows `xs` where the 0th dim is full replicated i.e. None.

PiperOrigin-RevId: 699014167
2024-11-21 20:13:23 -08:00
jax authors
e707edeafa Merge pull request #25034 from gnecula:poly_state
PiperOrigin-RevId: 698820458
2024-11-21 09:57:55 -08:00
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives. 2024-11-21 06:17:01 -08:00
Yash Katariya
40fc6598f9 [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted which I will fix in follow up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.

Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.

PiperOrigin-RevId: 698493184
2024-11-20 13:07:30 -08:00
Yash Katariya
e904c177f7 Delete _normalized_spec from NamedSharding
PiperOrigin-RevId: 697779844
2024-11-18 15:35:38 -08:00
Yash Katariya
05716b58b0 [sharding_in_types] Support shard_map with sharding in types. Right now only full manual mode is supported.
This change also adds AxisTypes to Mesh which are `User`, `Auto` and `Collective`.

In the following changes, I'll remove the `config.sharding_in_types` flag and we'll enter into various modes via AxisTypes mentioned on the mesh.

PiperOrigin-RevId: 696559375
2024-11-14 09:58:03 -08:00
Peter Hawkins
bc82203a5c Avoid using a contextmanager in Primitive.bind.
It's slightly faster to inline the context manager code into the implementation of bind.

PiperOrigin-RevId: 696142743
2024-11-13 08:20:36 -08:00
George Necula
c92507772c Cleanup more remnants of the jax.experimental.host_callback
Removes the outfeed rewriter mechanism and helper functions
`jaxpr_uses_outfeed`, which were used only by
`jax.experimental.host_callback`.
2024-11-12 03:27:10 -08:00
Dougal
763952a607 Fix buggy and confusing logic in the C++/pjit caching path.
When we have a cache miss in `_cpp_pjit` we want to compile the function and
store the executable. Previously we had a roundabout way of getting hold of that
executable. We'd trace the function to a jaxpr but we wouldn't lower and compile
it ourselves. Instead, we'd call `pjit_p.bind`. The layers of the tracing onion
would be peeled off and eventually we'd hit the `pjit_p` impl rule,
`_pjit_call_impl`. This rule has its own cache. With luck we'd also miss *that*
cache, and then `_pjit_call_impl` would lower and compile the jaxpr and store
the executable in `most_recent_pjit_call_executable`. We'd eventually pop the
stack back up to the `_cpp_pjit` cache miss and then we'd get hold of the
compiled object by looking up `most_recent_pjit_call_executable`.

There's room for bugs here if we hit one cache but not the other. For example,
if we miss the `_cpp_pjit` cache but we hit the `_pjit_call_impl` cache then we
won't compile the executable. Normally that would just mean that the `_cpp_pjit`
cache won't be populated. But if we've previously hit a function with the same
jaxpr but slightly different compilation parameters (e.g. device IDs) then we'll
get a bogus hit in `most_recent_call_exectuable` and we'll add an incorrect
cache entry. The divergent cache behavior you need to trigger this started
happening with the "stackless" change because the tracing context became a
bigger part of the cache key and `_cpp_pjit` and `_pjit_call_impl` will in
general have different tracing contexts.

With this change, we remove the whole `most_recent_pjit_call_executable` system.
Instead `_cpp_pjit` lowers, compiles and runs the jaxpr itself and obtains the
executable directly rather than calling into `pjit_p.bind`. We do call into
`pjit_p.bind` if we're not in an eval context, but in that case we don't expect
to be able to populate the `_cpp_pjit` cache anyway.
2024-11-11 00:42:47 -05:00
Dougal Maclaurin
d352f4f245 Put the set of current spmd axis names in the axis env instead of spelunking
through the trace stack to find it.

PiperOrigin-RevId: 694710181
2024-11-08 18:15:26 -08:00
Jake VanderPlas
095bb0e742 Make Tracers non-hashable 2024-11-05 09:08:33 -08:00
Peter Hawkins
0e8acff5c6 Reverts a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461
PiperOrigin-RevId: 693360032
2024-11-05 08:32:25 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
jax authors
a913fbf2fd rollback due to data race
Reverts ab47d4687f647de3aa145a9a782fb7b4aaf92af4

PiperOrigin-RevId: 693191298
2024-11-04 21:05:33 -08:00
Peter Hawkins
ab47d4687f [JAX] [XLA:Python] Move JAX configuration objects into C++.
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.

There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.

PiperOrigin-RevId: 693114411
2024-11-04 15:39:06 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Dougal Maclaurin
f355dcf34b Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Dougal Maclaurin
32bf19ac6f Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs.
debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors.

PiperOrigin-RevId: 691494516
2024-10-30 11:34:08 -07:00
Dougal Maclaurin
a45b0856c5 Relax leak checks under the jax_data_dependent_tracing_fallback flag.
PiperOrigin-RevId: 691409392
2024-10-30 07:22:29 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Yash Katariya
34611be53d Add sharding rules to some more primitives so that backward pass of minformer passes. There are a couple of changes here:
* Handled transpose of `dot_general` correctly with shardings
* Handled transpose of `reduce_sum` correctly with shardings
* `ShapedArray.to_tangent_aval` now sets the sharding of the tangent (not handling unreduced yet).
* `ConcreteArray.aval` correctly sets the sharding which is extracted from the `val` attribute.
* (Paired with Dougal!) Added sharding rule for `reshape_p` only when singleton dims are added/removed.
* Added sharding rule for `select_n_p` because it gets called during `jax.grad` of minformer.
* Added `sharding` attribute to `broadcast_in_dim` because we need to provide the correct sharding to it during `full` and transpose of `reduce_sum`.

PiperOrigin-RevId: 689837320
2024-10-25 10:35:25 -07:00
Sergei Lebedev
3ad1985e1a Bumped mypy and ruff versions used by pre-commit 2024-10-21 21:58:41 +01:00
Yash Katariya
e92e1191b3 [sharding_in_types] Add broadcast_in_dim rule.
PiperOrigin-RevId: 687054181
2024-10-17 14:55:10 -07:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Yash Katariya
8139c531a3 Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes
PiperOrigin-RevId: 685215764
2024-10-12 09:56:02 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example: (for array for shape (8, 2) with dtype float32

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
Sharad Vikram
80f963c003 Fix mutable array effects not being tracked properly
PiperOrigin-RevId: 680801564
2024-09-30 18:55:15 -07:00
Matthew Johnson
0a73d74a4e simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).

This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-25 00:10:01 +00:00
Dougal Maclaurin
d2ac88c193 Expose some APIs for querying trace state. This will let us move users away from
depending on our internals. Prep work for "stackless".

PiperOrigin-RevId: 678288660
2024-09-24 09:48:41 -07:00
Matthew Johnson
43cc70b7a1 add jax.experimental.primal_tangent_dtype helper
useful for constructing new dtypes which have a distinct tangent type (e.g. for
quantization)
2024-09-21 20:35:20 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Yash Katariya
8b5b71750b Fix jaxpr equation context propagation in jaxpr equations when inline=True.
PiperOrigin-RevId: 675754808
2024-09-17 16:40:36 -07:00
Yash Katariya
242cb7cbc7 Fix the __repr__ of JaxprEqnContext
PiperOrigin-RevId: 675673700
2024-09-17 12:51:00 -07:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
jax authors
02b7a76768 Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
Yash Katariya
4c8bed9270 Don't add a sharding property to ShapedArray if sharding_in_types flag is not switched on.
PiperOrigin-RevId: 671475186
2024-09-05 12:48:10 -07:00
Yash Katariya
e3110c18f8 Remove dtype and weak_type from __slots__ of ShapedArray since it comes from UnShapedArray
PiperOrigin-RevId: 669447416
2024-08-30 14:37:29 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we just support nary ops in forward only sharding propagation. Currently this functionality is experimental and hidden behind jax_sharding_in_types config flag.
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
jax authors
9bebf577dd Reverts 66a3f87a24016594794c2ee289826baed5e979a4
PiperOrigin-RevId: 665543713
2024-08-20 15:13:09 -07:00
Adam Paszke
66a3f87a24 Rollback for: Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)
It can cause issues in x32 when trying to get the aval for array dimension sizes that are larger than i32.

Reverts 24394a1b03f01138219013f4773104b834e498b7

PiperOrigin-RevId: 664742891
2024-08-19 04:28:44 -07:00
jax authors
24394a1b03 Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)
The plan here is to load it up with invariants, and start with a really simple kernel. After that, we can slowly relax the various invariants and implement support for others.

Note - the work saving here is compute only, not memory yet. A fast-followup CL is adding memory savings via index-map rewriting

PiperOrigin-RevId: 663752447
2024-08-16 09:20:57 -07:00
George Necula
ffd2b00516 Add concretization error check in core.min_dim and core.max_dim
Fixes: #22751
2024-08-01 07:27:35 +02:00
George Necula
70a11acbb1 [pallas] More simplification of grid mapping and calling convention
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00