[shape_poly] Add documentation for workaround with dimension parameters.

This commit is contained in:
George Necula 2024-06-17 11:54:16 +03:00
parent 4ea73bf787
commit b1a8c65883
2 changed files with 77 additions and 6 deletions

View File

@ -475,10 +475,82 @@ is unsound.
### Dimension variables must be solvable from the input shapes
When an exported object is invoked on inputs with concrete shapes
the dimension variables are derived from the input shapes.
This works only if the symbolic dimensions in the input shapes are linear.
See below an example of a failure:
Currently, the only way to pass the values of dimension variables
when an exported object is invoked is indirectly through the shapes
of the array arguments. E.g., the value of `b` can be inferred at the
call site from the shape of the first argument of type `f32[b]`.
This works well for most use cases, and
it mirrors the calling convention of JIT functions.
Sometimes you may want to export a function parameterized
by an integer values that determines some shapes in the program.
For example, we may
want to export the function `my_top_k` defined below,
parameterized by the
value of `k`, which determined the shape of the result.
The following attempt will lead to an error since the dimension
variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`:
```python
>>> def my_top_k(k, x): # x: i32[4, 10], k <= 10
... return lax.top_k(x, k)[0] # : i32[4, 3]
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))
>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
>>> exp_static_k.in_avals[0]
ShapedArray(int32[4,10])
>>> exp_static_k.out_avals[0]
ShapedArray(int32[4,3])
>>> # When calling the exported function we pass only the non-static arguments
>>> exp_static_k.call(x)
Array([[ 9, 8, 7],
[19, 18, 17],
[29, 28, 27],
[39, 38, 37]], dtype=int32)
>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments
```
In the future, we may add an additional mechanism to pass the values of
dimension variables, besides implicitly through the input shapes.
Meanwhile, the workaround for the above use case is to replace the
function parameter `k` with an array of shape `(0, k)`, so that
`k` can be derived from the input shape of an array.
The first dimension is 0 to ensure that the whole array is empty
and there is no performance penalty when we call the exported function.
```python
>>> def my_top_k_with_dimensions(dimensions, x): # dimensions: i32[0, k], x: i32[4, 10]
... return my_top_k(dimensions.shape[1], x)
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
... jax.ShapeDtypeStruct((0, k), dtype=np.int32),
... x)
>>> exp.in_avals
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))
>>> exp.out_avals[0]
ShapedArray(int32[4,k])
>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
Array([[ 9, 8, 7],
[19, 18, 17],
[29, 28, 27],
[39, 38, 37]], dtype=int32)
```
Another situation when you may get an error is when some dimension
variables do appear in the input shapes, but in a non-linear
expression that JAX cannot currently solve:
```python
>>> a, = export.symbolic_shape("a")
@ -537,7 +609,6 @@ In particular, JAX will handle the cases when either (a) there
is no remainder, or (b) the divisor is a constant
in which case there may be a constant remainder.
For example, the code below results in a division error when trying to
compute the inferred dimension for a `reshape` operation:

View File

@ -215,7 +215,7 @@ class _DimFactor:
return env[self.var]
except KeyError:
err_msg = (
f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n"
f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n"
"Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
raise KeyError(err_msg)
else: