Add swap as method to TransformedRef

PiperOrigin-RevId: 731541165
This commit is contained in:
Sharad Vikram 2025-02-26 19:18:30 -08:00 committed by jax authors
parent 1ecbac9702
commit b5fcffadd4

View File

@ -270,6 +270,10 @@ class TransformedRef:
from jax._src.state.primitives import ref_set # pytype: disable=import-error
return ref_set(self, idx, value)
def swap(self, value, idx=()):
from jax._src.state.primitives import ref_swap # pytype: disable=import-error
return ref_swap(self, idx, value)
def get(self, idx=()):
from jax._src.state.primitives import ref_get # pytype: disable=import-error
return ref_get(self, idx)
@ -355,6 +359,12 @@ class AbstractRef(core.AbstractValue):
from jax._src.state.primitives import ref_get # pytype: disable=import-error
return ref_get(tracer, idx)
@core.aval_method
@staticmethod
def swap(tracer, value, idx=()):
from jax._src.state.primitives import ref_swap # pytype: disable=import-error
return ref_swap(tracer, idx, value)
@core.aval_method
@staticmethod
def set(tracer, value, idx=()):