mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[shape_poly] Remove caching for the symbolic shape evaluator
The caching used for the shape_poly.CachingShapeEvaluator leads to leaked tracer errors. This is because the `lru_cache` is attached to the `CachingShapeEvaluator.evaluate` and persists for the duration of the program. It is possible to reimplement the caching, but in this case caching does not help much so we just remove it.
This commit is contained in:
parent
87ce0cbb00
commit
45ae4dfb9e
@ -1297,9 +1297,10 @@ def _call_exported_abstract_eval(
|
||||
# Must express the exported_dim_vars in terms of the shapes in in_avals.
|
||||
solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars(
|
||||
exported.in_avals, args_kwargs_tree=exported.in_tree)
|
||||
synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx]
|
||||
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
||||
synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env)
|
||||
synthetic_env: shape_poly.DimVarEnv = {
|
||||
vname: in_avals[arg_idx].shape[dim_idx]
|
||||
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
||||
synthetic_eval = shape_poly.ShapeEvaluator(synthetic_env)
|
||||
# We discharge all the constraints statically. This results in much simpler
|
||||
# composability (because we do not have to worry about the constraints of the
|
||||
# Exported called recursively; we only need to worry about entry-point
|
||||
|
@ -1746,11 +1746,10 @@ def all_dim_vars(args_avals: Sequence[core.ShapedArray]) -> Sequence[str]:
|
||||
return sorted(dim_vars)
|
||||
|
||||
|
||||
class CachingShapeEvaluator:
|
||||
def __init__(self, **env):
|
||||
class ShapeEvaluator:
|
||||
def __init__(self, env: DimVarEnv):
|
||||
self.env = env
|
||||
|
||||
@functools.lru_cache(128)
|
||||
def evaluate(self, e: DimSize):
|
||||
if core.is_constant_dim(e):
|
||||
res = op.index(e) # type: ignore
|
||||
@ -1769,7 +1768,7 @@ class ShapeConstraint:
|
||||
# is formed by evaluating the DimSize and concatenating the sequence.
|
||||
error_message_pieces: Sequence[str | DimSize]
|
||||
|
||||
def check_statically(self, eval: CachingShapeEvaluator) -> None:
|
||||
def check_statically(self, eval: ShapeEvaluator) -> None:
|
||||
"""Evaluates a constraint statically."""
|
||||
left, right = eval.evaluate(self.left), eval.evaluate(self.right)
|
||||
try:
|
||||
@ -1785,7 +1784,7 @@ class ShapeConstraint:
|
||||
if not ok:
|
||||
raise self.make_error(eval)
|
||||
|
||||
def compute(self, eval: CachingShapeEvaluator) -> jax.Array | None:
|
||||
def compute(self, eval: ShapeEvaluator) -> jax.Array | None:
|
||||
"""Computes if the constraint is satisfied.
|
||||
|
||||
If the constraint can be resolved statically returns None
|
||||
@ -1820,7 +1819,7 @@ class ShapeConstraint:
|
||||
|
||||
def error_message_and_inputs(
|
||||
self,
|
||||
eval: CachingShapeEvaluator) -> tuple[str, Sequence[Any]]:
|
||||
eval: ShapeEvaluator) -> tuple[str, Sequence[Any]]:
|
||||
"""Forms the error_message and error message_inputs.
|
||||
See shape_assertion.
|
||||
"""
|
||||
@ -1849,7 +1848,7 @@ class ShapeConstraint:
|
||||
return ("".join(error_message_strings),
|
||||
error_message_inputs)
|
||||
|
||||
def make_error(self, eval: CachingShapeEvaluator) -> Exception:
|
||||
def make_error(self, eval: ShapeEvaluator) -> Exception:
|
||||
error_message, error_message_inputs = self.error_message_and_inputs(eval)
|
||||
return ValueError(error_message.format(*error_message_inputs))
|
||||
|
||||
@ -1865,7 +1864,7 @@ class ShapeConstraints:
|
||||
c = ShapeConstraint(comp, left, right, error_message_pieces)
|
||||
self.constraints.append(c)
|
||||
|
||||
def check_statically(self, eval: CachingShapeEvaluator) -> None:
|
||||
def check_statically(self, eval: ShapeEvaluator) -> None:
|
||||
"""Evaluates all the constraints statically.
|
||||
|
||||
If the static checking of any constraint fails, raises ValueError.
|
||||
@ -1873,7 +1872,7 @@ class ShapeConstraints:
|
||||
for constraint in self.constraints:
|
||||
constraint.check_statically(eval)
|
||||
|
||||
def shape_assertions(self, eval: CachingShapeEvaluator) -> None:
|
||||
def shape_assertions(self, eval: ShapeEvaluator) -> None:
|
||||
"""Computes the shape assertions for the set of constraints.
|
||||
|
||||
See jax_export.Exported docstring.
|
||||
@ -2014,10 +2013,11 @@ def compute_dim_vars_from_arg_shapes(
|
||||
tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
|
||||
|
||||
# Replace the synthetic vars with the dynamic shape of the actual arg
|
||||
synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx],
|
||||
dimension=dim_idx)
|
||||
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
||||
synthetic_eval = CachingShapeEvaluator(**synthetic_env)
|
||||
synthetic_env: DimVarEnv = {
|
||||
vname: dimension_size_p.bind(actual_args[arg_idx], dimension=dim_idx)
|
||||
for (vname, arg_idx, dim_idx) in synth_dim_vars
|
||||
}
|
||||
synthetic_eval = ShapeEvaluator(synthetic_env)
|
||||
shape_constraints.shape_assertions(synthetic_eval)
|
||||
dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars]
|
||||
return tuple(dim_values)
|
||||
|
Loading…
x
Reference in New Issue
Block a user