[PJRT C API] Set the GPU plugin's allocator-related options.

PiperOrigin-RevId: 573111513
This commit is contained in:
Jieying Luo 2023-10-12 22:59:27 -07:00 committed by jax authors
parent 432506f1ae
commit 8a2c5d6a42

View File

@ -37,7 +37,28 @@ def initialize():
path,
__package__,
)
c_api = xb.register_plugin("cuda", priority=500, library_path=str(path))
# TODO(b/300099402): use the util method when it is ready.
options = {}
visible_devices = xb.CUDA_VISIBLE_DEVICES.value
if visible_devices != 'all':
options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '')
preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '').lower()
if allocator not in ('default', 'platform', 'bfc', 'cuda_async'):
raise ValueError(
'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
'"bfc", or "cuda_async", got "%s"' % allocator
)
options['allocator'] = allocator
if memory_fraction:
options['memory_fraction'] = float(memory_fraction)
if preallocate:
options['preallocate'] = preallocate not in ('false', '0')
c_api = xb.register_plugin(
'cuda', priority=500, library_path=str(path), options=options
)
if cuda_plugin_extension:
xla_client.register_custom_call_handler(
"CUDA",