mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Temporarily remove jit
decorator on gather/scatter ops.
This commit is contained in:
parent
d691f81264
commit
45a02f39f0
@ -2407,7 +2407,9 @@ def _rewriting_take(arr, idx):
|
||||
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
|
||||
return _gather(arr, treedef, static_idx, dynamic_idx)
|
||||
|
||||
@partial(jit, static_argnums=(1, 2))
|
||||
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
|
||||
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
|
||||
# @partial(jit, static_argnums=(1, 2))
|
||||
def _gather(arr, treedef, static_idx, dynamic_idx):
|
||||
idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
|
||||
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
|
||||
|
@ -57,7 +57,9 @@ def _scatter_update(x, idx, y, scatter_op):
|
||||
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx)
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(2, 3, 4))
|
||||
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
|
||||
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
|
||||
# @partial(jit, static_argnums=(2, 3, 4))
|
||||
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx):
|
||||
y = lax.convert_element_type(y, lax.dtype(x))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user