In the past symbolic expressions were polynomials, consisting of sums
of monomials, which were products of atoms. Over time the language
of symbolic expressions has become richer. Now expressions
are sums of terms, which are products of factors.
Here we rename references to monomials to terms, and `_DimMon`
to `_DimTerm`. We also rename reference of atoms to factors,
and `_DimAtom` to `_DimFactor`.
At the same time we rename most of the methods of `_DimExpr`
to have a leading underscore, to indicate that they are
private methods.
The original PR was reverted because of downstream breakage.
Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)
But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.
So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.
In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:
```
import jax
def f(x):
for _ in range(3):
x = jax.numpy.sin(x)
return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
# b:f32[] = sin a
# c:f32[] = sin b
# d:f32[] = sin c
# in (d,) }
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```
Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!
So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```
PiperOrigin-RevId: 607664497
The main changes are:
1. simplify the scan impl that we trace through to get the lowering, and
2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body
jaxpr we already have in hand.
The main motivation was (2), but (1) seems like a useful win too.
The way we achieve (2) is with a new trick: in our scan_impl function, which is
only ever traced to a jaxpr, instead of calling
`core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive
`eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging
rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and
that rule just generates a call into the jaxpr of interest. Therefore we will
not traverse into the jaxpr just to rebuild it inline (as before.
The current implementation of jit inlining uses core.eval_jaxpr() and retraces the subjaxpr. This ends up performing abstract evaluation a second time. Instead, write a direct implementation of inlining that doesn't use the tracing machinery.
PiperOrigin-RevId: 607418006
We don't need the full generality of issubdtype, and this is slightly faster. This operation is very common (e.g., for every aval construction, even with a non-extended dtype).
On my laptop:
```
In [18]: d = jnp.dtype(jnp.int32)
In [20]: %timeit jax.dtypes.issubdtype(d, jax.dtypes.extended)
490 ns ± 2.78 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [22]: %timeit isinstance(d, jax._src.dtypes.ExtendedDType)
78.3 ns ± 0.111 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
```
PiperOrigin-RevId: 606616884
This code triggers the relatively slow `Tracer.__getattr__` path on tracers, but as far as I can see a tracer can never have this attribute.
PiperOrigin-RevId: 606612790
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
StrictABC does not allow registering virtual subclasses and can thus avoid
using relatively expensive __instancecheck__/__sublclasscheck__ defined in
abc.ABCMeta.
The only abc.ABC subclass left is jax.Array which *does* use virtual
subclasses for natively-defined array types.
Previously, we optimized `core.max_dim(a, b)` to `a`
if `a >= b` and to `b` if `a < b`. Now we also optimize
it to `b` if `a <= b`.
Similarly for `core.min_dim`.
At the same time we move more of the logic from `core.py`
to `shape_poly.py`.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.
One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
tweaks to enable adding custom tangent dtypes:
* fix a bug in zeros_like_shaped_array and KeyTyRules.zero to ensure `scalar_zero` is actually a scalar
* upgrade the adder handler for ShapedArray to delegate to an extended dtype rule for addition
* convert_element_type shouldnt blanket-disallow extended dtypes; actually that can be a key operation for working with them! instead, add new `convert_from` and `convert_to` rules. instead of letting these rules perform arbitrary logic, for now they can just return a bool indicating whether the conversion is legit; if false, an error is raised, and if true, the existing convert_element_type lowering rule just generates a ConvertElementType HLO from one physical type to the other
this pr also adds a test for a custom tangent dtype of interest for plumbing quantization scales out of a backward pass
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change
We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
PiperOrigin-RevId: 571932143
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!
The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!
The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.
Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa
PiperOrigin-RevId: 561805840
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.
For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.
Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.
This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.
PiperOrigin-RevId: 551604974
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:
Rename opaque dtype to extended dtype.
This includes three deprecations:
- jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
- jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
- the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
This makes it closer to numpy, with dtypes.OpaqueDtype analogous to np.dtype,
and dtypes.opaque analogous to np.numeric. This will let us replace the
dtypes.is_opaque_dtype function with jnp.issubdtype(dtype, dtypes.opaque).
There are a few cases when JAX computes `max(v, 0)`, most
notably when computing the sizes of strided access,
dilated convolutions and padding, and for the size
of jnp.arange.
Until now these cases were supported
for shape polymorphism only when we can tell statically
that the size is >= 0. Here we add support to the
symbolic expressions for a `non_negative` operator,
which essentially implements `max(v, 0)` and with this
we can now support the general case for `jnp.arange`, with
simpler code.
We could add a general `max` operator, and we may do so in the
future, but for now `non_negative` suffices.
Note that this fixes a couple of bugs
* for core.dilated_dim we had the code "if d == 0 then 0 else ..."
but this works only if we can tell statically that `d == 0`, and
it produced wrong results when `d` was symbolic and could take
the value 0.
* for core.stride_dim we did not handle correctly the case when
`d < window_size`.
Handling the above fundamentally requires a `max(d, 0)` operation.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.
We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.
We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)