Add an example demonstrating input-output aliasing with the FFI.

This commit is contained in:
Dan Foreman-Mackey 2024-11-21 13:08:50 -05:00
parent 00c363e15d
commit 62656b32db
3 changed files with 46 additions and 3 deletions

View File

@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
Counter, CounterImpl,
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
// --------
// Aliasing
// --------
//
// This example demonstrates how input-output aliasing works. The handler
// doesn't do anything except to check that the input and output pointers
// address the same data.
ffi::Error AliasingImpl(ffi::AnyBuffer input,
ffi::Result<ffi::AnyBuffer> output) {
if (input.element_type() != output->element_type() ||
input.element_count() != output->element_count()) {
return ffi::Error::InvalidArgument(
"The input and output data types and sizes must match.");
}
if (input.untyped_data() != output->untyped_data()) {
return ffi::Error::InvalidArgument(
"When aliased, the input and output buffers should point to the same "
"data.");
}
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
Aliasing, AliasingImpl,
ffi::Ffi::Bind().Arg<ffi::AnyBuffer>().Ret<ffi::AnyBuffer>());
// Boilerplate for exposing handlers to Python
NB_MODULE(_cpu_examples, m) {
m.def("registrations", []() {
@ -111,9 +138,8 @@ NB_MODULE(_cpu_examples, m) {
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
registrations["dictionary_attr"] =
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));
registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));
registrations["aliasing"] = nb::capsule(reinterpret_cast<void *>(Aliasing));
return registrations;
});
}

View File

@ -39,3 +39,9 @@ def dictionary_attr(**kwargs):
def counter(index):
return jax.ffi.ffi_call(
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))
def aliasing(x):
return jax.ffi.ffi_call(
"aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype),
input_output_aliases={0: 0})(x)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import absltest, parameterized
import jax
import jax.numpy as jnp
@ -91,5 +91,16 @@ class CounterTests(jtu.JaxTestCase):
self.assertEqual(counter_fun(0)[1], 3)
class AliasingTests(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cpu"]):
self.skipTest("Unsupported platform")
@parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),))
def test_basic(self, x):
self.assertAllClose(cpu_examples.aliasing(x), x)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())