126 Commits

Author SHA1 Message Date
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
George Necula
550d1aa187 [better_errors] Continue adding debug info to Jaxprs (step 6)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
2025-02-11 11:28:58 +01:00
Peter Hawkins
2d9cd86ae1 Disable two tests of GC behavior if there are multiple threads per process.
These don't seem to work reliably with multiple threads per process, even though the test is marked thread unsafe.
2025-01-31 09:14:49 -08:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Dougal Maclaurin
f355dcf34b Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Matthew Johnson
8db862c02e fix memory leak in cond jaxpr tracig
fixes #12719
2024-07-23 23:57:02 +00:00
Peter Hawkins
2350a73f87 Use a class with __slots__ instead of a NamedTuple in JaxprEqn and SourceInfo, which are two tuples we build frequently.
Surprisingly this is faster. With Python 3.12:

```
In [1]: from typing import NamedTuple

In [2]: class C(NamedTuple):
   ...:     a: int
   ...:     b: int
   ...:     c: int
   ...:     d: int
   ...:     e: int
   ...:     f: int
   ...:     g: int
   ...:

In [3]: class D:
   ...:     __slots__ = ('a', 'b', 'c', 'd', 'e', 'f', 'g')
   ...:     def __init__(self, a, b, c, d, e, f, g):
   ...:         self.a = a
   ...:         self.b = b
   ...:         self.c = c
   ...:         self.d = d
   ...:         self.e = e
   ...:         self.f = f
   ...:         self.g = g
   ...:

In [4]: %timeit D(1, 2, 3, 4, 5, 6, 7)
158 ns ± 0.458 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [5]: %timeit C(1, 2, 3, 4, 5, 6, 7)
236 ns ± 0.498 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [6]: %timeit D(1, 2, 3, 4, 5, 6, 7)
159 ns ± 2.13 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [7]: %timeit C(1, 2, 3, 4, 5, 6, 7)
235 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```

No behavioral changes intended.

PiperOrigin-RevId: 648556436
2024-07-01 19:18:58 -07:00
Yash Katariya
96f888bcfe Reverts 1956ff7d7b73794012fece2d8452e097196587fc
PiperOrigin-RevId: 631974751
2024-05-08 17:23:13 -07:00
Chris Jones
20a8e2a6ec Allow replacing jaxpr debug_info with None.
The existing implementation of `Jaxpr.replace` would ignore the parameter `debug_info=None`.

PiperOrigin-RevId: 629421610
2024-04-30 08:31:39 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
Matthew Johnson
8c2f6b3e8c re-enable pjit forwarding optimization, add tests 2024-03-15 14:06:35 -07:00
Matthew Johnson
8a7c604aa7 disable optimization 2024-03-15 10:35:08 -07:00
Jake VanderPlas
7634708743 [key reuse] define KeyReuseError in jax.errors 2024-03-08 10:59:06 -08:00
Yash Katariya
1cb8d31c66 Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.

PiperOrigin-RevId: 613289252
2024-03-06 11:42:40 -08:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Peter Hawkins
67df647988 Reland https://github.com/google/jax/pull/10573.
The original PR was reverted because of downstream breakage.

Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)

But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.

So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.

In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:

```
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
#     b:f32[] = sin a
#     c:f32[] = sin b
#     d:f32[] = sin c
#   in (d,) }

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```

Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!

So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```

PiperOrigin-RevId: 607664497
2024-02-16 05:57:12 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
Sergei Lebedev
5d9c39f4b0 MAINT Use a generator expression with all() and any()
There is no reason to allocate a list only for the purpose of iteration.
2023-10-10 22:33:03 +01:00
Jake VanderPlas
bfed3d862e Improve behavior of core.valid_jaxtype 2023-09-22 13:46:09 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Matthew Johnson
268456ef54 enable pjit recursive typechecking
Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).

This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accomodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!

It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
2023-03-22 16:59:22 -07:00
Yash Katariya
181355335c Remove references to jax.config.jax_jit_pjit_api_merge, which is always True at head.
PiperOrigin-RevId: 516998437
2023-03-15 20:07:20 -07:00
Roy Frostig
6b4de4f91c remove several more symbols from jax.core
* `DBIdx`
* `DConcreteArray`
* `DimensionHandler`
* `DuplicateAxisNameError`

PiperOrigin-RevId: 510503517
2023-02-17 13:07:00 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Zeynep Cankara
995ef40f68 [JAX] Improve error message when jit tracer passed to a shape.
Adds additional debugging message to the shape explaining why the value is a tracer.

Fixes #14279

PiperOrigin-RevId: 509545985
2023-02-14 09:13:01 -08:00
Jake VanderPlas
a0eae5709f Raise an error when attempting to mutate Jaxpr objects 2023-01-23 09:37:58 -08:00
Yash Katariya
38f91bdaa5 Skip core tests which have nested pjits and DShapedArray.
PiperOrigin-RevId: 502013080
2023-01-13 22:39:31 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Peter Hawkins
6bda0d2863 Don't call dtypes.result_type() unnecessarily on the type of an array during abstractification.
Remove make_shaped_array since it has no more non-test users.

```
name        old cpu/op  new cpu/op  delta
device_put  69.4µs ± 6%  63.5µs ± 3%  -8.56%  (p=0.000 n=10+10)

name        old time/op             new time/op             delta
device_put  69.4µs ± 6%             63.5µs ± 3%  -8.56%        (p=0.000 n=10+10)
```

PiperOrigin-RevId: 491795793
2022-11-29 19:27:10 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
jax authors
5318df67fc Merge pull request #11151 from mattjj:input-fwd-residual-optimization
PiperOrigin-RevId: 455710192
2022-06-17 15:43:44 -07:00
Matthew Johnson
72a67906bf optimize grad-of-jit not to pass input-residuals as intermediate-residuals
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
Co-authored-by: Peter Hawkins <phawkins@google.com>
2022-06-17 15:24:28 -07:00
Matthew Johnson
f680269a4f [dynamic-shapes] initial support for dynamic shape typechecks
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-06-17 14:57:19 -07:00
Matthew Johnson
05dda56019 add core.closed_call_p 2022-05-14 14:06:30 -07:00
Matthew Johnson
0b841cf35b make core_test.py pass with core.call 2022-05-11 15:45:40 -07:00
Matthew Johnson
bb56f40947 Internal change
PiperOrigin-RevId: 447549479
2022-05-09 13:26:30 -07:00
Matthew Johnson
705c07ae6d remove count attribute and total_ordering from core.Var
Originally we used the 'Var.count' attribute to ensure Var instances were
printed consistently regardless of context, even though only their object id
was load-bearing. That is, Var.count was only used for pretty printing. (#1949
added a total_ordering on Var for reasons out of scope of JAX's core code.)

But #8019 revised our pretty-printing so as not to use Var.count. Instead it
chose how to pretty-print Var instances based on their order of appearance in a
jaxpr. That meant Var.count really wasn't useful anymore. So this PR removes
Var.count.

In fact, Var.__repr__ and JaxprEqn.__repr__ were made confusing after #8019,
since they could print variable names totally different from the names that
would appear when the same JaxprEqn or Var objects were printed as part of a
jaxpr. That is, before this PR< we might have a jaxpr which printed like:

```python
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
```

Notice the variable names in the equation pretty-print don't correspond to any
in the jaxpr pretty-print!

So this PR changes JaxprEqn.__repr__ and Var.__repr__ to show Var object ids.
2022-05-09 09:31:23 -07:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00
Matthew Johnson
477dfa6e46 [remove-units] don't use abstract_unit for dropvar avals 2022-04-28 22:51:41 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00