From b3c405c2f55b5c1d2cb42755bfcb8ed8a5e03b38 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 2 Dec 2024 15:58:13 -0800 Subject: [PATCH] [shape_poly] Remove obsolete part of the shape polymorphism documentation The section of division limitations is now obsolte, because JAX can represent division symbolically. --- docs/export/shape_poly.md | 41 +-------------------------------------- 1 file changed, 1 insertion(+), 40 deletions(-) diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 6ad7fb5c2..9254030a4 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -159,7 +159,7 @@ new shape: It is possible to convert dimension expressions explicitly to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`. The result of these operations can be used as regular JAX arrays, -bug cannot be used anymore as dimensions in shapes. +but cannot be used anymore as dimensions in shapes, e.g., in `reshape`: ```python >>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))( @@ -616,45 +616,6 @@ Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-ass These errors arise in a pre-processing step before the compilation. -### Division of symbolic dimensions is partially supported - -JAX will attempt to simplify division and modulo operations, -e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`. -In particular, JAX will handle the cases when either (a) there -is no remainder, or (b) the divisor is a constant -in which case there may be a constant remainder. - -For example, the code below results in a division error when trying to -compute the inferred dimension for a `reshape` operation: - -```python ->>> b, = export.symbolic_shape("b") ->>> export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((b,), dtype=np.int32)) -Traceback (most recent call last): -jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1). -The remainder mod(b, - 2) should be 0. - -``` - -Note that the following will succeed: - -```python ->>> b, = export.symbolic_shape("b") ->>> # We specify that the first dimension is a multiple of 4 ->>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((4*b,), dtype=np.int32)) ->>> exp.out_avals -(ShapedArray(int32[2,2*b]),) - ->>> # We specify that some other dimension is even ->>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32)) ->>> exp.out_avals -(ShapedArray(int32[2,15*b]),) - -``` - (shape_poly_debugging)= ## Debugging