mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add an example demonstrating input-output aliasing with the FFI.
This commit is contained in:
parent
00c363e15d
commit
62656b32db
@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
Counter, CounterImpl,
|
||||
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
|
||||
|
||||
// --------
|
||||
// Aliasing
|
||||
// --------
|
||||
//
|
||||
// This example demonstrates how input-output aliasing works. The handler
|
||||
// doesn't do anything except to check that the input and output pointers
|
||||
// address the same data.
|
||||
|
||||
ffi::Error AliasingImpl(ffi::AnyBuffer input,
|
||||
ffi::Result<ffi::AnyBuffer> output) {
|
||||
if (input.element_type() != output->element_type() ||
|
||||
input.element_count() != output->element_count()) {
|
||||
return ffi::Error::InvalidArgument(
|
||||
"The input and output data types and sizes must match.");
|
||||
}
|
||||
if (input.untyped_data() != output->untyped_data()) {
|
||||
return ffi::Error::InvalidArgument(
|
||||
"When aliased, the input and output buffers should point to the same "
|
||||
"data.");
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
Aliasing, AliasingImpl,
|
||||
ffi::Ffi::Bind().Arg<ffi::AnyBuffer>().Ret<ffi::AnyBuffer>());
|
||||
|
||||
// Boilerplate for exposing handlers to Python
|
||||
NB_MODULE(_cpu_examples, m) {
|
||||
m.def("registrations", []() {
|
||||
@ -111,9 +138,8 @@ NB_MODULE(_cpu_examples, m) {
|
||||
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
|
||||
registrations["dictionary_attr"] =
|
||||
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));
|
||||
|
||||
registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));
|
||||
|
||||
registrations["aliasing"] = nb::capsule(reinterpret_cast<void *>(Aliasing));
|
||||
return registrations;
|
||||
});
|
||||
}
|
||||
|
@ -39,3 +39,9 @@ def dictionary_attr(**kwargs):
|
||||
def counter(index):
|
||||
return jax.ffi.ffi_call(
|
||||
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))
|
||||
|
||||
|
||||
def aliasing(x):
|
||||
return jax.ffi.ffi_call(
|
||||
"aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype),
|
||||
input_output_aliases={0: 0})(x)
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@ -91,5 +91,16 @@ class CounterTests(jtu.JaxTestCase):
|
||||
self.assertEqual(counter_fun(0)[1], 3)
|
||||
|
||||
|
||||
class AliasingTests(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if not jtu.test_device_matches(["cpu"]):
|
||||
self.skipTest("Unsupported platform")
|
||||
|
||||
@parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),))
|
||||
def test_basic(self, x):
|
||||
self.assertAllClose(cpu_examples.aliasing(x), x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user