mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[PJRT C API] Set the gpu plugin allocator related options.
PiperOrigin-RevId: 573111513
This commit is contained in:
parent
432506f1ae
commit
8a2c5d6a42
@ -37,7 +37,28 @@ def initialize():
|
||||
path,
|
||||
__package__,
|
||||
)
|
||||
c_api = xb.register_plugin("cuda", priority=500, library_path=str(path))
|
||||
# TODO(b/300099402): use the util method when it is ready.
|
||||
options = {}
|
||||
visible_devices = xb.CUDA_VISIBLE_DEVICES.value
|
||||
if visible_devices != 'all':
|
||||
options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
|
||||
|
||||
allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
|
||||
memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '')
|
||||
preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '').lower()
|
||||
if allocator not in ('default', 'platform', 'bfc', 'cuda_async'):
|
||||
raise ValueError(
|
||||
'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
|
||||
'"bfc", or "cuda_async", got "%s"' % allocator
|
||||
)
|
||||
options['allocator'] = allocator
|
||||
if memory_fraction:
|
||||
options['memory_fraction'] = float(memory_fraction)
|
||||
if preallocate:
|
||||
options['preallocate'] = preallocate not in ('false', '0')
|
||||
c_api = xb.register_plugin(
|
||||
'cuda', priority=500, library_path=str(path), options=options
|
||||
)
|
||||
if cuda_plugin_extension:
|
||||
xla_client.register_custom_call_handler(
|
||||
"CUDA",
|
||||
|
Loading…
x
Reference in New Issue
Block a user