mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Added jit transformations to generated functions. Fixed bug in comparing numpy arrays for equality.
This commit is contained in:
parent
0900db35b6
commit
30124b6da1
@ -153,5 +153,4 @@ class WrapHashably(object):
|
||||
return id(self.val)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.val == other.val
|
||||
|
||||
return self.val is other.val
|
||||
|
@ -65,6 +65,7 @@ def gen_function(size, in_types):
|
||||
arg_types = [v.vartype for v in arg_vars]
|
||||
fun, out_types = gen_function(size / size_reduction_factor, arg_types)
|
||||
fun = partial(eval_fun, fun)
|
||||
fun = maybe_jit(fun, len(arg_types))
|
||||
else:
|
||||
arity = choice(list(primitive_generators))
|
||||
arg_vars = gen_sized_subset(cur_vars, arity)
|
||||
@ -96,6 +97,10 @@ def eval_fun(fun, *args):
|
||||
|
||||
return map(read, fun.out_vars)
|
||||
|
||||
def maybe_jit(f, num_args):
|
||||
static_argnums = thin(range(num_args), 0.5)
|
||||
return jit(f, static_argnums=static_argnums)
|
||||
|
||||
counter = it.count()
|
||||
def fresh_var(ty):
|
||||
return Var(next(counter), ty)
|
||||
|
Loading…
x
Reference in New Issue
Block a user