The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimenion needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.
This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,
```
def average(x):
  return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```
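Under jax2tf shape polymorphism this composes with `jax2tf.convert`. A
minimal self-contained sketch (the shape spec `(b, 4)` is an illustrative
assumption):
```
import jax.numpy as jnp
from jax import core
from jax.experimental import jax2tf

def average(x):
  return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])

# Trace with a symbolic leading dimension `b`; dim_as_value turns the
# symbolic size into a value usable in the division.
average_tf = jax2tf.convert(average, polymorphic_shapes=["(b, 4)"])
```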
This function is the identity function if the dimension size is
constant; otherwise it uses a new primitive `shape_poly.dim_as_value_p`.
Note that this does not fundamentally change the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but a shape never depends on the
input values. In fact, one could already have expressed `dim_as_value`:
```
def dim_as_value(d):
  return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```
We were able to support `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, and `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll out the solution
we need to make `core.dim_as_value` a public API and teach users
how to use it when they want shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
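For illustration, here is a sketch of the index-normalization pattern
mentioned above (`normalize_index` is a hypothetical helper, not part of
this change):
```
import jax.numpy as jnp
from jax import core

def normalize_index(i, axis_size):
  # Hypothetical helper: under shape polymorphism `axis_size` may be a
  # dimension polynomial, so it must go through core.dim_as_value before
  # participating in the array computation.
  d = core.dim_as_value(axis_size)
  return jnp.where(i < 0, i + d, i)
```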
Instead of hoisting all float-type arrays during closure conversion,
only hoist JVPTracers (or tracers carrying such tracers
indirectly). Doing so better approximates the subset of
closure-captured values that participate in AD.
Co-authored-by: Matthew Johnson <mattjj@google.com>
At the moment, xmap SPMD lowering only enforces sharding constraints for
computation inputs and outputs, while leaving sharding propagation in the
body entirely up to the XLA SPMD partitioner. This patch adds a new flag
`experimental_xmap_enforce_inferred_sharding` that inserts additional
sharding constraints between consecutive JAX primitives in the xmapped function.
Assuming that the SPMD partitioner never overrides user-defined constraints,
this should restrict it sufficiently to generate a computation that is
partitioned exactly as implied by the evolution of intermediate named shapes.
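A sketch of how the flag might be used, assuming it is toggled through
`jax.config` like the existing experimental xmap options (mesh layout and
axis names are illustrative):
```
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import xmap, mesh

# Assumption: both flags are exposed through jax.config.update.
jax.config.update("experimental_xmap_spmd_lowering", True)
jax.config.update("experimental_xmap_enforce_inferred_sharding", True)

f = xmap(lambda x: jnp.sin(x) * 2.,
         in_axes=['i', ...], out_axes=['i', ...],
         axis_resources={'i': 'devices'})

# Sharding constraints are now inserted between consecutive primitives
# in the body, not just at the inputs and outputs.
with mesh(np.array(jax.devices()), ('devices',)):
  y = f(jnp.arange(8.))
```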
PiperOrigin-RevId: 385562158
This fixes the case in which the primal shape-polymorphic function has
output shapes that are polynomials of the input shapes (not just
dimension variables).
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multivariate dimension polynomials. These polynomials overload addition,
subtraction, and multiplication. They also partially support equality and
inequality checking. Equality and inequality are supported only when the
result of the operation is the same for all valuations of the variables
greater than 0. For example, `a == a`, `a * b + 1 == 1 + b * a`, `a >= 1`,
`2 * a + b >= 3`, and `a >= a` all succeed. However, the following raise a
`core.InconclusiveDimensionOperation`: `a == b`, `a >= 2`.
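For example, a sketch of how an inconclusive comparison surfaces under
jax2tf (the error is raised when the converted function is traced):
```
from jax.experimental import jax2tf
import jax.numpy as jnp

def f(x):
  # With polymorphic shape "(a,)" the comparison `a >= 2` is inconclusive:
  # it holds for a >= 2 but fails for a == 1, so tracing raises
  # core.InconclusiveDimensionOperation.
  if x.shape[0] >= 2:
    return x[:2]
  return x

f_tf = jax2tf.convert(f, polymorphic_shapes=["(a,)"])
```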
Division is supported only in the cases when either there is no remainder,
or the divisor is a constant.
This change allows us to support more general uses of `jnp.reshape` with
a `-1` dimension, such as those used in the internal implementation of `random_gamma`:
```
y = x.reshape((2, -1))
z = ... y ...
return z.reshape(x.shape)
```
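A sketch of the kind of program this enables under jax2tf (the shape
spec `(b, 8)` is an illustrative assumption; the divisor 2 is a constant,
so the division succeeds):
```
from jax.experimental import jax2tf
import jax.numpy as jnp

def f(x):
  y = x.reshape((2, -1))  # with x: (b, 8), -1 resolves to (8*b) // 2 == 4*b
  return (y * 2.).reshape(x.shape)

f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 8)"])
```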
--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:
add 'inline' option to xla_call for jaxpr inlining
--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:
add 'inline' to jit docstring
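A minimal sketch of the option, assuming it is surfaced through `jax.jit`
as the docstring change suggests:
```
import jax
import jax.numpy as jnp

# inline=True asks for the jitted function's jaxpr to be inlined into the
# caller's jaxpr instead of being staged out as a separate XLA call.
f = jax.jit(lambda x: jnp.sin(x) + 1., inline=True)
```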
--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:
new_sublevel in jax2tf
PiperOrigin-RevId: 371542778
Partly to make it more robust (e.g. we no longer need to implement
post_process_call), partly because it is not really a call primitive
(it modifies the argument and return avals in a multiprocess mesh),
and partly as an experiment to see how difficult it would be to actually
make it more autodidax-like.
Overall, it seems like a mixed bag, maybe slightly positive. The thunks
are gone which is nice, but one has to be much more careful when dealing
with avals.
PiperOrigin-RevId: 371352737
This uses the recently added ability to modify the `BatchTrace` to add a new
`SPMDBatchTrace`, which additionally fills in `spmd_in_axes` and `spmd_out_axes`
of xmap primitives. Those fields are necessary, because XLA does not allow us to
emit partial `OpSharding` annotations, meaning that we have to track where the
positional axes of outer xmaps are inserted at the boundaries of inner xmaps.
Otherwise, XLA could misinterpret our intention and unnecessarily force
replication along the mesh axes used by the outer xmaps.
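For intuition, a sketch of the nesting that exercises this trace (axis
resources and SPMD lowering are omitted for brevity; sizes are
illustrative):
```
import jax.numpy as jnp
from jax.experimental.maps import xmap

# While tracing the outer xmap, its named axis 'i' is handled by batching,
# so the inner xmap primitive sees 'i' as a positional batch axis;
# SPMDBatchTrace records its position in spmd_in_axes/spmd_out_axes so
# that full (non-partial) OpSharding annotations can be emitted.
inner = xmap(lambda x: jnp.sin(x), in_axes=['j', ...], out_axes=['j', ...])
outer = xmap(lambda x: inner(x), in_axes=['i', ...], out_axes=['i', ...])
y = outer(jnp.ones((4, 8)))
```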
PiperOrigin-RevId: 369831571
Starting from this change, we start introducing xmapped names when
tracing the xmap jaxpr and eliminating them from avals when the values
are returned. This lets us enable two long-awaited checks:
1. Returning values that are mapped along more axes than `out_axes`
   declares now results in a readable error instead of an internal
   vmap assertion (see the sketch after this list).
2. We catch the resource-overlap error triggered by making two axes
mapped to the same resources coincide in a single value.
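A sketch of a program that now triggers the first check:
```
import jax.numpy as jnp
from jax.experimental.maps import xmap

# The result is still mapped along 'i', but out_axes=[...] declares a
# purely positional output, so this fails with a readable error instead
# of an internal vmap assertion.
f = xmap(lambda x: x * 2., in_axes=['i', ...], out_axes=[...])
f(jnp.ones((4,)))
```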
This makes the named shape joining function significantly simpler and
also has the added benefit of removing the requirement of having a
global total order on named axes, which we definitely shouldn't require!
After all, they might be classes defined by users who are unaware of
e.g. the classes we use internally.
Previously, in order to increase the coverage of masking, we added special
cases in lax.py and lax_numpy.py to avoid exceptions in the presence of
`masking.Poly`.
For example:
```
if not isinstance(d, masking.Poly):
  if some_check(d):
    raise ValueError
```
All such conditionals make the code potentially behave differently when
tracing with masking.Poly than when tracing with concrete shapes, which
makes it hard to ensure soundness.
Perhaps the most egregious was:
```
if type(i) is Poly:
  # dummy index if i is polynomial, doesn't matter for shape inference
  i = 0
```
[JAX] Add an opaque `extra_jit_context` field to the JAX C++ jit code.
This allows the JAX Python code to include extra context from, for example, the interpreter state as part of the C++ jit cache key.
PiperOrigin-RevId: 364611475