mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[shape_poly] Add documentation for workaround with dimension parameters.
This commit is contained in:
parent
4ea73bf787
commit
b1a8c65883
@ -475,10 +475,82 @@ is unsound.
|
||||
|
||||
### Dimension variables must be solvable from the input shapes
|
||||
|
||||
When an exported object is invoked on inputs with concrete shapes
|
||||
the dimension variables are derived from the input shapes.
|
||||
This works only if the symbolic dimensions in the input shapes are linear.
|
||||
See below an example of a failure:
|
||||
Currently, the only way to pass the values of dimension variables
|
||||
when an exported object is invoked is indirectly through the shapes
|
||||
of the array arguments. E.g., the value of `b` can be inferred at the
|
||||
call site from the shape of the first argument of type `f32[b]`.
|
||||
This works well for most use cases, and
|
||||
it mirrors the calling convention of JIT functions.
|
||||
|
||||
Sometimes you may want to export a function parameterized
|
||||
by an integer values that determines some shapes in the program.
|
||||
For example, we may
|
||||
want to export the function `my_top_k` defined below,
|
||||
parameterized by the
|
||||
value of `k`, which determined the shape of the result.
|
||||
The following attempt will lead to an error since the dimension
|
||||
variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`:
|
||||
|
||||
```python
|
||||
>>> def my_top_k(k, x): # x: i32[4, 10], k <= 10
|
||||
... return lax.top_k(x, k)[0] # : i32[4, 3]
|
||||
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))
|
||||
|
||||
>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
|
||||
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
|
||||
>>> exp_static_k.in_avals[0]
|
||||
ShapedArray(int32[4,10])
|
||||
|
||||
>>> exp_static_k.out_avals[0]
|
||||
ShapedArray(int32[4,3])
|
||||
|
||||
>>> # When calling the exported function we pass only the non-static arguments
|
||||
>>> exp_static_k.call(x)
|
||||
Array([[ 9, 8, 7],
|
||||
[19, 18, 17],
|
||||
[29, 28, 27],
|
||||
[39, 38, 37]], dtype=int32)
|
||||
|
||||
>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
|
||||
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
|
||||
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments
|
||||
|
||||
```
|
||||
|
||||
In the future, we may add an additional mechanism to pass the values of
|
||||
dimension variables, besides implicitly through the input shapes.
|
||||
Meanwhile, the workaround for the above use case is to replace the
|
||||
function parameter `k` with an array of shape `(0, k)`, so that
|
||||
`k` can be derived from the input shape of an array.
|
||||
The first dimension is 0 to ensure that the whole array is empty
|
||||
and there is no performance penalty when we call the exported function.
|
||||
|
||||
```python
|
||||
>>> def my_top_k_with_dimensions(dimensions, x): # dimensions: i32[0, k], x: i32[4, 10]
|
||||
... return my_top_k(dimensions.shape[1], x)
|
||||
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
|
||||
... jax.ShapeDtypeStruct((0, k), dtype=np.int32),
|
||||
... x)
|
||||
>>> exp.in_avals
|
||||
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))
|
||||
|
||||
>>> exp.out_avals[0]
|
||||
ShapedArray(int32[4,k])
|
||||
|
||||
>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
|
||||
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
|
||||
Array([[ 9, 8, 7],
|
||||
[19, 18, 17],
|
||||
[29, 28, 27],
|
||||
[39, 38, 37]], dtype=int32)
|
||||
|
||||
```
|
||||
|
||||
Another situation when you may get an error is when some dimension
|
||||
variables do appear in the input shapes, but in a non-linear
|
||||
expression that JAX cannot currently solve:
|
||||
|
||||
```python
|
||||
>>> a, = export.symbolic_shape("a")
|
||||
@ -537,7 +609,6 @@ In particular, JAX will handle the cases when either (a) there
|
||||
is no remainder, or (b) the divisor is a constant
|
||||
in which case there may be a constant remainder.
|
||||
|
||||
|
||||
For example, the code below results in a division error when trying to
|
||||
compute the inferred dimension for a `reshape` operation:
|
||||
|
||||
|
@ -215,7 +215,7 @@ class _DimFactor:
|
||||
return env[self.var]
|
||||
except KeyError:
|
||||
err_msg = (
|
||||
f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n"
|
||||
f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n"
|
||||
"Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
|
||||
raise KeyError(err_msg)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user