146 Commits

Author SHA1 Message Date
Matthew Johnson
10d285dea7 fix error message for vjp arguments 2024-05-30 21:22:35 +00:00
Yash Katariya
972fd66525 Add the threefry_partitionable config to JaxprEqnContext to allow setting it inside jit.
Before this information was lost in the roundtrip via `mlir.lower_fun` -> `jaxpr_subcomp`. But now since it's on the jaxpr equations, the information is preserved in jaxpr_subcomp as we enter into each eqn's ctx.

Fixes: https://github.com/google/jax/issues/21061
PiperOrigin-RevId: 636940742
2024-05-24 09:15:50 -07:00
Sergei Lebedev
15b974c90b Another attempt to land #20445
Reverts fa9f02ba2fd7e874edee0169773923e162ed0ea1

PiperOrigin-RevId: 636926775
2024-05-24 08:24:17 -07:00
Yash Katariya
711190155d Initialize JaxprEqnContext only in new_jaxpr_eqn and new_eqn_recipe with the current active compute type if no ctx is specified.
PiperOrigin-RevId: 636309959
2024-05-22 15:16:58 -07:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Yash Katariya
9c01fc5f0f Add tests for sharded host computations
PiperOrigin-RevId: 636038410
2024-05-21 22:43:18 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Yash Katariya
6577f47b83 Make eqn.ctx context manager thread safe by creating eqn.ctx.manager.
PiperOrigin-RevId: 635057475
2024-05-18 08:46:18 -07:00
Yash Katariya
25aa13c46b Support remat + compute_on. If the rematted computation is annotated to run on host, the backward pass will also execute on host. Also enable no-op nested compute tests.
PiperOrigin-RevId: 634943450
2024-05-17 18:59:49 -07:00
Yash Katariya
02c19e9600 Make jax.grad and compute_on work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.
PiperOrigin-RevId: 634917402
2024-05-17 16:38:35 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
Matthew Johnson
e4c76c97e2 [omnitracing] partially un-regress dispatch time 2024-05-02 22:18:36 +00:00
jax authors
70f2ef211f Merge pull request #20971 from google:mutable-array-scan
PiperOrigin-RevId: 630130893
2024-05-02 11:40:54 -07:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
jax authors
738612802e Merge pull request #21017 from mattjj:omnitracing-1
PiperOrigin-RevId: 629570166
2024-04-30 17:04:16 -07:00
Matthew Johnson
7af2b3cf72 [omnitracing] pop frames, see if anything downstream breaks 2024-04-30 14:52:27 -07:00
Chris Jones
20a8e2a6ec Allow replacing jaxpr debug_info with None.
The existing implementation of `Jaxpr.replace` would ignore the parameter `debug_info=None`.

PiperOrigin-RevId: 629421610
2024-04-30 08:31:39 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
Jake VanderPlas
dc2d8c13d0 [key reuse] call key reuse logic directly in dispatch 2024-04-11 17:08:32 -07:00
Jake VanderPlas
1b3aea8205 Finalize the deprecation of the arr.device() method
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.

PiperOrigin-RevId: 623015500
2024-04-08 19:04:15 -07:00
Jake VanderPlas
5115b89538 Fix typos in comments 2024-04-08 15:16:39 -07:00
Matthew Johnson
46a516275f [mutable-arrays] enable refs without cps, and not just at top level
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-04-03 16:23:19 -07:00
Matthew Johnson
fa9f02ba2f Reverts 0dde8f7f9607d09841ece7125dfc0773c3613fab
PiperOrigin-RevId: 619416732
2024-03-26 22:26:41 -07:00
Matthew Johnson
9474b46012 [scan] don't traverse body jaxpr in lowering
This is an attempt to re-land #19819 aka cl/607570860 after a small number of
performance regressions.

As before, the main changes are:
 1. simplify the scan impl that we trace through to get the lowering, and
 2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body
    jaxpr we already have in hand.

The main motivation was (2), but (1) seems like a useful win too.

The way we achieve (2) is with a new trick: in our scan_impl function, which is
only ever traced to a jaxpr, instead of calling
`core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive
`eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging
rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and
that rule just generates a call into the jaxpr of interest. Therefore we will
not traverse into the jaxpr just to rebuild it inline (as before).

The code in #19819 was simpler in that it avoided reshapes, concats, and
un-concats. But it caused at least one apparent performance regression (an XLA
bug?) and it was unrelated to the original goal of reducing tracing time. So
here we just land the trace time improvement.
2024-03-26 17:17:58 -07:00
Yash Katariya
0b4634170e Don't report origin_msg if any execption is raised in self._origin_msg
PiperOrigin-RevId: 618237231
2024-03-22 11:23:46 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
jax authors
ce0d0c17c3 Merge pull request #20218 from mattjj:mutable-arrays-closure
PiperOrigin-RevId: 615463712
2024-03-13 10:23:23 -07:00
Matthew Johnson
649cd50681 [mutable-arrays] support closed-over mutable arrays in jit 2024-03-13 09:59:03 -07:00
Roy Frostig
98f790f5d5 update package/API reference docs to new-style typed PRNG keys 2024-03-07 12:40:09 -08:00
Jake VanderPlas
b349328d5d Remove some dead code 2024-03-06 11:30:48 -08:00
Sergei Lebedev
5283d4b4a5 Axis names are now tracked via an effect
This allows propagating the names bottom up -- from equations to the jaxpr,
instead of "discovering" them top-down by traversing (and rebuilding) the
jaxpr via core.subst_axis_names.

