diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 6ad7fb5c2..9254030a4 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -159,7 +159,7 @@ new shape: It is possible to convert dimension expressions explicitly to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`. The result of these operations can be used as regular JAX arrays, -bug cannot be used anymore as dimensions in shapes. +but cannot be used anymore as dimensions in shapes, e.g., in `reshape`: ```python >>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))( @@ -616,45 +616,6 @@ Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-ass These errors arise in a pre-processing step before the compilation. -### Division of symbolic dimensions is partially supported - -JAX will attempt to simplify division and modulo operations, -e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`. -In particular, JAX will handle the cases when either (a) there -is no remainder, or (b) the divisor is a constant -in which case there may be a constant remainder. - -For example, the code below results in a division error when trying to -compute the inferred dimension for a `reshape` operation: - -```python ->>> b, = export.symbolic_shape("b") ->>> export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((b,), dtype=np.int32)) -Traceback (most recent call last): -jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1). -The remainder mod(b, - 2) should be 0. - -``` - -Note that the following will succeed: - -```python ->>> b, = export.symbolic_shape("b") ->>> # We specify that the first dimension is a multiple of 4 ->>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((4*b,), dtype=np.int32)) ->>> exp.out_avals -(ShapedArray(int32[2,2*b]),) - ->>> # We specify that some other dimension is even ->>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32)) ->>> exp.out_avals -(ShapedArray(int32[2,15*b]),) - -``` - (shape_poly_debugging)= ## Debugging