mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 03:06:04 +00:00
jit lax_numpy.roll (#2392)
This was making tracing slow for code with lots of rolls.
This commit is contained in:
parent
f3f0abb53e
commit
863576c5c1
@ -2672,8 +2672,8 @@ def msort(a):
|
||||
return sort(a, axis=0)
|
||||
|
||||
|
||||
@_wraps(onp.roll)
|
||||
def roll(a, shift, axis=None):
|
||||
@partial(jit, static_argnums=(2,))
|
||||
def _roll(a, shift, axis):
|
||||
a = asarray(a)
|
||||
a_shape = shape(a)
|
||||
if axis is None:
|
||||
@ -2696,6 +2696,11 @@ def roll(a, shift, axis=None):
|
||||
return a
|
||||
|
||||
|
||||
@_wraps(onp.roll)
|
||||
def roll(a, shift, axis=None):
|
||||
return _roll(a, shift, axis)
|
||||
|
||||
|
||||
@_wraps(onp.take)
|
||||
def take(a, indices, axis=None, out=None, mode=None):
|
||||
if out:
|
||||
|
Loading…
x
Reference in New Issue
Block a user