mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic. Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval. PiperOrigin-RevId: 404071001
This commit is contained in:
parent
6c833a16a1
commit
48bbdbc890
@ -321,7 +321,7 @@ For example, here is an example fori loop
|
||||
{ lambda ; a:f32[16] b:i32[]. let
|
||||
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
|
||||
d:f32[16] = add a c
|
||||
_:* _:* e:f32[16] = while[
|
||||
_:i32[] _:i32[] e:f32[16] = while[
|
||||
body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
|
||||
k:i32[] = add h 1
|
||||
l:f32[16] = mul f 3.0
|
||||
|
@ -2467,7 +2467,7 @@ def make_jaxpr(fun: Callable,
|
||||
{ lambda ; a:f32[]. let
|
||||
b:f32[] = cos a
|
||||
c:f32[] = sin a
|
||||
_:* = sin b
|
||||
_:f32[] = sin b
|
||||
d:f32[] = cos b
|
||||
e:f32[] = mul 1.0 d
|
||||
f:f32[] = neg e
|
||||
|
12
jax/core.py
12
jax/core.py
@ -212,13 +212,9 @@ def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
|
||||
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
|
||||
# treat it as a special case of one. Its `aval` is similarly inexact.
|
||||
class DropVar(Var):
|
||||
count = -1
|
||||
suffix = ''
|
||||
def __init__(self): pass
|
||||
@property
|
||||
def aval(self): return abstract_unit
|
||||
def __init__(self, aval: 'AbstractValue'):
|
||||
super().__init__(-1, '', aval)
|
||||
def __repr__(self): return '_'
|
||||
dropvar = DropVar()
|
||||
|
||||
class Literal:
|
||||
__slots__ = ["val", "hash"]
|
||||
@ -1806,7 +1802,7 @@ class DuplicateAxisNameError(Exception):
|
||||
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: Dict[Var, Var]) -> Var:
|
||||
# Var identity is load-bearing, so we can't have duplicates!
|
||||
if v is unitvar: return v
|
||||
if v is dropvar: return v
|
||||
if isinstance(v, DropVar): return v
|
||||
assert v not in var_map
|
||||
if not hasattr(v.aval, 'named_shape'):
|
||||
var_map[v] = v
|
||||
@ -1941,7 +1937,7 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
|
||||
|
||||
def write(v: Var, a: AbstractValue) -> None:
|
||||
typecheck_assert(v not in env, f"Variable '{v}' already bound")
|
||||
if v is not dropvar:
|
||||
if not isinstance(v, DropVar):
|
||||
typecheck_assert(typecompat(v.aval, a),
|
||||
f"Variable '{v}' inconsistently typed as {a}, "
|
||||
f"bound as {v.aval}")
|
||||
|
@ -175,12 +175,12 @@ def eval_sparse(
|
||||
return env[var]
|
||||
|
||||
def write_buffer(var: core.Var, a: Array) -> None:
|
||||
if var is core.dropvar:
|
||||
if isinstance(var, core.DropVar):
|
||||
return
|
||||
env[var] = ArgSpec(a.shape, spenv.push(a), None)
|
||||
|
||||
def write(var: core.Var, a: ArgSpec) -> None:
|
||||
if var is core.dropvar:
|
||||
if isinstance(var, core.DropVar):
|
||||
return
|
||||
assert a is not None
|
||||
env[var] = a
|
||||
@ -210,7 +210,7 @@ def eval_sparse(
|
||||
out_bufs = out_bufs if prim.multiple_results else [out_bufs]
|
||||
out = []
|
||||
for buf, outvar in safe_zip(out_bufs, eqn.outvars):
|
||||
if outvar is core.dropvar:
|
||||
if isinstance(outvar, core.DropVar):
|
||||
out.append(None)
|
||||
else:
|
||||
out.append(ArgSpec(buf.shape, spenv.push(buf), None))
|
||||
|
@ -36,7 +36,7 @@ from .._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
as_hashable_function)
|
||||
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
|
||||
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
|
||||
dropvar, ConcreteArray, raise_to_shaped, Var, Atom,
|
||||
ConcreteArray, raise_to_shaped, Var, Atom,
|
||||
JaxprEqn, Primitive)
|
||||
from jax._src import source_info_util
|
||||
from ..config import config
|
||||
@ -587,8 +587,8 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
|
||||
_, in_tracers, out_tracer_refs, primitive, params, source_info = recipe
|
||||
out_tracers = [t_ref() for t_ref in out_tracer_refs]
|
||||
invars = [getvar(t) for t in in_tracers]
|
||||
outvars = [core.dropvar if t is None else cast(Var, getvar(t))
|
||||
for t in out_tracers]
|
||||
outvars = [core.DropVar(core.abstract_unit) if t is None
|
||||
else cast(Var, getvar(t)) for t in out_tracers]
|
||||
return new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
|
||||
|
||||
def tracers_to_jaxpr(
|
||||
@ -1254,7 +1254,8 @@ def _inline_literals(jaxpr, constvals):
|
||||
new_eqns = []
|
||||
for eqn in jaxpr.eqns:
|
||||
invars = [lit(v) or var(v) for v in eqn.invars]
|
||||
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
|
||||
outvars = [var(v) if v in used else core.DropVar(v.aval)
|
||||
for v in eqn.outvars]
|
||||
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
|
||||
eqn.source_info))
|
||||
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
|
||||
|
@ -53,7 +53,7 @@ def primitives_by_source(jaxpr: core.Jaxpr):
|
||||
|
||||
def primitives_by_shape(jaxpr: core.Jaxpr):
|
||||
def shape_fmt(var):
|
||||
return '*' if var is core.dropvar else var.aval.str_short()
|
||||
return '*' if isinstance(var, core.DropVar) else var.aval.str_short()
|
||||
def key(eqn):
|
||||
return (eqn.primitive.name, ' '.join(map(shape_fmt, eqn.outvars)))
|
||||
return histogram(jaxpr, key, ' :: '.join)
|
||||
@ -79,7 +79,7 @@ def var_defs_and_refs(jaxpr: core.Jaxpr):
|
||||
assert v is not core.unitvar
|
||||
assert v not in defs, v
|
||||
assert v not in refs, v
|
||||
if v is not core.dropvar:
|
||||
if not isinstance(v, core.DropVar):
|
||||
defs[v] = eqn
|
||||
refs[v] = []
|
||||
|
||||
|
@ -450,7 +450,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
return y + 3
|
||||
|
||||
jaxpr = make_jaxpr(f)(1).jaxpr
|
||||
assert jaxpr.eqns[0].outvars[0] is core.dropvar
|
||||
assert isinstance(jaxpr.eqns[0].outvars[0], core.DropVar)
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
def test_jaxpr_dropvar_from_loop(self):
|
||||
@ -461,7 +461,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
return y + 1.
|
||||
|
||||
jaxpr = make_jaxpr(f)(1.).jaxpr
|
||||
assert jaxpr.eqns[0].outvars[0] is core.dropvar
|
||||
assert isinstance(jaxpr.eqns[0].outvars[0], core.DropVar)
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
def test_jaxpr_dropvar_from_cond(self):
|
||||
@ -473,7 +473,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
return y
|
||||
|
||||
jaxpr = make_jaxpr(f)(1.).jaxpr
|
||||
assert jaxpr.eqns[-1].outvars[0] is core.dropvar
|
||||
assert isinstance(jaxpr.eqns[-1].outvars[0], core.DropVar)
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
def test_jaxpr_undefined_eqn_invar(self):
|
||||
|
@ -930,9 +930,9 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
identity=True
|
||||
transforms=()
|
||||
] b
|
||||
_:* = mul c 2.00
|
||||
_:f32[] = mul c 2.00
|
||||
d:f32[] = mul 1.00 2.00
|
||||
_:* = broadcast_in_dim[broadcast_dimensions=() shape=()] 0.00
|
||||
_:f32[] = broadcast_in_dim[broadcast_dimensions=() shape=()] 0.00
|
||||
e:f32[] = outside_call[
|
||||
arg_treedef={treedef}
|
||||
callback=...
|
||||
|
Loading…
x
Reference in New Issue
Block a user