PiperOrigin-RevId: 612416803
2024-03-04 05:42:03 -08:00
Matthew Johnson
3a403f2a0e [mutable-arrays] move MutableArray, add eager, improve tests, fix bug
1. move MutableArray to core.py, and some handlers to their respective files
2. fix a bug in aliasing setup (it was just broken before, now better test coverage)
3. add eager support by enabling get_p, swap_p, and addupdate_p impls
4. improve tests slightly
2024-03-01 15:03:23 -08:00
Matthew Johnson
ab0f7061ad [mutable-arrays] allow state effects in jit by building in run_state
with help from @sharadmv, @yashkatariya, @dougalm, and others

The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
   handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
   refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.

As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
2024-02-29 21:50:19 -08:00
Peter Hawkins
9dee375dc6 Avoid use of operator.attrgetter in core.find_top_Trace, since it allocates and is less efficient than a lambda.
PiperOrigin-RevId: 609739110
2024-02-23 08:47:32 -08:00
George Necula
6705a96b56 [shape_poly] Cleaning up naming of terms and factors.
In the past symbolic expressions were polynomials, consisting of sums
of monomials, which were products of atoms. Over time the language
of symbolic expressions has become richer. Now expressions
are sums of terms, which are products of factors.

Here we rename references to monomials to terms, and `_DimMon`
to `_DimTerm`. We also rename reference of atoms to factors,
and `_DimAtom` to `_DimFactor`.

At the same time we rename most of the methods of `_DimExpr`
to have a leading underscore, to indicate that they are
private methods.
2024-02-21 09:18:22 +01:00
Anselm Levskaya
772743e6a4 Internal change
Reverts 330afdc8bebe900d999202c4d59613e99cadb0ad

PiperOrigin-RevId: 607783139
2024-02-19 14:03:09 +00:00
Jake VanderPlas
1fe46aa8be Error for deprecated scalar conversions of non-scalar arrays 2024-02-16 11:26:30 -08:00
Peter Hawkins
67df647988 Reland https://github.com/google/jax/pull/10573.
The original PR was reverted because of downstream breakage.

Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)

But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.

So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.

In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:

```
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
#     b:f32[] = sin a
#     c:f32[] = sin b
#     d:f32[] = sin c
#   in (d,) }

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```

Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!

So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```

PiperOrigin-RevId: 607664497
2024-02-16 05:57:12 -08:00
jax authors
330afdc8be Merge pull request #19819 from mattjj:scan-dont-traverse-body-jaxpr-in-lowering
PiperOrigin-RevId: 607570860
2024-02-15 22:36:35 -08:00
Matthew Johnson
5ead7a6ef2 [scan] don't traverse body jaxpr in lowering
The main changes are:
 1. simplify the scan impl that we trace through to get the lowering, and
 2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body
    jaxpr we already have in hand.

The main motivation was (2), but (1) seems like a useful win too.

The way we achieve (2) is with a new trick: in our scan_impl function, which is
only ever traced to a jaxpr, instead of calling
`core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive
`eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging
rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and
that rule just generates a call into the jaxpr of interest. Therefore we will
not traverse into the jaxpr just to rebuild it inline (as before.
2024-02-15 22:13:22 -08:00
Peter Hawkins
885e8a2311 Don't recompute abstract eval rules when inlining a jit jaxpr.
The current implementation of jit inlining uses core.eval_jaxpr() and retraces the subjaxpr. This ends up performing abstract evaluation a second time. Instead, write a direct implementation of inlining that doesn't use the tracing machinery.

PiperOrigin-RevId: 607418006
2024-02-15 12:28:48 -08:00
George Necula
18698a1f19 [shape_poly] Add support for jnp.split 2024-02-15 14:43:41 +01:00
Peter Hawkins
5833b0767b Use an isinstance check rather than dtypes.issubdtype to check whether the dtype in an aval is an extended dtype.
We don't need the full generality of issubdtype, and this is slightly faster. This operation is very common (e.g., for every aval construction, even with a non-extended dtype).

On my laptop:

```
In [18]: d = jnp.dtype(jnp.int32)

In [20]: %timeit jax.dtypes.issubdtype(d, jax.dtypes.extended)
490 ns ± 2.78 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [22]: %timeit isinstance(d, jax._src.dtypes.ExtendedDType)
78.3 ns ± 0.111 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
```

PiperOrigin-RevId: 606616884
2024-02-13 07:38:02 -08:00
Peter Hawkins
70edc40d4b Don't look for dimension_as_value on Tracers in core.full_raise.
This code triggers the relatively slow `Tracer.__getattr__` path on tracers, but as far as I can see a tracer can never have this attribute.

PiperOrigin-RevId: 606612790
2024-02-13 07:19:11 -08:00
George Necula
983bb32ae6 [shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.

See more details in the README.md changes.
2024-02-08 10:09:47 +01:00
Sergei Lebedev
078bb00fdb Replaced most usages of abc.ABC with util.StrictABC
StrictABC does not allow registering virtual subclasses and can thus avoid
using relatively expensive __instancecheck__/__sublclasscheck__ defined in
abc.ABCMeta.

The only abc.ABC subclass left is jax.Array which *does* use virtual
subclasses for natively-defined array types.
2024-01-29 12:40:43 +00:00
George Necula
a1286d0021 [shape_poly] Improve core.max_dim and core.min_dim
Previously, we optimized `core.max_dim(a, b)` to `a`
if `a >= b` and to `b` if `a < b`. Now we also optimize
it to `b` if `a <= b`.

Similarly for `core.min_dim`.
At the same time we move more of the logic from `core.py`
to `shape_poly.py`.
2024-01-15 15:10:28 +02:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00