206 Commits

Author SHA1 Message Date
Dougal
763952a607 Fix buggy and confusing logic in the C++/pjit caching path.
When we have a cache miss in `_cpp_pjit` we want to compile the function and
store the executable. Previously we had a roundabout way of getting hold of that
executable. We'd trace the function to a jaxpr but we wouldn't lower and compile
it ourselves. Instead, we'd call `pjit_p.bind`. The layers of the tracing onion
would be peeled off and eventually we'd hit the `pjit_p` impl rule,
`_pjit_call_impl`. This rule has its own cache. With luck we'd also miss *that*
cache, and then `_pjit_call_impl` would lower and compile the jaxpr and store
the executable in `most_recent_pjit_call_executable`. We'd eventually pop the
stack back up to the `_cpp_pjit` cache miss and then we'd get hold of the
compiled object by looking up `most_recent_pjit_call_executable`.

There's room for bugs here if we hit one cache but not the other. For example,
if we miss the `_cpp_pjit` cache but we hit the `_pjit_call_impl` cache then we
won't compile the executable. Normally that would just mean that the `_cpp_pjit`
cache won't be populated. But if we've previously hit a function with the same
jaxpr but slightly different compilation parameters (e.g. device IDs) then we'll
get a bogus hit in `most_recent_call_exectuable` and we'll add an incorrect
cache entry. The divergent cache behavior you need to trigger this started
happening with the "stackless" change because the tracing context became a
bigger part of the cache key and `_cpp_pjit` and `_pjit_call_impl` will in
general have different tracing contexts.

With this change, we remove the whole `most_recent_pjit_call_executable` system.
Instead `_cpp_pjit` lowers, compiles and runs the jaxpr itself and obtains the
executable directly rather than calling into `pjit_p.bind`. We do call into
`pjit_p.bind` if we're not in an eval context, but in that case we don't expect
to be able to populate the `_cpp_pjit` cache anyway.
2024-11-11 00:42:47 -05:00
Dougal Maclaurin
d352f4f245 Put the set of current spmd axis names in the axis env instead of spelunking
through the trace stack to find it.

PiperOrigin-RevId: 694710181
2024-11-08 18:15:26 -08:00
Jake VanderPlas
095bb0e742 Make Tracers non-hashable 2024-11-05 09:08:33 -08:00
Peter Hawkins
0e8acff5c6 Reverts a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461
PiperOrigin-RevId: 693360032
2024-11-05 08:32:25 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
jax authors
a913fbf2fd rollback due to data race
Reverts ab47d4687f647de3aa145a9a782fb7b4aaf92af4

PiperOrigin-RevId: 693191298
2024-11-04 21:05:33 -08:00
Peter Hawkins
ab47d4687f [JAX] [XLA:Python] Move JAX configuration objects into C++.
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.

There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.

PiperOrigin-RevId: 693114411
2024-11-04 15:39:06 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Dougal Maclaurin
f355dcf34b Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Dougal Maclaurin
32bf19ac6f Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs.
debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors.

PiperOrigin-RevId: 691494516
2024-10-30 11:34:08 -07:00
Dougal Maclaurin
a45b0856c5 Relax leak checks under the jax_data_dependent_tracing_fallback flag.
PiperOrigin-RevId: 691409392
2024-10-30 07:22:29 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Yash Katariya
34611be53d Add sharding rules to some more primitives so that backward pass of minformer passes. There are a couple of changes here:
* Handled transpose of `dot_general` correctly with shardings
* Handled transpose of `reduce_sum` correctly with shardings
* `ShapedArray.to_tangent_aval` now sets the sharding of the tangent (not handling unreduced yet).
* `ConcreteArray.aval` correctly sets the sharding which is extracted from the `val` attribute.
* (Paired with Dougal!) Added sharding rule for `reshape_p` only when singleton dims are added/removed.
* Added sharding rule for `select_n_p` because it gets called during `jax.grad` of minformer.
* Added `sharding` attribute to `broadcast_in_dim` because we need to provide the correct sharding to it during `full` and transpose of `reduce_sum`.

