mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add swap as method to TransformedRef
PiperOrigin-RevId: 731541165
This commit is contained in:
parent
1ecbac9702
commit
b5fcffadd4
@ -270,6 +270,10 @@ class TransformedRef:
|
||||
from jax._src.state.primitives import ref_set # pytype: disable=import-error
|
||||
return ref_set(self, idx, value)
|
||||
|
||||
def swap(self, value, idx=()):
|
||||
from jax._src.state.primitives import ref_swap # pytype: disable=import-error
|
||||
return ref_swap(self, idx, value)
|
||||
|
||||
def get(self, idx=()):
|
||||
from jax._src.state.primitives import ref_get # pytype: disable=import-error
|
||||
return ref_get(self, idx)
|
||||
@ -355,6 +359,12 @@ class AbstractRef(core.AbstractValue):
|
||||
from jax._src.state.primitives import ref_get # pytype: disable=import-error
|
||||
return ref_get(tracer, idx)
|
||||
|
||||
@core.aval_method
|
||||
@staticmethod
|
||||
def swap(tracer, value, idx=()):
|
||||
from jax._src.state.primitives import ref_swap # pytype: disable=import-error
|
||||
return ref_swap(tracer, idx, value)
|
||||
|
||||
@core.aval_method
|
||||
@staticmethod
|
||||
def set(tracer, value, idx=()):
|
||||
|
Loading…
x
Reference in New Issue
Block a user