164 Commits

Author SHA1 Message Date
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
George Necula
459b83cf4a Reverts 093b92be8ed7bd979486614325956e88cc474ff1
PiperOrigin-RevId: 655114622
2024-07-23 04:32:56 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
2024-07-23 09:19:01 +00:00
Sharad Vikram
ca284d778e Add shard/unshard_aval_handlers for custom aval handling for shard_map.
PiperOrigin-RevId: 654959243
2024-07-22 17:56:55 -07:00
jax authors
2a2aa612be Merge pull request #22541 from mattjj:21343
PiperOrigin-RevId: 654522436
2024-07-21 11:37:39 -07:00
jax authors
ff36ea5de3 Merge pull request #21567 from mattjj:skip-invar-origin-msg-if-malformed
PiperOrigin-RevId: 654356735
2024-07-20 14:30:17 -07:00
Matthew Johnson
c5fd3b0ced skip _origin_msg invar debug info if invar_pos/arg_info is malformed
cf #20397, #20396
2024-07-20 17:22:49 +00:00
Matthew Johnson
83dfed1c02 make core.as_named_shape treat int like tuple[int]
fixes #21343
2024-07-20 17:14:38 +00:00
Sharad Vikram
f3c1cbc709 Add custom rules for str_eqn_compact
PiperOrigin-RevId: 651911281
2024-07-12 16:03:27 -07:00
Sergei Lebedev
ec7dd0fac1 `debug_info no longer requires non-None func_src_info`
I suspect in the past lack of source info meant that the function also has
no signature, but this is no longer the case.

I also removed an unused parameter from ``explain_tracing_cache_miss`` as
a drive by change.

This is a follow up to #22269.
2024-07-05 20:08:53 +01:00
Peter Hawkins
2350a73f87 Use a class with __slots__ instead of a NamedTuple in JaxprEqn and SourceInfo, which are two tuples we build frequently.
Surprisingly this is faster. With Python 3.12:

```
In [1]: from typing import NamedTuple

In [2]: class C(NamedTuple):
   ...:     a: int
   ...:     b: int
   ...:     c: int
   ...:     d: int
   ...:     e: int
   ...:     f: int
   ...:     g: int
   ...:

In [3]: class D:
   ...:     __slots__ = ('a', 'b', 'c', 'd', 'e', 'f', 'g')
   ...:     def __init__(self, a, b, c, d, e, f, g):
   ...:         self.a = a
   ...:         self.b = b
   ...:         self.c = c
   ...:         self.d = d
   ...:         self.e = e
   ...:         self.f = f
   ...:         self.g = g
   ...:

In [4]: %timeit D(1, 2, 3, 4, 5, 6, 7)
158 ns ± 0.458 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [5]: %timeit C(1, 2, 3, 4, 5, 6, 7)
236 ns ± 0.498 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [6]: %timeit D(1, 2, 3, 4, 5, 6, 7)
159 ns ± 2.13 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [7]: %timeit C(1, 2, 3, 4, 5, 6, 7)
235 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```

No behavioral changes intended.

PiperOrigin-RevId: 648556436
2024-07-01 19:18:58 -07:00
jax authors
639892bb04 Merge pull request #22123 from mattjj:dynamic-trace-state-simplification
PiperOrigin-RevId: 647129667
2024-06-26 17:21:05 -07:00
Matthew Johnson
275ddad51d tweak dynamic trace state to only depend on level int, not MainTrace 2024-06-26 23:42:49 +00:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Matthew Johnson
8564b55ee2 remove double dots from an error message 2024-06-26 00:04:42 +00:00
Sergei Lebedev
ce0d9e9b9f Changed the naming of internal config APIs
The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.

The renames are

* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
2024-06-18 11:48:57 +01:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
Matthew Johnson
10d285dea7 fix error message for vjp arguments 2024-05-30 21:22:35 +00:00
Yash Katariya
972fd66525 Add the threefry_partitionable config to JaxprEqnContext to allow setting it inside jit.
Before this information was lost in the roundtrip via `mlir.lower_fun` -> `jaxpr_subcomp`. But now since it's on the jaxpr equations, the information is preserved in jaxpr_subcomp as we enter into each eqn's ctx.

Fixes: https://github.com/google/jax/issues/21061
PiperOrigin-RevId: 636940742
2024-05-24 09:15:50 -07:00
Sergei Lebedev
15b974c90b Another attempt to land #20445
Reverts fa9f02ba2fd7e874edee0169773923e162ed0ea1

