Merge pull request #13603 from gnecula:native_unused
PiperOrigin-RevId: 494769977
commit 23001ae782
@@ -381,6 +381,10 @@ lowered with the batch dimension polymorphic and the remaining dimensions concrete

It is reasonable to expect that there will be JAX programs for which there is a
shape-polymorphic TensorFlow graph, but which will give an error when lowering with jax2tf.
In general, you should expect that shape polymorphism can handle those programs for which
all the intermediate shapes can be expressed as polynomials in the dimension variables
appearing in the input shapes. In particular, this does not include programs whose
intermediate shapes depend on the data.
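For instance (an illustrative sketch added here, not part of the diff), a function whose
result size depends on the values in its input cannot be expressed with dimension variables:

```python
import jax.numpy as jnp

# Hypothetical example: the number of selected elements depends on the values
# in x, not only on its shape, so no polynomial of dimension variables can
# describe the result shape.
def data_dependent(x):
  return x[x >= 0]  # boolean-mask indexing has a data-dependent result size
```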

### Details

@@ -613,6 +617,38 @@ jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
               polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
```

### Dimension variables must be solvable from the input shapes

`jax2tf` will generate code to derive the values of the dimension variables
from the input shapes. This works only if the dimension polynomials in the input shapes are linear.
For example, the following `polymorphic_shapes` will result in errors:

```python
polymorphic_shapes = ["a * a"]  # Not a linear polynomial
polymorphic_shapes = ["a + b"]  # Too few equations to derive both `a` and `b`
```
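As an illustration of how the derivation works (an editor's sketch using assumed shapes,
not an example from the diff), a linear specification determines the dimension variables
uniquely from the concrete input shape:

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

# With the linear spec "2*b, d", a call on an array of shape (4, 7) lets
# jax2tf solve b = 2 and d = 7 from the input shape alone.
f_tf = jax2tf.convert(jnp.sin, polymorphic_shapes=["2*b, d"])
f_tf(np.ones((4, 7), dtype=np.float32))
```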

If you are using native lowering, the restrictions are stronger: every dimension
variable must occur as the value of some dimension of some input, e.g.,
the following will work:

```python
polymorphic_shapes = ["a, 2*a, b"]
polymorphic_shapes = ["a * a, a"]
```
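Conversely (an editor's inference from the rule above, not an example taken from the diff),
a specification in which a variable never appears by itself as a dimension is expected to
fail under native lowering even though it is linear:

```python
polymorphic_shapes = ["a + 5, b"]  # linear, but `a` never occurs alone as a dimension
```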

Furthermore, when using native lowering, inputs that are not needed in the computation
are ignored, so the dimension variables must be derivable from the used inputs alone.
In the following example, `x_unused` is not part of the computation, so its shape
cannot be used to derive the dimension variables, and you will
get an error that `a` cannot be derived:

```python
jax2tf.convert(lambda x_unused, y: y * 2.,
               polymorphic_shapes=["b, a", "b, 2 * a"])(x, y)
```
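One possible workaround (an editor's sketch, not part of this change, but following the same
pattern as the tests added below): mark the unused argument as non-polymorphic, so that every
dimension variable can be read off the shape of a used argument. Note that `a` then simply
names the second dimension of `y`:

```python
# `x_unused` keeps its concrete shape (spec None); `b` and `a` are derived
# from the shape of `y`, which is used in the computation.
jax2tf.convert(lambda x_unused, y: y * 2.,
               polymorphic_shapes=[None, "b, a"])(x, y)
```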

## Known issues

`jax2tf` has been in use since 2020 and the vast majority of users encounter
@@ -595,46 +595,28 @@ def _lower_native_and_run(fun_jax: Callable,

  Special care must be taken in presence of shape polymorphism.
  """
  # Look for shape polymorphism

  # For each dimension variable, encode how to compute its value from the
  # shape of the explicit arguments. E.g., "2.1" denotes args_tf[2].shape[1].
  # The order of the dimension variables must match the order of the first N
  # arguments of the lowered function.

  # We now have two implementations for the native lowering. If --jax_dynamic_shapes
  # then we use JAX's in-progress support for native dynamic shapes. In that
  # case we assume that the dimension variables are listed in the order in which
  # they are encountered by scanning the arguments and their shapes in order.
  # If we don't use --jax_dynamic_shapes then the dimension variables are passed
  # in the alphabetical order of their names.
  abstracted_axes: Sequence[Dict[int, str]] = []
  dim_args_spec_dict: Dict[str, str] = {}  # map dim var name to dim_args_spec
  dim_vars_seen: List[str] = []  # the dim var names in order
  for arg_idx, aval in enumerate(args_avals):
    one_abstract_axes = {}
    for axis_idx, d in enumerate(aval.shape):
      if not core.is_constant_dim(d):
        d_var = d.to_var()
        if d_var is None:
          raise ValueError(f"Only simple dimension variables supported: {aval.shape}")
        if not d_var in dim_vars_seen:
          dim_args_spec_dict[d_var] = f"{arg_idx}.{axis_idx}"
          dim_vars_seen.append(d_var)
        one_abstract_axes[axis_idx] = d_var
    abstracted_axes.append(one_abstract_axes)
  # then we use JAX's in-progress support for native dynamic shapes, and we pass
  # abstracted_axes to lowering functions. Otherwise, we just lower using
  # abstract values whose shapes may include polynomials (already in args_avals).
  if config.jax_dynamic_shapes:
    abstracted_axes: Sequence[Dict[int, str]] = []
    for arg_idx, aval in enumerate(args_avals):
      one_abstract_axes = {}
      for axis_idx, d in enumerate(aval.shape):
        if not core.is_constant_dim(d):
          d_var = d.to_var()
          if d_var is None:
            raise ValueError(f"Only trivial dimension polynomials on input: {aval.shape}")
          one_abstract_axes[axis_idx] = d_var
      abstracted_axes.append(one_abstract_axes)
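      # Editor's illustration (not in the original diff): for an argument with
      # aval.shape == (b, 4), the loop above records one_abstract_axes == {0: 'b'},
      # i.e. only the non-constant axes, keyed by their position.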

  if any(abstracted_axes):
    if config.jax_dynamic_shapes:
    if any(abstracted_axes):
      abstracted_axes = tuple(abstracted_axes)
      # In the order we have seen them
      dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_seen]
    else:
      abstracted_axes = None  # type: ignore
      # In sorted order by name
      dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_seen)]
  else:
    abstracted_axes = None  # type: ignore
    dim_args_spec = []

  arg_specs_jax = [
      jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape)
@@ -647,7 +629,6 @@ def _lower_native_and_run(fun_jax: Callable,
    # convert(f_jax), in which case a "jit" is implied. We also add a jit when
    # we need to pass the abstracted axes.
    fun_jax_lower = jax.jit(fun_jax, backend=backend,
                            keep_unused=True,  # TODO: allow dropping unused
                            abstracted_axes=abstracted_axes).lower
  else:
    fun_jax_lower = fun_jax.lower
@@ -658,10 +639,6 @@ def _lower_native_and_run(fun_jax: Callable,
  else:
    mhlo_module = lowered.mhlo()
    xla_call_module_version = 1
  if logging.vlog_is_on(3):
    mhlo_module_text = mlir.module_to_string(mhlo_module)
    logging.vlog(3, "XlaCallModule (version=%d)\n%s", xla_call_module_version,
                 mhlo_module_text)

  mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module)
  # Figure out the result types and shapes
@@ -685,6 +662,62 @@ def _lower_native_and_run(fun_jax: Callable,
    return jax_type
  out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals)

  module_kept_var_idx = lowered.compile_args["kept_var_idx"]
  # We must compute the dim_args_spec: for each dimension variable, encode how
  # to compute its value from the shape of the explicit arguments. E.g., "2.1"
  # denotes args_tf[2].shape[1]. The order of the dimension variables must match
  # the order of the first N arguments of the lowered function.
  # If we use --jax_dynamic_shapes, the dimension variables are listed in the
  # order in which they are encountered by scanning the arguments and their
  # shapes in order. Otherwise, the dimension variables are passed in the
  # alphabetical order of their names.
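  # Editor's illustration (not in the original diff): if the kept arguments have
  # shapes args_tf[0]: (b, 8) and args_tf[1]: (c, d), then dim_args_spec becomes
  # ["0.0", "1.0", "1.1"], i.e. b = args_tf[0].shape[0], c = args_tf[1].shape[0],
  # and d = args_tf[1].shape[1].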
  dim_args_spec_dict: Dict[str, str] = {}  # map dim var name to dim_args_spec
  dim_vars_order: List[str] = []
  all_dim_vars: Set[str] = set()
  current_kept_arg_idx = -1  # The index among the kept arguments
  for arg_idx, aval in enumerate(args_avals):
    is_kept = arg_idx in module_kept_var_idx
    if is_kept:
      current_kept_arg_idx += 1

    for axis_idx, d in enumerate(aval.shape):
      if not core.is_constant_dim(d):
        # We collect dimension variables even from dropped args
        all_dim_vars = all_dim_vars.union(d.get_vars())
        if not is_kept: continue
        d_var = d.to_var()
        # We can compute dim vars only from trivial polynomials
        if d_var is None: continue
        if not d_var in dim_args_spec_dict:
          dim_vars_order.append(d_var)
          dim_args_spec_dict[d_var] = f"{current_kept_arg_idx}.{axis_idx}"

  if all_dim_vars:
    dim_args_spec_set = set(dim_vars_order)
    if dim_args_spec_set != all_dim_vars:
      missing = all_dim_vars.difference(dim_args_spec_set)
      args_list = [f" Arg[{arg_idx}] - {'KEPT ' if arg_idx in module_kept_var_idx else 'DROPPED'}: {aval}"
                   for arg_idx, aval in enumerate(args_avals)]
      raise ValueError(
          "The following dimension variables cannot be computed from the static "
          f"shapes of the kept lowered arguments: {missing}. These are the "
          "argument shapes:\n" +
          "\n".join(args_list) +
          "\n"
          "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")

    if config.jax_dynamic_shapes:
      # In the order we have seen them
      dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_order]
    else:
      # In sorted order by name
      dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_order)]
  else:
    dim_args_spec = []

  args_avals = [aval for i, aval in enumerate(args_avals) if i in module_kept_var_idx]
  args_tf = [atf for i, atf in enumerate(args_tf) if i in module_kept_var_idx]

  # Apply the shardings on arguments and results for pjit. This is redundant
  # because the mhlo_module_text will already contain the shardings, but it
  # makes it easier for tools like the TPU inference converter to see the
@@ -694,6 +727,11 @@ def _lower_native_and_run(fun_jax: Callable,
  args_tf = tuple(
      map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))

  if logging.vlog_is_on(3):
    mhlo_module_text = mlir.module_to_string(mhlo_module)
    logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s",
                 xla_call_module_version, ", ".join(dim_args_spec),
                 mhlo_module_text)
  res = tfxla.call_module(
      args_tf,
      version=xla_call_module_version,
@@ -886,5 +886,7 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
    err_msg = (
        f"Cannot solve for values of dimension variables {unsolved_vars} from "
        f"the remaining dimension polynomials\n {eqns_str}.{_shapeenv_to_str()} "
        "Dimension variables can be solved only from linear polynomials.")
        "Dimension variables can be solved only from linear polynomials.\n"
        "\n"
        "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
    raise ValueError(err_msg)
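For reference, this error path is exercised by one of the tests added further below; a call
like the following (an editor's sketch mirroring that test, with an assumed concrete `x`)
triggers the message:

```python
import numpy as np
from jax.experimental import jax2tf

x = np.ones((5,), dtype=np.float32)
# "a + b" cannot be solved for both `a` and `b` from a single dimension of size 5,
# so this raises ValueError("Cannot solve for values of dimension variables ...").
jax2tf.convert(lambda x: x, polymorphic_shapes=["a + b"])(x)
```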
@@ -818,6 +818,28 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):

    jax2tf.convert(func)(2.)  # No error

  def test_jit_unused(self):
    def f_jax(x, y_unused):
      return x * np.float32(2.)
    x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32)
    res_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))(x, y_unused)
    self.assertAllClose(f_jax(x, None), res_tf)

  def test_jit_unused_grad(self):
    def f_jax(x, y_unused):
      return x * np.float32(2.)

    x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32)
    f_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))
    xv, y_unused_v = tf.Variable(x), tf.Variable(y_unused)
    with tf.GradientTape() as tape:
      res_tf = f_tf(xv, y_unused_v)
    grad_tf_x, grad_tf_y = tape.gradient(res_tf, (xv, y_unused_v))

    self.assertAllClose(f_jax(x, None), res_tf)
    self.assertAllClose(np.float32(2.), grad_tf_x)
    self.assertIsNone(grad_tf_y)

  def test_nested_convert_error(self):
    def outer(y):
      return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args
@@ -20,7 +20,6 @@ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import collections
import functools
from functools import partial
import logging
import operator
import re

@@ -722,6 +721,70 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
        input_signature=[tf.TensorSpec([None, None])],
        polymorphic_shapes=["w, h"])

  def test_non_trivial_polynomials(self):
    if config.jax_dynamic_shapes:
      raise unittest.SkipTest("--jax_dynamic_shapes supports only trivial polynomials")
    # We can handle non-trivial polynomials in the input shape,
    # as long as all variables also occur in trivial polynomials
    self.CheckShapePolymorphism(
        lambda x, y: x + y.reshape((-1,)),
        input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])],
        polymorphic_shapes=["b * b", "b, b"])

  def test_unused_args(self):
    # Tests with functions that do not use their inputs.

    # First arg unused, not polymorphic
    self.CheckShapePolymorphism(
        lambda x_unused, y: y * 2.0,
        input_signature=[tf.TensorSpec([]), tf.TensorSpec([None])],
        polymorphic_shapes=[None, "b"])

    # Some args unused, not polymorphic
    self.CheckShapePolymorphism(
        lambda x_unused, y, z_unused, w: jnp.concatenate([y, w]),
        input_signature=[tf.TensorSpec([]), tf.TensorSpec([None]),
                         tf.TensorSpec([]), tf.TensorSpec([None])],
        polymorphic_shapes=[None, "b1", None, "b2"])

    # A polymorphic arg is not used, but the dimension var appears
    # in a used arg also
    self.CheckShapePolymorphism(
        lambda x_unused, y: y * 2.0,
        input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
        polymorphic_shapes=["b", "b"])

    # A polymorphic arg is not used, and the dimension var does not appear
    # elsewhere.
    if config.jax2tf_default_experimental_native_lowering:
      with self.assertRaisesRegex(ValueError,
                                  "The following dimension variables cannot be computed"):
        self.CheckShapePolymorphism(
            lambda x_unused, y: y * 2.0,
            input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
            polymorphic_shapes=["b1", "b2"])
    else:
      self.CheckShapePolymorphism(
          lambda x_unused, y: y * 2.0,
          input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
          polymorphic_shapes=["b1", "b2"])

    # A polymorphic arg is not used, and the dimension var does appear
    # elsewhere but not as a trivial monomial.
    if config.jax2tf_default_experimental_native_lowering:
      with self.assertRaisesRegex(ValueError,
                                  "The following dimension variables cannot be computed"):
        self.CheckShapePolymorphism(
            lambda x_unused, y: y * 2.0,
            input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
            polymorphic_shapes=["b1", "b1 * b1"])
    else:
      self.CheckShapePolymorphism(
          lambda x_unused, y: y * 2.0,
          input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
          polymorphic_shapes=["b1", "b1 * b1"])

  def test_with_custom_vjp(self):
    """Shape-polymorphic custom VJP."""
@@ -1065,6 +1128,11 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
        jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
    self.assertEqual(1, f_tf(x45))

    x = np.ones((5,), dtype=np.float32)
    with self.assertRaisesRegex(ValueError,
                                "Cannot solve for values of dimension variables"):
      jax2tf.convert(lambda x: x, polymorphic_shapes=["a + b"])(x)


class DimAsValueTest(tf_test_util.JaxToTfTestCase):
  """Dimension polynomials used as values in the JAX computation."""