Remove make_shaped_array since it has no more non-test users.
```
name old cpu/op new cpu/op delta
device_put 69.4µs ± 6% 63.5µs ± 3% -8.56% (p=0.000 n=10+10)
name old time/op new time/op delta
device_put 69.4µs ± 6% 63.5µs ± 3% -8.56% (p=0.000 n=10+10)
```
PiperOrigin-RevId: 491795793
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.
In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816
Originally we used the 'Var.count' attribute to ensure Var instances were
printed consistently regardless of context, even though only their object id
was load-bearing. That is, Var.count was only used for pretty printing. (#1949
added a total_ordering on Var for reasons out of scope of JAX's core code.)
But #8019 revised our pretty-printing so as not to use Var.count. Instead it
chose how to pretty-print Var instances based on their order of appearance in a
jaxpr. That meant Var.count really wasn't useful anymore. So this PR removes
Var.count.
In fact, Var.__repr__ and JaxprEqn.__repr__ were made confusing after #8019,
since they could print variable names totally different from the names that
would appear when the same JaxprEqn or Var objects were printed as part of a
jaxpr. That is, before this PR< we might have a jaxpr which printed like:
```python
import jax
def f(x):
for _ in range(3):
x = jax.numpy.sin(x)
return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
```
Notice the variable names in the equation pretty-print don't correspond to any
in the jaxpr pretty-print!
So this PR changes JaxprEqn.__repr__ and Var.__repr__ to show Var object ids.
f35014d had to revert part of #8955 because of a surprising downstream
breakage (relying on internal APIs). That breakage was isolated to how
_inline_literals handled invars.
The approach was a temporary one anyway: it relied on the fact that we
expect only to bind axis size variables at the top level and hence if we
didn't rename the input binders in _inline_literals we wouldn't need to
substitute new variables for any variables appearing in types. But a
more general approach would be to perform the necessary substitution
everywhere; after all, we might be inlining a literal into an axis size!
This commit takes the more general approach. It may fix the downstream
breakage automatically, just by virtue of being different; if not, I'll
figure out how to fix downstream.
The pretty-printing changes a few months ago defined variable names
based on the state in JaxprPpContext instances. But that meant incorrect
variable names could be printed in jaxpr type checking error messages.
This commit correctly threads through the context so as to provide
error messages with coherent variable names.
Final-style higher-order primitives, like call_p, xla_call_p (underlying
jit), xla_pmap_p (underlying pmap), and xmap_p (underlying xmap) have
slightly different bind signatures (while tracing) from their signatures
when they appear in jaxprs. In particular, their trace-time binds are
parameterized by a Python callable (or really a lu.WrappedFun)
representing the function to be applied, while in jaxpr eqns they are
parameterized by a jaxpr representing the same.
As a result, to round-trip from jaxpr to Python traceable, in
core.eval_jaxpr we have to convert from one parameter signature to the
other. (Basically we had to take the jaxpr and turn it into a Python
callable, via lu.wrap_init(partial(core.eval_jaxpr, call_jaxpr, ...)).)
However due to historical path dependence these conversion mechanisms
were all slightly distinct and kind of a mess. There was a case analysis
for call_jaxpr and map_jaxpr in core.eval_jaxpr_eqn (a helper function
created only because of this complexity), and there was a separate table
only used for the xmap rule.
In this PR we uniformized things! We basically only have a table (to
simplify core.eval_jaxpr), but instead of having it as a table we just
attached the rules to the different primitive classes (CallPrimitive,
MapPrimitive, and XmapPrimitive) to make things less error-prone (we
have a few different CallPrimitive instantiations, like call_p,
xla_call_p, named_call_p, and remat_call_p, and this way we don't have
to remember to populate the table separately for each).
This was actually a warmup simplification before we attempt to simplify
custom derivatives (to unify custom_jvp_call_p and
custom_jvp_call_jaxpr_p).
Co-authored-by: Roy Frostig <frostig@google.com>
partial_eval.py's _inline_literals) and skip new tests.
Some code seems to depend on whether we generate fresh invars (i.e. jaxpr input
binders) in that code. I'm not sure if it's a bug in the new JAX code or a bug in
the user code, but I'd like to un-break things while investigating.
PiperOrigin-RevId: 420296461
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.
The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.
PiperOrigin-RevId: 414704700
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.
Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.
PiperOrigin-RevId: 404071001
The string representation of ConcreteArray did not include the data type of the
wrapped value. This makes it harder to spot the reason for errors arising from
inconsistent values (issue #5364). This commit adds the data type to the string
representation of ConcreteArray.
The second change in the avals-with-names stack:
- https://github.com/google/jax/pull/5524 Revise aval constructor call sites to use a new `aval.update` method
- **Add `named_shape` to `ShapedArray` and update typecompat**
- Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules
- Make `mapped_aval`, `unmapped_aval`, and their xmap equivalents swap positional and named axes (rather than just creating and deleting positional ones)
- Enable `lax.full` to create values with named axes
- Ensure `grad` and `jacfwd`/`jacrev` consistently act elementwise over named axes (by e.g. using a seed with named axes in `grad`, and prohibiting collectives if TAP isn't too unhappy) and align `vmap(transpose)` with `transpose(vmap)` by moving the `psum` in `transpose(psum)` into `backward_pass`
- Add `axis_name` kwarg to grad to indicate operating collectively over one or more named axes
PiperOrigin-RevId: 355880632
In order to test that the typechecker identifies invalid jaxprs, some
tests modify jaxprs in place. This is typically not allowed, since
jaxprs are assumed immutable, and may be cached. As a workaround, this
change clears the relevant caches before every test. This ought to
prevent some order-dependent test failures.
rename and simplify TypedJaxpr -> ClosedJaxpr
This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
in_avals / out_avals no longer need to be constructed), making them
easier to work with;
* correspondingly rules out a class of errors (mismatches between
invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
but they're closed in that they are packaged with their constant
values).
This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)
Co-authored-by: George Necula <gcnecula@gmail.com>
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 fo more information.