* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).
The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)
PiperOrigin-RevId: 411565432
also:
* fix jit invariance bug around weak types
* elide trivial broadcasts
This started as an attempt to simplify some jaxpr pretty-prints, by (1)
eliding some convert_element_type applications that I thought were
unnecessary and (2) eliding some trivial broadcasts.
But it turned out that we were actually pruning more
convert_element_types than we should! In particular, see
test_weak_type_jit_invariance; that test fails on the main branch even
if we add the fixes in DynamicJaxprTrace.new_const, because [this
logic](b53a174042/jax/interpreters/partial_eval.py (L1225))
was not paying attention to weak types and hence clobbered them.
In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this
PR generalizes the jaxpr simplification machinery so as not to be a
couple special cases on convert_element_type_p. Insetad, we have tables
of rules! How we love them.
These rule signatures should let us add simplifications like forwarding
variables through calls and other higher-order primitives. That's all
future work though.
Bug: 8367
Small refactoring to jax.image.resize to make it compatible with
shape polymorphismin jax2tf. In the process added also support for
jnp.arange([dim_poly]). Note that the underlying lax.iota already
supported shape polymorphism.
--
b40245e38d7837a7777735ad60f3b5b1ac2d499d by Sharad Vikram <sharad.vikram@gmail.com>:
Use `SourceInfo` named tuple to keep track of source information
PiperOrigin-RevId: 406293469
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.
Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.
PiperOrigin-RevId: 404071001
The number of buffers used to represent an abstract value is a property specific to a particular representation of that abstract value. Currently the only representation is an XLA representation, but that may change in the future. Instead, callers who want to know how XLA would represent an aval should ask the XLA module instead. In this case, we call len(xla.aval_to_xla_shapes(...)) instead.
It was confusing to overload, since we sometimes think of avals like
shapes paired with dtypes, and in that case len(aval) should perhaps be
like len(aval.shape). The only place where this behavior was relied on
was sparse/ops.py.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
All control-flow primitives are `AxisPrimitive`s now, which means that we're doing
lots of those used-names traversals during dispatch and they can be expensive!
This adds caching to try to lower the cost of that change.
PiperOrigin-RevId: 395921887
Even though `vmap` and `pmap` don't use avals with names, the batching infrastructure
is used to implement xmap and pjit. So while we keep the introduction of names carefully
scoped, forgetting to remove them at the right points leads to extremely confusing errors.
PiperOrigin-RevId: 395423006
I've also updated the docs for ``jax.ops`` to note that ``at[].set()``
is guaranteed to be performed in-place under JIT. Someone who knows XLA
well should double check that fact!
The string representation of ConcreteArray did not include the data type of the
wrapped value. This makes it harder to spot the reason for errors arising from
inconsistent values (issue #5364). This commit adds the data type to the string
representation of ConcreteArray.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimenion needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.
This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,
```
def average(x):
return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```
This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.
Note that this does not change fundamentally the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but never does a shape depend on the
input values. In fact, one could have expressed the `dim_as_value`
already:
```
def dim_as_value(d):
jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```
We were able to suppot `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll-up the solution
we need to make `core.dim_as_value` a public API and teach the
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
Instead of hoisting all float-type arrays during closure conversion,
only hoist JVPTracers (or tracers carrying such tracers
indirectly). Doing so better approximates the subset of
closure-captured values that participate in AD.
Co-authored-by: Matthew Johnson <mattjj@google.com>
At the moment, xmap SPMD lowering only enforces sharding constraints for
computation inputs and outputs, while leaving sharding propagation in the
body entirely up to the XLA SPMD partitioner. This patch adds a new flag
`experimental_xmap_enforce_inferred_sharding` that inserts additional
sharding constraint between every JAX primitive in the xmapped function.
Assuming that the SPMD partitioner never overrides user-defined constraints,
this should restrict it sufficiently to generate a computation that is
partitioned exactly as implied by the evolution of intermediate named shapes.
PiperOrigin-RevId: 385562158
This fixes the case when the primal shape polymorphic function has
output shapes that are polynomials of the input shapes (not just
dimension variables).