This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.
PiperOrigin-RevId: 613289252
The original PR was reverted because of downstream breakage.
Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)
But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.
So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.
In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:
```
import jax
def f(x):
for _ in range(3):
x = jax.numpy.sin(x)
return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
# b:f32[] = sin a
# c:f32[] = sin b
# d:f32[] = sin c
# in (d,) }
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```
Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!
So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```
PiperOrigin-RevId: 607664497
Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).
This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accomodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!
It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
Remove make_shaped_array since it has no more non-test users.
```
name old cpu/op new cpu/op delta
device_put 69.4µs ± 6% 63.5µs ± 3% -8.56% (p=0.000 n=10+10)
name old time/op new time/op delta
device_put 69.4µs ± 6% 63.5µs ± 3% -8.56% (p=0.000 n=10+10)
```
PiperOrigin-RevId: 491795793
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.
In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816
Originally we used the 'Var.count' attribute to ensure Var instances were
printed consistently regardless of context, even though only their object id
was load-bearing. That is, Var.count was only used for pretty printing. (#1949
added a total_ordering on Var for reasons out of scope of JAX's core code.)
But #8019 revised our pretty-printing so as not to use Var.count. Instead it
chose how to pretty-print Var instances based on their order of appearance in a
jaxpr. That meant Var.count really wasn't useful anymore. So this PR removes
Var.count.
In fact, Var.__repr__ and JaxprEqn.__repr__ were made confusing after #8019,
since they could print variable names totally different from the names that
would appear when the same JaxprEqn or Var objects were printed as part of a
jaxpr. That is, before this PR< we might have a jaxpr which printed like:
```python
import jax
def f(x):
for _ in range(3):
x = jax.numpy.sin(x)
return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
```
Notice the variable names in the equation pretty-print don't correspond to any
in the jaxpr pretty-print!
So this PR changes JaxprEqn.__repr__ and Var.__repr__ to show Var object ids.