fix XLA metadata for primitives with many args

This commit is contained in:
James Bradbury 2019-10-08 10:57:36 -07:00
parent 59ae24e874
commit bc0e79767b
3 changed files with 24 additions and 22 deletions

View File

@ -19,6 +19,7 @@ from __future__ import print_function
from operator import attrgetter
from contextlib import contextmanager
from collections import namedtuple, Counter, defaultdict
import itertools as it
from weakref import ref
import threading
import types
@ -87,6 +88,25 @@ JaxprEqn = namedtuple('JaxprEqn', ['eqn_id', 'invars', 'outvars', 'primitive',
JaxprEqn.__repr__ = JaxprEqn.__str__ = lambda eqn: str(pp_eqn(eqn))[:-1]
class Var(object):
def __init__(self, count, suffix):
self.count = count
self.suffix = suffix
def __repr__(self):
rem = self.count
s = ''
while True:
rem, i = rem // 26, rem % 26
s = chr(97 + i % 26) + s
if not rem:
break
return s + self.suffix
def gensym(suffix):
counter = it.count()
return lambda: Var(next(counter), suffix)
class Literal(object):
__slots__ = ["val", "hash"]

View File

@ -382,7 +382,7 @@ def eqn_tracer_to_var(var, eqn):
def tracers_to_jaxpr(in_tracers, out_tracers):
newvar = gensym('')
newvar = core.gensym('')
t_to_var = defaultdict(newvar)
var = lambda t: t_to_var[id(t)]
sorted_tracers = toposort(out_tracers)
@ -421,25 +421,6 @@ def tracers_to_jaxpr(in_tracers, out_tracers):
return jaxpr, const_vals, env_vals
def gensym(suffix):
counter = it.count()
return lambda: Var(next(counter), suffix)
class Var(object):
def __init__(self, count, suffix):
self.count = count
self.suffix = suffix
def __repr__(self):
rem = self.count
s = ''
while True:
rem, i = rem // 26, rem % 26
s = chr(97 + i % 26) + s
if not rem:
break
return s + self.suffix
def eqn_parents(eqn):
subjaxpr_tracers = [it.chain(c, f) for _, c, f in eqn.bound_subjaxprs]
return list(it.chain(eqn.invars, *subjaxpr_tracers))

View File

@ -143,11 +143,12 @@ def primitive_computation(prim, *xla_shapes, **params):
backend = params.get('backend', None)
new_params = {k: params[k] for k in params if k != 'backend'}
c = xb.make_computation_builder("primitive_computation_{}".format(prim.name))
newvar = core.gensym('')
c.SetOpMetadata(xc.OpMetadata(
op_type=prim.name,
op_name=str(core.new_jaxpr_eqn(
[chr(ord('a') + i) for i in range(len(xla_shapes))],
[chr(ord('a') + len(xla_shapes))],
[newvar() for i in range(len(xla_shapes))],
[newvar()],
prim, (), params))
))
platform = xb.get_backend(backend).platform