PiperOrigin-RevId: 636926775
2024-05-24 08:24:17 -07:00
Yash Katariya
711190155d Initialize JaxprEqnContext only in new_jaxpr_eqn and new_eqn_recipe with the current active compute type if no ctx is specified.
PiperOrigin-RevId: 636309959
2024-05-22 15:16:58 -07:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Yash Katariya
9c01fc5f0f Add tests for sharded host computations
PiperOrigin-RevId: 636038410
2024-05-21 22:43:18 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Yash Katariya
6577f47b83 Make eqn.ctx context manager thread safe by creating eqn.ctx.manager.
PiperOrigin-RevId: 635057475
2024-05-18 08:46:18 -07:00
Yash Katariya
25aa13c46b Support remat + compute_on. If the rematted computation is annotated to run on host, the backward pass will also execute on host. Also enable no-op nested compute tests.
PiperOrigin-RevId: 634943450
2024-05-17 18:59:49 -07:00
Yash Katariya
02c19e9600 Make jax.grad and compute_on work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.
PiperOrigin-RevId: 634917402
2024-05-17 16:38:35 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
Matthew Johnson
e4c76c97e2 [omnitracing] partially un-regress dispatch time 2024-05-02 22:18:36 +00:00
jax authors
70f2ef211f Merge pull request #20971 from google:mutable-array-scan
PiperOrigin-RevId: 630130893
2024-05-02 11:40:54 -07:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
jax authors
738612802e Merge pull request #21017 from mattjj:omnitracing-1
PiperOrigin-RevId: 629570166
2024-04-30 17:04:16 -07:00
Matthew Johnson
7af2b3cf72 [omnitracing] pop frames, see if anything downstream breaks 2024-04-30 14:52:27 -07:00
Chris Jones
20a8e2a6ec Allow replacing jaxpr debug_info with None.
The existing implementation of `Jaxpr.replace` would ignore the parameter `debug_info=None`.

PiperOrigin-RevId: 629421610
2024-04-30 08:31:39 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
Jake VanderPlas
dc2d8c13d0 [key reuse] call key reuse logic directly in dispatch 2024-04-11 17:08:32 -07:00
Jake VanderPlas
1b3aea8205 Finalize the deprecation of the arr.device() method
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.

PiperOrigin-RevId: 623015500
2024-04-08 19:04:15 -07:00
Jake VanderPlas
5115b89538 Fix typos in comments 2024-04-08 15:16:39 -07:00
Matthew Johnson
46a516275f [mutable-arrays] enable refs without cps, and not just at top level
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-04-03 16:23:19 -07:00
Matthew Johnson
fa9f02ba2f Reverts 0dde8f7f9607d09841ece7125dfc0773c3613fab
PiperOrigin-RevId: 619416732
2024-03-26 22:26:41 -07:00
Matthew Johnson
9474b46012 [scan] don't traverse body jaxpr in lowering
This is an attempt to re-land #19819 aka cl/607570860 after a small number of
performance regressions.

As before, the main changes are:
 1. simplify the scan impl that we trace through to get the lowering, and
 2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body
    jaxpr we already have in hand.

The main motivation was (2), but (1) seems like a useful win too.

The way we achieve (2) is with a new trick: in our scan_impl function, which is
only ever traced to a jaxpr, instead of calling
`core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive
`eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging
rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and
that rule just generates a call into the jaxpr of interest. Therefore we will
not traverse into the jaxpr just to rebuild it inline (as before).

The code in #19819 was simpler in that it avoided reshapes, concats, and
un-concats. But it caused at least one apparent performance regression (an XLA
bug?) and it was unrelated to the original goal of reducing tracing time. So
here we just land the trace time improvement.
2024-03-26 17:17:58 -07:00
Yash Katariya
0b4634170e Don't report origin_msg if any execption is raised in self._origin_msg
PiperOrigin-RevId: 618237231
2024-03-22 11:23:46 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
jax authors
ce0d0c17c3 Merge pull request #20218 from mattjj:mutable-arrays-closure
PiperOrigin-RevId: 615463712
2024-03-13 10:23:23 -07:00
Matthew Johnson
649cd50681 [mutable-arrays] support closed-over mutable arrays in jit 2024-03-13 09:59:03 -07:00
Roy Frostig
98f790f5d5 update package/API reference docs to new-style typed PRNG keys 2024-03-07 12:40:09 -08:00
Jake VanderPlas
b349328d5d Remove some dead code 2024-03-06 11:30:48 -08:00