[shape_poly] Fixed bug with dimension variables in unused args

JAX will aggressively drop module input arguments if they are not
used. This can interfere with shape polymorphism, because it may
result in dropping arguments from which we need to derive the
values of shape variables.

We fix this for now by disabling argument dropping whenever any argument's
shape contains dimension variables. A more precise approach would be to keep
only those arguments that are actually needed for deriving the dimension
variables, but that would be a much more involved change for an uncertain
benefit.
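
For illustration, a minimal sketch of the failure mode being fixed. It mirrors the example that this commit removes from the jax2tf docs below; the concrete `x` and `y` arrays and their shapes are assumptions added here:

```python
import numpy as np
from jax.experimental import jax2tf

# Hypothetical inputs: only x carries the dimension variable "a".
x = np.zeros((3, 5), dtype=np.float32)
y = np.zeros((3, 7), dtype=np.float32)

# x_unused does not contribute to the computation, so it used to be pruned
# from the lowered module; with native serialization the value of "a" then
# could not be derived from the remaining (kept) arguments.
jax2tf.convert(lambda x_unused, y: y * 2.,
               polymorphic_shapes=["b, a", "b, _"])(x, y)
```

With this change the unused argument is kept in the lowered module whenever its shape is not constant, so `a` can still be derived from its shape and the call succeeds.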
George Necula 2023-03-27 13:12:10 +02:00
parent 99facbab2a
commit befb449f05
3 changed files with 6 additions and 36 deletions


@@ -2462,7 +2462,8 @@ def lower_sharding_computation(
       "Argument mapping: %s.",
       fun_name, global_in_avals, in_shardings)
-  if keep_unused:
+  if keep_unused or any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
+                        for a in global_in_avals):
     kept_var_idx = set(range(len(global_in_avals)))
   else:
     jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
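
For clarity, the new condition restated as a standalone predicate. The helper name is hypothetical and not part of the JAX API; `core` refers to the internal `jax._src.core` module already used by the patched file:

```python
from jax._src import core  # internal module, as imported in the patched file


# Hypothetical helper restating the guard above: keep every input if the
# caller asked for unused args to be kept, or if any input abstract value
# has a shape that is not fully constant (such a shape may carry dimension
# variables that the lowered module still needs).
def _keep_all_inputs(global_in_avals, keep_unused: bool) -> bool:
  return keep_unused or any(
      hasattr(a, "shape") and not core.is_constant_shape(a.shape)
      for a in global_in_avals)
```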


@@ -716,25 +716,6 @@ polymorphic_shapes = ["a, 2*a, b"]
 polymorphic_shapes = ["a * a, a"]
 ```
-Furthermore, when using the native serialization the inputs that are not needed in the computation
-are ignored, so the dimension variables must be derivable only from used inputs.
-In the following example, the `x_unused` is not part of the computation so its
-input shapes cannot be used for deriving the dimension variables, and you will
-get an error that `a` cannot be derived:
-```python
-jax2tf.convert(lambda x_unused, y: y * 2.,
-               polymorphic_shapes=["b, a", "b, _"])(x, y)
-```
-An input is still considered unused if the computation uses only its shape.
-The code below gives the same error:
-```python
-jax2tf.convert(lambda x_unused, y: y * x_unused.shape[0],
-               polymorphic_shapes=["b, a", "b, _"])(x, y)
-```
 
 ## Known issues
 
 `jax2tf` has been in use since 2020 and the vast majority of users encounter
@@ -954,7 +935,7 @@ g_tf_native_0 = tape.gradient(res, xs,
 # Now with jax2tf.convert
 with tf.GradientTape() as tape:
-  res = jax2tf.convert(fn, with_gradient=True)(*xs0
+  res = jax2tf.convert(fn, with_gradient=True)(*xs)
 g_jax2tf = tape.gradient(res, xs)
 # Returns: 0., 0., 2., None


@@ -1051,11 +1051,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
         lambda x_unused, y: y * 2.0,
         arg_descriptors=[RandArg((4,), _f32), RandArg((3,), _f32)],
         input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
-        polymorphic_shapes=["b1", "b2"],
-        expect_error=(
-            (None, None) if not config.jax2tf_default_native_serialization else
-            (ValueError,
-             "The following dimension variables cannot be computed from the static shapes of the kept lowered arguments")))
+        polymorphic_shapes=["b1", "b2"])
 
     # A polymorphic arg is not used, and the dimension var does appear
     # elsewhere but not as a trivial monomial.
@@ -1063,22 +1059,14 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
         lambda x_unused, y: y * 2.0,
         arg_descriptors=[RandArg((3,), _f32), RandArg((9,), _f32)],
         input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
-        polymorphic_shapes=["b1", "b1 * b1"],
-        expect_error=(
-            (None, None) if not config.jax2tf_default_native_serialization else
-            (ValueError,
-             "The following dimension variables cannot be computed from the static shapes of the kept lowered arguments")))
+        polymorphic_shapes=["b1", "b1 * b1"])
 
     # It is not sufficient to just use the shape of an input; it is still unused
     check_shape_poly(self,
         lambda x_unused, y: y + x_unused.shape[0],
         arg_descriptors=[RandArg((3,), _f32), RandArg((9,), _f32)],
         input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
-        polymorphic_shapes=["b1", "b2"],
-        expect_error=(
-            (None, None) if not config.jax2tf_default_native_serialization else
-            (KeyError,
-             "Encountered dimension variable 'b1' that is not appearing in the shapes")))
+        polymorphic_shapes=["b1", "b2"])
 
   def test_with_custom_vjp(self):
     """Shape-polymorphic custom VJP."""