193 Commits

Author SHA1 Message Date
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Yash Katariya
34611be53d Add sharding rules to some more primitives so that backward pass of minformer passes. There are a couple of changes here:
* Handled transpose of `dot_general` correctly with shardings
* Handled transpose of `reduce_sum` correctly with shardings
* `ShapedArray.to_tangent_aval` now sets the sharding of the tangent (not handling unreduced yet).
* `ConcreteArray.aval` correctly sets the sharding which is extracted from the `val` attribute.
* (Paired with Dougal!) Added sharding rule for `reshape_p` only when singleton dims are added/removed.
* Added sharding rule for `select_n_p` because it gets called during `jax.grad` of minformer.
* Added `sharding` attribute to `broadcast_in_dim` because we need to provide the correct sharding to it during `full` and transpose of `reduce_sum`.

PiperOrigin-RevId: 689837320
2024-10-25 10:35:25 -07:00
Sergei Lebedev
3ad1985e1a Bumped mypy and ruff versions used by pre-commit 2024-10-21 21:58:41 +01:00
Yash Katariya
e92e1191b3 [sharding_in_types] Add broadcast_in_dim rule.
PiperOrigin-RevId: 687054181
2024-10-17 14:55:10 -07:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Yash Katariya
8139c531a3 Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes
PiperOrigin-RevId: 685215764
2024-10-12 09:56:02 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example: (for array for shape (8, 2) with dtype float32

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
Sharad Vikram
80f963c003 Fix mutable array effects not being tracked properly
PiperOrigin-RevId: 680801564
2024-09-30 18:55:15 -07:00
Matthew Johnson
0a73d74a4e simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).

This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-25 00:10:01 +00:00
Dougal Maclaurin
d2ac88c193 Expose some APIs for querying trace state. This will let us move users away from
depending on our internals. Prep work for "stackless".

PiperOrigin-RevId: 678288660
2024-09-24 09:48:41 -07:00
Matthew Johnson
43cc70b7a1 add jax.experimental.primal_tangent_dtype helper
useful for constructing new dtypes which have a distinct tangent type (e.g. for
quantization)
2024-09-21 20:35:20 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Yash Katariya
8b5b71750b Fix jaxpr equation context propagation in jaxpr equations when inline=True.
PiperOrigin-RevId: 675754808
2024-09-17 16:40:36 -07:00
Yash Katariya
242cb7cbc7 Fix the __repr__ of JaxprEqnContext
PiperOrigin-RevId: 675673700
2024-09-17 12:51:00 -07:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
jax authors
02b7a76768 Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
Yash Katariya
4c8bed9270 Don't add a sharding property to ShapedArray if sharding_in_types flag is not switched on.
PiperOrigin-RevId: 671475186
2024-09-05 12:48:10 -07:00
Yash Katariya
e3110c18f8 Remove dtype and weak_type from __slots__ of ShapedArray since it comes from UnShapedArray
PiperOrigin-RevId: 669447416
2024-08-30 14:37:29 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we just support nary ops in forward only sharding propagation. Currently this functionality is experimental and hidden behind jax_sharding_in_types config flag.
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
jax authors
9bebf577dd Reverts 66a3f87a24016594794c2ee289826baed5e979a4
PiperOrigin-RevId: 665543713
2024-08-20 15:13:09 -07:00
Adam Paszke
66a3f87a24 Rollback for: Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)
It can cause issues in x32 when trying to get the aval for array dimension sizes that are larger than i32.

Reverts 24394a1b03f01138219013f4773104b834e498b7

PiperOrigin-RevId: 664742891
2024-08-19 04:28:44 -07:00
jax authors
24394a1b03 Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)
The plan here is to load it up with invariants, and start with a really simple kernel. After that, we can slowly relax the various invariants and implement support for others.

Note - the work saving here is compute only, not memory yet. A fast-followup CL is adding memory savings via index-map rewriting

PiperOrigin-RevId: 663752447
2024-08-16 09:20:57 -07:00
George Necula
ffd2b00516 Add concretization error check in core.min_dim and core.max_dim
Fixes: #22751
2024-08-01 07:27:35 +02:00
George Necula
70a11acbb1 [pallas] More simplification of grid mapping and calling convention
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
Matthew Johnson
c8ea86c9c9 remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al
fixes breakage from cl/655766534 aka https://github.com/google/jax/pull/21069

