jit lax_numpy.roll (#2392)

This was making tracing slow for code with lots of rolls.
This commit is contained in:
Stephan Hoyer 2020-03-09 13:21:30 -07:00 committed by GitHub
parent f3f0abb53e
commit 863576c5c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2672,8 +2672,8 @@ def msort(a):
return sort(a, axis=0)
@_wraps(onp.roll)
def roll(a, shift, axis=None):
@partial(jit, static_argnums=(2,))
def _roll(a, shift, axis):
a = asarray(a)
a_shape = shape(a)
if axis is None:
@ -2696,6 +2696,11 @@ def roll(a, shift, axis=None):
return a
@_wraps(onp.roll)
def roll(a, shift, axis=None):
return _roll(a, shift, axis)
@_wraps(onp.take)
def take(a, indices, axis=None, out=None, mode=None):
if out: