182 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Yash Katariya
8b5b71750b Fix jaxpr equation context propagation in jaxpr equations when inline=True.
PiperOrigin-RevId: 675754808
2024-09-17 16:40:36 -07:00
Yash Katariya
242cb7cbc7 Fix the __repr__ of JaxprEqnContext
PiperOrigin-RevId: 675673700
2024-09-17 12:51:00 -07:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
jax authors
02b7a76768 Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
Yash Katariya
4c8bed9270 Don't add a sharding property to ShapedArray if sharding_in_types flag is not switched on.
PiperOrigin-RevId: 671475186
2024-09-05 12:48:10 -07:00
Yash Katariya
e3110c18f8 Remove dtype and weak_type from __slots__ of ShapedArray since it comes from UnShapedArray
PiperOrigin-RevId: 669447416
2024-08-30 14:37:29 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we just support nary ops in forward only sharding propagation. Currently this functionality is experimental and hidden behind jax_sharding_in_types config flag.
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
jax authors
9bebf577dd Reverts 66a3f87a24016594794c2ee289826baed5e979a4
PiperOrigin-RevId: 665543713
2024-08-20 15:13:09 -07:00
Adam Paszke
66a3f87a24 Rollback for: Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)
It can cause issues in x32 when trying to get the aval for array dimension sizes that are larger than i32.

Reverts 24394a1b03f01138219013f4773104b834e498b7

PiperOrigin-RevId: 664742891
2024-08-19 04:28:44 -07:00
jax authors
24394a1b03 Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)
The plan here is to load it up with invariants, and start with a really simple kernel. After that, we can slowly relax the various invariants and implement support for others.

Note - the work saving here is compute only, not memory yet. A fast-followup CL is adding memory savings via index-map rewriting

PiperOrigin-RevId: 663752447
2024-08-16 09:20:57 -07:00
George Necula
ffd2b00516 Add concretization error check in core.min_dim and core.max_dim
Fixes: #22751
2024-08-01 07:27:35 +02:00
George Necula
70a11acbb1 [pallas] More simplification of grid mapping and calling convention
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
Matthew Johnson
c8ea86c9c9 remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al
fixes breakage from cl/655766534 aka https://github.com/google/jax/pull/21069

PiperOrigin-RevId: 655806010
2024-07-24 20:55:36 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
George Necula
459b83cf4a Reverts 093b92be8ed7bd979486614325956e88cc474ff1
PiperOrigin-RevId: 655114622
2024-07-23 04:32:56 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
2024-07-23 09:19:01 +00:00
Sharad Vikram
ca284d778e Add shard/unshard_aval_handlers for custom aval handling for shard_map.
PiperOrigin-RevId: 654959243
2024-07-22 17:56:55 -07:00
jax authors
2a2aa612be Merge pull request #22541 from mattjj:21343
PiperOrigin-RevId: 654522436
2024-07-21 11:37:39 -07:00
jax authors
ff36ea5de3 Merge pull request #21567 from mattjj:skip-invar-origin-msg-if-malformed
PiperOrigin-RevId: 654356735
2024-07-20 14:30:17 -07:00
Matthew Johnson
c5fd3b0ced skip _origin_msg invar debug info if invar_pos/arg_info is malformed
cf #20397, #20396
2024-07-20 17:22:49 +00:00
Matthew Johnson
83dfed1c02 make core.as_named_shape treat int like tuple[int]
fixes #21343
2024-07-20 17:14:38 +00:00
Sharad Vikram
f3c1cbc709 Add custom rules for str_eqn_compact
PiperOrigin-RevId: 651911281
2024-07-12 16:03:27 -07:00
Sergei Lebedev
ec7dd0fac1 `debug_info no longer requires non-None func_src_info`
I suspect in the past lack of source info meant that the function also has
no signature, but this is no longer the case.

I also removed an unused parameter from ``explain_tracing_cache_miss`` as
a drive by change.

