mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Fixed bug with dimension variables in unused args
JAX aggressively drops module input arguments that are not used. This can interfere with shape polymorphism, because it may drop the very arguments from which the values of shape (dimension) variables must be derived. For now we fix this by disabling argument dropping whenever dimension variables appear in the argument shapes. A more precise technique would be to keep only those arguments that are needed to derive the dimension variables, but that would be a much more involved change for an uncertain benefit.
This commit is contained in:
parent
99facbab2a
commit
befb449f05
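To see the failure mode this commit addresses, here is a minimal sketch built from the README example that the diff below removes; the arrays `x` and `y` and their concrete shapes are illustrative stand-ins:

```python
import numpy as np
from jax.experimental import jax2tf

x = np.ones((3, 4), dtype=np.float32)
y = np.ones((3, 5), dtype=np.float32)

# `x_unused` does not feed the computation, yet it is the only input whose
# shape mentions the dimension variable "a". Before this fix, pruning the
# unused argument under native serialization made "a" underivable and
# raised an error; with the fix, the argument is kept.
jax2tf.convert(lambda x_unused, y: y * 2.,
               polymorphic_shapes=["b, a", "b, _"])(x, y)
```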
@@ -2462,7 +2462,8 @@ def lower_sharding_computation(
       "Argument mapping: %s.",
       fun_name, global_in_avals, in_shardings)
 
-  if keep_unused:
+  if keep_unused or any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
+                        for a in global_in_avals):
     kept_var_idx = set(range(len(global_in_avals)))
   else:
     jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
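A hedged sketch of the new guard in isolation; the helper name `keep_all_inputs` is invented for illustration, and `core` refers to JAX's internal `jax._src.core` module, the one used in `pxla.py`:

```python
from jax._src import core  # internal module, as used in the diff above

def keep_all_inputs(keep_unused, global_in_avals) -> bool:
    # Mirrors the added condition: keep every input when the user asked
    # for it, or when any input aval has a non-constant (polymorphic)
    # shape, since such inputs may be needed to derive the values of the
    # dimension variables.
    return keep_unused or any(
        hasattr(a, "shape") and not core.is_constant_shape(a.shape)
        for a in global_in_avals)
```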
@@ -716,25 +716,6 @@ polymorphic_shapes = ["a, 2*a, b"]
 polymorphic_shapes = ["a * a, a"]
 ```
 
-Furthermore, when using the native serialization the inputs that are not needed in the computation
-are ignored, so the dimension variables must be derivable only from used inputs.
-In the following example, the `x_unused` is not part of the computation so its
-input shapes cannot be used for deriving the dimension variables, and you will
-get an error that `a` cannot be derived:
-
-```python
-jax2tf.convert(lambda x_unused, y: y * 2.,
-               polymorphic_shapes=["b, a", "b, _"])(x, y)
-```
-
-An input is still considered unused if the computation uses only its shape.
-The code below gives the same error:
-
-```python
-jax2tf.convert(lambda x_unused, y: y * x_unused.shape[0],
-               polymorphic_shapes=["b, a", "b, _"])(x, y)
-```
-
 ## Known issues
 
 `jax2tf` has been in use since 2020 and the vast majority of users encounter
@@ -954,7 +935,7 @@ g_tf_native_0 = tape.gradient(res, xs,
 
 # Now with jax2tf.convert
 with tf.GradientTape() as tape:
-  res = jax2tf.convert(fn, with_gradient=True)(*xs0
+  res = jax2tf.convert(fn, with_gradient=True)(*xs)
 
 g_jax2tf = tape.gradient(res, xs)
 # Returns: 0., 0., 2., None
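For readers without the surrounding README, a self-contained sketch of the same pattern; `fn` and `xs` here are toy stand-ins rather than the README's definitions, so the gradients differ from the `0., 0., 2., None` noted above:

```python
import tensorflow as tf
from jax.experimental import jax2tf

fn = lambda a, b: a * b                    # toy stand-in JAX function
xs = (tf.constant(3.0), tf.constant(4.0))  # toy stand-in inputs

with tf.GradientTape() as tape:
    for x in xs:
        tape.watch(x)  # constants must be watched explicitly
    res = jax2tf.convert(fn, with_gradient=True)(*xs)

# Gradients flow back through the converted function.
g_jax2tf = tape.gradient(res, xs)
```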
@@ -1051,11 +1051,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
                      lambda x_unused, y: y * 2.0,
                      arg_descriptors=[RandArg((4,), _f32), RandArg((3,), _f32)],
                      input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
-                     polymorphic_shapes=["b1", "b2"],
-                     expect_error=(
-                         (None, None) if not config.jax2tf_default_native_serialization else
-                         (ValueError,
-                          "The following dimension variables cannot be computed from the static shapes of the kept lowered arguments")))
+                     polymorphic_shapes=["b1", "b2"])
 
     # A polymorphic arg is not used, and the dimension var does appear
     # elsewhere but not as a trivial monomial.
@@ -1063,22 +1059,14 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
                      lambda x_unused, y: y * 2.0,
                      arg_descriptors=[RandArg((3,), _f32), RandArg((9,), _f32)],
                      input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
-                     polymorphic_shapes=["b1", "b1 * b1"],
-                     expect_error=(
-                         (None, None) if not config.jax2tf_default_native_serialization else
-                         (ValueError,
-                          "The following dimension variables cannot be computed from the static shapes of the kept lowered arguments")))
+                     polymorphic_shapes=["b1", "b1 * b1"])
 
     # It is not sufficient to just use the shape of an input; it is still unused
     check_shape_poly(self,
                      lambda x_unused, y: y + x_unused.shape[0],
                      arg_descriptors=[RandArg((3,), _f32), RandArg((9,), _f32)],
                      input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
-                     polymorphic_shapes=["b1", "b2"],
-                     expect_error=(
-                         (None, None) if not config.jax2tf_default_native_serialization else
-                         (KeyError,
-                          "Encountered dimension variable 'b1' that is not appearing in the shapes")))
+                     polymorphic_shapes=["b1", "b2"])
 
   def test_with_custom_vjp(self):
     """Shape-polymorphic custom VJP."""
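Taken together, the test changes assert that conversions which previously failed under native serialization now succeed. A hedged standalone illustration, with array shapes matching the `RandArg` descriptors above and arbitrary values:

```python
import numpy as np
from jax.experimental import jax2tf

x = np.ones((4,), dtype=np.float32)
y = np.ones((3,), dtype=np.float32)

# Unused polymorphic argument: previously a ValueError under native
# serialization; the argument is now kept, so conversion succeeds.
jax2tf.convert(lambda x_unused, y: y * 2.0,
               polymorphic_shapes=["b1", "b2"])(x, y)

# Using only the shape of an input previously still counted as unused
# (a KeyError for 'b1'); it now works for the same reason.
x2 = np.ones((3,), dtype=np.float32)
y2 = np.ones((9,), dtype=np.float32)
jax2tf.convert(lambda x_unused, y: y + x_unused.shape[0],
               polymorphic_shapes=["b1", "b2"])(x2, y2)
```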