mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[shape_poly] Remove obsolete part of the shape polymorphism documentation
The section of division limitations is now obsolte, because JAX can represent division symbolically.
This commit is contained in:
parent
0134fa834c
commit
b3c405c2f5
@ -159,7 +159,7 @@ new shape:
|
||||
It is possible to convert dimension expressions explicitly
|
||||
to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`.
|
||||
The result of these operations can be used as regular JAX arrays,
|
||||
bug cannot be used anymore as dimensions in shapes.
|
||||
but cannot be used anymore as dimensions in shapes, e.g., in `reshape`:
|
||||
|
||||
```python
|
||||
>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
|
||||
@ -616,45 +616,6 @@ Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-ass
|
||||
These errors arise in a pre-processing step before the
|
||||
compilation.
|
||||
|
||||
### Division of symbolic dimensions is partially supported
|
||||
|
||||
JAX will attempt to simplify division and modulo operations,
|
||||
e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`.
|
||||
In particular, JAX will handle the cases when either (a) there
|
||||
is no remainder, or (b) the divisor is a constant
|
||||
in which case there may be a constant remainder.
|
||||
|
||||
For example, the code below results in a division error when trying to
|
||||
compute the inferred dimension for a `reshape` operation:
|
||||
|
||||
```python
|
||||
>>> b, = export.symbolic_shape("b")
|
||||
>>> export.export(jax.jit(lambda x: x.reshape((2, -1))))(
|
||||
... jax.ShapeDtypeStruct((b,), dtype=np.int32))
|
||||
Traceback (most recent call last):
|
||||
jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1).
|
||||
The remainder mod(b, - 2) should be 0.
|
||||
|
||||
```
|
||||
|
||||
Note that the following will succeed:
|
||||
|
||||
```python
|
||||
>>> b, = export.symbolic_shape("b")
|
||||
>>> # We specify that the first dimension is a multiple of 4
|
||||
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
|
||||
... jax.ShapeDtypeStruct((4*b,), dtype=np.int32))
|
||||
>>> exp.out_avals
|
||||
(ShapedArray(int32[2,2*b]),)
|
||||
|
||||
>>> # We specify that some other dimension is even
|
||||
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
|
||||
... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32))
|
||||
>>> exp.out_avals
|
||||
(ShapedArray(int32[2,15*b]),)
|
||||
|
||||
```
|
||||
|
||||
(shape_poly_debugging)=
|
||||
## Debugging
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user