mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Raise an error when attempting to mutate Jaxpr objects
This commit is contained in:
parent
13e875f8b8
commit
a0eae5709f
@ -19,6 +19,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
dimension to JAX arrays. Operations involving symbolic dimensions and
|
||||
`np.ndarray` now can raise errors when the result is used as a shape value
|
||||
({jax-issue}`#14106`).
|
||||
* jaxpr objects now raise an error on attribute setting in order to avoid
|
||||
problematic mutations ({jax-issue}`14102`)
|
||||
|
||||
* Changes
|
||||
* {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`)
|
||||
|
@ -67,11 +67,17 @@ control_flow_allowed_effects: Set[Effect] = set()
|
||||
|
||||
|
||||
class Jaxpr:
|
||||
constvars: List[Var]
|
||||
invars: List[Var]
|
||||
outvars: List[Atom]
|
||||
eqns: List[JaxprEqn]
|
||||
effects: Effects
|
||||
_constvars: List[Var]
|
||||
_invars: List[Var]
|
||||
_outvars: List[Atom]
|
||||
_eqns: List[JaxprEqn]
|
||||
_effects: Effects
|
||||
|
||||
constvars = property(lambda self: self._constvars)
|
||||
invars = property(lambda self: self._invars)
|
||||
outvars = property(lambda self: self._outvars)
|
||||
eqns = property(lambda self: self._eqns)
|
||||
effects = property(lambda self: self._effects)
|
||||
|
||||
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
|
||||
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
|
||||
@ -87,11 +93,11 @@ class Jaxpr:
|
||||
effects: set of effects. The effects on a jaxpr are a superset of the
|
||||
union of the effects for each equation.
|
||||
"""
|
||||
self.constvars = list(constvars)
|
||||
self.invars = list(invars)
|
||||
self.outvars = list(outvars)
|
||||
self.eqns = list(eqns)
|
||||
self.effects = effects
|
||||
self._constvars = list(constvars)
|
||||
self._invars = list(invars)
|
||||
self._outvars = list(outvars)
|
||||
self._eqns = list(eqns)
|
||||
self._effects = effects
|
||||
|
||||
def __str__(self):
|
||||
return str(pp_jaxpr(self, JaxprPpContext(), JaxprPpSettings()))
|
||||
@ -142,14 +148,17 @@ def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
|
||||
|
||||
|
||||
class ClosedJaxpr:
|
||||
jaxpr: Jaxpr
|
||||
consts: List[Any]
|
||||
_jaxpr: Jaxpr
|
||||
_consts: List[Any]
|
||||
|
||||
jaxpr = property(lambda self: self._jaxpr)
|
||||
consts = property(lambda self: self._consts)
|
||||
|
||||
def __init__(self, jaxpr: Jaxpr, consts: Sequence):
|
||||
assert len(consts) == len(jaxpr.constvars)
|
||||
# assert not any(isinstance(c, Tracer) for c in consts) # TODO(mattjj): enable
|
||||
self.jaxpr = jaxpr
|
||||
self.consts = list(consts)
|
||||
self._jaxpr = jaxpr
|
||||
self._consts = list(consts)
|
||||
|
||||
@property
|
||||
def in_avals(self):
|
||||
@ -2423,10 +2432,10 @@ def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst
|
||||
consts = jaxpr.consts
|
||||
jaxpr = jaxpr.jaxpr
|
||||
var_map: Dict[Var, Var] = {}
|
||||
invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars]
|
||||
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars]
|
||||
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns]
|
||||
outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars]
|
||||
invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars] # type: ignore[union-attr]
|
||||
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars] # type: ignore[union-attr]
|
||||
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] # type: ignore[union-attr]
|
||||
outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr]
|
||||
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, jaxpr.effects)
|
||||
if consts is not None:
|
||||
return ClosedJaxpr(new_jaxpr, consts)
|
||||
|
@ -579,8 +579,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
# Drop any extensive output we can instead get by forwarding an input.
|
||||
# TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
|
||||
jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
|
||||
jaxpr_known_.outvars = [x for x, i in zip(jaxpr_known_.outvars, fwds_known)
|
||||
if i is None]
|
||||
jaxpr_known_ = jaxpr_known_.replace(
|
||||
outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None])
|
||||
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
|
||||
del jaxpr_known_
|
||||
# We use `fwds_known` below when forming the output of scanning jaxpr_known.
|
||||
@ -1327,7 +1327,9 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
|
||||
body_nconsts_known = len(body_consts_uk) - sum(body_consts_uk)
|
||||
num_known_outs = len(carry_uk) - sum(carry_uk)
|
||||
# TODO(mattjj): use pe.dce_jaxpr to drop res computations and not just outputs
|
||||
body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:num_known_outs]
|
||||
body_jaxpr_known = body_jaxpr_known.replace(
|
||||
jaxpr=body_jaxpr_known.jaxpr.replace(
|
||||
outvars=body_jaxpr_known.jaxpr.outvars[:num_known_outs]))
|
||||
out_known = while_p.bind(
|
||||
*in_consts, cond_nconsts=cond_nconsts_known, cond_jaxpr=cond_jaxpr_known,
|
||||
body_nconsts=body_nconsts_known, body_jaxpr=body_jaxpr_known)
|
||||
|
@ -1641,13 +1641,13 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
new_jaxpr_invars = (
|
||||
new_jaxpr_invars[0:nr_const_and_carry] + new_jaxpr_invars[-2:] +
|
||||
new_jaxpr_invars[nr_const_and_carry:-2])
|
||||
new_jaxpr.jaxpr.invars = new_jaxpr_invars
|
||||
new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(invars=new_jaxpr_invars))
|
||||
|
||||
new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
|
||||
new_jaxpr_outvars = (
|
||||
new_jaxpr_outvars[0:num_carry] + new_jaxpr_outvars[-2:] +
|
||||
new_jaxpr_outvars[num_carry:-2])
|
||||
new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
|
||||
new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(outvars=new_jaxpr_outvars))
|
||||
eqns.append(
|
||||
eqn.replace(
|
||||
invars=new_invars,
|
||||
|
@ -411,7 +411,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
def test_check_jaxpr_cond_invalid(self):
|
||||
jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
|
||||
cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
|
||||
cond.params['branches'][0].jaxpr.invars = ()
|
||||
cond.params['branches'][0].jaxpr._invars = ()
|
||||
self.assertRaisesRegex(
|
||||
core.JaxprTypeError,
|
||||
'cond branch 0 takes 0 inputs, branch 1 takes 1',
|
||||
@ -445,7 +445,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
lambda x: lax.switch(0, [jnp.sin, jnp.cos], x), 100))(1.).jaxpr
|
||||
|
||||
cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
|
||||
cond.params['branches'][0].jaxpr.invars = ()
|
||||
cond.params['branches'][0].jaxpr._invars = ()
|
||||
msg = ''
|
||||
try:
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
Loading…
x
Reference in New Issue
Block a user