This is a follow up to #22269.
2024-07-05 20:08:53 +01:00
Peter Hawkins
2350a73f87 Use a class with __slots__ instead of a NamedTuple in JaxprEqn and SourceInfo, which are two tuples we build frequently.
Surprisingly this is faster. With Python 3.12:

```
In [1]: from typing import NamedTuple

In [2]: class C(NamedTuple):
   ...:     a: int
   ...:     b: int
   ...:     c: int
   ...:     d: int
   ...:     e: int
   ...:     f: int
   ...:     g: int
   ...:

In [3]: class D:
   ...:     __slots__ = ('a', 'b', 'c', 'd', 'e', 'f', 'g')
   ...:     def __init__(self, a, b, c, d, e, f, g):
   ...:         self.a = a
   ...:         self.b = b
   ...:         self.c = c
   ...:         self.d = d
   ...:         self.e = e
   ...:         self.f = f
   ...:         self.g = g
   ...:

In [4]: %timeit D(1, 2, 3, 4, 5, 6, 7)
158 ns ± 0.458 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [5]: %timeit C(1, 2, 3, 4, 5, 6, 7)
236 ns ± 0.498 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [6]: %timeit D(1, 2, 3, 4, 5, 6, 7)
159 ns ± 2.13 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [7]: %timeit C(1, 2, 3, 4, 5, 6, 7)
235 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```

No behavioral changes intended.

PiperOrigin-RevId: 648556436
2024-07-01 19:18:58 -07:00
jax authors
639892bb04 Merge pull request #22123 from mattjj:dynamic-trace-state-simplification
PiperOrigin-RevId: 647129667
2024-06-26 17:21:05 -07:00
Matthew Johnson
275ddad51d tweak dynamic trace state to only depend on level int, not MainTrace 2024-06-26 23:42:49 +00:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Matthew Johnson
8564b55ee2 remove double dots from an error message 2024-06-26 00:04:42 +00:00
Sergei Lebedev
ce0d9e9b9f Changed the naming of internal config APIs
The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.

The renames are

* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
2024-06-18 11:48:57 +01:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
Matthew Johnson
10d285dea7 fix error message for vjp arguments 2024-05-30 21:22:35 +00:00
Yash Katariya
972fd66525 Add the threefry_partitionable config to JaxprEqnContext to allow setting it inside jit.
Before this information was lost in the roundtrip via `mlir.lower_fun` -> `jaxpr_subcomp`. But now since it's on the jaxpr equations, the information is preserved in jaxpr_subcomp as we enter into each eqn's ctx.

Fixes: https://github.com/google/jax/issues/21061
PiperOrigin-RevId: 636940742
2024-05-24 09:15:50 -07:00
Sergei Lebedev
15b974c90b Another attempt to land #20445
Reverts fa9f02ba2fd7e874edee0169773923e162ed0ea1

PiperOrigin-RevId: 636926775
2024-05-24 08:24:17 -07:00
Yash Katariya
711190155d Initialize JaxprEqnContext only in new_jaxpr_eqn and new_eqn_recipe with the current active compute type if no ctx is specified.
PiperOrigin-RevId: 636309959
2024-05-22 15:16:58 -07:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Yash Katariya
9c01fc5f0f Add tests for sharded host computations
PiperOrigin-RevId: 636038410
2024-05-21 22:43:18 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Yash Katariya
6577f47b83 Make eqn.ctx context manager thread safe by creating eqn.ctx.manager.
PiperOrigin-RevId: 635057475
2024-05-18 08:46:18 -07:00
Yash Katariya
25aa13c46b Support remat + compute_on. If the rematted computation is annotated to run on host, the backward pass will also execute on host. Also enable no-op nested compute tests.
PiperOrigin-RevId: 634943450
2024-05-17 18:59:49 -07:00
Yash Katariya
02c19e9600 Make jax.grad and compute_on work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.
PiperOrigin-RevId: 634917402
2024-05-17 16:38:35 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
Matthew Johnson
e4c76c97e2 [omnitracing] partially un-regress dispatch time 2024-05-02 22:18:36 +00:00