The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding the dimension size for negative indices). In both of these cases
the size of a dimension needs to be used as a value in the array
computation. Until now, the size of a dimension has been used only to
customize primitives.
This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,
```
import jax.numpy as jnp
from jax import core

def average(x):
  # Divide by the polymorphic size of dimension 0.
  return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```
This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.
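For the indexing use case above, a minimal sketch (the helper name
`normalize_index` is hypothetical):
```
import jax.numpy as jnp
from jax import core

def normalize_index(i, x):
  # Add the (possibly polymorphic) size of dimension 0 to negative indices.
  d = core.dim_as_value(x.shape[0])
  return jnp.where(i < 0, i + d, i)
```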
Note that this does not fundamentally change the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but never does a shape depend on the
input values. In fact, one could have expressed `dim_as_value`
already:
```
def dim_as_value(d):
  # Summing d ones yields the dimension size d as an array value.
  return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```
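A dedicated primitive, however, sidesteps building (and summing) an array of
size `d` just to recover the size as a value.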
We were able to support `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll out the solution
we need to make `core.dim_as_value` a public API and teach users
how to use it when they want shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
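For instance, a user-facing sketch (assuming the jax2tf `polymorphic_shapes`
API, with an illustrative shape spec):
```
from jax.experimental import jax2tf

# `b` is a dimension variable: the converted function accepts any batch
# size, and `average` above divides by that size via core.dim_as_value.
average_tf = jax2tf.convert(average, polymorphic_shapes=["(b, 8)"])
```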
Lower to the XLA cbrt() operator in sufficiently new jaxlibs.
On TPU, use a Newton-Raphson step to improve the cube root.
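The kind of refinement step meant here is one Newton-Raphson iteration on
f(y) = y**3 - x; a minimal sketch (the helper name is hypothetical, and the
actual TPU lowering may differ):
```
def refine_cbrt(y, x):
  # One Newton-Raphson step for f(y) = y**3 - x:
  #   y' = y - (y**3 - x) / (3 * y**2)
  return y - (y * y * y - x) / (3.0 * y * y)

print(refine_cbrt(1.25, 2.0))  # 1.26, vs. cbrt(2) ~= 1.259921
```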
Remove support for complex cbrt() in jax.numpy; the existing lowering was wrong and it is not entirely clear to me that we actually want to support complex `jnp.cbrt()`. NumPy itself does not support complex numbers in this case.
Add testing for `sqrt`/`rsqrt` for more types.
[XLA:Python] Add cbrt to XLA:Python bindings.
PiperOrigin-RevId: 386316949
Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
XLA itself does not consume these, but they can be propagated onto scatter() when computing gradients.
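A hedged illustration of the hints at the user level (assuming `jnp.take`
exposes these flags, as it does in current JAX):
```
import jax.numpy as jnp

x = jnp.arange(10.0)
idx = jnp.array([1, 3, 7])
# Assert that the indices are sorted and unique so the lowering can attach
# these hints to the gather (and to the scatter in its gradient).
y = jnp.take(x, idx, indices_are_sorted=True, unique_indices=True)
```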
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multivariate dimension polynomials. These polynomials overload addition, subtraction,
and multiplication. They also partially support equality and inequality checking.
Equality and inequality are supported only when the result of the operation is the
same for all valuations of the variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, and `a >= a` all hold. However,
`a == b` and `a >= 2` raise a `core.InconclusiveDimensionOperation`, because, e.g.,
`a >= 2` is False for `a = 1` but True for `a = 2`.
Division is supported only in the cases when either there is no remainder,
or the divisor is a constant.
This change allows us to support more general cases of `jnp.reshape(-1)`
(where the `-1` is resolved by dividing the total input size by the product
of the known output dimensions), such as those used in the internal
implementation of `random_gamma`:
```
y = x.reshape((2, -1))
z = ... y ...
return z.reshape(x.shape)
```
The main problem was that jnp.einsum uses opt_einsum.contract_path
to parse the specification string and compute the order of the
contractions. This function wants to compute the sizes of operands
and intermediate results, and will fail if some dimensions are
polymorphic.
The (partial) solution here is to replace the operands with
`jax.ShapeDtypeStruct`s that use a fixed size for all dimension variables,
then call opt_einsum.contract_path and use that result if there
is only one contraction; we abort if there are multiple contractions.
This behavior is sound: with a single contraction there is no ordering
to choose, whereas with multiple contractions the order might differ
for different dimension sizes.
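A minimal sketch of the trick (the fixed size 8 is an arbitrary stand-in for
every dimension variable; the actual code uses `jax.ShapeDtypeStruct`):
```
import numpy as np
import opt_einsum

# Dummy operands whose polymorphic dimensions are replaced by a fixed size.
lhs = np.zeros((8, 4))   # stands in for an operand of shape (b, 4)
rhs = np.zeros((4, 3))
path, info = opt_einsum.contract_path("ij,jk->ik", lhs, rhs)
# Only safe to reuse this result if there is a single contraction.
assert len(info.contraction_list) == 1
```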
--
8226dfc8a4974b4c8031ee267fa5327e778140ee by Nicholas Junge <nicholas.junge@web.de>:
Handle negative values for list-like sections in jnp.split
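For example (a negative entry is interpreted like a negative index,
i.e. relative to the end):
```
import jax.numpy as jnp

x = jnp.arange(10)
# Split points 3 and -2 (== 8): pieces of lengths 3, 5, and 2.
parts = jnp.split(x, [3, -2])
```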
PiperOrigin-RevId: 376302305
Mark the index_update() etc. operators as deprecated in the documentation.
Add new .divide and .power operators. Fixes #2694.
Add .multiply as an alias for .mul. To be more numpy-like we should probably prefer the longer names.
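A short illustration of the new spellings, assuming these are the update
methods on the `x.at[idx]` property that replaces `index_update()`:
```
import jax.numpy as jnp

x = jnp.ones(4)
x = x.at[0].divide(2.0)    # new
x = x.at[1].power(3.0)     # new
x = x.at[2].multiply(5.0)  # alias for .mul
```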