Edit pycapsule docstring to provide a little bit more context

The docstring for the recently added `pycapsule` function in
`jax.extend.ffi` didn't conform to our usual docstring format, so I
updated it and added a little bit more context.
This commit is contained in:
Dan Foreman-Mackey 2024-06-07 11:00:50 -04:00
parent 5fcd50b7fa
commit 1fa66590d1
3 changed files with 30 additions and 20 deletions

10
docs/jax.extend.ffi.rst Normal file
View File

@ -0,0 +1,10 @@
``jax.extend.ffi`` module
=========================
.. automodule:: jax.extend.ffi
.. autosummary::
:toctree: _autosummary
ffi_lowering
pycapsule

View File

@ -11,6 +11,7 @@ Modules
.. toctree::
:maxdepth: 1
jax.extend.ffi
jax.extend.linear_util
jax.extend.mlir
jax.extend.random

View File

@ -29,38 +29,37 @@ from jax._src.typing import DimSize
def pycapsule(funcptr):
"""Construct a PyCapsule out of a ctypes function pointer.
"""Wrap a ctypes function pointer in a PyCapsule.
A typical use for this is registering custom call targets with XLA:
The primary use of this function, and the reason why it lives with in the
``jax.extend.ffi`` submodule, is to wrap function calls from external
compiled libraries to be registered as XLA custom calls.
Example usage::
import ctypes
import jax
from jax.lib import xla_client
fooso = ctypes.cdll.LoadLibrary('./foo.so')
libfoo = ctypes.cdll.LoadLibrary('./foo.so')
xla_client.register_custom_call_target(
name="bar",
fn=jax.extend.ffi.pycapsule(fooso.bar),
fn=jax.extend.ffi.pycapsule(libfoo.bar),
platform=PLATFORM,
api_version=API_VERSION
)
Args:
funcptr: A function pointer loaded from a dynamic library using ``ctypes``.
Returns:
An opaque ``PyCapsule`` object wrapping ``funcptr``.
"""
# Note (https://docs.python.org/3/library/ctypes.html):
#
# ctypes.pythonapi
# An instance of PyDLL that exposes Python C API functions as attributes.
# Note that all these functions are assumed to return C int, which is of
# course not always the truth, so you have to assign the correct restype
# attribute to use these functions.
#
# Following this advice we annotate argument and return types of
# PyCapsule_New before we call it, based on the example here:
# https://stackoverflow.com/questions/65056619/converting-ctypes-c-void-p-to-pycapsule
PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
PyCapsule_New = ctypes.pythonapi.PyCapsule_New
PyCapsule_New.restype = ctypes.py_object
PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor)
return PyCapsule_New(funcptr, None, PyCapsule_Destructor(0))
destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
builder = ctypes.pythonapi.PyCapsule_New
builder.restype = ctypes.py_object
builder.argtypes = (ctypes.c_void_p, ctypes.c_char_p, destructor)
return builder(funcptr, None, destructor(0))
def include_dir() -> str: