Added jit transformations to generated functions. Fixed bug in comparing numpy arrays for equality.

This commit is contained in:
Dougal Maclaurin 2018-12-08 00:03:34 -05:00
parent 0900db35b6
commit 30124b6da1
2 changed files with 6 additions and 2 deletions

View File

@ -153,5 +153,4 @@ class WrapHashably(object):
return id(self.val)
def __eq__(self, other):
return self.val == other.val
return self.val is other.val

View File

@ -65,6 +65,7 @@ def gen_function(size, in_types):
arg_types = [v.vartype for v in arg_vars]
fun, out_types = gen_function(size / size_reduction_factor, arg_types)
fun = partial(eval_fun, fun)
fun = maybe_jit(fun, len(arg_types))
else:
arity = choice(list(primitive_generators))
arg_vars = gen_sized_subset(cur_vars, arity)
@ -96,6 +97,10 @@ def eval_fun(fun, *args):
return map(read, fun.out_vars)
def maybe_jit(f, num_args):
static_argnums = thin(range(num_args), 0.5)
return jit(f, static_argnums=static_argnums)
counter = it.count()
def fresh_var(ty):
return Var(next(counter), ty)