Ensure lax.scatter cache hits in op-by-op mode

This commit is contained in:
Jamie Townsend 2019-09-24 19:20:12 +02:00
parent ba5bcf2af5
commit f9b9146a92

View File

@ -785,6 +785,9 @@ def scatter_max(operand, scatter_indices, updates, dimension_numbers):
update_consts=consts, dimension_numbers=dimension_numbers,
updates_shape=updates.shape)
# Define this outside of scatter to ensure cache hits.
_scatter_reduction_computation = lambda x, y: y
def scatter(operand, scatter_indices, updates, dimension_numbers):
"""Scatter-update operator.
@ -809,7 +812,8 @@ def scatter(operand, scatter_indices, updates, dimension_numbers):
Returns:
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = _reduction_jaxpr(lambda x, y: y, _abstractify(_const(operand, 0)))
jaxpr, consts = _reduction_jaxpr(_scatter_reduction_computation,
_abstractify(_const(operand, 0)))
return scatter_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,