[shape_poly] Remove caching for the symbolic shape evaluator

The caching used for the shape_poly.CachingShapeEvaluator leads to
leaked tracer errors. This is because the `lru_cache` is attached
to the `CachingShapeEvaluator.evaluate` and persists for the
duration of the program. It is possible to reimplement the caching,
but in this case caching does not help much so we just remove it.
This commit is contained in:
George Necula 2024-11-09 11:10:16 +02:00
parent 87ce0cbb00
commit 45ae4dfb9e
2 changed files with 17 additions and 16 deletions

View File

@ -1297,9 +1297,10 @@ def _call_exported_abstract_eval(
# Must express the exported_dim_vars in terms of the shapes in in_avals.
solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars(
exported.in_avals, args_kwargs_tree=exported.in_tree)
synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx]
for (vname, arg_idx, dim_idx) in synth_dim_vars}
synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env)
synthetic_env: shape_poly.DimVarEnv = {
vname: in_avals[arg_idx].shape[dim_idx]
for (vname, arg_idx, dim_idx) in synth_dim_vars}
synthetic_eval = shape_poly.ShapeEvaluator(synthetic_env)
# We discharge all the constraints statically. This results in much simpler
# composability (because we do not have to worry about the constraints of the
# Exported called recursively; we only need to worry about entry-point

View File

@ -1746,11 +1746,10 @@ def all_dim_vars(args_avals: Sequence[core.ShapedArray]) -> Sequence[str]:
return sorted(dim_vars)
class CachingShapeEvaluator:
def __init__(self, **env):
class ShapeEvaluator:
def __init__(self, env: DimVarEnv):
self.env = env
@functools.lru_cache(128)
def evaluate(self, e: DimSize):
if core.is_constant_dim(e):
res = op.index(e) # type: ignore
@ -1769,7 +1768,7 @@ class ShapeConstraint:
# is formed by evaluating the DimSize and concatenating the sequence.
error_message_pieces: Sequence[str | DimSize]
def check_statically(self, eval: CachingShapeEvaluator) -> None:
def check_statically(self, eval: ShapeEvaluator) -> None:
"""Evaluates a constraint statically."""
left, right = eval.evaluate(self.left), eval.evaluate(self.right)
try:
@ -1785,7 +1784,7 @@ class ShapeConstraint:
if not ok:
raise self.make_error(eval)
def compute(self, eval: CachingShapeEvaluator) -> jax.Array | None:
def compute(self, eval: ShapeEvaluator) -> jax.Array | None:
"""Computes if the constraint is satisfied.
If the constraint can be resolved statically returns None
@ -1820,7 +1819,7 @@ class ShapeConstraint:
def error_message_and_inputs(
self,
eval: CachingShapeEvaluator) -> tuple[str, Sequence[Any]]:
eval: ShapeEvaluator) -> tuple[str, Sequence[Any]]:
"""Forms the error_message and error message_inputs.
See shape_assertion.
"""
@ -1849,7 +1848,7 @@ class ShapeConstraint:
return ("".join(error_message_strings),
error_message_inputs)
def make_error(self, eval: CachingShapeEvaluator) -> Exception:
def make_error(self, eval: ShapeEvaluator) -> Exception:
error_message, error_message_inputs = self.error_message_and_inputs(eval)
return ValueError(error_message.format(*error_message_inputs))
@ -1865,7 +1864,7 @@ class ShapeConstraints:
c = ShapeConstraint(comp, left, right, error_message_pieces)
self.constraints.append(c)
def check_statically(self, eval: CachingShapeEvaluator) -> None:
def check_statically(self, eval: ShapeEvaluator) -> None:
"""Evaluates all the constraints statically.
If the static checking of any constraint fails, raises ValueError.
@ -1873,7 +1872,7 @@ class ShapeConstraints:
for constraint in self.constraints:
constraint.check_statically(eval)
def shape_assertions(self, eval: CachingShapeEvaluator) -> None:
def shape_assertions(self, eval: ShapeEvaluator) -> None:
"""Computes the shape assertions for the set of constraints.
See jax_export.Exported docstring.
@ -2014,10 +2013,11 @@ def compute_dim_vars_from_arg_shapes(
tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
# Replace the synthetic vars with the dynamic shape of the actual arg
synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx],
dimension=dim_idx)
for (vname, arg_idx, dim_idx) in synth_dim_vars}
synthetic_eval = CachingShapeEvaluator(**synthetic_env)
synthetic_env: DimVarEnv = {
vname: dimension_size_p.bind(actual_args[arg_idx], dimension=dim_idx)
for (vname, arg_idx, dim_idx) in synth_dim_vars
}
synthetic_eval = ShapeEvaluator(synthetic_env)
shape_constraints.shape_assertions(synthetic_eval)
dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars]
return tuple(dim_values)