PiperOrigin-RevId: 689837320
2024-10-25 10:35:25 -07:00
Sergei Lebedev
3ad1985e1a Bumped mypy and ruff versions used by pre-commit 2024-10-21 21:58:41 +01:00
Yash Katariya
e92e1191b3 [sharding_in_types] Add broadcast_in_dim rule.
PiperOrigin-RevId: 687054181
2024-10-17 14:55:10 -07:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Yash Katariya
8139c531a3 Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes
PiperOrigin-RevId: 685215764
2024-10-12 09:56:02 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example: (for array for shape (8, 2) with dtype float32

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
Sharad Vikram
80f963c003 Fix mutable array effects not being tracked properly
PiperOrigin-RevId: 680801564
2024-09-30 18:55:15 -07:00
Matthew Johnson
0a73d74a4e simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).

This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-25 00:10:01 +00:00
Dougal Maclaurin
d2ac88c193 Expose some APIs for querying trace state. This will let us move users away from
depending on our internals. Prep work for "stackless".

PiperOrigin-RevId: 678288660
2024-09-24 09:48:41 -07:00
Matthew Johnson
43cc70b7a1 add jax.experimental.primal_tangent_dtype helper
useful for constructing new dtypes which have a distinct tangent type (e.g. for
quantization)
2024-09-21 20:35:20 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Yash Katariya
8b5b71750b Fix jaxpr equation context propagation in jaxpr equations when inline=True.
PiperOrigin-RevId: 675754808
2024-09-17 16:40:36 -07:00
Yash Katariya
242cb7cbc7 Fix the __repr__ of JaxprEqnContext
PiperOrigin-RevId: 675673700
2024-09-17 12:51:00 -07:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
jax authors
02b7a76768 Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
Yash Katariya
4c8bed9270 Don't add a sharding property to ShapedArray if sharding_in_types flag is not switched on.
PiperOrigin-RevId: 671475186
2024-09-05 12:48:10 -07:00
Yash Katariya
e3110c18f8 Remove dtype and weak_type from __slots__ of ShapedArray since it comes from UnShapedArray
PiperOrigin-RevId: 669447416
2024-08-30 14:37:29 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we just support nary ops in forward only sharding propagation. Currently this functionality is experimental and hidden behind jax_sharding_in_types config flag.
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
jax authors
9bebf577dd Reverts 66a3f87a24016594794c2ee289826baed5e979a4
PiperOrigin-RevId: 665543713
2024-08-20 15:13:09 -07:00
Adam Paszke
66a3f87a24 Rollback for: Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)
It can cause issues in x32 when trying to get the aval for array dimension sizes that are larger than i32.

Reverts 24394a1b03f01138219013f4773104b834e498b7

PiperOrigin-RevId: 664742891
2024-08-19 04:28:44 -07:00
jax authors
24394a1b03 Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)
The plan here is to load it up with invariants, and start with a really simple kernel. After that, we can slowly relax the various invariants and implement support for others.

Note - the work saving here is compute only, not memory yet. A fast-followup CL is adding memory savings via index-map rewriting

PiperOrigin-RevId: 663752447
2024-08-16 09:20:57 -07:00
George Necula
ffd2b00516 Add concretization error check in core.min_dim and core.max_dim
Fixes: #22751
2024-08-01 07:27:35 +02:00
George Necula
70a11acbb1 [pallas] More simplification of grid mapping and calling convention
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
Matthew Johnson
c8ea86c9c9 remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al
fixes breakage from cl/655766534 aka https://github.com/google/jax/pull/21069

PiperOrigin-RevId: 655806010
2024-07-24 20:55:36 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
George Necula
459b83cf4a Reverts 093b92be8ed7bd979486614325956e88cc474ff1
PiperOrigin-RevId: 655114622
2024-07-23 04:32:56 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
2024-07-23 09:19:01 +00:00
Sharad Vikram
ca284d778e Add shard/unshard_aval_handlers for custom aval handling for shard_map.
PiperOrigin-RevId: 654959243
2024-07-22 17:56:55 -07:00
jax authors
2a2aa612be Merge pull request #22541 from mattjj:21343
PiperOrigin-RevId: 654522436
2024-07-21 11:37:39 -07:00
jax authors
ff36ea5de3 Merge pull request #21567 from mattjj:skip-invar-origin-msg-if-malformed
PiperOrigin-RevId: 654356735
2024-07-20 14:30:17 -07:00
Matthew Johnson
c5fd3b0ced skip _origin_msg invar debug info if invar_pos/arg_info is malformed
cf #20397, #20396
2024-07-20 17:22:49 +00:00