mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Ensure lax.scatter cache hits in op-by-op mode
This commit is contained in:
parent
ba5bcf2af5
commit
f9b9146a92
@ -785,6 +785,9 @@ def scatter_max(operand, scatter_indices, updates, dimension_numbers):
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
updates_shape=updates.shape)
|
||||
|
||||
# Define this outside of scatter to ensure cache hits.
|
||||
_scatter_reduction_computation = lambda x, y: y
|
||||
|
||||
def scatter(operand, scatter_indices, updates, dimension_numbers):
|
||||
"""Scatter-update operator.
|
||||
|
||||
@ -809,7 +812,8 @@ def scatter(operand, scatter_indices, updates, dimension_numbers):
|
||||
Returns:
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = _reduction_jaxpr(lambda x, y: y, _abstractify(_const(operand, 0)))
|
||||
jaxpr, consts = _reduction_jaxpr(_scatter_reduction_computation,
|
||||
_abstractify(_const(operand, 0)))
|
||||
return scatter_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
|
Loading…
x
Reference in New Issue
Block a user