A unique name_stack is built for every equation, which means that we're constantly rebuilding ModuleContext objects, even though the lifetime of almost everything else (naturally) is the Module scope. Split name_stack into an object that is threaded separately, including as part of mlir.LoweringRuleContext.
PiperOrigin-RevId: 608594374
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases
PiperOrigin-RevId: 608534008
The fix is simple, just avoid using `int(stride)`.
While fixing this I discovered some issues with a test
being disabled and handling of division by 0 when
computing the bounds of floordiv.
We make the following improvements:
* Add a `linear_combination` function to use for computing
linear combinations fo symbolic expressions. E.g, `a - b` used
to involve 2 operations: "-1 * b" and "a + -1*b".
* Change the representation of terms (_DimMon) from a dictionary
mapping factors (_DimAtom) to exponents, into a sorted tuple of
pairs (factor, exponent). This is worthwhile because in almost
all cases a term contains a single factor. Everywhere we used
`term.items()` now we use `term._factors`.
* Make the computation of `._hash` lazy. Previously, we used dictionaries
heavily for symbolic expressions and we always needed the hash value,
now we use dictionaries less.
* Replace `t.degree` with `t.is_constant`.
* Add `__slots__` to the representation of symbolic expressions
Micro benchmark: `a * 2 - b * 2 - a * 3 + c * 4`
After: 12.51 μsec (mean 12.6 μsec ± 105.2 nsec, of 7 runs, 20000 loops each)
Before: 40.33 μsec (mean 40.5 μsec ± 247.6 nsec, of 7 runs, 5000 loops each)
The original PR was reverted because of downstream breakage.
Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)
But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.
So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.
In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:
```
import jax
def f(x):
for _ in range(3):
x = jax.numpy.sin(x)
return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
# b:f32[] = sin a
# c:f32[] = sin b
# d:f32[] = sin c
# in (d,) }
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```
Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!
So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```
PiperOrigin-RevId: 607664497
The main changes are:
1. simplify the scan impl that we trace through to get the lowering, and
2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body
jaxpr we already have in hand.
The main motivation was (2), but (1) seems like a useful win too.
The way we achieve (2) is with a new trick: in our scan_impl function, which is
only ever traced to a jaxpr, instead of calling
`core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive
`eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging
rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and
that rule just generates a call into the jaxpr of interest. Therefore we will
not traverse into the jaxpr just to rebuild it inline (as before.
```
---------------------------------------------------------
Benchmark Time CPU Iterations
---------------------------------------------------------
device_put_big 419 ms 0.363 ms 10
```
PiperOrigin-RevId: 607512568
The args_consumed and forwarded_inputs context is not actually needed, because it can be checked
afterward. The only reason for this was to have more granular errors, but arguably it's better
to error on jaxpr input.
This allows us to rely on this throughout the code and replace some checks with TPU_ASSERT_*. They have the semantics of an assert and make it clearer that it is an unexpected internal error (instead of unimplemented or invalid user input that we should handle).
Note: the original error messages for some of these checks were using the wrong input names.
PiperOrigin-RevId: 607463728