Raise an error when attempting to mutate Jaxpr objects

This commit is contained in:
Jake VanderPlas 2023-01-23 09:37:58 -08:00
parent 13e875f8b8
commit a0eae5709f
5 changed files with 38 additions and 25 deletions

View File

@ -19,6 +19,8 @@ Remember to align the itemized text with the first line of an item within a list
dimension to JAX arrays. Operations involving symbolic dimensions and
`np.ndarray` now can raise errors when the result is used as a shape value
({jax-issue}`#14106`).
* jaxpr objects now raise an error on attribute setting in order to avoid
problematic mutations ({jax-issue}`14102`)
* Changes
* {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`)

View File

@ -67,11 +67,17 @@ control_flow_allowed_effects: Set[Effect] = set()
class Jaxpr:
constvars: List[Var]
invars: List[Var]
outvars: List[Atom]
eqns: List[JaxprEqn]
effects: Effects
_constvars: List[Var]
_invars: List[Var]
_outvars: List[Atom]
_eqns: List[JaxprEqn]
_effects: Effects
constvars = property(lambda self: self._constvars)
invars = property(lambda self: self._invars)
outvars = property(lambda self: self._outvars)
eqns = property(lambda self: self._eqns)
effects = property(lambda self: self._effects)
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
@ -87,11 +93,11 @@ class Jaxpr:
effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation.
"""
self.constvars = list(constvars)
self.invars = list(invars)
self.outvars = list(outvars)
self.eqns = list(eqns)
self.effects = effects
self._constvars = list(constvars)
self._invars = list(invars)
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
def __str__(self):
return str(pp_jaxpr(self, JaxprPpContext(), JaxprPpSettings()))
@ -142,14 +148,17 @@ def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
class ClosedJaxpr:
jaxpr: Jaxpr
consts: List[Any]
_jaxpr: Jaxpr
_consts: List[Any]
jaxpr = property(lambda self: self._jaxpr)
consts = property(lambda self: self._consts)
def __init__(self, jaxpr: Jaxpr, consts: Sequence):
assert len(consts) == len(jaxpr.constvars)
# assert not any(isinstance(c, Tracer) for c in consts) # TODO(mattjj): enable
self.jaxpr = jaxpr
self.consts = list(consts)
self._jaxpr = jaxpr
self._consts = list(consts)
@property
def in_avals(self):
@ -2423,10 +2432,10 @@ def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
var_map: Dict[Var, Var] = {}
invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars]
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars]
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns]
outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars]
invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars] # type: ignore[union-attr]
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars] # type: ignore[union-attr]
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] # type: ignore[union-attr]
outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr]
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, jaxpr.effects)
if consts is not None:
return ClosedJaxpr(new_jaxpr, consts)

View File

@ -579,8 +579,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
# Drop any extensive output we can instead get by forwarding an input.
# TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
jaxpr_known_.outvars = [x for x, i in zip(jaxpr_known_.outvars, fwds_known)
if i is None]
jaxpr_known_ = jaxpr_known_.replace(
outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None])
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
del jaxpr_known_
# We use `fwds_known` below when forming the output of scanning jaxpr_known.
@ -1327,7 +1327,9 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
body_nconsts_known = len(body_consts_uk) - sum(body_consts_uk)
num_known_outs = len(carry_uk) - sum(carry_uk)
# TODO(mattjj): use pe.dce_jaxpr to drop res computations and not just outputs
body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:num_known_outs]
body_jaxpr_known = body_jaxpr_known.replace(
jaxpr=body_jaxpr_known.jaxpr.replace(
outvars=body_jaxpr_known.jaxpr.outvars[:num_known_outs]))
out_known = while_p.bind(
*in_consts, cond_nconsts=cond_nconsts_known, cond_jaxpr=cond_jaxpr_known,
body_nconsts=body_nconsts_known, body_jaxpr=body_jaxpr_known)

View File

@ -1641,13 +1641,13 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
new_jaxpr_invars = (
new_jaxpr_invars[0:nr_const_and_carry] + new_jaxpr_invars[-2:] +
new_jaxpr_invars[nr_const_and_carry:-2])
new_jaxpr.jaxpr.invars = new_jaxpr_invars
new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(invars=new_jaxpr_invars))
new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
new_jaxpr_outvars = (
new_jaxpr_outvars[0:num_carry] + new_jaxpr_outvars[-2:] +
new_jaxpr_outvars[num_carry:-2])
new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(outvars=new_jaxpr_outvars))
eqns.append(
eqn.replace(
invars=new_invars,

View File

@ -411,7 +411,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
def test_check_jaxpr_cond_invalid(self):
jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
cond.params['branches'][0].jaxpr.invars = ()
cond.params['branches'][0].jaxpr._invars = ()
self.assertRaisesRegex(
core.JaxprTypeError,
'cond branch 0 takes 0 inputs, branch 1 takes 1',
@ -445,7 +445,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
lambda x: lax.switch(0, [jnp.sin, jnp.cos], x), 100))(1.).jaxpr
cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
cond.params['branches'][0].jaxpr.invars = ()
cond.params['branches'][0].jaxpr._invars = ()
msg = ''
try:
core.check_jaxpr(jaxpr)