mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix XLA metadata for primitives with many args
This commit is contained in:
parent
59ae24e874
commit
bc0e79767b
20
jax/core.py
20
jax/core.py
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
from operator import attrgetter
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple, Counter, defaultdict
|
||||
import itertools as it
|
||||
from weakref import ref
|
||||
import threading
|
||||
import types
|
||||
@ -87,6 +88,25 @@ JaxprEqn = namedtuple('JaxprEqn', ['eqn_id', 'invars', 'outvars', 'primitive',
|
||||
|
||||
JaxprEqn.__repr__ = JaxprEqn.__str__ = lambda eqn: str(pp_eqn(eqn))[:-1]
|
||||
|
||||
class Var(object):
|
||||
def __init__(self, count, suffix):
|
||||
self.count = count
|
||||
self.suffix = suffix
|
||||
|
||||
def __repr__(self):
|
||||
rem = self.count
|
||||
s = ''
|
||||
while True:
|
||||
rem, i = rem // 26, rem % 26
|
||||
s = chr(97 + i % 26) + s
|
||||
if not rem:
|
||||
break
|
||||
return s + self.suffix
|
||||
|
||||
def gensym(suffix):
|
||||
counter = it.count()
|
||||
return lambda: Var(next(counter), suffix)
|
||||
|
||||
class Literal(object):
|
||||
__slots__ = ["val", "hash"]
|
||||
|
||||
|
@ -382,7 +382,7 @@ def eqn_tracer_to_var(var, eqn):
|
||||
|
||||
|
||||
def tracers_to_jaxpr(in_tracers, out_tracers):
|
||||
newvar = gensym('')
|
||||
newvar = core.gensym('')
|
||||
t_to_var = defaultdict(newvar)
|
||||
var = lambda t: t_to_var[id(t)]
|
||||
sorted_tracers = toposort(out_tracers)
|
||||
@ -421,25 +421,6 @@ def tracers_to_jaxpr(in_tracers, out_tracers):
|
||||
return jaxpr, const_vals, env_vals
|
||||
|
||||
|
||||
def gensym(suffix):
|
||||
counter = it.count()
|
||||
return lambda: Var(next(counter), suffix)
|
||||
|
||||
class Var(object):
|
||||
def __init__(self, count, suffix):
|
||||
self.count = count
|
||||
self.suffix = suffix
|
||||
|
||||
def __repr__(self):
|
||||
rem = self.count
|
||||
s = ''
|
||||
while True:
|
||||
rem, i = rem // 26, rem % 26
|
||||
s = chr(97 + i % 26) + s
|
||||
if not rem:
|
||||
break
|
||||
return s + self.suffix
|
||||
|
||||
def eqn_parents(eqn):
|
||||
subjaxpr_tracers = [it.chain(c, f) for _, c, f in eqn.bound_subjaxprs]
|
||||
return list(it.chain(eqn.invars, *subjaxpr_tracers))
|
||||
|
@ -143,11 +143,12 @@ def primitive_computation(prim, *xla_shapes, **params):
|
||||
backend = params.get('backend', None)
|
||||
new_params = {k: params[k] for k in params if k != 'backend'}
|
||||
c = xb.make_computation_builder("primitive_computation_{}".format(prim.name))
|
||||
newvar = core.gensym('')
|
||||
c.SetOpMetadata(xc.OpMetadata(
|
||||
op_type=prim.name,
|
||||
op_name=str(core.new_jaxpr_eqn(
|
||||
[chr(ord('a') + i) for i in range(len(xla_shapes))],
|
||||
[chr(ord('a') + len(xla_shapes))],
|
||||
[newvar() for i in range(len(xla_shapes))],
|
||||
[newvar()],
|
||||
prim, (), params))
|
||||
))
|
||||
platform = xb.get_backend(backend).platform
|
||||
|
Loading…
x
Reference in New Issue
Block a user