From b1a8c658831f5900c198f415289500e08495df1d Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 17 Jun 2024 11:54:16 +0300 Subject: [PATCH] [shape_poly] Add documentation for workaround with dimension parameters. --- docs/export/shape_poly.md | 81 ++++++++++++++++++++++++++++++++--- jax/_src/export/shape_poly.py | 2 +- 2 files changed, 77 insertions(+), 6 deletions(-) diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 3c28d38ab..8b07a3666 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -475,10 +475,82 @@ is unsound. ### Dimension variables must be solvable from the input shapes -When an exported object is invoked on inputs with concrete shapes -the dimension variables are derived from the input shapes. -This works only if the symbolic dimensions in the input shapes are linear. -See below an example of a failure: +Currently, the only way to pass the values of dimension variables +when an exported object is invoked is indirectly through the shapes +of the array arguments. E.g., the value of `b` can be inferred at the +call site from the shape of the first argument of type `f32[b]`. +This works well for most use cases, and +it mirrors the calling convention of JIT functions. + +Sometimes you may want to export a function parameterized +by an integer value that determines some shapes in the program. +For example, we may +want to export the function `my_top_k` defined below, +parameterized by the +value of `k`, which determines the shape of the result. +The following attempt will lead to an error since the dimension +variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`: + +```python +>>> def my_top_k(k, x): # x: i32[4, 10], k <= 10 +... return lax.top_k(x, k)[0] # : i32[4, 3] +>>> x = np.arange(40, dtype=np.int32).reshape((4, 10)) + +>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`. 
+>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x) +>>> exp_static_k.in_avals[0] +ShapedArray(int32[4,10]) + +>>> exp_static_k.out_avals[0] +ShapedArray(int32[4,3]) + +>>> # When calling the exported function we pass only the non-static arguments +>>> exp_static_k.call(x) +Array([[ 9, 8, 7], + [19, 18, 17], + [29, 28, 27], + [39, 38, 37]], dtype=int32) + +>>> # Now attempt to export with symbolic `k` so that we choose `k` after export. +>>> k, = export.symbolic_shape("k", constraints=["k <= 10"]) +>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments + +``` + +In the future, we may add an additional mechanism to pass the values of +dimension variables, besides implicitly through the input shapes. +Meanwhile, the workaround for the above use case is to replace the +function parameter `k` with an array of shape `(0, k)`, so that +`k` can be derived from the input shape of an array. +The first dimension is 0 to ensure that the whole array is empty +and there is no performance penalty when we call the exported function. + +```python +>>> def my_top_k_with_dimensions(dimensions, x): # dimensions: i32[0, k], x: i32[4, 10] +... return my_top_k(dimensions.shape[1], x) +>>> exp = export.export(jax.jit(my_top_k_with_dimensions))( +... jax.ShapeDtypeStruct((0, k), dtype=np.int32), +... 
x) +>>> exp.in_avals +(ShapedArray(int32[0,k]), ShapedArray(int32[4,10])) + +>>> exp.out_avals[0] +ShapedArray(int32[4,k]) + +>>> # When we invoke `exp` we must construct and pass an array of shape (0, k) +>>> exp.call(np.zeros((0, 3), dtype=np.int32), x) +Array([[ 9, 8, 7], + [19, 18, 17], + [29, 28, 27], + [39, 38, 37]], dtype=int32) + +``` + +Another situation when you may get an error is when some dimension +variables do appear in the input shapes, but in a non-linear +expression that JAX cannot currently solve: ```python >>> a, = export.symbolic_shape("a") @@ -537,7 +609,6 @@ In particular, JAX will handle the cases when either (a) there is no remainder, or (b) the divisor is a constant in which case there may be a constant remainder. - For example, the code below results in a division error when trying to compute the inferred dimension for a `reshape` operation: diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 846a95ba0..43a827e8a 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -215,7 +215,7 @@ class _DimFactor: return env[self.var] except KeyError: err_msg = ( - f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n" + f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n" "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") raise KeyError(err_msg) else: