[shape_poly] Remove obsolete part of the shape polymorphism documentation

The section of division limitations is now obsolte, because JAX can represent
division symbolically.
This commit is contained in:
George Necula 2024-12-02 15:58:13 -08:00
parent 0134fa834c
commit b3c405c2f5

View File

@ -159,7 +159,7 @@ new shape:
It is possible to convert dimension expressions explicitly
to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`.
The result of these operations can be used as regular JAX arrays,
bug cannot be used anymore as dimensions in shapes.
but cannot be used anymore as dimensions in shapes, e.g., in `reshape`:
```python
>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
@ -616,45 +616,6 @@ Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-ass
These errors arise in a pre-processing step before the
compilation.
### Division of symbolic dimensions is partially supported
JAX will attempt to simplify division and modulo operations,
e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`.
In particular, JAX will handle the cases when either (a) there
is no remainder, or (b) the divisor is a constant
in which case there may be a constant remainder.
For example, the code below results in a division error when trying to
compute the inferred dimension for a `reshape` operation:
```python
>>> b, = export.symbolic_shape("b")
>>> export.export(jax.jit(lambda x: x.reshape((2, -1))))(
... jax.ShapeDtypeStruct((b,), dtype=np.int32))
Traceback (most recent call last):
jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1).
The remainder mod(b, - 2) should be 0.
```
Note that the following will succeed:
```python
>>> b, = export.symbolic_shape("b")
>>> # We specify that the first dimension is a multiple of 4
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
... jax.ShapeDtypeStruct((4*b,), dtype=np.int32))
>>> exp.out_avals
(ShapedArray(int32[2,2*b]),)
>>> # We specify that some other dimension is even
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32))
>>> exp.out_avals
(ShapedArray(int32[2,15*b]),)
```
(shape_poly_debugging)=
## Debugging