Temporarily remove jit decorator on gather/scatter ops.

This commit is contained in:
Peter Hawkins 2019-09-16 13:57:07 -07:00
parent d691f81264
commit 45a02f39f0
2 changed files with 6 additions and 2 deletions

View File

@ -2407,7 +2407,9 @@ def _rewriting_take(arr, idx):
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
return _gather(arr, treedef, static_idx, dynamic_idx)
@partial(jit, static_argnums=(1, 2))
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(1, 2))
def _gather(arr, treedef, static_idx, dynamic_idx):
idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update

View File

@ -57,7 +57,9 @@ def _scatter_update(x, idx, y, scatter_op):
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx)
@partial(jit, static_argnums=(2, 3, 4))
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(2, 3, 4))
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx):
y = lax.convert_element_type(y, lax.dtype(x))