rocm_jax/jaxlib/gpu_linalg.py
Dan Foreman-Mackey f93c2a1aa5 Add and test support for partitioning of batch dimensions in lax.linalg.
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU would insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, to cover a wide range of use cases.

There are a few remaining GPU ops that don't support partitioning, either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions in the kernel. I'm going to fix these in follow-up changes.

PiperOrigin-RevId: 731732301
2025-02-27 08:16:16 -08:00
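
To illustrate why an operation that treats each batch element independently needs no `all-gather`, here is a minimal pure-Python sketch. It is not the JAX or XLA implementation. The helper names (`pivots_to_permutation`, `batched`, `sharded`) are hypothetical, and the pivot convention (0-based sequential row swaps) is an assumption made for the example.

```python
def pivots_to_permutation(pivots, size):
    """Turn 0-based sequential row-swap pivots into a permutation of range(size)."""
    perm = list(range(size))
    for i, j in enumerate(pivots):
        perm[i], perm[j] = perm[j], perm[i]
    return perm


def batched(pivot_batch, size):
    # Each batch element is processed independently of the others.
    return [pivots_to_permutation(p, size) for p in pivot_batch]


def sharded(pivot_batch, size, num_shards):
    # Split the batch into contiguous shards, run the op on each shard
    # separately (as each device would), then concatenate the results.
    step = max(1, -(-len(pivot_batch) // num_shards))  # ceiling division
    out = []
    for start in range(0, len(pivot_batch), step):
        out.extend(batched(pivot_batch[start:start + step], size))
    return out


pivot_batch = [[2, 2, 2], [0, 1, 2], [1, 1, 2], [2, 1, 2]]
full = batched(pivot_batch, size=3)
split = sharded(pivot_batch, size=3, num_shards=2)
print(full == split)
```

Because no shard ever reads another shard's data, the sharded result matches the unsharded one, which is the property the commit relies on when marking a custom call as batch partitionable.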

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib

from jaxlib import xla_client

# Try the module bundled inside jaxlib first, then the standalone plugin
# package. Stop at the first candidate that imports successfully.
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
  try:
    _cuda_linalg = importlib.import_module(
        f"{cuda_module_name}._linalg", package="jaxlib"
    )
  except ImportError:
    _cuda_linalg = None
  else:
    break

# Same lookup for the ROCm (HIP) build.
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
  try:
    _hip_linalg = importlib.import_module(
        f"{rocm_module_name}._linalg", package="jaxlib"
    )
  except ImportError:
    _hip_linalg = None
  else:
    break

# Register every custom call target exported by the available backend, then
# mark the pivots-to-permutation kernel as partitionable along batch dimensions.
if _cuda_linalg:
  for _name, _value in _cuda_linalg.registrations().items():
    xla_client.register_custom_call_target(
        _name, _value, platform="CUDA", api_version=1
    )
  xla_client.register_custom_call_as_batch_partitionable(
      "cu_lu_pivots_to_permutation")

if _hip_linalg:
  for _name, _value in _hip_linalg.registrations().items():
    xla_client.register_custom_call_target(
        _name, _value, platform="ROCM", api_version=1
    )
  xla_client.register_custom_call_as_batch_partitionable(
      "hip_lu_pivots_to_permutation")
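
The two import loops in this file follow the same "first importable candidate wins" pattern. The sketch below factors that pattern into a helper. `import_first_available` is a hypothetical name that does not exist in jaxlib, and the usage lines rely on standard-library modules only so the example runs without a GPU build.

```python
import importlib


def import_first_available(candidates, package=None):
    """Return the first importable module from candidates, or None.

    Relative names such as ".cuda._linalg" require `package` to be set.
    """
    for name in candidates:
        try:
            return importlib.import_module(name, package=package)
        except ImportError:
            continue
    return None


# Hypothetical usage: the first name does not exist, so the second is returned.
mod = import_first_available(["_no_such_module_for_demo", "json"])
missing = import_first_available(["_no_such_module_a", "_no_such_module_b"])
print(mod.__name__, missing)
```

Returning `None` when nothing imports mirrors how the module-level code leaves `_cuda_linalg` or `_hip_linalg` unset, so callers can guard registration with a plain truthiness check.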