Change jax.core.DropVar to be a non-singleton.

Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
This commit is contained in:
Peter Hawkins 2021-10-18 15:02:26 -07:00 committed by jax authors
parent 6c833a16a1
commit 48bbdbc890
8 changed files with 21 additions and 24 deletions

View File

@ -321,7 +321,7 @@ For example, here is an example fori loop
{ lambda ; a:f32[16] b:i32[]. let
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
d:f32[16] = add a c
_:* _:* e:f32[16] = while[
_:i32[] _:i32[] e:f32[16] = while[
body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
k:i32[] = add h 1
l:f32[16] = mul f 3.0

View File

@ -2467,7 +2467,7 @@ def make_jaxpr(fun: Callable,
{ lambda ; a:f32[]. let
b:f32[] = cos a
c:f32[] = sin a
_:* = sin b
_:f32[] = sin b
d:f32[] = cos b
e:f32[] = mul 1.0 d
f:f32[] = neg e

View File

@ -212,13 +212,9 @@ def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
count = -1
suffix = ''
def __init__(self): pass
@property
def aval(self): return abstract_unit
def __init__(self, aval: 'AbstractValue'):
super().__init__(-1, '', aval)
def __repr__(self): return '_'
dropvar = DropVar()
class Literal:
__slots__ = ["val", "hash"]
@ -1806,7 +1802,7 @@ class DuplicateAxisNameError(Exception):
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: Dict[Var, Var]) -> Var:
# Var identity is load-bearing, so we can't have duplicates!
if v is unitvar: return v
if v is dropvar: return v
if isinstance(v, DropVar): return v
assert v not in var_map
if not hasattr(v.aval, 'named_shape'):
var_map[v] = v
@ -1941,7 +1937,7 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
def write(v: Var, a: AbstractValue) -> None:
typecheck_assert(v not in env, f"Variable '{v}' already bound")
if v is not dropvar:
if not isinstance(v, DropVar):
typecheck_assert(typecompat(v.aval, a),
f"Variable '{v}' inconsistently typed as {a}, "
f"bound as {v.aval}")

View File

@ -175,12 +175,12 @@ def eval_sparse(
return env[var]
def write_buffer(var: core.Var, a: Array) -> None:
if var is core.dropvar:
if isinstance(var, core.DropVar):
return
env[var] = ArgSpec(a.shape, spenv.push(a), None)
def write(var: core.Var, a: ArgSpec) -> None:
if var is core.dropvar:
if isinstance(var, core.DropVar):
return
assert a is not None
env[var] = a
@ -210,7 +210,7 @@ def eval_sparse(
out_bufs = out_bufs if prim.multiple_results else [out_bufs]
out = []
for buf, outvar in safe_zip(out_bufs, eqn.outvars):
if outvar is core.dropvar:
if isinstance(outvar, core.DropVar):
out.append(None)
else:
out.append(ArgSpec(buf.shape, spenv.push(buf), None))

View File

@ -36,7 +36,7 @@ from .._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
as_hashable_function)
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
dropvar, ConcreteArray, raise_to_shaped, Var, Atom,
ConcreteArray, raise_to_shaped, Var, Atom,
JaxprEqn, Primitive)
from jax._src import source_info_util
from ..config import config
@ -587,8 +587,8 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
_, in_tracers, out_tracer_refs, primitive, params, source_info = recipe
out_tracers = [t_ref() for t_ref in out_tracer_refs]
invars = [getvar(t) for t in in_tracers]
outvars = [core.dropvar if t is None else cast(Var, getvar(t))
for t in out_tracers]
outvars = [core.DropVar(core.abstract_unit) if t is None
else cast(Var, getvar(t)) for t in out_tracers]
return new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
def tracers_to_jaxpr(
@ -1254,7 +1254,8 @@ def _inline_literals(jaxpr, constvals):
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(v) or var(v) for v in eqn.invars]
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
outvars = [var(v) if v in used else core.DropVar(v.aval)
for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]

View File

@ -53,7 +53,7 @@ def primitives_by_source(jaxpr: core.Jaxpr):
def primitives_by_shape(jaxpr: core.Jaxpr):
def shape_fmt(var):
return '*' if var is core.dropvar else var.aval.str_short()
return '*' if isinstance(var, core.DropVar) else var.aval.str_short()
def key(eqn):
return (eqn.primitive.name, ' '.join(map(shape_fmt, eqn.outvars)))
return histogram(jaxpr, key, ' :: '.join)
@ -79,7 +79,7 @@ def var_defs_and_refs(jaxpr: core.Jaxpr):
assert v is not core.unitvar
assert v not in defs, v
assert v not in refs, v
if v is not core.dropvar:
if not isinstance(v, core.DropVar):
defs[v] = eqn
refs[v] = []

View File

@ -450,7 +450,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
return y + 3
jaxpr = make_jaxpr(f)(1).jaxpr
assert jaxpr.eqns[0].outvars[0] is core.dropvar
assert isinstance(jaxpr.eqns[0].outvars[0], core.DropVar)
core.check_jaxpr(jaxpr)
def test_jaxpr_dropvar_from_loop(self):
@ -461,7 +461,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
return y + 1.
jaxpr = make_jaxpr(f)(1.).jaxpr
assert jaxpr.eqns[0].outvars[0] is core.dropvar
assert isinstance(jaxpr.eqns[0].outvars[0], core.DropVar)
core.check_jaxpr(jaxpr)
def test_jaxpr_dropvar_from_cond(self):
@ -473,7 +473,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
return y
jaxpr = make_jaxpr(f)(1.).jaxpr
assert jaxpr.eqns[-1].outvars[0] is core.dropvar
assert isinstance(jaxpr.eqns[-1].outvars[0], core.DropVar)
core.check_jaxpr(jaxpr)
def test_jaxpr_undefined_eqn_invar(self):

View File

@ -930,9 +930,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
identity=True
transforms=()
] b
_:* = mul c 2.00
_:f32[] = mul c 2.00
d:f32[] = mul 1.00 2.00
_:* = broadcast_in_dim[broadcast_dimensions=() shape=()] 0.00
_:f32[] = broadcast_in_dim[broadcast_dimensions=() shape=()] 0.00
e:f32[] = outside_call[
arg_treedef={treedef}
callback=...