PiperOrigin-RevId: 655806010
2024-07-24 20:55:36 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
George Necula
459b83cf4a Reverts 093b92be8ed7bd979486614325956e88cc474ff1
PiperOrigin-RevId: 655114622
2024-07-23 04:32:56 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
2024-07-23 09:19:01 +00:00
Sharad Vikram
ca284d778e Add shard/unshard_aval_handlers for custom aval handling for shard_map.
PiperOrigin-RevId: 654959243
2024-07-22 17:56:55 -07:00
jax authors
2a2aa612be Merge pull request #22541 from mattjj:21343
PiperOrigin-RevId: 654522436
2024-07-21 11:37:39 -07:00
jax authors
ff36ea5de3 Merge pull request #21567 from mattjj:skip-invar-origin-msg-if-malformed
PiperOrigin-RevId: 654356735
2024-07-20 14:30:17 -07:00
Matthew Johnson
c5fd3b0ced skip _origin_msg invar debug info if invar_pos/arg_info is malformed
cf #20397, #20396
2024-07-20 17:22:49 +00:00
Matthew Johnson
83dfed1c02 make core.as_named_shape treat int like tuple[int]
fixes #21343
2024-07-20 17:14:38 +00:00
Sharad Vikram
f3c1cbc709 Add custom rules for str_eqn_compact
PiperOrigin-RevId: 651911281
2024-07-12 16:03:27 -07:00
Sergei Lebedev
ec7dd0fac1 `debug_info no longer requires non-None func_src_info`
I suspect in the past lack of source info meant that the function also has
no signature, but this is no longer the case.

I also removed an unused parameter from ``explain_tracing_cache_miss`` as
a drive by change.

This is a follow up to #22269.
2024-07-05 20:08:53 +01:00
Peter Hawkins
2350a73f87 Use a class with __slots__ instead of a NamedTuple in JaxprEqn and SourceInfo, which are two tuples we build frequently.
Surprisingly this is faster. With Python 3.12:

```
In [1]: from typing import NamedTuple

In [2]: class C(NamedTuple):
   ...:     a: int
   ...:     b: int
   ...:     c: int
   ...:     d: int
   ...:     e: int
   ...:     f: int
   ...:     g: int
   ...:

In [3]: class D:
   ...:     __slots__ = ('a', 'b', 'c', 'd', 'e', 'f', 'g')
   ...:     def __init__(self, a, b, c, d, e, f, g):
   ...:         self.a = a
   ...:         self.b = b
   ...:         self.c = c
   ...:         self.d = d
   ...:         self.e = e
   ...:         self.f = f
   ...:         self.g = g
   ...:

In [4]: %timeit D(1, 2, 3, 4, 5, 6, 7)
158 ns ± 0.458 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [5]: %timeit C(1, 2, 3, 4, 5, 6, 7)
236 ns ± 0.498 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [6]: %timeit D(1, 2, 3, 4, 5, 6, 7)
159 ns ± 2.13 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [7]: %timeit C(1, 2, 3, 4, 5, 6, 7)
235 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```

No behavioral changes intended.

PiperOrigin-RevId: 648556436
2024-07-01 19:18:58 -07:00
jax authors
639892bb04 Merge pull request #22123 from mattjj:dynamic-trace-state-simplification
PiperOrigin-RevId: 647129667
2024-06-26 17:21:05 -07:00
Matthew Johnson
275ddad51d tweak dynamic trace state to only depend on level int, not MainTrace 2024-06-26 23:42:49 +00:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Matthew Johnson
8564b55ee2 remove double dots from an error message 2024-06-26 00:04:42 +00:00
Sergei Lebedev
ce0d9e9b9f Changed the naming of internal config APIs
The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.

The renames are

* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
2024-06-18 11:48:57 +01:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
Matthew Johnson
10d285dea7 fix error message for vjp arguments 2024-05-30 21:22:35 +00:00
Yash Katariya
972fd66525 Add the threefry_partitionable config to JaxprEqnContext to allow setting it inside jit.
Before this information was lost in the roundtrip via `mlir.lower_fun` -> `jaxpr_subcomp`. But now since it's on the jaxpr equations, the information is preserved in jaxpr_subcomp as we enter into each eqn's ctx.

Fixes: https://github.com/google/jax/issues/21061
PiperOrigin-RevId: 636940742
2024-05-24 09:15:50 -07:00
Sergei Lebedev
15b974c90b Another attempt to land #20445
Reverts fa9f02ba2fd7e874edee0169773923e162ed0ea1

PiperOrigin-RevId: 636926775
2024-05-24 08:24:17 -07:00