Mirror of https://github.com/ROCm/jax.git, synced 2025-04-14 10:56:06 +00:00.

Merge pull request #271 from ROCm/ci-upstream-sync-142_1
CI: 03/11/25 upstream sync

Commit 6ee76a8a6d
.bazelrc (6 changed lines)

@@ -54,6 +54,12 @@ build:macos --apple_platform_type=macos
build:macos --linkopt=-Wl,-undefined,dynamic_lookup
build:macos --host_linkopt=-Wl,-undefined,dynamic_lookup

# Use cc toolchains from apple_support for Apple builds.
# https://github.com/bazelbuild/apple_support/tree/master?tab=readme-ov-file#bazel-6-setup
build:macos --apple_crosstool_top=@local_config_apple_cc//:toolchain
build:macos --crosstool_top=@local_config_apple_cc//:toolchain
build:macos --host_crosstool_top=@local_config_apple_cc//:toolchain

# Windows has a relatively short command line limit, which JAX has begun to hit.
# See https://docs.bazel.build/versions/main/windows.html
build:windows --features=compiler_param_file
.github/workflows/jax-array-api.yml (2 changed lines)

@@ -28,7 +28,7 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: 'd982a6245400295477f5da5afa1c4a2a5e641ea4' # Latest commit as of 2025-01-30
ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
.github/workflows/pytest_cpu.yml (24 changed lines)

@@ -29,11 +29,6 @@ on:
type: string
required: true
default: "0"
install-jax-current-commit:
description: "Should the 'jax' package be installed from the current commit?"
type: string
required: true
default: "1"
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
required: true
@@ -62,7 +57,6 @@ jobs:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -88,7 +82,7 @@ jobs:
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
- name: Download jaxlib wheel from GCS (non-Windows runs)
- name: Download wheels from GCS (non-Windows runs)
id: download-wheel-artifacts-nw
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the step below so that we can print a more
@@ -96,14 +90,10 @@ jobs:
continue-on-error: true
if: ${{ !contains(inputs.runner, 'windows-x86') }}
run: |
mkdir -p $(pwd)/dist &&
mkdir -p $(pwd)/dist
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/

# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
fi
- name: Download jaxlib wheel from GCS (Windows runs)
- name: Download wheels from GCS (Windows runs)
id: download-wheel-artifacts-w
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in step below so that we can print a more
@@ -115,12 +105,8 @@ jobs:
mkdir dist
@REM Use `call` so that we can run sequential gsutil commands on Windows
@REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/

@REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
if not "${{ inputs.install-jax-current-commit }}"=="1" (
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
)
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
run: |
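Note: the PYTHON_MAJOR_MINOR step above builds the wheel-filename glob fragment that differs between regular and free-threaded CPython builds. The following Python sketch is illustrative only (not part of the workflow) and mirrors the shell expansion "cp${python_major_minor%t}-cp${python_major_minor}-":

    def wheel_tag_prefix(python_major_minor: str) -> str:
        # Drop a trailing "t" from the first tag component if present.
        base = python_major_minor.removesuffix("t")
        return f"cp{base}-cp{python_major_minor}-"

    assert wheel_tag_prefix("313") == "cp313-cp313-"    # regular build
    assert wheel_tag_prefix("313t") == "cp313-cp313t-"  # free-threaded build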
.github/workflows/pytest_cuda.yml (14 changed lines)

@@ -34,11 +34,6 @@ on:
type: string
required: true
default: "0"
install-jax-current-commit:
description: "Should the 'jax' package be installed from the current commit?"
type: string
required: true
default: "1"
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
required: true
@@ -66,7 +61,6 @@ jobs:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -86,7 +80,7 @@ jobs:
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
- name: Download the wheel artifacts from GCS
- name: Download wheels from GCS
id: download-wheel-artifacts
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the next step so that we can print a more
@@ -94,14 +88,10 @@ jobs:
continue-on-error: true
run: |
mkdir -p $(pwd)/dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/

# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
fi
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |
.github/workflows/tsan-suppressions.txt (5 changed lines)

@@ -26,8 +26,6 @@ race_top:PyMember_GetOne
# https://github.com/python/cpython/issues/129547
race:type_get_annotations

# https://github.com/python/cpython/issues/130547
race:split_keys_entry_added

# https://github.com/python/cpython/issues/129748
race:mi_block_set_nextx
@@ -64,3 +62,6 @@ race:gemm_oncopy

# https://github.com/python/cpython/issues/130571
# race:_PyObject_GetMethod

# https://github.com/python/cpython/issues/130547
# race:split_keys_entry_added
.github/workflows/tsan.yaml (2 changed lines)

@@ -35,7 +35,7 @@ jobs:
apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
libffi-dev liblzma-dev
libffi-dev liblzma-dev file zip
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: jax
.github/workflows/wheel_tests_continuous.yml (16 changed lines)

@@ -27,6 +27,16 @@ concurrency:
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}

jobs:
build-jax-artifact:
uses: ./.github/workflows/build_artifacts.yml
with:
# Note that since jax is a pure python package, the runner OS and Python values do not
# matter. In addition, cloning main XLA also has no effect.
runner: "linux-x86-n2-16"
artifact: "jax"
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

build-jaxlib-artifact:
uses: ./.github/workflows/build_artifacts.yml
strategy:
@@ -66,7 +76,7 @@ jobs:
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: build-jaxlib-artifact
needs: [build-jax-artifact, build-jaxlib-artifact]
uses: ./.github/workflows/pytest_cpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
@@ -80,7 +90,6 @@ jobs:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
install-jax-current-commit: 1
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

run-pytest-cuda:
@@ -88,7 +97,7 @@ jobs:
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jaxlib-artifact, build-cuda-artifacts]
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
uses: ./.github/workflows/pytest_cuda.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
@@ -111,7 +120,6 @@ jobs:
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
install-jax-current-commit: 1
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
@@ -40,9 +40,6 @@ jobs:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
# GCS bucket.
install-jax-current-commit: 0
gcs_download_uri: ${{inputs.gcs_download_uri}}

run-pytest-cuda:
@@ -61,7 +58,4 @@ jobs:
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
# GCS bucket.
install-jax-current-commit: 0
gcs_download_uri: ${{inputs.gcs_download_uri}}
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
load(
"//jaxlib:jax.bzl",
"jax_wheel",
CHANGELOG.md (11 changed lines)

@@ -23,6 +23,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
true, matching the current behavior. If set to false, JAX does not need to
emit code clamping negative indices, which improves code size.

## jax 0.5.2 (Mar 4, 2025)

Patch release of 0.5.1

* Bug fixes
  * Fixes TPU metric logging and `tpu-info`, which was broken in 0.5.1

## jax 0.5.1 (Feb 24, 2025)

* New Features
@@ -54,6 +61,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
A downstream effect of this several other internal functions need debug
info. This change does not affect public APIs.
See https://github.com/jax-ml/jax/issues/26480 for more detail.
* In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`,
non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.

* Bug fixes
* TPU runtime startup and shutdown time should be significantly improved on
@@ -169,8 +178,6 @@ to signify this.

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

## jax 0.4.37

* Bug fixes
* Fixed a bug where `jit` would error if an argument was named `f` (#25329).
* Fix a bug that will throw `index out of range` error in
WORKSPACE (12 changed lines)

@@ -70,7 +70,7 @@ jax_python_wheel_repository(
)

load(
"@tsl//third_party/py:python_wheel.bzl",
"@xla//third_party/py:python_wheel.bzl",
"python_wheel_version_suffix_repository",
)
python_wheel_version_suffix_repository(
@@ -78,7 +78,7 @@ python_wheel_version_suffix_repository(
)

load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
)

@@ -90,7 +90,7 @@ load(
"CUDNN_REDISTRIBUTIONS",
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"cuda_redist_init_repositories",
"cudnn_redist_init_repository",
)
@@ -104,21 +104,21 @@ cudnn_redist_init_repository(
)

load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"cuda_configure",
)

cuda_configure(name = "local_config_cuda")

load(
"@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"nccl_redist_init_repository",
)

nccl_redist_init_repository()

load(
"@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
"@xla//third_party/nccl/hermetic:nccl_configure.bzl",
"nccl_configure",
)
@@ -475,7 +475,7 @@ def bench_pjit_check_aval_sharding(state):
aval = jax.core.ShapedArray((8, 2), np.int32)

while state:
pjit_check_aval_sharding([s] * 100, [aval] * 100, None, 'benchmark', False)
pjit_check_aval_sharding([s] * 100, [aval] * 100, [''] * 100, 'benchmark', False)


@google_benchmark.register
@@ -63,6 +63,17 @@ WHEEL_BUILD_TARGET_DICT = {
"jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
}

# Dictionary with the new wheel build rule. Note that when JAX migrates to the
# new wheel build rule fully, the build CLI will switch to the new wheel build
# rule as the default.
WHEEL_BUILD_TARGET_DICT_NEW = {
"jax": "//:jax_wheel",
"jaxlib": "//jaxlib/tools:jaxlib_wheel",
"jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
"jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
"jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel",
"jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
}

def add_global_arguments(parser: argparse.ArgumentParser):
"""Adds all the global arguments that applies to all the CLI subcommands."""
@@ -147,6 +158,16 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
""",
)

parser.add_argument(
"--use_new_wheel_build_rule",
action="store_true",
help=
"""
Whether to use the new wheel build rule. Temporary flag and will be
removed once JAX migrates to the new wheel build rule fully.
""",
)

parser.add_argument(
"--editable",
action="store_true",
@@ -386,7 +407,10 @@ async def main():
for option in args.bazel_startup_options:
bazel_command_base.append(option)

bazel_command_base.append("run")
if not args.use_new_wheel_build_rule or args.command == "requirements_update":
bazel_command_base.append("run")
else:
bazel_command_base.append("build")

if args.python_version:
# Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version
@@ -592,13 +616,19 @@ async def main():
wheel_build_command_base.append("--config=cuda_libraries_from_stubs")

with open(".jax_configure.bazelrc", "w") as f:
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list())
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule)
if not jax_configure_options:
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
sys.exit(1)
f.write(jax_configure_options)
logging.info("Bazel options written to .jax_configure.bazelrc")

if args.use_new_wheel_build_rule:
logging.info("Using new wheel build rule")
wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW
else:
wheel_build_targets = WHEEL_BUILD_TARGET_DICT

if args.configure_only:
logging.info("--configure_only is set so not running any Bazel commands.")
else:
@@ -611,7 +641,7 @@ async def main():
if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
wheel = "jax-" + wheel

if wheel not in WHEEL_BUILD_TARGET_DICT.keys():
if wheel not in wheel_build_targets.keys():
logging.error(
"Incorrect wheel name provided, valid choices are jaxlib,"
" jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
@@ -629,32 +659,33 @@ async def main():
)

# Append the build target to the Bazel command.
build_target = WHEEL_BUILD_TARGET_DICT[wheel]
build_target = wheel_build_targets[wheel]
wheel_build_command.append(build_target)

wheel_build_command.append("--")
if not args.use_new_wheel_build_rule:
wheel_build_command.append("--")

if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")
if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")

wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")
wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")

if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")
if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")

if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")
if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")

wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")

result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
# Exit with error if any wheel build fails.
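Note: as a reading aid for the CLI change above, here is a small illustrative sketch (not the repository's code) of how a requested wheel name is normalized and then looked up in the newly selected target dictionary:

    WHEEL_BUILD_TARGET_DICT_NEW = {
        "jax": "//:jax_wheel",
        "jaxlib": "//jaxlib/tools:jaxlib_wheel",
        "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
        "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
    }

    def resolve_target(wheel: str) -> str:
        # "cuda-plugin" and "cuda-pjrt" are accepted as shorthand for the jax-prefixed names.
        if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
            wheel = "jax-" + wheel
        targets = WHEEL_BUILD_TARGET_DICT_NEW  # the old dict is chosen when the flag is off
        if wheel not in targets:
            raise ValueError(f"Unknown wheel name: {wheel}")
        return targets[wheel]

    print(resolve_target("cuda-plugin"))  # //jaxlib/tools:jax_cuda_plugin_wheel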
@@ -18,5 +18,4 @@ ml_dtypes>=0.4.0
opt_einsum
zstandard
etils[epath]
# TODO(ybaturina): remove setuptools version
setuptools<71.0.0
setuptools
@@ -634,9 +634,9 @@ zstandard==0.22.0 \
# via -r build/requirements.in

# The following packages are considered to be unsafe in a requirements file:
setuptools==69.5.1 \
--hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \
--hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32
setuptools==76.0.0 \
--hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \
--hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4
# via
# -r build/requirements.in
# -r build/test-requirements.txt

@@ -623,9 +623,9 @@ zstandard==0.22.0 \
# via -r build/requirements.in

# The following packages are considered to be unsafe in a requirements file:
setuptools==69.5.1 \
--hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \
--hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32
setuptools==76.0.0 \
--hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \
--hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4
# via
# -r build/requirements.in
# -r build/test-requirements.txt

@@ -623,9 +623,9 @@ zstandard==0.22.0 \
# via -r build/requirements.in

# The following packages are considered to be unsafe in a requirements file:
setuptools==69.5.1 \
--hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \
--hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32
setuptools==76.0.0 \
--hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \
--hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4
# via
# -r build/requirements.in
# -r build/test-requirements.txt

@@ -747,9 +747,9 @@ zstandard==0.23.0 \
# via -r build/requirements.in

# The following packages are considered to be unsafe in a requirements file:
setuptools==70.3.0 \
--hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \
--hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc
setuptools==76.0.0 \
--hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \
--hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4
# via
# -r build/requirements.in
# -r build/test-requirements.txt
@@ -12,8 +12,7 @@ portpicker; python_version<"3.13"
pytest-xdist
wheel
rich
# TODO(ybaturina): remove setuptools version
setuptools<71.0.0
setuptools
# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement
# below.
matplotlib~=3.8.4; python_version=="3.10"
@@ -213,11 +213,15 @@ def get_gcc_major_version(gcc_path: str):
return major_version


def get_jax_configure_bazel_options(bazel_command: list[str]):
def get_jax_configure_bazel_options(bazel_command: list[str], use_new_wheel_build_rule: bool):
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
# Get the index of the "run" parameter. Build options will come after "run" so
# we find the index of "run" and filter everything after it.
start = bazel_command.index("run")
# we find the index of "run" and filter everything after it. If we are using
# the new wheel build rule, we will find the index of "build" instead.
if use_new_wheel_build_rule:
start = bazel_command.index("build")
else:
start = bazel_command.index("run")
jax_configure_bazel_options = ""
try:
for i in range(start + 1, len(bazel_command)):
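Note: to make the change above concrete, a small illustrative example (option values are made up) of the slicing the helper performs. Options that follow the Bazel verb are what end up in .jax_configure.bazelrc, and the verb is "build" under the new wheel build rule and "run" otherwise:

    cmd = ["bazel", "build", "--config=ci_linux_x86_64", "--repo_env=HERMETIC_PYTHON_VERSION=3.12"]
    verb = "build"  # would be "run" for the old wheel build rule
    options = cmd[cmd.index(verb) + 1:]
    print(options)  # ['--config=ci_linux_x86_64', '--repo_env=HERMETIC_PYTHON_VERSION=3.12']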
@@ -45,52 +45,82 @@ if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then
arch="amd64"
fi

# Determine the artifact tag flags based on the artifact type. A release
# wheel is tagged with the release version (e.g. 0.5.1), a nightly wheel is
# tagged with the release version and a nightly suffix that contains the
# current date (e.g. 0.5.2.dev20250227), and a default wheel is tagged with
# the git commit hash of the HEAD of the current branch and the date of the
# commit (e.g. 0.5.1.dev20250128+3e75e20c7).
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "nightly" ]]; then
current_date=$(date +%Y%m%d)
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "default" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=custom --bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --bazel_options=--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)"
else
echo "Error: Invalid artifact type: $JAXCI_ARTIFACT_TYPE. Allowed values are: release, nightly, default"
exit 1
fi

if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"

# Build the jax artifact
if [[ "$artifact" == "jax" ]]; then
python -m build --outdir $JAXCI_OUTPUT_DIR
# On platforms with no RBE support, we can use the Bazel remote cache. Set
# it to be empty by default to avoid unbound variable errors.
bazel_remote_cache=""

if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
bazelrc_config="rbe_${bazelrc_config}"
else
bazelrc_config="ci_${bazelrc_config}"

# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"

# On platforms with no RBE support, we can use the Bazel remote cache. Set
# it to be empty by default to avoid unbound variable errors.
bazel_remote_cache=""

if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
bazelrc_config="rbe_${bazelrc_config}"
# Set remote cache flags. Pushes to the cache bucket is limited to JAX's
# CI system.
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
bazel_remote_cache="--bazel_options=--config=public_cache_push"
else
bazelrc_config="ci_${bazelrc_config}"

# Set remote cache flags. Pushes to the cache bucket is limited to JAX's
# CI system.
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
bazel_remote_cache="--bazel_options=--config=public_cache_push"
else
bazel_remote_cache="--bazel_options=--config=public_cache"
fi
bazel_remote_cache="--bazel_options=--config=public_cache"
fi
fi

# Use the "_cuda" configs when building the CUDA artifacts.
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
bazelrc_config="${bazelrc_config}_cuda"
fi
# Use the "_cuda" configs when building the CUDA artifacts.
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
bazelrc_config="${bazelrc_config}_cuda"
fi

# Build the artifact.
# Build the artifact.
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
$artifact_tag_flags

# If building release artifacts, we also build a release candidate ("rc")
# tagged wheel.
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
$artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION"
fi

# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
if [[ "$os" == "linux" ]]; then
./ci/utilities/run_auditwheel.sh
fi
# Move the built artifacts from the Bazel cache directory to the output
# directory.
if [[ "$artifact" == "jax" ]]; then
mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR"
mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR"
else
mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR"
fi

# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then
./ci/utilities/run_auditwheel.sh
fi

else
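Note: the comment block above describes three wheel-tagging modes; the following sketch (not the build system's code, values are hypothetical) spells out the version strings those modes aim to produce, using the examples given in the comment:

    # release:  plain release version, e.g. "0.5.1"
    # nightly:  release version plus a dated dev suffix, e.g. "0.5.2.dev20250227"
    # default:  dev suffix from the commit date plus the git hash, e.g. "0.5.1.dev20250128+3e75e20c7"
    def example_wheel_version(artifact_type: str, base: str, date: str, git_hash: str) -> str:
        if artifact_type == "release":
            return base
        if artifact_type == "nightly":
            return f"{base}.dev{date}"
        if artifact_type == "default":
            return f"{base}.dev{date}+{git_hash}"
        raise ValueError("Allowed values are: release, nightly, default")

    print(example_wheel_version("nightly", "0.5.2", "20250227", ""))           # 0.5.2.dev20250227
    print(example_wheel_version("default", "0.5.1", "20250128", "3e75e20c7"))  # 0.5.1.dev20250128+3e75e20c7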
@@ -50,6 +50,15 @@ export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
# flag is enabled only for CI builds.
export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0}

# Type of artifacts to build. Valid values are "default", "release", "nightly".
# This affects the wheel naming/tag.
export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"}

# When building release artifacts, we build a release candidate wheel ("rc"
# tagged wheel) in addition to the release wheel. This environment variable
# sets the version of the release candidate ("RC") artifact to build.
export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-}

# #############################################################################
# Test script specific environment variables.
# #############################################################################
@@ -65,9 +74,4 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels
# on the system. By default, it is set to match the version of the hermetic
# Python used by Bazel for building the wheels.
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}

# Installs the JAX package in editable mode at the current commit. Enabled by
# default. Nightly/Release builds disable this flag in the Github action
# workflow files.
export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-"1"}
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
@@ -19,7 +19,7 @@
# avoid using the Windows version of `find` on Msys.
WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )

if [[ -z "$WHEELS" ]]; then
if [[ -z "${WHEELS[@]}" ]]; then
echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
exit 1
fi
@@ -38,10 +38,4 @@ if [[ $(uname -s) =~ "MSYS_NT" ]]; then
"$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}")
else
"$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}"
fi

if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then
echo "Installing the JAX package at the current commit..."
# Install JAX package at the current commit.
"$JAXCI_PYTHON" -m uv pip install .
fi
@@ -98,3 +98,6 @@ function retry {

# Retry "bazel --version" 3 times to avoid flakiness when downloading bazel.
retry "bazel --version"

# Create the output directory if it doesn't exist.
mkdir -p "$JAXCI_OUTPUT_DIR"
@@ -10,7 +10,7 @@ DeepMind](https://deepmind.google/), Alphabet more broadly,
and elsewhere.

At the heart of the project is the [JAX
core](http://github.com/google/jax) library, which focuses on the
core](http://github.com/jax-ml/jax) library, which focuses on the
fundamentals of machine learning and numerical computing, at scale.

When [developing](#development) the core, we want to maintain agility
@@ -91,8 +91,8 @@ def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
#         x: 0.0
# Prints: x: 0.0
#         x: 1.0
# OR
# Prints: x: 1.0
#         x: 0.0
@@ -195,7 +195,7 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
# and then use `updates` instead of `grad` to actually update the params.
# (And we'd include `new_optimizer_state` in the output, naturally.)

new_params = jax.tree_map(
new_params = jax.tree.map(
lambda param, g: param - g * LEARNING_RATE, params, grad)

return new_params
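Note: the tutorial change above swaps the deprecated jax.tree_map alias for jax.tree.map. A minimal standalone example of the renamed API (illustrative, not taken from the tutorial):

    import jax

    params = {"w": 1.0, "b": -0.5}
    grads = {"w": 0.2, "b": 0.1}
    LEARNING_RATE = 0.01

    # Apply the update leaf-by-leaf across both pytrees.
    new_params = jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    print(new_params)  # each parameter moves a small step against its gradient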
@@ -69,7 +69,7 @@ def dots_saveable(prim, *_, **__) -> bool:
lax_convolution.conv_general_dilated_p}
checkpoint_dots = dots_saveable

def dot_with_no_batch_dims_saveable(prim, *_, **params) -> bool:
def dots_with_no_batch_dims_saveable(prim, *_, **params) -> bool:
# This is a useful heuristic for transformers.
if prim is lax_internal.dot_general_p:
(_, _), (lhs_b, rhs_b) = params['dimension_numbers']
@@ -160,8 +160,8 @@ checkpoint_policies = types.SimpleNamespace(
nothing_saveable=nothing_saveable,
dots_saveable=dots_saveable,
checkpoint_dots=dots_saveable,
dots_with_no_batch_dims_saveable=dot_with_no_batch_dims_saveable,
checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims_saveable,
dots_with_no_batch_dims_saveable=dots_with_no_batch_dims_saveable,
checkpoint_dots_with_no_batch_dims=dots_with_no_batch_dims_saveable,
offload_dot_with_no_batch_dims=offload_dot_with_no_batch_dims,
save_anything_except_these_names=save_anything_except_these_names,
save_any_names_but_these=save_any_names_but_these,
@@ -355,7 +355,7 @@ def _remat_static_argnums(fun, static_argnums, args):
raise ValueError("the `static_argnums` argument to `jax.checkpoint` / "
"`jax.remat` can only take integer values greater than or "
"equal to `-len(args)` and less than `len(args)`, but got "
f"{static_argnums}")
f"{static_argnums}, while `len(args)` = {len(args)}")

if not static_argnums:
return fun, args
@@ -1094,6 +1094,9 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
return f"args{keystr(key_path)}"
# args is a tuple, so key_path[0].idx is the index into args.
i = key_path[0].idx
# This can happen with star arguments (*args)
if i >= len(signature_parameters):
return f"args{keystr(key_path)}"
res = f"argument {signature_parameters[i]}"
if len(key_path) > 1:
res += keystr(key_path[1:])
@@ -1135,8 +1138,8 @@ def pmap(
fun: Callable,
axis_name: AxisName | None = None,
*,
in_axes=0,
out_axes=0,
in_axes: int | None | Sequence[Any] = 0,
out_axes: Any = 0,
static_broadcasted_argnums: int | Iterable[int] = (),
devices: Sequence[xc.Device] | None = None,  # noqa: F811
backend: str | None = None,
@@ -2002,8 +2005,8 @@ def vjp(
raise NotImplementedError("reduce_axes argument to vjp is deprecated")
del reduce_axes
check_callable(fun)
wrapped_fun = lu.wrap_init(fun,
debug_info=debug_info("vjp", fun, primals, {}))
wrapped_fun = lu.wrap_init(
fun, debug_info=debug_info("vjp", fun, primals, {}))
return _vjp(wrapped_fun, *primals, has_aux=has_aux)

def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
@@ -382,6 +382,9 @@ def is_hashable(arg):
return False


SENTINEL = object()


def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
# given an axis spec tree axis_tree (a pytree with integers and Nones at the
# leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
@@ -389,7 +392,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
# and return the flattened result
# TODO(mattjj,phawkins): improve this implementation
proxy = object()
dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
dummy = tree_unflatten(treedef, [SENTINEL] * treedef.num_leaves)
axes = []
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
try:
@@ -564,8 +567,9 @@ def resolve_kwargs(fun: Callable, args, kwargs) -> tuple[Any, ...]:
passed_kwargs = [k for k in ba.kwargs if k in kwargs]
if passed_kwargs:
raise TypeError(
f"keyword arguments ({passed_kwargs}) could not be resolved to "
"positions")
"The following keyword arguments could not be resolved to positions: "
f"{', '.join(passed_kwargs)}"
)
return ba.args
@@ -901,7 +901,7 @@ error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings,
in_layouts, out_layouts,
resource_env, donated_invars, name, inline, keep_unused,
donated_invars, ctx_mesh, name, inline, keep_unused,
compiler_options_kvs):
# jaxpr to checked_jaxpr
err_vals, err_tree = jtu.tree_flatten(error)
@@ -928,8 +928,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
out_shardings=new_out_shardings,
in_layouts=new_in_layouts,
out_layouts=new_out_layouts,
resource_env=resource_env,
donated_invars=new_donated_invars,
ctx_mesh=ctx_mesh,
name=name,
inline=inline,
keep_unused=keep_unused,
@@ -15,6 +15,7 @@
import datetime
import os
import re
import warnings
from jax import version
from jax._src import config
from jax._src import hardware_utils
@@ -72,7 +73,19 @@ def cloud_tpu_init() -> None:

# Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed.
libtpu_path = get_tpu_library_path()
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id()
if (
tpu_id is not None
and tpu_id >= hardware_utils.TpuVersion.v5e
and not hardware_utils.transparent_hugepages_enabled()
):
warnings.warn(
'Transparent hugepages are not enabled. TPU runtime startup and'
' shutdown time should be significantly improved on TPU v5e and newer.'
' If not already set, you may need to enable transparent hugepages in'
' your VM image (sudo sh -c "echo always >'
' /sys/kernel/mm/transparent_hugepage/enabled")'
)
if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init():
return
@@ -15,8 +15,8 @@
from __future__ import annotations

import abc
import pathlib

from jax._src import path as pathlib
from jax._src import util
@@ -1815,6 +1815,7 @@ def _make_lengths_same(sharding, ndim):
if ndim > len(sharding.spec):
return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim))
if ndim < len(sharding.spec):
assert all(s is None for s in sharding.spec[ndim:])
return sharding.with_spec(sharding.spec[:ndim])
assert False, "unreachable"

@@ -1840,6 +1841,8 @@ def _maybe_modify_sharding(sharding, ndim):
return sharding

if sharding.mesh._are_all_axes_explicit:
if ndim > len(sharding.spec):
return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim))
return sharding

out = sharding.with_spec(modify_spec_for_auto_manual(
@@ -1849,9 +1852,22 @@ def _maybe_modify_sharding(sharding, ndim):
out = _make_lengths_same(out, ndim)
return out

def _check_divisibility(sharding, shape):
mesh = sharding.mesh
for dim, (spec, sh) in enumerate(zip(sharding.spec, shape)):
if spec is None:
continue
spec = spec if isinstance(spec, tuple) else (spec,)
size = math.prod(mesh.shape[s] for s in spec)
_, remainder = divmod(sh, size)
if remainder != 0:
raise ValueError(
f"Sharding spec {spec} implies that array axis {dim} is partitioned"
f" {size} times, but does not evenly divide the dimension size {sh}."
f" Got shape: {shape} and sharding {sharding}")

@cache(max_size=4096, trace_context_in_key=True)
def get_sharding(sharding, ndim):
def get_sharding(sharding, shape):
"""Modifies and checks the sharding.

Some modifications/checks include:
@@ -1860,6 +1876,7 @@ def get_sharding(sharding, ndim):
* Checking for len(spec)-ndim match
* Checking if the mesh is an AbstractMesh.
"""
ndim = len(shape)
if sharding is None:
return NamedSharding(mesh_lib.empty_abstract_mesh, P(*[None] * ndim))

@@ -1871,6 +1888,7 @@ def get_sharding(sharding, ndim):
if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
raise ValueError("Mesh of an aval must be an AbstractMesh. "
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
_check_divisibility(out_s, shape)
return out_s

@@ -1882,7 +1900,7 @@ class ShapedArray(UnshapedArray):
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
self.sharding = get_sharding(sharding, len(self.shape))
self.sharding = get_sharding(sharding, self.shape)

def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:
@@ -2489,8 +2507,8 @@ class MapPrimitive(Primitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr: Jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info), jaxpr, ())
subfun = lu.hashable_partial(
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
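Note: the new _check_divisibility above enforces that each sharded array dimension is a multiple of the product of the mesh axis sizes it is split over. A small self-contained sketch of that arithmetic (hypothetical mesh and spec values, plain Python rather than JAX API calls):

    import math

    mesh_shape = {"x": 4, "y": 2}   # hypothetical mesh axis sizes
    spec = [("x", "y"), None]       # dim 0 split over x and y, dim 1 unsharded
    shape = (16, 5)

    for dim, (axes, size) in enumerate(zip(spec, shape)):
        if axes is None:
            continue
        partitions = math.prod(mesh_shape[a] for a in axes)
        if size % partitions != 0:
            raise ValueError(f"axis {dim} of length {size} is not divisible by {partitions}")
    # shape (16, 5) passes: 16 % (4*2) == 0 and dim 1 is unsharded; (10, 5) would fail.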
@@ -333,7 +333,7 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,


def check_is_flash_attention(
query, key, layout: int, cudnn_version, has_bias, is_training, is_packed,
query, key, layout: int, cudnn_version, has_bias, is_training, is_packed=False,
is_fp8=False):
# Extract sequence length (T) and head dim (H) based on layout
if layout == AttentionLayout.BNTH.value:
@@ -143,7 +143,15 @@ class custom_vmap:
def __call__(self, *args, **kwargs):
debug_fun = api_util.debug_info("custom_vmap fun", self.fun,
args, kwargs)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
try:
args = api_util.resolve_kwargs(self.fun, args, kwargs)
except TypeError as e:
raise TypeError(
"The input arguments to the custom_vmap-decorated function "
f"{debug_fun.func_name} could not be resolved to positional-only "
f"arguments. Binding failed with the error:\n{e}"
) from e

if not self.vmap_rule:
raise AttributeError(
f"No batching rule defined for custom_vmap function {debug_fun.func_name} "
@@ -133,7 +133,15 @@ class custom_dce:
debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
try:
args = api_util.resolve_kwargs(self.fun, args, kwargs)
except TypeError as e:
raise TypeError(
"The input arguments to the custom_dce-decorated function "
f"{debug.func_name} could not be resolved to positional-only "
f"arguments. Binding failed with the error:\n{e}"
) from e

if self.static_argnums:
static_argnums = set(self.static_argnums)
for i in static_argnums:
@@ -250,7 +250,15 @@ class custom_jvp(Generic[ReturnValue]):
msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
raise AttributeError(msg)

args = resolve_kwargs(self.fun, args, kwargs)
try:
args = resolve_kwargs(self.fun, args, kwargs)
except TypeError as e:
raise TypeError(
"The input arguments to the custom_jvp-decorated function "
f"{primal_name} could not be resolved to positional-only arguments. "
f"Binding failed with the error:\n{e}"
) from e

if self.nondiff_argnums:
nondiff_argnums = set(self.nondiff_argnums)
args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
@@ -634,7 +642,16 @@ class custom_vjp(Generic[ReturnValue]):
if not self.fwd or not self.bwd:
msg = f"No VJP defined for custom_vjp function {debug_fun.func_name} using defvjp."
raise AttributeError(msg)
args = resolve_kwargs(self.fun, args, kwargs)

try:
args = resolve_kwargs(self.fun, args, kwargs)
except TypeError as e:
raise TypeError(
"The input arguments to the custom_vjp-decorated function "
f"{debug_fun.func_name} could not be resolved to positional-only "
f"arguments. Binding failed with the error:\n{e}"
) from e

debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs,
static_argnums=self.nondiff_argnums)
# TODO(necula): figure out how to construct the debug_bwd args
@@ -1238,7 +1255,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value
# with a nontrivial tangent attached), up to heuristics, and True otherwise.
# See https://github.com/google/jax/issues/6415 for motivation.
# See https://github.com/jax-ml/jax/issues/6415 for motivation.
if not isinstance(x, core.Tracer):
# If x is not a Tracer, it can't be perturbed.
return False
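Note: for context on the try/except added above, here is a minimal definition using the public jax.custom_jvp API. If a call site passes keyword arguments that cannot be bound to the function's positional parameters, the wrapper now re-raises the binding failure as the clearer TypeError shown in the diff (sketch, not taken from the JAX test suite):

    import jax

    @jax.custom_jvp
    def f(x):
        return x * x

    @f.defjvp
    def f_jvp(primals, tangents):
        (x,), (t,) = primals, tangents
        return f(x), 2.0 * x * t

    print(jax.grad(f)(3.0))  # 6.0
    # f(3.0, bogus=1) would fail to bind `bogus` and raise the TypeError described above.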
@@ -181,7 +181,8 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
*tiled_args
)
if closed_jaxpr.out_avals != tiled_results:
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
[(t.shape, t.dtype) for t in tiled_results]):
raise ValueError(
"Mismatch in result shapes. %s vs %s"
% (repr(closed_jaxpr.out_avals), repr(tiled_results))
@@ -109,6 +109,12 @@ _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)

# fp4 support
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
float4_e2m1fn: type[np.generic] | None = None

_float4_e2m1fn_dtype: np.dtype | None = None

def supports_inf(dtype: DTypeLike) -> bool:
"""Return true if the dtype supports infinity, else return False."""
typ = np.dtype(dtype).type
@@ -144,6 +150,8 @@ _float8_dtypes = [
_float8_e5m2fnuz_dtype,
]

_float4_dtypes: list[np.dtype] = []

# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
if hasattr(ml_dtypes, "float8_e4m3"):
float8_e4m3 = ml_dtypes.float8_e4m3
@@ -163,6 +171,12 @@ if hasattr(ml_dtypes, "float8_e8m0fnu"):
_custom_float_scalar_types.insert(0, float8_e8m0fnu)  # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
if hasattr(ml_dtypes, "float4_e2m1fn"):
float4_e2m1fn = ml_dtypes.float4_e2m1fn
_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
_custom_float_scalar_types.insert(0, float4_e2m1fn)  # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
_float4_dtypes.insert(0, _float4_e2m1fn_dtype)

# 2-bit integer support
int2: type[np.generic] | None = None
@@ -716,6 +730,12 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
"promotion path. To avoid unintended promotion, 8-bit floats do not support "
"implicit promotion. If you'd like your inputs to be promoted to another type, "
"you can do so explicitly using e.g. x.astype('float32')")
elif any(n in _float4_dtypes for n in nodes):
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path. To avoid unintended promotion, 4-bit floats do not support "
"implicit promotion. If you'd like your inputs to be promoted to another type, "
"you can do so explicitly using e.g. x.astype('float32')")
elif any(n in _intn_dtypes for n in nodes):
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
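Note: the new branch above extends the no-implicit-promotion rule from 8-bit to 4-bit floats. A tiny standalone sketch of the guard's logic (using a stand-in dtype-name set rather than the module's internal lists):

    float4_dtype_names = {"float4_e2m1fn"}   # stand-in for _float4_dtypes

    def check_promotion(*dtype_names: str) -> None:
        # Mirrors the shape of the new elif branch: any fp4 input blocks implicit promotion.
        if any(n in float4_dtype_names for n in dtype_names):
            raise TypeError(
                f"Input dtypes {dtype_names} have no available implicit dtype promotion "
                "path; 4-bit floats must be converted explicitly, e.g. x.astype('float32')")

    check_promotion("float32", "int32")            # fine
    # check_promotion("float4_e2m1fn", "float32")  # would raise TypeError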
@@ -33,12 +33,11 @@ class JaxValueError(ValueError):
"""Exception raised for failed runtime error checks in JAX."""


#: The default error code for no error.
#:
#: This value is chosen because we can use `jnp.min()` to obtain the
#: first error when performing reductions.
_NO_ERROR = jnp.iinfo(jnp.uint32).max
"""The default error code for no error.

We choose this value because when performing reductions, we can use `min` to
obtain the smallest error code.
"""


_error_list_lock = threading.Lock()
@@ -62,7 +61,7 @@ def _initialize_error_code_ref() -> None:


def set_error_if(pred: jax.Array, msg: str) -> None:
"""Set error if pred is true.
"""Set error if any element of pred is true.

If the error is already set, the new error will be ignored. It will not
override the existing error.
@@ -74,7 +73,7 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
traceback = source_info_util.current().traceback
assert traceback is not None
with _error_list_lock:
new_error_code = len(_error_list)
new_error_code = jnp.uint32(len(_error_list))
_error_list.append((msg, traceback))

pred = pred.any()
@@ -86,18 +85,26 @@ def set_error_if(pred: jax.Array, msg: str) -> None:


def raise_if_error() -> None:
"""Raise error if an error is set."""
if _error_storage.ref is None:
return  # if not initialized, do nothing
"""Raise error if an error is set.

This function should be called after the computation is finished. It should
not be called within a traced context, such as within a jitted function."
"""
if _error_storage.ref is None:  # if not initialized, do nothing
return

error_code = _error_storage.ref[...]
if isinstance(error_code, core.Tracer):
raise ValueError(
"raise_if_error() should not be called within a traced context, such as"
" within a jitted function."
)
if error_code == jnp.uint32(_NO_ERROR):
return
try:
msg, traceback = _error_list[error_code]
exc = JaxValueError(msg)
traceback = traceback.as_python_traceback()
filtered_traceback = traceback_util.filter_traceback(traceback)
raise exc.with_traceback(filtered_traceback)
finally:
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)

msg, traceback = _error_list[error_code]
exc = JaxValueError(msg)
traceback = traceback.as_python_traceback()
filtered_traceback = traceback_util.filter_traceback(traceback)
raise exc.with_traceback(filtered_traceback)
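Note: as a usage illustration of set_error_if / raise_if_error as documented above, the error is recorded inside traced code and surfaced afterwards, outside any jit. The import path below is an assumption (the file path is not shown in this view), so treat this as a sketch rather than a supported recipe:

    import jax
    import jax.numpy as jnp
    from jax._src import error_check  # assumed module path, not confirmed by this diff

    @jax.jit
    def safe_log(x):
        # Record an error if any element is non-positive; the check itself is traced.
        error_check.set_error_if(x <= 0, "safe_log requires strictly positive inputs")
        return jnp.log(x)

    y = safe_log(jnp.array([1.0, -2.0]))
    error_check.raise_if_error()  # raises JaxValueError outside the traced context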
@@ -75,6 +75,7 @@ enum DType: byte {
f8_e5m2 = 20,
f8_e5m2fnuz = 21,
f8_e8m0fnu = 25,
f4_e2m1fn = 26,
}

table AbstractValue {

@@ -365,6 +365,8 @@ if dtypes._float8_e4m3_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
if dtypes._float8_e8m0fnu_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
if dtypes._float4_e2m1fn_dtype is not None:
_dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}

@@ -62,6 +62,7 @@ class DType(object):
f8_e5m2fnuz = 21
f0 = 22
f8_e8m0fnu = 25
f4_e2m1fn = 26


class ShardingKind(object):
@ -12,25 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import os
import pathlib
import glob

_GOOGLE_PCI_VENDOR_ID = '0x1ae0'
_TPU_PCI_DEVICE_IDS = [
    # TPU v2, v3
    '0x0027',
    # No public name (plc)
    '0x0056',
    # TPU v4
    '0x005e',
    # TPU v5p
    '0x0062',
    # TPU v5e
    '0x0063',
    # TPU v6e
    '0x006f',
]

_NVIDIA_GPU_DEVICES = [
    '/dev/nvidia0',
@ -38,10 +25,36 @@ _NVIDIA_GPU_DEVICES = [
    '/dev/dxg', # WSL2
]


class TpuVersion(enum.IntEnum):
  # TPU v2, v3
  v2 = 0
  v3 = 1
  # No public name (plc)
  plc = 2
  # TPU v4
  v4 = 3
  # TPU v5p
  v5p = 4
  # TPU v5e
  v5e = 5
  # TPU v6e
  v6e = 6


_TPU_PCI_DEVICE_IDS = {
    '0x0027': TpuVersion.v3,
    '0x0056': TpuVersion.plc,
    '0x005e': TpuVersion.v4,
    '0x0062': TpuVersion.v5p,
    '0x0063': TpuVersion.v5e,
    '0x006f': TpuVersion.v6e,
}

def num_available_tpu_chips_and_device_id():
  """Returns the device id and number of TPU chips attached through PCI."""
  num_chips = 0
  device_id = ''
  tpu_version = None
  for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'):
    vendor_id = pathlib.Path(vendor_path).read_text().strip()
    if vendor_id != _GOOGLE_PCI_VENDOR_ID:
@ -50,12 +63,20 @@ def num_available_tpu_chips_and_device_id():
    device_path = os.path.join(os.path.dirname(vendor_path), 'device')
    device_id = pathlib.Path(device_path).read_text().strip()
    if device_id in _TPU_PCI_DEVICE_IDS:
      tpu_version = _TPU_PCI_DEVICE_IDS[device_id]
      num_chips += 1

  return num_chips, device_id
  return num_chips, tpu_version


def has_visible_nvidia_gpu() -> bool:
  """True if there's a visible nvidia gpu available on device, False otherwise."""

  return any(os.path.exists(d) for d in _NVIDIA_GPU_DEVICES)


def transparent_hugepages_enabled() -> bool:
  # See https://docs.kernel.org/admin-guide/mm/transhuge.html for more
  # information about transparent huge pages.
  path = pathlib.Path('/sys/kernel/mm/transparent_hugepage/enabled')
  return path.exists() and path.read_text().strip() == '[always] madvise never'
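
Since the PCI table now resolves to a TpuVersion rather than a raw device-id string, callers can compare and pretty-print versions directly. A toy illustration, with the enum re-declared so the snippet is self-contained (it mirrors the definition above and is not part of this change):

import enum

class TpuVersion(enum.IntEnum):  # mirrors the enum added above
  v2 = 0
  v3 = 1
  plc = 2
  v4 = 3
  v5p = 4
  v5e = 5
  v6e = 6

num_chips, tpu_version = 4, TpuVersion.v5e   # as returned by the new function
if tpu_version is not None and tpu_version >= TpuVersion.v4:
  print(f"{num_chips} chips of TPU {tpu_version.name}")
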
@ -39,7 +39,7 @@ from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
|
||||
from jax._src.dtypes import dtype, float0
|
||||
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
|
||||
as_hashable_function, weakref_lru_cache,
|
||||
partition_list)
|
||||
partition_list, subs_list2)
|
||||
|
||||
zip = safe_zip
|
||||
map = safe_map
|
||||
@ -91,6 +91,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
|
||||
*primals, **params):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
tangent_trace = pe.DynamicJaxprTrace(debug_info)
|
||||
tangent_trace.tag = _tag
|
||||
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
|
||||
tracers = [LinearizeTracer(linearize_trace, p,
|
||||
tangent_trace.new_arg(get_aval(p).to_tangent_aval()))
|
||||
@ -104,11 +105,23 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
|
||||
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
|
||||
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
|
||||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
|
||||
residual_avals = map(get_aval, consts)
|
||||
if attrs_tracked:
|
||||
raise NotImplementedError("TODO: attrs")
|
||||
_store.store((residual_avals, nzs_out, jaxpr))
|
||||
return tuple(consts) + tuple(out_primals)
|
||||
which_env = [(isinstance(c, pe.DynamicJaxprTracer) and
|
||||
getattr(c._trace, 'tag', None) is _tag) for c in consts]
|
||||
jaxpr = pe.move_envvars(jaxpr, tuple(which_env))
|
||||
res, env = partition_list(which_env, consts)
|
||||
residual_avals = map(get_aval, res)
|
||||
# Which residuals are just forwarded inputs? Check object id.
|
||||
id_map = {id(p): i for i, p in enumerate(primals)}
|
||||
in_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
|
||||
# Which residuals are already primal outputs? Check object id.
|
||||
id_map = {id(p): i for i, p in enumerate(out_primals)}
|
||||
out_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
|
||||
# Prune residuals not to include forwarded primal inputs or outputs.
|
||||
res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd) if f1 is None and f2 is None]
|
||||
_store.store((residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd))
|
||||
return *res, *out_primals
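
A toy, plain-Python illustration of the object-identity check above: residuals that are the same object as a primal input or a primal output are pruned here and rebuilt from those values downstream, so only genuinely new residuals are returned. The jaxpr tracers are replaced by plain objects.

a, b, fresh = object(), object(), object()
primals, out_primals = [a, b], [b]
res = [a, b, fresh]

id_in = {id(p): i for i, p in enumerate(primals)}
id_out = {id(p): i for i, p in enumerate(out_primals)}
in_fwd = [id_in.get(id(r)) for r in res]      # [0, 1, None]
out_fwd = [id_out.get(id(r)) for r in res]    # [None, 0, None]
kept = [r for r, f1, f2 in zip(res, in_fwd, out_fwd)
        if f1 is None and f2 is None]
assert kept == [fresh]
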
|
||||
|
||||
@lu.transformation2
|
||||
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
|
||||
@ -157,6 +170,7 @@ def _linearize_jaxpr(
|
||||
primal_trace = pe.DynamicJaxprTrace(dbg)
|
||||
tangent_trace = pe.DynamicJaxprTrace(dbg)
|
||||
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
|
||||
tangent_trace.tag = lin_trace.tag
|
||||
|
||||
def new_arg(trace, primal_aval, nz):
|
||||
primal = primal_trace.new_arg(primal_aval)
|
||||
@ -197,6 +211,7 @@ def direct_linearize(traceable: lu.WrappedFun,
|
||||
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
|
||||
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
|
||||
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
|
||||
tangent_trace.tag = linearize_trace.tag
|
||||
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
|
||||
tracers = [t.full_lower() for t in tracers]
|
||||
with (core.set_current_trace(linearize_trace, check_leaks=True),
|
||||
@ -217,6 +232,10 @@ def direct_linearize(traceable: lu.WrappedFun,
|
||||
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents)
|
||||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
|
||||
tangent_trace.invalidate()
|
||||
jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
|
||||
jaxpr, [True] * len(jaxpr.outvars),
|
||||
[False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
|
||||
consts = [c for c, used in zip(consts, used_consts) if used]
|
||||
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
|
||||
pe.PartialVal.known(zeros_like_aval(t.aval))
|
||||
for t, nz in zip(out_tangents, out_nzs)]
|
||||
@ -330,12 +349,34 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
|
||||
# forces primal_in to contain UndefinedPrimals for tangent values!
|
||||
map(write_primal, jaxpr.invars, primals_in)
|
||||
|
||||
# Start with a forward pass to evaluate any side-effect-free JaxprEqns that
|
||||
# only operate on primals. This is required to support primitives with
|
||||
# linearization rules that include computations on the residuals.
|
||||
lin_eqns = []
|
||||
for eqn in jaxpr.eqns:
|
||||
# TODO (dfm): The effects check is probably stricter than necessary.
|
||||
# Consider adding an allowlist of effects here.
|
||||
if jaxpr.effects or any(
|
||||
type(x) is not Literal and x not in primal_env for x in eqn.invars):
|
||||
lin_eqns.append(eqn)
|
||||
continue
|
||||
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
|
||||
traceback = eqn.source_info.traceback
|
||||
with source_info_util.user_context(
|
||||
traceback, name_stack=name_stack), eqn.ctx.manager:
|
||||
ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
else:
|
||||
write_primal(eqn.outvars[0], ans)
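
A toy sketch of the partition performed by the loop above, with plain Python standing in for the jaxpr machinery: effect-free equations whose inputs are all already-known primals are evaluated eagerly, and everything else is deferred to lin_eqns for the reverse sweep.

known = {'a', 'b'}                                       # primal inputs in primal_env
eqns = [('c', {'a'}), ('d', {'c', 't'}), ('e', {'b'})]   # (outvar, invars)
lin_eqns = []
for out, ins in eqns:
  if ins <= known:
    known.add(out)        # evaluate now, like the primal-only equations
  else:
    lin_eqns.append(out)  # 't' is a tangent value, so defer this equation
assert lin_eqns == ['d'] and known == {'a', 'b', 'c', 'e'}
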
|
||||
|
||||
ct_env: dict[Any, Any] = {}
|
||||
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
|
||||
else contextlib.nullcontext())
|
||||
with ctx:
|
||||
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
for eqn in lin_eqns[::-1]:
|
||||
if eqn.primitive.ref_primitive:
|
||||
if eqn.primitive is core.mutable_array_p:
|
||||
val_var, = eqn.invars
|
||||
@ -586,7 +627,7 @@ def _primal_tangent_shapes_match(primal, tangent):
|
||||
if type(tangent) is not Zero:
|
||||
primal_aval = get_aval(primal).strip_weak_type()
|
||||
tangent_aval = get_aval(tangent).strip_weak_type()
|
||||
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape)
|
||||
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape)
|
||||
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
|
||||
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
|
||||
|
||||
@ -641,6 +682,7 @@ class LinearizeTrace(Trace):
|
||||
return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
|
||||
dict(symbolic_zeros=symbolic_zeros))
|
||||
|
||||
@partial(lu.wrap_init, debug_info=f_jvp.debug_info)
|
||||
def _f_jvp(primals, tangents):
|
||||
outs = f_jvp.call_wrapped(*primals, *tangents)
|
||||
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
|
||||
@ -651,7 +693,7 @@ class LinearizeTrace(Trace):
|
||||
nonzeros_in = [type(t) is not Zero for t in tangents_in]
|
||||
primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
|
||||
_f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros,
|
||||
f_jvp.debug_info, primals_in, {})
|
||||
primals_in, {})
|
||||
|
||||
with core.set_current_trace(self.tangent_trace):
|
||||
tangents_out = linearized(residuals, *tangents_in)
|
||||
@ -690,53 +732,64 @@ class LinearizeTrace(Trace):
|
||||
assert call_primitive.multiple_results
|
||||
primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
|
||||
nzs_in = tuple(type(t) is not Zero for t in tangents)
|
||||
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in,
|
||||
f.debug_info)
|
||||
f_primal, linearize_outs_thunk = linearize_subtrace(
|
||||
f, self.tag, nzs_in, f.debug_info)
|
||||
if isinstance(call_primitive, core.MapPrimitive):
|
||||
@as_hashable_function(closure=(linearize_outs_thunk))
|
||||
out_axes_thunk = params['out_axes_thunk']
|
||||
@as_hashable_function(closure=out_axes_thunk)
|
||||
def new_out_axes_thunk():
|
||||
residual_avals, _, _ = linearize_outs_thunk()
|
||||
out_axes = params['out_axes_thunk']()
|
||||
return (*(0 for _ in residual_avals), *out_axes)
|
||||
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
|
||||
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
out_axes = out_axes_thunk()
|
||||
return (*(0 for _ in range(num_res_out)), *out_axes)
|
||||
primal_params = dict(params, out_axes_thunk=new_out_axes_thunk)
|
||||
else:
|
||||
primal_params = params
|
||||
|
||||
all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params)
|
||||
residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
|
||||
num_residuals = len(residual_avals)
|
||||
residuals = all_primal_results[:num_residuals]
|
||||
primals_out = all_primal_results[num_residuals:]
|
||||
residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
|
||||
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
non_fwd_res = all_primal_results[:num_res_out]
|
||||
primals_out = all_primal_results[num_res_out:]
|
||||
residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
|
||||
|
||||
if isinstance(call_primitive, core.MapPrimitive):
|
||||
in_axes = params['in_axes']
|
||||
out_axes = params['out_axes_thunk']()
|
||||
residual_avals = map(get_aval, residuals)
|
||||
new_in_axes = (*(0 for _ in residual_avals),
|
||||
residual_axes = [in_axes[f1] if f1 is not None else
|
||||
out_axes[f2] if f2 is not None else
|
||||
0 for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
new_in_axes = (*residual_axes, *(None for _ in range(len(env))),
|
||||
*(ax for ax, nz in zip(in_axes, nzs_in) if nz))
|
||||
new_out_axes = (*(ax for ax, nz in zip(out_axes, nzs_out) if nz),)
|
||||
# NOTE: This assumes that the output tangents being zero is a
|
||||
# deterministic function of which input tangents were zero.
|
||||
@as_hashable_function(closure=(new_out_axes))
|
||||
@as_hashable_function(closure=new_out_axes)
|
||||
def new_out_axes_thunk():
|
||||
return new_out_axes
|
||||
params = dict(params,
|
||||
in_axes=new_in_axes,
|
||||
out_axes_thunk=new_out_axes_thunk)
|
||||
params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
|
||||
|
||||
update_params = call_linearize_param_updaters.get(call_primitive)
|
||||
new_params = update_params(params, residual_avals, nzs_in) if update_params else params
|
||||
num_new_args = len(residuals) + len(env)
|
||||
new_params = update_params(params, num_new_args, nzs_in) if update_params else params
|
||||
num_residuals = len(residual_avals)
|
||||
|
||||
@as_hashable_function(closure=(num_residuals, lin_jaxpr))
|
||||
def f_tangent(*args):
|
||||
residuals = args[:num_residuals]
|
||||
consts = args[:num_residuals]
|
||||
nz_tangents = args[num_residuals:]
|
||||
return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)
|
||||
return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents)
|
||||
# TODO(mattjj,dougalm): this tag is read by DynamicJaxprTrace.process_map to
|
||||
# avoid round-tripping the jaxpr and thus getting grad-of-pmap cache misses.
|
||||
# Remove when we replace the pmap implementation.
|
||||
f_tangent._pmap_tag = isinstance(call_primitive, core.MapPrimitive)
|
||||
|
||||
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
|
||||
nz_tangents_out = call_primitive.bind_with_trace(
|
||||
self.tangent_trace, (lu.wrap_init(f_tangent,
|
||||
debug_info=lin_jaxpr.debug_info),
|
||||
*residuals, *nz_tangents_in), new_params)
|
||||
self.tangent_trace,
|
||||
(lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
|
||||
*residuals, *env, *nz_tangents_in), new_params)
|
||||
nz_tangents_out_iter = iter(nz_tangents_out)
|
||||
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
|
||||
for nz, primal in zip(nzs_out, primals_out)]
|
||||
@ -762,14 +815,14 @@ def fallback_linearize_rule(_prim: core.Primitive,
|
||||
msg = f"Differentiation rule for '{_prim}' not implemented"
|
||||
raise NotImplementedError(msg)
|
||||
debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params)
|
||||
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False,
|
||||
debug_jvp, primals, params)
|
||||
return linearize_from_jvp(lu.wrap_init(jvp, debug_info=debug_jvp),
|
||||
_prim.multiple_results, _nonzeros, False, False,
|
||||
primals, params)
|
||||
|
||||
def linearize_from_jvp(jvp: Callable,
|
||||
def linearize_from_jvp(jvp: lu.WrappedFun,
|
||||
multiple_results: bool,
|
||||
nonzeros: Sequence[bool],
|
||||
user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
|
||||
debug_info: core.DebugInfo,
|
||||
primals, params):
|
||||
current_name_stack = source_info_util.current_name_stack()
|
||||
with core.take_current_trace() as parent_trace:
|
||||
@ -792,13 +845,18 @@ def linearize_from_jvp(jvp: Callable,
|
||||
tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
|
||||
for aval, nz in zip(tangent_avals, nonzeros))
|
||||
with core.set_current_trace(trace):
|
||||
out_primals, out_tangents = jvp(primals, tangent_args, **params)
|
||||
out_primals, out_tangents = jvp.call_wrapped(primals, tangent_args, **params)
|
||||
|
||||
if not multiple_results:
|
||||
out_primals = [out_primals]
|
||||
out_tangents = [out_tangents]
|
||||
|
||||
out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals]
|
||||
if any(p is None for p in out_primals):
|
||||
raise ValueError(
|
||||
"Linearization failed to produce known values for all output primals. "
|
||||
"This is typically caused by attempting to differentiate a function "
|
||||
"uses an operation that does not support reverse-mode autodiff.")
|
||||
|
||||
out_nzs = [type(t) is not zero_type and not trace.to_jaxpr_tracer(t).is_known()
|
||||
for t in out_tangents]
|
||||
@ -806,7 +864,7 @@ def linearize_from_jvp(jvp: Callable,
|
||||
out_nz_tracers = [trace.to_jaxpr_tracer(r)
|
||||
for (r, nz) in zip(out_tangents, out_nzs) if nz]
|
||||
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
|
||||
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, debug_info)
|
||||
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info)
|
||||
|
||||
def linearized(residuals, *tangents):
|
||||
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
|
||||
@ -973,9 +1031,8 @@ def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
|
||||
else:
|
||||
consts = ()
|
||||
all_args, in_tree_def = tree_flatten((consts, args, ct))
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
|
||||
debug_info=call_jaxpr.debug_info),
|
||||
call_jaxpr, False)
|
||||
fun = lu.hashable_partial(lu.wrap_init(
|
||||
backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
update_params = call_transpose_param_updaters.get(primitive)
|
||||
if update_params:
|
||||
@ -1013,9 +1070,8 @@ def map_transpose(primitive: core.Primitive, params,
|
||||
call_jaxpr: core.Jaxpr, args, ct, _):
|
||||
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
||||
# TODO(necula): use the right debug_info for the backwards pass
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
|
||||
debug_info=call_jaxpr.debug_info),
|
||||
call_jaxpr, False)
|
||||
fun = lu.hashable_partial(lu.wrap_init(
|
||||
backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False)
|
||||
fun, nz_arg_cts = nonzero_outputs(fun)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
|
||||
@ -1083,8 +1139,8 @@ def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
|
||||
assert len(jaxpr.in_avals) == len(nonzeros)
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
|
||||
debug_info=jaxpr.jaxpr.debug_info)
|
||||
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
|
||||
nonzeros)
|
||||
f_jvp, out_nonzeros = f_jvp_traceable(
|
||||
jvp(f, instantiate=instantiate, transform_stack=False), nonzeros)
|
||||
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
|
||||
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
|
||||
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(
|
||||
|
@ -199,6 +199,9 @@ if dtypes.float8_e4m3 is not None:
|
||||
if dtypes.float8_e8m0fnu is not None:
|
||||
_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
|
||||
|
||||
if dtypes.float4_e2m1fn is not None:
|
||||
_dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get
|
||||
|
||||
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
|
||||
if isinstance(dtype, core.bint):
|
||||
# TODO Support different-size underlying dtypes to take advantage of the
|
||||
@ -939,7 +942,7 @@ def sharded_aval(aval: core.AbstractValue,
|
||||
return aval
|
||||
if not isinstance(aval, (core.ShapedArray, core.DShapedArray)):
|
||||
raise NotImplementedError
|
||||
return aval.update(sharding.shard_shape(aval.shape)) # type: ignore
|
||||
return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore
|
||||
|
||||
|
||||
def eval_dynamic_shape(ctx: LoweringRuleContext,
|
||||
|
@ -46,7 +46,8 @@ from jax._src.tree_util import (PyTreeDef, treedef_tuple,
|
||||
tree_flatten, tree_structure)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
merge_lists, partition_list, OrderedSet,
|
||||
as_hashable_function, weakref_lru_cache, subs_list)
|
||||
as_hashable_function, weakref_lru_cache, subs_list,
|
||||
HashableFunction)
|
||||
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
@ -837,6 +838,11 @@ def tracers_to_jaxpr(
|
||||
# del getvar # needed to avoid cyclic-reference closure, apparently!
|
||||
return jaxpr, const_vals, env_vals
|
||||
|
||||
@weakref_lru_cache
|
||||
def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr:
|
||||
constvars, envvars = partition_list(which, jaxpr.constvars)
|
||||
return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars])
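
A toy sketch of what move_envvars does to the variable lists: constvars flagged in `which` stop being constants and become leading invars. Jaxpr variables are replaced by plain strings here.

constvars = ['c0', 'e0', 'c1']
invars = ['x0']
which = (False, True, False)
new_constvars = [v for v, w in zip(constvars, which) if not w]
envvars = [v for v, w in zip(constvars, which) if w]
new_invars = [*envvars, *invars]
assert new_constvars == ['c0', 'c1'] and new_invars == ['e0', 'x0']
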
|
||||
|
||||
@weakref_lru_cache
|
||||
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
|
||||
"""Moves the constvars to the start of invars."""
|
||||
@ -1840,7 +1846,7 @@ def _inline_literals(
|
||||
|
||||
|
||||
class DynamicJaxprTrace(core.Trace):
|
||||
__slots__ = ("frame",)
|
||||
__slots__ = ("frame", "tag")
|
||||
|
||||
def __init__(self, debug_info: core.DebugInfo):
|
||||
self.frame = JaxprStackFrame(debug_info)
|
||||
@ -1972,17 +1978,18 @@ class DynamicJaxprTrace(core.Trace):
|
||||
self.frame.add_eqn(eqn)
|
||||
return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
|
||||
|
||||
def process_map(self, map_primitive, f: lu.WrappedFun,
|
||||
tracers: Sequence[core.Tracer], params):
|
||||
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
|
||||
tracers = map(self.to_jaxpr_tracer, tracers)
|
||||
in_avals = [t.aval for t in tracers]
|
||||
axis_name, axis_size = params['axis_name'], params['axis_size']
|
||||
reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
|
||||
if in_axis is not None else a
|
||||
for a, in_axis in zip(in_avals, params['in_axes'])]
|
||||
|
||||
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
|
||||
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
|
||||
f, reduced_in_avals)
|
||||
jaxpr, consts = _linearize_of_pmap_hack(f, jaxpr, consts)
|
||||
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
|
||||
if ordered_effects:
|
||||
raise ValueError("Ordered effects not supported for "
|
||||
@ -2582,3 +2589,13 @@ def inline_jaxpr_into_trace(
|
||||
return tracer
|
||||
return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env
|
||||
else new_tracer(x) for x in jaxpr.outvars]
|
||||
|
||||
# TODO(mattjj,dougalm): this special handling is to avoid round-tripping the
|
||||
# jaxpr when we do grad-of-pmap. The tag is set by LinearizeTrace.process_call's
|
||||
# handling of pmap. Remove when we replace the pmap implementation.
|
||||
def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, list]:
|
||||
if (not f.transforms and type(f.f) is HashableFunction and
|
||||
getattr(f.f, '_pmap_tag', None)):
|
||||
_, jaxpr = f.f.closure
|
||||
return convert_constvars_jaxpr(jaxpr), []
|
||||
return jaxpr, consts
|
||||
|
@ -1394,9 +1394,9 @@ def xla_call_jvp_update_params(params, nz_tangents):
|
||||
new_donated_invars = (*donated_invars, *donated_tangents)
|
||||
return dict(params, donated_invars=new_donated_invars)
|
||||
|
||||
def _xla_call_linearize_update_params(params, residual_avals, nz_tangents):
|
||||
def _xla_call_linearize_update_params(params, num_new_inputs, nz_tangents):
|
||||
donated_invars_prev = params['donated_invars']
|
||||
donated_invars = (*(False for _ in residual_avals),
|
||||
donated_invars = (*(False for _ in range(num_new_inputs)),
|
||||
*(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz))
|
||||
return dict(params, donated_invars=donated_invars)
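
A toy illustration of the updated parameter rule above: the `num_new_inputs` residual/env arguments prepended to the call are never donated, and donation flags for dropped (symbolically zero) tangents are filtered out. Values are made up for illustration.

donated_invars_prev = (True, False, True)
nz_tangents = (True, False, True)        # middle tangent is symbolically zero
num_new_inputs = 2                       # residuals + env values prepended
donated_invars = (*(False for _ in range(num_new_inputs)),
                  *(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz))
assert donated_invars == (False, False, True, True)
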
|
||||
|
||||
@ -1663,7 +1663,7 @@ class MismatchType(enum.Enum):
|
||||
elif self.name == 'OUT_SHARDING':
|
||||
return 'explicit output sharding'
|
||||
elif self.name == 'CONTEXT_DEVICES':
|
||||
return 'devices'
|
||||
return 'context mesh'
|
||||
return f'{self.name}'
|
||||
|
||||
|
||||
@ -3060,7 +3060,6 @@ class JitGlobalCppCacheKeys:
|
||||
in_layouts_leaves: tuple[Any, ...] | None = None
|
||||
out_layouts_treedef: PyTreeDef | None = None
|
||||
out_layouts_leaves: tuple[Any, ...] | None = None
|
||||
use_resource_env: bool = False
|
||||
compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None
|
||||
|
||||
@functools.cached_property
|
||||
|
@ -27,7 +27,6 @@ from jax._src.lax import lax
|
||||
from jax._src import effects
|
||||
from jax._src import ad_util
|
||||
from jax._src import state
|
||||
from jax._src import util
|
||||
from jax._src.util import weakref_lru_cache, safe_map, partition_list
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef
|
||||
@ -144,52 +143,55 @@ def _initial_style_jaxprs_with_common_consts(
|
||||
# b[] <- 2.0
|
||||
# in () }
|
||||
canonical_ref_indices = []
|
||||
canonical_non_ref_indices = []
|
||||
canonical_refs: list[Any] = []
|
||||
tracer_id_to_canonical_id = {}
|
||||
all_nonref_consts = []
|
||||
canonical_non_refs: list[Any] = []
|
||||
tracer_id_to_canonical_ref_id = {}
|
||||
tracer_id_to_canonical_non_ref_id = {}
|
||||
canonical_ref_avals = []
|
||||
all_nonref_const_avals = []
|
||||
canonical_non_ref_avals = []
|
||||
for consts, consts_avals in zip(all_consts, all_const_avals):
|
||||
ref_indices = []
|
||||
nonref_consts = []
|
||||
nonref_const_avals = []
|
||||
non_ref_indices = []
|
||||
for c, aval in zip(consts, consts_avals):
|
||||
tracer_id = id(c)
|
||||
if isinstance(aval, state.AbstractRef):
|
||||
tracer_id = id(c)
|
||||
if tracer_id not in tracer_id_to_canonical_id:
|
||||
if tracer_id not in tracer_id_to_canonical_ref_id:
|
||||
canonical_id = len(canonical_refs)
|
||||
canonical_refs.append(c)
|
||||
tracer_id_to_canonical_id[tracer_id] = canonical_id
|
||||
tracer_id_to_canonical_ref_id[tracer_id] = canonical_id
|
||||
canonical_ref_avals.append(aval)
|
||||
canonical_id = tracer_id_to_canonical_id[tracer_id]
|
||||
canonical_id = tracer_id_to_canonical_ref_id[tracer_id]
|
||||
ref_indices.append(canonical_id)
|
||||
else:
|
||||
nonref_consts.append(c)
|
||||
nonref_const_avals.append(aval)
|
||||
all_nonref_consts.append(nonref_consts)
|
||||
all_nonref_const_avals.append(tuple(nonref_const_avals))
|
||||
if tracer_id not in tracer_id_to_canonical_non_ref_id:
|
||||
canonical_id = len(canonical_non_refs)
|
||||
canonical_non_refs.append(c)
|
||||
tracer_id_to_canonical_non_ref_id[tracer_id] = canonical_id
|
||||
canonical_non_ref_avals.append(aval)
|
||||
canonical_id = tracer_id_to_canonical_non_ref_id[tracer_id]
|
||||
non_ref_indices.append(canonical_id)
|
||||
canonical_ref_indices.append(tuple(ref_indices))
|
||||
canonical_non_ref_indices.append(tuple(non_ref_indices))
|
||||
|
||||
consts = [*canonical_refs, *util.concatenate(all_nonref_consts)]
|
||||
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*all_nonref_const_avals,))
|
||||
consts = [*canonical_refs, *canonical_non_refs]
|
||||
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,))
|
||||
for i, jaxpr in enumerate(jaxprs))
|
||||
return jaxprs, consts, all_out_trees
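
A toy sketch of the canonicalization above: constants are deduplicated across branches by object identity (now for non-ref constants as well as refs), and each branch records indices into the canonical list instead of padding its jaxpr with unused constvars. Plain lists stand in for the tracers.

shared, only_first = object(), object()
branch_consts = [[shared, only_first], [shared]]

canonical, id_to_idx, per_branch_indices = [], {}, []
for consts in branch_consts:
  idxs = []
  for c in consts:
    if id(c) not in id_to_idx:
      id_to_idx[id(c)] = len(canonical)
      canonical.append(c)
    idxs.append(id_to_idx[id(c)])
  per_branch_indices.append(tuple(idxs))

assert per_branch_indices == [(0, 1), (0,)] and len(canonical) == 2
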
|
||||
|
||||
@weakref_lru_cache
|
||||
def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
|
||||
all_nonref_const_avals):
|
||||
canonical_non_ref_avals, canonical_non_ref_indices):
|
||||
is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars]
|
||||
nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars)
|
||||
newvar = core.gensym(suffix='_')
|
||||
unused_const_vars = [tuple(map(newvar, const_avals))
|
||||
for const_avals in all_nonref_const_avals]
|
||||
padded_ref_constvars = map(newvar, canonical_ref_avals)
|
||||
padded_non_ref_constvars = map(newvar, canonical_non_ref_avals)
|
||||
for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars):
|
||||
padded_ref_constvars[canonical_id] = ref_var
|
||||
const_prefix = util.concatenate(unused_const_vars[:i])
|
||||
const_suffix = util.concatenate(unused_const_vars[i + 1:])
|
||||
constvars = [*padded_ref_constvars, *const_prefix, *nonref_constvars,
|
||||
*const_suffix]
|
||||
for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars):
|
||||
padded_non_ref_constvars[canonical_id] = non_ref_var
|
||||
constvars = [*padded_ref_constvars, *padded_non_ref_constvars]
|
||||
jaxpr = jaxpr.replace(constvars=constvars)
|
||||
effects = pe.make_jaxpr_effects(jaxpr.constvars, jaxpr.invars,
|
||||
jaxpr.outvars, jaxpr.eqns)
|
||||
|
@ -281,8 +281,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
num_consts = len(consts)
|
||||
out_ = iter(out)
|
||||
|
||||
all_inputs = [*consts, *ops]
|
||||
out = [
|
||||
next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts])
|
||||
next(out_) if fwd is None else lax.asarray(all_inputs[fwd])
|
||||
for fwd in in_fwd
|
||||
]
|
||||
assert next(out_, None) is None
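
A toy illustration of the indexing fix above: `in_fwd` holds indices into the combined `[*consts, *ops]` list, so forwarded outputs must be looked up there rather than by offsetting into `ops` alone. Values are made up for illustration.

consts, ops = ['c0'], ['x0', 'x1']
all_inputs = [*consts, *ops]
in_fwd = [None, 2]                 # second output forwards all_inputs[2]
out_ = iter(['fresh'])
out = [next(out_) if fwd is None else all_inputs[fwd] for fwd in in_fwd]
assert out == ['fresh', 'x1'] and next(out_, None) is None
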
|
||||
|
@ -443,42 +443,26 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
|
||||
consts, carry, xs_ = split_list(args, [num_consts, num_carry])
|
||||
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
|
||||
num_trips, remainder = divmod(length, unroll)
|
||||
|
||||
if unroll != 1 and num_trips == 1 and remainder == 0:
|
||||
# In that case, we explicitly want to fully unroll the loop. Put everything
|
||||
# into the remainder block and avoid lowering to a while loop.
|
||||
num_trips, remainder = 0, length
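
A worked example of the trip/remainder bookkeeping just above, with hypothetical sizes:

length, unroll = 10, 4
num_trips, remainder = divmod(length, unroll)   # 2 trips of 4 steps, 2 left over
assert (num_trips, remainder) == (2, 2)

length, unroll = 8, 8                           # exactly one full trip
num_trips, remainder = divmod(length, unroll)
if unroll != 1 and num_trips == 1 and remainder == 0:
  num_trips, remainder = 0, length              # fully unroll; skip the while loop
assert (num_trips, remainder) == (0, 8)
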
|
||||
if unroll == 1:
|
||||
xss = xs_
|
||||
yss = _map(partial(_empty_array, (length,), None), y_avals)
|
||||
yss = _map(partial(_empty_array, (length,), (None,)), y_avals)
|
||||
else:
|
||||
if remainder:
|
||||
if not reverse:
|
||||
xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
|
||||
else:
|
||||
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
|
||||
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
|
||||
yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals)
|
||||
if num_trips:
|
||||
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
|
||||
yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals)
|
||||
else:
|
||||
yss = _map(partial(_empty_array, (num_trips * unroll,), (None,)), y_avals)
|
||||
|
||||
def cond_fun(while_carry):
|
||||
i, _, _ = while_carry
|
||||
return i < num_trips
|
||||
def body_fun(while_carry):
|
||||
i_, carry, yss = while_carry
|
||||
i = num_trips - i_ - 1 if reverse else i_
|
||||
xs = [
|
||||
slicing.dynamic_index_in_dim(
|
||||
xs, i, keepdims=False, allow_negative_indices=False
|
||||
)
|
||||
for xs in xss
|
||||
]
|
||||
carry, ys = inner(unroll, carry, xs)
|
||||
yss = [
|
||||
slicing.dynamic_update_index_in_dim(
|
||||
ys, upd, i, 0, allow_negative_indices=False
|
||||
)
|
||||
for ys, upd in zip(yss, ys)
|
||||
]
|
||||
return i_ + 1, carry, yss
|
||||
def inner(n, carry, xs):
|
||||
ys = []
|
||||
if unroll == 1:
|
||||
@ -493,10 +477,26 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
|
||||
ys = list(reversed(ys)) if reverse else ys
|
||||
return carry, _map(_stack, zip(*ys))
|
||||
|
||||
def body_fun(while_carry):
|
||||
i_, carry, yss = while_carry
|
||||
i = num_trips - i_ - 1 if reverse else i_
|
||||
xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False,
|
||||
allow_negative_indices=False)
|
||||
for xs in xss]
|
||||
carry, ys = inner(unroll, carry, xs)
|
||||
yss = [slicing.dynamic_update_index_in_dim(y, upd, i, 0,
|
||||
allow_negative_indices=False)
|
||||
for y, upd in zip(yss, ys)]
|
||||
return i_ + 1, carry, yss
|
||||
|
||||
def cond_fun(while_carry):
|
||||
i, _, _ = while_carry
|
||||
return i < num_trips
|
||||
|
||||
if num_trips:
|
||||
i = lax._const(num_trips, 0)
|
||||
_, carry, yss = while_loop(cond_fun, body_fun, (i, carry, yss))
|
||||
if unroll != 1:
|
||||
if unroll != 1 and num_trips != 0:
|
||||
ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
|
||||
else:
|
||||
ys = yss
|
||||
@ -512,7 +512,7 @@ def _split_leading(sz, x):
|
||||
def _concat(a, b): return lax.concatenate([a, b], 0)
|
||||
|
||||
def _empty_array(prefix, length_spec, aval):
|
||||
sharding = aval.sharding.with_spec((length_spec, *aval.sharding.spec))
|
||||
sharding = aval.sharding.with_spec((*length_spec, *aval.sharding.spec))
|
||||
return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape),
|
||||
out_sharding=sharding)
|
||||
|
||||
|
@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
from functools import partial
|
||||
@ -2227,10 +2228,122 @@ def ragged_dot(
|
||||
Results:
|
||||
(m, n) shaped array with preferred_element_type element type.
|
||||
"""
|
||||
return ragged_dot_p.bind(lhs, rhs, group_sizes,
|
||||
precision=canonicalize_precision(precision),
|
||||
preferred_element_type=preferred_element_type,
|
||||
group_offset=group_offset)
|
||||
return ragged_dot_general(
|
||||
lhs,
|
||||
rhs,
|
||||
group_sizes,
|
||||
ragged_dot_dimension_numbers=_BASIC_RAGGED_DOT_DIMENSION_NUMBERS,
|
||||
precision=canonicalize_precision(precision),
|
||||
preferred_element_type=preferred_element_type,
|
||||
group_offset=group_offset,
|
||||
)
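
A minimal usage sketch of the basic case that `ragged_dot` still covers (the first of the three modes described further below): rows of `lhs` are split into `g` groups along `m`, and each group is multiplied by its own `rhs` slice. Shapes are made up for illustration.

import jax.numpy as jnp
from jax import lax

m, k, n, g = 6, 4, 3, 2
lhs = jnp.ones((m, k))
rhs = jnp.ones((g, k, n))
group_sizes = jnp.array([2, 4], dtype=jnp.int32)   # sums to m
out = lax.ragged_dot(lhs, rhs, group_sizes)
assert out.shape == (m, n)
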
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RaggedDotDimensionNumbers():
|
||||
"""Describes ragged, group, and dot dimensions for ragged dot general.
|
||||
|
||||
Args:
|
||||
dot_dimension_numbers: a tuple of tuples of sequences of ints of the form
|
||||
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
|
||||
rhs_batch_dims))`.
|
||||
lhs_ragged_dimensions: a sequence of ints indicating the 'lhs' ragged
|
||||
dimensions.
|
||||
rhs_group_dimensions: a sequence of ints indicating the 'rhs' group
|
||||
dimensions.
|
||||
"""
|
||||
dot_dimension_numbers: DotDimensionNumbers
|
||||
lhs_ragged_dimensions: Sequence[int]
|
||||
rhs_group_dimensions: Sequence[int]
|
||||
|
||||
def __init__(
|
||||
self, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions
|
||||
):
|
||||
super().__setattr__(
|
||||
'dot_dimension_numbers',
|
||||
tuple(tuple(map(tuple, t)) for t in dot_dimension_numbers),
|
||||
)
|
||||
super().__setattr__('lhs_ragged_dimensions', tuple(lhs_ragged_dimensions))
|
||||
super().__setattr__('rhs_group_dimensions', tuple(rhs_group_dimensions))
|
||||
|
||||
|
||||
def _from_maybe_ragged(
|
||||
dot_dimension_numbers: RaggedDotDimensionNumbers | DotDimensionNumbers,
|
||||
) -> DotDimensionNumbers:
|
||||
return (
|
||||
dot_dimension_numbers.dot_dimension_numbers
|
||||
if isinstance(dot_dimension_numbers, RaggedDotDimensionNumbers)
|
||||
else dot_dimension_numbers
|
||||
)
|
||||
|
||||
|
||||
# RaggedDotDimensionNumbers that specify the simple case (i.e., lax.ragged_dot.)
|
||||
_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = RaggedDotDimensionNumbers(
|
||||
dot_dimension_numbers=(([1], [1]), ([], [])),
|
||||
lhs_ragged_dimensions=[0],
|
||||
rhs_group_dimensions=[0],
|
||||
)
|
||||
|
||||
|
||||
def ragged_dot_general(
|
||||
lhs: Array,
|
||||
rhs: Array,
|
||||
group_sizes: Array,
|
||||
ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
group_offset: Array | None = None,
|
||||
) -> Array:
|
||||
"""Ragged matrix multiplication.
|
||||
|
||||
Ragged dot takes three arrays---``lhs``, ``rhs``, and ``group_sizes``---and
|
||||
a ``ragged_dot_dimension_numbers`` argument. Like `dot_general`, ``lhs`` and
|
||||
``rhs`` are allowed arbitrary batch and contracting dimensions. Additionally,
|
||||
``lhs`` is required to have one ragged dimension, and ``rhs`` may have at
|
||||
most one group dimension.
|
||||
|
||||
Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has
|
||||
three modes, depending on the kind of the lhs ragged dimension:
|
||||
1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`.
|
||||
Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``,
|
||||
and `x...` are the lhs non-contracting dims outer to the ragged dim.
|
||||
2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`.
|
||||
Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and
|
||||
``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim.
|
||||
3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`.
|
||||
Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and
|
||||
``rhs``, and `x...` are the lhs batch dims outer to the ragged dim.
|
||||
If ``group_sizes`` is passed-in with shape `[g]`, it is broadcasted according
|
||||
to the rules above.
|
||||
|
||||
Args:
|
||||
lhs: an array
|
||||
rhs: an array
|
||||
group_sizes: an array with integer element type
|
||||
ragged_dot_dimension_numbers: a ``RaggedDotDimensionNumbers`` object to
|
||||
specify the dot dimension numbers, lhs ragged dimension, and rhs group
|
||||
dimension.
|
||||
precision: Optional. Consistent with precision argument for
|
||||
:func:`jax.lax.dot`.
|
||||
preferred_element_type: Optional. Consistent with precision argument for
|
||||
:func:`jax.lax.dot`.
|
||||
group_offset: Optional. (1,) shaped array that indicates the group in
|
||||
group_sizes to start computing from. If not specified, defaults to [0].
|
||||
|
||||
Results:
|
||||
An array whose shape is the same as that produced by `dot_general`, with an
|
||||
extra leading dimension of size `g` in the case where the lhs ragged
|
||||
dimension is a contracting dimension.
|
||||
"""
|
||||
return ragged_dot_general_p.bind(
|
||||
lhs,
|
||||
rhs,
|
||||
group_sizes,
|
||||
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
|
||||
precision=canonicalize_precision(precision),
|
||||
preferred_element_type=preferred_element_type,
|
||||
group_offset=group_offset,
|
||||
)
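
For reference, the dimension numbers that make `ragged_dot_general` reproduce `ragged_dot` are exactly `_BASIC_RAGGED_DOT_DIMENSION_NUMBERS` above. A sketch, assuming both new names are re-exported under `jax.lax` (the diff only shows their definitions in `jax._src.lax.lax`):

import jax.numpy as jnp
from jax import lax

dims = lax.RaggedDotDimensionNumbers(
    dot_dimension_numbers=(([1], [1]), ([], [])),  # contract lhs k with rhs k
    lhs_ragged_dimensions=[0],                     # m is the ragged dimension
    rhs_group_dimensions=[0],                      # leading rhs axis indexes groups
)
out = lax.ragged_dot_general(
    jnp.ones((6, 4)), jnp.ones((2, 4, 3)),
    jnp.array([2, 4], dtype=jnp.int32),
    ragged_dot_dimension_numbers=dims,
)
assert out.shape == (6, 3)
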
|
||||
|
||||
|
||||
def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None
|
||||
@ -3723,10 +3836,11 @@ def _sin_complex(x):
|
||||
# 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2
|
||||
a, b = real(x), imag(x)
|
||||
a_is_zero = eq(a, _const(a, 0))
|
||||
two = _const(a, 2)
|
||||
sn, cs = sin(a), cos(a)
|
||||
e1m, e2m = expm1(b), expm1(-b)
|
||||
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
|
||||
re, im = sn * csh, cs * snh
|
||||
e1m, e2m = expm1(b), expm1(neg(b))
|
||||
snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two)
|
||||
re, im = mul(sn, csh), mul(cs, snh)
|
||||
# avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf
|
||||
return select(a_is_zero, complex(_const(a, 0), im), complex(re, im))
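
A quick NumPy check of the sinh/cosh-via-expm1 identity that the rewrite above relies on; computing sinh from expm1 stays accurate for small imaginary parts where exp(b) - exp(-b) would cancel.

import numpy as np

b = 1e-8
e1m, e2m = np.expm1(b), np.expm1(-b)
snh = (e1m - e2m) / 2        # == sinh(b), without catastrophic cancellation
csh = (e1m + e2m + 2) / 2    # == cosh(b)
np.testing.assert_allclose(snh, np.sinh(b), rtol=1e-12)
np.testing.assert_allclose(csh, np.cosh(b), rtol=1e-12)
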
|
||||
|
||||
@ -3736,14 +3850,14 @@ def _sin_lowering(ctx, x):
|
||||
return sine(ctx, x)
|
||||
return _nary_lower_hlo(hlo.sine, ctx, x)
|
||||
|
||||
def _sin_p_lin(nzs, x):
|
||||
def _sin_lin(nzs, x):
|
||||
nz, = nzs
|
||||
cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
|
||||
return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_))
|
||||
|
||||
sin_p = standard_unop(_float | _complex, 'sin')
|
||||
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
||||
ad.primitive_linearizations[sin_p] = _sin_p_lin
|
||||
ad.primitive_linearizations[sin_p] = _sin_lin
|
||||
mlir.register_lowering(sin_p, _sin_lowering)
|
||||
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
|
||||
|
||||
@ -3752,10 +3866,11 @@ def _cos_complex(x):
|
||||
# see also _sin_complex
|
||||
a, b = real(x), imag(x)
|
||||
a_is_zero = eq(a, _const(a, 0))
|
||||
two = _const(a, 2)
|
||||
sn, cs = sin(a), cos(a)
|
||||
e1m, e2m = expm1(b), expm1(-b)
|
||||
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
|
||||
re, im = cs * csh, -sn * snh
|
||||
e1m, e2m = expm1(b), expm1(neg(b))
|
||||
snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two)
|
||||
re, im = mul(cs, csh), mul(neg(sn), snh)
|
||||
return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im))
|
||||
|
||||
def _cos_lowering(ctx, x):
|
||||
@ -3769,28 +3884,28 @@ ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
||||
mlir.register_lowering(cos_p, _cos_lowering)
|
||||
|
||||
tan_p = standard_unop(_float | _complex, 'tan')
|
||||
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
|
||||
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans))))
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
|
||||
|
||||
asin_p = standard_unop(_float | _complex, 'asin')
|
||||
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
|
||||
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x)))))
|
||||
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin))
|
||||
|
||||
acos_p = standard_unop(_float | _complex, 'acos')
|
||||
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
|
||||
ad.defjvp(acos_p, lambda g, x: mul(g, neg(rsqrt(sub(_const(x, 1), square(x))))))
|
||||
mlir.register_lowering(acos_p, partial(_nary_lower_hlo, chlo.acos))
|
||||
|
||||
def atan_impl(x):
|
||||
return atan2(x, _const(x, 1))
|
||||
|
||||
atan_p = standard_unop(_float | _complex, 'atan')
|
||||
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
|
||||
ad.defjvp(atan_p, lambda g, x: div(g, add(_const(x, 1), square(x))))
|
||||
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan))
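
A quick numeric check of the atan JVP rule rewritten above in terms of explicit lax ops; the derivative itself is unchanged: d/dx atan(x) = 1 / (1 + x**2).

import jax
import jax.numpy as jnp

x = 0.5
g = jax.grad(jnp.arctan)(x)
assert abs(g - 1.0 / (1.0 + x * x)) < 1e-6
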
|
||||
|
||||
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
|
||||
ad.defjvp(atan2_p,
|
||||
lambda g, x, y: g * (y / (square(x) + square(y))),
|
||||
lambda g, x, y: g * -x / (square(x) + square(y)))
|
||||
lambda g, x, y: mul(g, div(y, add(square(x), square(y)))),
|
||||
lambda g, x, y: mul(g, div(neg(x), add(square(x), square(y)))))
|
||||
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2))
|
||||
|
||||
sinh_p = standard_unop(_float | _complex, 'sinh')
|
||||
@ -3802,17 +3917,17 @@ ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
|
||||
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh))
|
||||
|
||||
asinh_p = standard_unop(_float | _complex, 'asinh')
|
||||
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
|
||||
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(add(square(x), _one(x)))))
|
||||
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh))
|
||||
|
||||
acosh_p = standard_unop(_float | _complex, 'acosh')
|
||||
ad.defjvp(acosh_p,
|
||||
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
|
||||
lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x))))))
|
||||
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh))
|
||||
|
||||
atanh_p = standard_unop(_float | _complex, 'atanh')
|
||||
ad.defjvp(atanh_p,
|
||||
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
|
||||
lambda g, x: mul(reciprocal(add(_one(x), x)), div(g, sub(_one(x), x))))
|
||||
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh))
|
||||
|
||||
real_p = unop(_complex_basetype, _complex, 'real')
|
||||
@ -3906,11 +4021,11 @@ def _square_complex(x):
|
||||
a, b = real(x), imag(x)
|
||||
# zero square(x).real is handled explicitly for abs(a)==abs(b) cases
|
||||
# where for finite a, 2 * a is non-finite:
|
||||
zero_re = is_finite(a) & (eq(a, b) | eq(a, -b))
|
||||
zero_re = is_finite(a) & (eq(a, b) | eq(a, neg(b)))
|
||||
# equivalent to a**2 - b**2 but avoids overflow errors for large a
|
||||
# and large b cases:
|
||||
re = (a - b) * (a + b)
|
||||
im = a * b * 2
|
||||
re = mul(sub(a, b), add(a, b))
|
||||
im = mul(mul(a, b), _const(a, 2))
|
||||
return select(zero_re, complex(_const(a, 0), im), complex(re, im))
|
||||
|
||||
def _square_lower_hlo(ctx, x):
|
||||
@ -4591,7 +4706,7 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
out_sharding):
|
||||
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
|
||||
raise NotImplementedError
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
|
||||
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
|
||||
for d in (lhs_contracting, lhs_batch)):
|
||||
msg = ("dot_general requires lhs dimension numbers to be nonnegative and "
|
||||
@ -4652,12 +4767,17 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)
|
||||
|
||||
def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
|
||||
batch_shape = tuple(lhs_shape[i] for i in lhs_batch)
|
||||
lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
|
||||
lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch)
|
||||
rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
|
||||
rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch)
|
||||
rhs_group = ()
|
||||
if isinstance(dimension_numbers, RaggedDotDimensionNumbers):
|
||||
rhs_group = tuple(dimension_numbers.rhs_group_dimensions)
|
||||
rhs_contract_or_batch_or_group = tuple(
|
||||
sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group)
|
||||
)
|
||||
rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch_or_group)
|
||||
return batch_shape + lhs_tensored_shape + rhs_tensored_shape
|
||||
|
||||
|
||||
@ -4721,7 +4841,7 @@ def tuple_delete(tup, idx):
|
||||
|
||||
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
preferred_element_type: DTypeLike | None,
|
||||
out_sharding):
|
||||
out_sharding, name: str = 'lax.dot_general'):
|
||||
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
|
||||
raise NotImplementedError
|
||||
del dimension_numbers # unused
|
||||
@ -4742,8 +4862,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
result_dtype = rhs.dtype
|
||||
else:
|
||||
if lhs.dtype != rhs.dtype:
|
||||
raise TypeError(
|
||||
f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
|
||||
raise TypeError(f'{name} argument type error: {lhs.dtype}, {rhs.dtype}')
|
||||
result_dtype = lhs.dtype
|
||||
has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset))
|
||||
return _maybe_upcast(result_dtype, preferred_element_type,
|
||||
@ -4882,8 +5001,9 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
|
||||
# explicitly present dimensions that this dot_general is zipping together.
|
||||
lbd, rbd = batch_dims
|
||||
assert lbd is not None or rbd is not None
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
|
||||
|
||||
is_ragged_dot = isinstance(dimension_numbers, RaggedDotDimensionNumbers)
|
||||
def bump_dims(dims, b):
|
||||
return tuple(np.add(dims, np.greater_equal(dims, b)))
|
||||
|
||||
@ -4906,8 +5026,14 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
|
||||
elif (type(rbd) is int and lbd is None):
|
||||
# The right vmapped dimension becomes an additional tensor dimension in the
|
||||
# batched dot_general.
|
||||
rhs_tensor = [d for d in range(rhs_ndim)
|
||||
if d not in rhs_batch and d not in rhs_contract]
|
||||
rhs_tensor = list(
|
||||
remaining(
|
||||
range(rhs_ndim),
|
||||
rhs_batch,
|
||||
rhs_contract,
|
||||
dimension_numbers.rhs_group_dimensions if is_ragged_dot else [],
|
||||
)
|
||||
)
|
||||
result_batch_dim = (lhs_ndim - len(lhs_contract) +
|
||||
int(sum(np.less(rhs_tensor, rbd))))
|
||||
rhs_batch = bump_dims(rhs_batch, rbd)
|
||||
@ -4917,6 +5043,16 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
|
||||
assert False
|
||||
|
||||
new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
|
||||
if is_ragged_dot:
|
||||
new_dimension_numbers = RaggedDotDimensionNumbers(
|
||||
dot_dimension_numbers=new_dimension_numbers,
|
||||
lhs_ragged_dimensions=bump_dims(
|
||||
dimension_numbers.lhs_ragged_dimensions, lbd
|
||||
),
|
||||
rhs_group_dimensions=bump_dims(
|
||||
dimension_numbers.rhs_group_dimensions, rbd
|
||||
),
|
||||
)
|
||||
return new_dimension_numbers, result_batch_dim
|
||||
|
||||
def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
|
||||
@ -5008,15 +5144,6 @@ def _dot_general_batch_unpack_dims(batch_dims):
|
||||
lbd, rbd = batch_dims
|
||||
return (lbd, rbd)
|
||||
|
||||
# DotDimensionNumbers used in the dot_general call for ragged_dot().
|
||||
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
|
||||
([2, 0], [1, 0]),
|
||||
([], []),
|
||||
)
|
||||
_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
|
||||
([3, 1], [2, 1]),
|
||||
([0], [0]),
|
||||
)
|
||||
|
||||
ad.defbilinear(dot_general_p,
|
||||
_dot_general_transpose_lhs, _dot_general_transpose_rhs)
|
||||
@ -5184,58 +5311,181 @@ for platform in ["cpu", "tpu"]:
|
||||
platform=platform)
|
||||
|
||||
|
||||
def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
|
||||
if len(lhs.shape) == 3:
|
||||
# Batched case
|
||||
b, m, k = lhs.shape
|
||||
b2, group_count, rk, n = rhs.shape
|
||||
b3 = group_sizes.shape[0]
|
||||
if b != b2:
|
||||
raise TypeError(
|
||||
f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and'
|
||||
f' {b2}.'
|
||||
)
|
||||
if b3 != b:
|
||||
raise TypeError(
|
||||
'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got'
|
||||
f' {b3} and {b}.'
|
||||
)
|
||||
if k != rk:
|
||||
raise TypeError(
|
||||
f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and'
|
||||
f' {rk}.'
|
||||
)
|
||||
num_groups = group_sizes.shape[1]
|
||||
if group_count != num_groups:
|
||||
raise TypeError(
|
||||
'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got'
|
||||
f' {group_count} and {num_groups}.'
|
||||
)
|
||||
return (b, m, n)
|
||||
class RaggedDotMode(enum.Enum):
|
||||
RAGGED_NONCONTRACTING = 1 # [b,m,k], [g,b,k,n], [b,g] -> [b,m,n]
|
||||
RAGGED_CONTRACTING = 2 # [b,m,k], [b,k,n], [b,g] -> [g,b,m,n]
|
||||
RAGGED_BATCH = 3 # [b,m,k], [b,k,n], [g] -> [b,m,n]
|
||||
|
||||
m, k = lhs.shape
|
||||
group_count, rk, n = rhs.shape
|
||||
if k != rk:
|
||||
raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
|
||||
num_groups = group_sizes.shape[0]
|
||||
if group_count != num_groups:
|
||||
raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
|
||||
return (m, n)
|
||||
|
||||
def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
|
||||
precision, preferred_element_type: DTypeLike | None,
|
||||
**_) -> np.dtype:
|
||||
def _ragged_dot_mode_and_dim(
|
||||
lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
|
||||
) -> tuple[RaggedDotMode, int]:
|
||||
assert len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) == 1
|
||||
lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
|
||||
(lhs_contracting, _), (lhs_batch, _) = ragged_dot_dimension_numbers.dot_dimension_numbers
|
||||
lhs_noncontracting = remaining(range(lhs_rank), lhs_contracting, lhs_batch)
|
||||
if lhs_ragged_dim in lhs_noncontracting:
|
||||
mode = RaggedDotMode.RAGGED_NONCONTRACTING
|
||||
elif lhs_ragged_dim in lhs_contracting:
|
||||
mode = RaggedDotMode.RAGGED_CONTRACTING
|
||||
elif lhs_ragged_dim in lhs_batch:
|
||||
mode = RaggedDotMode.RAGGED_BATCH
|
||||
else:
|
||||
raise TypeError(
|
||||
f'lhs_ragged_dim {lhs_ragged_dim} not found in '
|
||||
f'lhs_noncontracting {lhs_noncontracting}, '
|
||||
f'lhs_contracting {lhs_contracting}, or '
|
||||
f'lhs_batch {lhs_batch}.'
|
||||
)
|
||||
return mode, lhs_ragged_dim
|
||||
|
||||
|
||||
def _ragged_dot_mode(
|
||||
lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
|
||||
) -> RaggedDotMode:
|
||||
return _ragged_dot_mode_and_dim(lhs_rank, ragged_dot_dimension_numbers)[0]
|
||||
|
||||
|
||||
def _is_ragged_contracting(
|
||||
lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
|
||||
) -> bool:
|
||||
return (
|
||||
_ragged_dot_mode(lhs_rank, ragged_dot_dimension_numbers)
|
||||
== RaggedDotMode.RAGGED_CONTRACTING
|
||||
)
|
||||
|
||||
|
||||
def _ragged_dot_prefix_dims(mode, rank, ragged_dim, batch, contract):
|
||||
batch, contract = map(list, (batch, contract))
|
||||
noncontract = remaining(range(rank), contract, batch)
|
||||
match mode:
|
||||
case RaggedDotMode.RAGGED_NONCONTRACTING:
|
||||
return batch + noncontract[: noncontract.index(ragged_dim)]
|
||||
case RaggedDotMode.RAGGED_CONTRACTING:
|
||||
return batch + contract[: contract.index(ragged_dim)]
|
||||
case RaggedDotMode.RAGGED_BATCH:
|
||||
return batch[: batch.index(ragged_dim)]
|
||||
|
||||
|
||||
def _ragged_dot_general_shape_rule(
|
||||
lhs,
|
||||
rhs,
|
||||
group_sizes,
|
||||
*,
|
||||
ragged_dot_dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type: DTypeLike | None,
|
||||
**_,
|
||||
):
|
||||
def _check_in_range(dim, rank, dim_name, arg_name):
|
||||
if dim < 0 or dim >= rank:
|
||||
raise TypeError(
|
||||
f'ragged_dot_general requires {dim_name} numbers to be nonnegative '
|
||||
f'and less than the number of axes of the {arg_name} value, '
|
||||
          f'got {dim} for {arg_name} of rank {rank}.'
      )

  # Validate the lhs ragged dimension, and find out which mode we're in.
  if len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) != 1:
    raise TypeError(
        'ragged_dot_general expects exactly one lhs ragged dimension.'
    )
  lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
  _check_in_range(lhs_ragged_dim, lhs.ndim, 'lhs ragged dimension', 'lhs')
  mode = _ragged_dot_mode(lhs.ndim, ragged_dot_dimension_numbers)

  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = (
      ragged_dot_dimension_numbers.dot_dimension_numbers
  )

  # Validate the shape of group_sizes, if it is something other than [g].
  if group_sizes.ndim == 0:
    raise TypeError('expected rank of group_sizes to be >=1.')
  if group_sizes.ndim != 1:
    # Construct the expected shape [b...,x...,g] of group_sizes.
    prefix_dims = _ragged_dot_prefix_dims(
        mode, lhs.ndim, lhs_ragged_dim, lhs_batch, lhs_contracting
    )
    expected_gs_shape = tuple(lhs.shape[i] for i in prefix_dims)
    expected_gs_shape += (group_sizes.shape[-1],)
    # TODO(pravnar): Permit other broadcastable shapes.
    if not core.definitely_equal_shape(group_sizes.shape, expected_gs_shape):
      raise TypeError(
          'expected group_sizes to have shape '
          f'{expected_gs_shape}, got {group_sizes.shape}.'
      )
  num_groups = group_sizes.shape[-1]

  # Validate properties of the rhs group dimension(s).
  rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
  match mode:
    case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
      if len(rhs_group_dims) != 0:
        raise TypeError(
            'ragged_dot_general requires zero group dimensions in the rhs '
            'when lhs ragged dimension is contracting or batch.'
        )
    case RaggedDotMode.RAGGED_NONCONTRACTING:
      if len(rhs_group_dims) != 1:
        raise TypeError(
            'ragged_dot_general requires exactly one rhs group dimension '
            'when lhs ragged dimension is noncontracting.'
        )
      rhs_group_dim = rhs_group_dims[0]
      _check_in_range(rhs_group_dim, rhs.ndim, 'rhs group dimension', 'rhs')
      if rhs_group_dim in rhs_batch or rhs_group_dim in rhs_contracting:
        raise TypeError(
            'ragged_dot_general requires rhs group dimension numbers to be '
            'distinct from contracting and batch dimensions.'
        )
      if rhs.shape[rhs_group_dim] != num_groups:
        raise TypeError(
            'expected rhs group dimension size to be '
            f'{num_groups}, got {rhs.shape[rhs_group_dim]}.'
        )

  out_shape = _dot_general_shape_rule(
      lhs,
      rhs,
      dimension_numbers=ragged_dot_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      out_sharding=None,
  )
  if mode == RaggedDotMode.RAGGED_CONTRACTING:
    out_shape = (num_groups,) + out_shape
  return out_shape


def _ragged_dot_general_dtype_rule(
    lhs: Array,
    rhs: Array,
    group_sizes: Array,
    ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
    precision,
    preferred_element_type: DTypeLike | None,
    **_,
) -> np.dtype:
  if not dtypes.issubdtype(group_sizes.dtype, np.integer):
    raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
  # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
    raise TypeError(
        'ragged_dot_general requires that '
        'group_sizes.dtype is subtype of np.integer.'
    )
  # defer the output dtype to dot_general, which is part of the _ragged_dot_general_impl.
  return _dot_general_dtype_rule(
      lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
      precision=precision, preferred_element_type=preferred_element_type,
      out_sharding=None)
      lhs,
      rhs,
      dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      out_sharding=None,
      name='lax.ragged_dot_general',
  )


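Note (illustrative, not part of the diff): based on the constructor fields used later in this change, the classic ragged_dot case presumably maps onto RaggedDotDimensionNumbers roughly as follows. Assumed shapes are lhs [m, k], rhs [g, k, n], group_sizes [g]; the variable name is hypothetical.

# Hedged sketch of the dimension numbers for the basic ragged_dot case.
basic_ragged_dot_dims = RaggedDotDimensionNumbers(
    dot_dimension_numbers=(([1], [1]), ([], [])),  # contract lhs k with rhs k; no batch dims
    lhs_ragged_dimensions=[0],                     # the lhs m dimension is ragged
    rhs_group_dimensions=[0],                      # the rhs g dimension indexes per-group matrices
)
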
def _ragged_dot_jvp_rule(
|
||||
primals, tangents, precision, preferred_element_type, group_offset
|
||||
def _ragged_dot_general_jvp_rule(
|
||||
primals, tangents, ragged_dot_dimension_numbers,
|
||||
precision, preferred_element_type, group_offset
|
||||
):
|
||||
# note - we could ostensibly just get this by passing on the
|
||||
# value to ragged_dot below, but, this feels cleaner.
|
||||
@ -5245,20 +5495,22 @@ def _ragged_dot_jvp_rule(
|
||||
dx, dy, _ = tangents # no tan on the gs
|
||||
|
||||
# primal
|
||||
primal_out = ragged_dot(
|
||||
primal_out = ragged_dot_general(
|
||||
x,
|
||||
y,
|
||||
gs,
|
||||
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)
|
||||
|
||||
# tangent
|
||||
dx_out = (
|
||||
ragged_dot(
|
||||
ragged_dot_general(
|
||||
dx,
|
||||
y,
|
||||
gs,
|
||||
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)
|
||||
@ -5266,73 +5518,127 @@ def _ragged_dot_jvp_rule(
|
||||
else _zeros(primal_out)
|
||||
)
|
||||
dy_out = (
|
||||
ragged_dot(
|
||||
ragged_dot_general(
|
||||
x,
|
||||
dy,
|
||||
gs,
|
||||
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)
|
||||
if type(dy) is not ad_util.Zero
|
||||
else _zeros(primal_out)
|
||||
)
|
||||
tangent_out = dx_out + dy_out
|
||||
tangent_out = add(dx_out, dy_out)
|
||||
|
||||
return primal_out, tangent_out
|
||||
|
||||
|
||||
def _ragged_to_dense(x, y, group_sizes):
  from jax._src.lax import control_flow  # avoid circular imports
  shape = (y.shape[0], x.shape[0], x.shape[1])
  x = broadcast_in_dim(x, shape, [1, 2])
  iota = broadcasted_iota(group_sizes.dtype, shape, 1)
  group_ends = control_flow.cumsum(group_sizes)
  group_starts = concatenate(
      [_zeros(group_sizes)[:1], group_ends[:-1]],
      dimension=0,
  )
  group_ends = broadcast_in_dim(group_ends, shape, (0,))
  group_starts = broadcast_in_dim(group_starts, shape, (0,))
  mask = bitwise_and(group_starts <= iota, iota < group_ends)
  x = select(mask, x, _zeros(x))
  return x


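As a reading aid (not part of the change), here is a minimal standalone sketch of the masking that `_ragged_to_dense` performs, assuming x of shape [m, k] and group_sizes of shape [g]; the function name is hypothetical.

import jax.numpy as jnp

def ragged_to_dense_reference(x, group_sizes):
  # Rows of x are laid out group by group along the leading (m) axis.
  ends = jnp.cumsum(group_sizes)                       # exclusive end row of each group
  starts = jnp.concatenate([jnp.zeros((1,), group_sizes.dtype), ends[:-1]])
  rows = jnp.arange(x.shape[0])                        # row indices along m
  # mask[i, r] is True when row r belongs to group i.
  mask = (starts[:, None] <= rows[None, :]) & (rows[None, :] < ends[:, None])
  # Broadcast x to [g, m, k] and zero out rows outside each group's range.
  return jnp.where(mask[:, :, None], x[None, :, :], 0)
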
def _ragged_dot_transpose_rule(
|
||||
ct, *operands, precision, preferred_element_type, group_offset
|
||||
def _ragged_dot_general_transpose_rule(
|
||||
ct,
|
||||
x,
|
||||
y,
|
||||
group_sizes,
|
||||
*,
|
||||
ragged_dot_dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type: DTypeLike | None,
|
||||
group_offset: Array | None,
|
||||
):
|
||||
x, y, gs = operands
|
||||
if group_offset is not None:
|
||||
raise NotImplementedError('Unimplemented group_offset support.')
|
||||
|
||||
if ad.is_undefined_primal(y):
|
||||
grad_x = None
|
||||
else:
|
||||
y_t = _matrix_transpose(y)
|
||||
grad_x = ragged_dot(
|
||||
ct,
|
||||
y_t,
|
||||
gs,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)
|
||||
(x_contract, y_contract), (x_batch, y_batch) = ragged_dot_dimension_numbers.dot_dimension_numbers
|
||||
x_ndim = x.aval.ndim if ad.is_undefined_primal(x) else np.ndim(x)
|
||||
y_ndim = y.aval.ndim if ad.is_undefined_primal(y) else np.ndim(y)
|
||||
x_kept = remaining(range(x_ndim), x_contract, x_batch)
|
||||
y_group = ragged_dot_dimension_numbers.rhs_group_dimensions
|
||||
y_kept = remaining(range(y_ndim), y_contract, y_batch, y_group)
|
||||
mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
|
||||
x_ndim, ragged_dot_dimension_numbers
|
||||
)
|
||||
|
||||
if ad.is_undefined_primal(x):
|
||||
grad_y = None
|
||||
else:
|
||||
y = y.aval if ad.is_undefined_primal(y) else y
|
||||
x_dense = _ragged_to_dense(x, y, group_sizes=gs)
|
||||
ct_dense = _ragged_to_dense(ct, y, group_sizes=gs)
|
||||
dimension_numbers = (([1], [1]), ([0], [0]))
|
||||
grad_y = dot_general(
|
||||
x_dense,
|
||||
ct_dense,
|
||||
dimension_numbers,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)
|
||||
unimplemented = lambda fn_name, ragged_dot_mode: NotImplementedError(
|
||||
f'Unimplemented {fn_name} for ragged dot general in mode '
|
||||
f'{ragged_dot_mode.name}.'
|
||||
)
|
||||
|
||||
return grad_x, grad_y, None
|
||||
# This is a hack to ensure we continue to emit the `_matrix_transpose` for the
|
||||
# grad_x case. This isn't strictly necessary since we have dot_dim_nums.
|
||||
# TODO(pravnar): Remove this once we no longer care to emit the transpose.
|
||||
_is_basic_ragged_dot = (
|
||||
x_ndim == 2
|
||||
and y_ndim == 3
|
||||
and ragged_dot_dimension_numbers == _BASIC_RAGGED_DOT_DIMENSION_NUMBERS
|
||||
)
|
||||
|
||||
def grad_x_dims():
|
||||
match mode:
|
||||
case RaggedDotMode.RAGGED_NONCONTRACTING:
|
||||
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
|
||||
dims = (
|
||||
ragged_dot_dimension_numbers
|
||||
if _is_basic_ragged_dot
|
||||
else RaggedDotDimensionNumbers(
|
||||
dot_dimension_numbers=((ans_y, y_kept), (ans_batch, y_batch)),
|
||||
lhs_ragged_dimensions=[
|
||||
len(x_batch) + x_kept.index(lhs_ragged_dim)
|
||||
],
|
||||
rhs_group_dimensions=y_group,
|
||||
)
|
||||
)
|
||||
x_contract_sorted_by_y = list(
|
||||
np.take(x_contract, np.argsort(y_contract))
|
||||
)
|
||||
unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
|
||||
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
|
||||
raise unimplemented('grad_x_dims', mode)
|
||||
return dims, unsorted_axes
|
||||
|
||||
def grad_y_dims():
|
||||
match mode:
|
||||
case RaggedDotMode.RAGGED_NONCONTRACTING:
|
||||
ans_batch, ans_x, _ = ranges_like(x_batch, x_kept, y_kept)
|
||||
dims = RaggedDotDimensionNumbers(
|
||||
dot_dimension_numbers=((x_kept, ans_x), (x_batch, ans_batch)),
|
||||
lhs_ragged_dimensions=[lhs_ragged_dim],
|
||||
rhs_group_dimensions=[],
|
||||
)
|
||||
y_contract_sorted_by_x = list(
|
||||
np.take(y_contract, np.argsort(x_contract))
|
||||
)
|
||||
unsorted_axes = (
|
||||
list(y_group) + list(y_batch) + y_contract_sorted_by_x + y_kept
|
||||
)
|
||||
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
|
||||
raise unimplemented('grad_y_dims', mode)
|
||||
return dims, unsorted_axes
|
||||
|
||||
def _ragged_dot_grad(lhs, rhs, dims_fn, aval):
|
||||
dims, unsorted_axes = dims_fn()
|
||||
ragged_dot_general_out = ragged_dot_general(
|
||||
lhs, rhs, group_sizes, dims, precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
group_offset=group_offset)
|
||||
result = transpose(ragged_dot_general_out, tuple(np.argsort(unsorted_axes)))
|
||||
if result.dtype != aval.dtype:
|
||||
result = _convert_element_type(result, aval.dtype, aval.weak_type)
|
||||
return result
|
||||
|
||||
x_bar = (
|
||||
None
|
||||
if ad.is_undefined_primal(y)
|
||||
else _ragged_dot_grad(ct,
|
||||
_matrix_transpose(y) if _is_basic_ragged_dot else y,
|
||||
grad_x_dims,
|
||||
x.aval)
|
||||
)
|
||||
y_bar = (
|
||||
None
|
||||
if ad.is_undefined_primal(x)
|
||||
else _ragged_dot_grad(x, ct, grad_y_dims, y.aval)
|
||||
)
|
||||
return x_bar, y_bar, None
|
||||
|
||||
|
||||
def _ragged_dot_batch_unpack_args(batched_args):
|
||||
@ -5347,62 +5653,71 @@ def _ragged_dot_batch_unpack_dims(batch_dims):
|
||||
return (lbd, rbd)
|
||||
|
||||
|
||||
def _ragged_dot_invoke_prim(
|
||||
def _ragged_dot_general_invoke_prim(
|
||||
group_sizes,
|
||||
lhs,
|
||||
rhs,
|
||||
new_dimension_numbers,
|
||||
new_ragged_dot_dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type,
|
||||
out_sharding,
|
||||
):
|
||||
del out_sharding
|
||||
return ragged_dot(
|
||||
return ragged_dot_general(
|
||||
lhs,
|
||||
rhs,
|
||||
group_sizes,
|
||||
ragged_dot_dimension_numbers=new_ragged_dot_dimension_numbers,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)
|
||||
|
||||
|
||||
def _ragged_dot_batch_rule(
|
||||
def _ragged_dot_general_batch_rule(
|
||||
axis_data,
|
||||
batched_args,
|
||||
batch_dims,
|
||||
*,
|
||||
ragged_dot_dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type: DTypeLike | None,
|
||||
**_,
|
||||
):
|
||||
invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
|
||||
|
||||
return _dot_batch_rule(
|
||||
invoke = partial(_ragged_dot_general_invoke_prim, batched_args[2])
|
||||
batched_out, result_batch_dim = _dot_batch_rule(
|
||||
_ragged_dot_batch_unpack_args,
|
||||
_ragged_dot_batch_unpack_dims,
|
||||
invoke,
|
||||
axis_data,
|
||||
batched_args,
|
||||
batch_dims,
|
||||
dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
|
||||
dimension_numbers=ragged_dot_dimension_numbers,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
out_sharding=None,
|
||||
)
|
||||
if _is_ragged_contracting(batched_args[0].ndim - 1,
|
||||
ragged_dot_dimension_numbers):
|
||||
result_batch_dim += 1
|
||||
return batched_out, result_batch_dim
|
||||
|
||||
|
||||
ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
                                  _ragged_dot_dtype_rule, 'ragged_dot')
ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
batching.skippable_batchers[ragged_dot_p] = lambda _: ()
ragged_dot_general_p = standard_primitive(
    _ragged_dot_general_shape_rule,
    _ragged_dot_general_dtype_rule,
    'ragged_dot_general',
)
ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule
ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule
batching.fancy_primitive_batchers[ragged_dot_general_p] = _ragged_dot_general_batch_rule
batching.skippable_batchers[ragged_dot_general_p] = lambda _: ()

def _ragged_dot_impl(
|
||||
|
||||
def _ragged_dot_general_impl(
|
||||
lhs: Array,
|
||||
rhs: Array,
|
||||
group_sizes: Array,
|
||||
ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
group_offset: Array | None = None,
|
||||
@ -5410,24 +5725,100 @@ def _ragged_dot_impl(
|
||||
if group_offset is not None:
|
||||
raise NotImplementedError("Unimplemented group_offset support.")
|
||||
|
||||
if len(lhs.shape) == 3:
|
||||
ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS
|
||||
ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0))
|
||||
else:
|
||||
ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS
|
||||
ragged_to_dense = _ragged_to_dense
|
||||
def ragged_to_dense(x: Array, gs: Array, *, dim: int):
|
||||
from jax._src.lax import control_flow # avoid circular imports
|
||||
assert gs.ndim == 1
|
||||
shape = gs.shape + x.shape
|
||||
x = broadcast_in_dim(x, shape, list(range(1, len(shape))))
|
||||
iota = broadcasted_iota(gs.dtype, shape, dim+1)
|
||||
group_ends = control_flow.cumsum(gs)
|
||||
group_starts = concatenate(
|
||||
[_zeros(gs)[:1], group_ends[:-1]],
|
||||
dimension=0,
|
||||
)
|
||||
group_ends = broadcast_in_dim(group_ends, shape, (0,))
|
||||
group_starts = broadcast_in_dim(group_starts, shape, (0,))
|
||||
mask = bitwise_and(group_starts <= iota, iota < group_ends)
|
||||
x = select(mask, x, _zeros(x))
|
||||
return x
|
||||
|
||||
lhs = ragged_to_dense(lhs, rhs, group_sizes)
|
||||
def batched_ragged_to_dense(dim, *x_in_axes: int):
|
||||
if not x_in_axes:
|
||||
return partial(ragged_to_dense, dim=dim)
|
||||
x_axis, *rest = x_in_axes
|
||||
decr = lambda d: d - 1 if d >= x_axis else d
|
||||
return api.vmap(
|
||||
batched_ragged_to_dense(decr(dim), *[decr(ax) for ax in rest]),
|
||||
in_axes=(x_axis, 0),
|
||||
)
|
||||
|
||||
return dot_general(
|
||||
lhs,
|
||||
rhs,
|
||||
dimension_numbers=ragged_dot_dims,
|
||||
incr = lambda dims: [d + 1 for d in dims]
|
||||
|
||||
# Expand the ragged `dim` of `x`, given its batching `axes`.
|
||||
# The group axis from `gs` becomes the outermost axis of the result.
|
||||
# Some examples:
|
||||
# x: [m,k] , gs: [g] ==> expand(x, 0, gs): [g,m,k]
|
||||
# x: [b1,m,b2,k], gs: [b1,b2,g] ==> expand(x, 1, gs, 0, 2): [g,b1,m,b2,k]
|
||||
def expand(x, dim, gs, *axes):
|
||||
expanded = batched_ragged_to_dense(dim, *axes)(x, gs)
|
||||
unsorted_dims = incr(axes) + [0] + incr(remaining(range(x.ndim), axes))
|
||||
return transpose(expanded, np.argsort(unsorted_dims))
|
||||
|
||||
mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
|
||||
lhs.ndim, ragged_dot_dimension_numbers
|
||||
)
|
||||
(l_contract, r_contract), (l_batch, r_batch) = (
|
||||
ragged_dot_dimension_numbers.dot_dimension_numbers
|
||||
)
|
||||
l_prefix = _ragged_dot_prefix_dims(
|
||||
mode, lhs.ndim, lhs_ragged_dim, l_batch, l_contract
|
||||
)
|
||||
|
||||
_dot_general = partial(
|
||||
dot_general,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)
|
||||
# TODO(pravnar): Permit other broadcastable shapes.
|
||||
if group_sizes.ndim == 1:
|
||||
group_sizes = broadcast(group_sizes, [lhs.shape[i] for i in l_prefix])
|
||||
|
||||
mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False))
|
||||
match mode:
|
||||
case RaggedDotMode.RAGGED_NONCONTRACTING:
|
||||
rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
|
||||
assert len(rhs_group_dims) == 1
|
||||
return _dot_general(
|
||||
expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
|
||||
rhs,
|
||||
dimension_numbers=(
|
||||
(incr(l_contract) + [0], list(r_contract) + [rhs_group_dims[0]]),
|
||||
(incr(l_batch), r_batch),
|
||||
),
|
||||
)
|
||||
case RaggedDotMode.RAGGED_CONTRACTING:
|
||||
rhs_ragged_dim = r_contract[l_contract.index(lhs_ragged_dim)]
|
||||
r_prefix = _ragged_dot_prefix_dims(
|
||||
mode, rhs.ndim, rhs_ragged_dim, r_batch, r_contract
|
||||
)
|
||||
return _dot_general(
|
||||
expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
|
||||
expand(rhs, rhs_ragged_dim, group_sizes, *r_prefix),
|
||||
dimension_numbers=(
|
||||
(incr(l_contract), incr(r_contract)),
|
||||
([0] + incr(l_batch), [0] + incr(r_batch)),
|
||||
),
|
||||
)
|
||||
case RaggedDotMode.RAGGED_BATCH:
|
||||
return _dot_general(
|
||||
lhs,
|
||||
rhs,
|
||||
dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
|
||||
)
|
||||
|
||||
|
||||
mlir.register_lowering(ragged_dot_general_p,
                       mlir.lower_fun(_ragged_dot_general_impl,
                                      multiple_results=False))


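For orientation (illustrative only), the grouped-matmul semantics these primitives implement in the ragged-noncontracting case can be written as a plain NumPy loop. Assumed shapes are lhs [m, k], rhs [g, k, n], and group_sizes [g] summing to m; the helper name is hypothetical.

import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes):
  m, k = lhs.shape
  g, k2, n = rhs.shape
  assert k == k2 and group_sizes.shape == (g,)
  out = np.zeros((m, n), dtype=np.result_type(lhs, rhs))
  start = 0
  for i, size in enumerate(group_sizes):
    # Rows [start, start + size) of lhs are multiplied by the i-th rhs slice.
    out[start:start + size] = lhs[start:start + size] @ rhs[i]
    start += size
  return out
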
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions,
|
||||
@ -5802,8 +6193,12 @@ def _concatenate_transpose_rule(t, *operands, dimension):
|
||||
def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
|
||||
size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
|
||||
if bdim is not None)
|
||||
spec = next(core.get_aval(op).sharding.spec[bdim]
|
||||
for op, bdim in zip(batched_args, batch_dims) if bdim is not None)
|
||||
operands = [batching.moveaxis(op, bdim, 0) if bdim is not None
|
||||
else broadcast(op, (size,))
|
||||
else broadcast(
|
||||
op, (size,), out_sharding=core.get_aval(op).sharding.with_spec(
|
||||
(spec, *core.get_aval(op).sharding.spec)))
|
||||
for op, bdim in zip(batched_args, batch_dims)]
|
||||
return concatenate(operands, dimension + 1), 0
|
||||
|
||||
|
@ -44,6 +44,10 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import svd as lax_svd
|
||||
from jax._src.lax import utils as lax_utils
|
||||
from jax._src.lax.lax import _float, _complex, _int
|
||||
from jax._src.lib import gpu_linalg
|
||||
from jax._src.lib import gpu_solver
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
@ -51,12 +55,23 @@ from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec as P
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
# The following imports may be unused but they are needed to register the
|
||||
# custom call targets defined in each module.
|
||||
from jax._src.lib import gpu_linalg # pylint:disable=unused-import # noqa: F401
|
||||
from jax._src.lib import gpu_solver # pylint:disable=unused-import # noqa: F401
|
||||
from jax._src.lib import gpu_sparse # pylint:disable=unused-import # noqa: F401
|
||||
from jax._src.lib import lapack # pylint:disable=unused-import # noqa: F401
|
||||
|
||||
def register_module_custom_calls(module):
|
||||
if hasattr(module, "registrations"):
|
||||
for platform, targets in module.registrations().items():
|
||||
for name, value, api_version in targets:
|
||||
ffi.register_ffi_target(
|
||||
name, value, platform=platform, api_version=api_version
|
||||
)
|
||||
if hasattr(module, "batch_partitionable_targets"):
|
||||
for name in module.batch_partitionable_targets():
|
||||
ffi.register_ffi_target_as_batch_partitionable(name)
|
||||
|
||||
|
||||
register_module_custom_calls(gpu_linalg)
|
||||
register_module_custom_calls(gpu_solver)
|
||||
register_module_custom_calls(gpu_sparse)
|
||||
register_module_custom_calls(lapack)
|
||||
|
||||
|
||||
# Top-level functions in alphabetical order.
|
||||
|
@ -1608,8 +1608,9 @@ def _dynamic_update_slice_sharding_rule(operand, update, *start_indices):
|
||||
if operand.sharding != update.sharding:
|
||||
raise TypeError(
|
||||
"dynamic_update_slice update sharding must be equal to operand"
|
||||
f" sharding, got update sharding {update.sharding} for operand sharding"
|
||||
f" {operand.sharding}.")
|
||||
" sharding, got update sharding"
|
||||
f" {update.str_short(mesh_axis_types=True)} for operand sharding"
|
||||
f" {operand.str_short(mesh_axis_types=True)}.")
|
||||
return operand.sharding
|
||||
|
||||
def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):
|
||||
|
@ -200,7 +200,7 @@ class _BaseMesh:
|
||||
|
||||
_mesh_object_dict = {} # type: ignore
|
||||
|
||||
MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]
|
||||
MeshAxisType = dict[AxisTypes, MeshAxisName | tuple[MeshAxisName, ...]]
|
||||
|
||||
class Mesh(_BaseMesh, contextlib.ContextDecorator):
|
||||
"""Declare the hardware resources available in the scope of this manager.
|
||||
|
@ -45,8 +45,8 @@ RealNumeric = Any # Scalar jnp array or float
|
||||
@export
|
||||
@typing.runtime_checkable
|
||||
class Initializer(Protocol):
|
||||
@staticmethod
|
||||
def __call__(key: Array,
|
||||
def __call__(self,
|
||||
key: Array,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Array:
|
||||
raise NotImplementedError
|
||||
|
@ -547,6 +547,9 @@ def _einsum(
|
||||
if not last_contraction:
|
||||
dot_general_out_sharding = None
|
||||
elif out_sharding is not None and names != result_names:
|
||||
if len(result_names) > len(out_sharding.spec):
|
||||
out_sharding = out_sharding.with_spec(
|
||||
out_sharding.spec._normalized_spec_for_aval(len(result_names)))
|
||||
spec = out_sharding.spec
|
||||
inverse_spec = tuple(spec[result_names.index(name)] for name in names)
|
||||
dot_general_out_sharding = NamedSharding(
|
||||
|
@ -93,6 +93,8 @@ float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
|
||||
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
|
||||
float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
|
||||
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
|
||||
if dtypes.float4_e2m1fn is not None:
|
||||
float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
|
||||
bfloat16 = _make_scalar_type(dtypes.bfloat16)
|
||||
float16 = _make_scalar_type(np.float16)
|
||||
float32 = single = _make_scalar_type(np.float32)
|
||||
|
@ -26,7 +26,7 @@ from jax._src import dtypes
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.sharding_impls import SingleDeviceSharding
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
from jax._src.util import safe_zip, safe_map, set_module
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
|
||||
from jax.sharding import Sharding
|
||||
|
||||
@ -35,6 +35,8 @@ import numpy as np
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
map, unsafe_map = safe_map, map
|
||||
|
||||
export = set_module('jax.numpy')
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
|
||||
@ -308,3 +310,124 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin
|
||||
return SingleDeviceSharding(device)
|
||||
else:
|
||||
return device
|
||||
|
||||
|
||||
@export
|
||||
def ndim(a: ArrayLike) -> int:
|
||||
"""Return the number of dimensions of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function
|
||||
raises a :class:`TypeError` if the input is a collection such as a list or
|
||||
tuple.
|
||||
|
||||
Args:
|
||||
a: array-like object.
|
||||
|
||||
Returns:
|
||||
An integer specifying the number of dimensions of ``a``.
|
||||
|
||||
Examples:
|
||||
Number of dimensions for arrays:
|
||||
|
||||
>>> x = jnp.arange(10)
|
||||
>>> jnp.ndim(x)
|
||||
1
|
||||
>>> y = jnp.ones((2, 3))
|
||||
>>> jnp.ndim(y)
|
||||
2
|
||||
|
||||
This also works for scalars:
|
||||
|
||||
>>> jnp.ndim(3.14)
|
||||
0
|
||||
|
||||
For arrays, this can also be accessed via the :attr:`jax.Array.ndim` property:
|
||||
|
||||
>>> x.ndim
|
||||
1
|
||||
"""
|
||||
# Deprecation warning added 2025-2-20.
|
||||
check_arraylike("ndim", a, emit_warning=True)
|
||||
return np.ndim(a) # NumPy dispatches to a.ndim if available.
|
||||
|
||||
|
||||
@export
|
||||
def shape(a: ArrayLike) -> tuple[int, ...]:
  """Return the shape of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function
|
||||
raises a :class:`TypeError` if the input is a collection such as a list or
|
||||
tuple.
|
||||
|
||||
Args:
|
||||
a: array-like object.
|
||||
|
||||
Returns:
|
||||
A tuple of integers representing the shape of ``a``.
|
||||
|
||||
Examples:
|
||||
Shape for arrays:
|
||||
|
||||
>>> x = jnp.arange(10)
|
||||
>>> jnp.shape(x)
|
||||
(10,)
|
||||
>>> y = jnp.ones((2, 3))
|
||||
>>> jnp.shape(y)
|
||||
(2, 3)
|
||||
|
||||
This also works for scalars:
|
||||
|
||||
>>> jnp.shape(3.14)
|
||||
()
|
||||
|
||||
For arrays, this can also be accessed via the :attr:`jax.Array.shape` property:
|
||||
|
||||
>>> x.shape
|
||||
(10,)
|
||||
"""
|
||||
# Deprecation warning added 2025-2-20.
|
||||
check_arraylike("shape", a, emit_warning=True)
|
||||
return np.shape(a) # NumPy dispatches to a.shape if available.
|
||||
|
||||
|
||||
@export
|
||||
def size(a: ArrayLike, axis: int | None = None) -> int:
|
||||
"""Return number of elements along a given axis.
|
||||
|
||||
JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function
|
||||
raises a :class:`TypeError` if the input is a collection such as a list or
|
||||
tuple.
|
||||
|
||||
Args:
|
||||
a: array-like object
|
||||
axis: optional integer along which to count elements. By default, return
|
||||
the total number of elements.
|
||||
|
||||
Returns:
|
||||
An integer specifying the number of elements in ``a``.
|
||||
|
||||
Examples:
|
||||
Size for arrays:
|
||||
|
||||
>>> x = jnp.arange(10)
|
||||
>>> jnp.size(x)
|
||||
10
|
||||
>>> y = jnp.ones((2, 3))
|
||||
>>> jnp.size(y)
|
||||
6
|
||||
>>> jnp.size(y, axis=1)
|
||||
3
|
||||
|
||||
This also works for scalars:
|
||||
|
||||
>>> jnp.size(3.14)
|
||||
1
|
||||
|
||||
For arrays, this can also be accessed via the :attr:`jax.Array.size` property:
|
||||
|
||||
>>> y.size
|
||||
6
|
||||
"""
|
||||
# Deprecation warning added 2025-2-20.
|
||||
check_arraylike("size", a, emit_warning=True)
|
||||
return np.size(a, axis=axis) # NumPy dispatches to a.size if available.
|
||||
|
@ -25,8 +25,8 @@ from typing import Any, Callable, Protocol, Sequence
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import api_util
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import linear_util as lu
|
||||
@ -351,7 +351,7 @@ def _pull_block_spec(
|
||||
jaxpr.constvars,
|
||||
jaxpr.invars,
|
||||
needed_invars,
|
||||
jaxpr.eqns[:jaxpr.eqns.index(eqn)],
|
||||
jaxpr.eqns[: jaxpr.eqns.index(eqn)],
|
||||
debug_info=jaxpr.debug_info,
|
||||
)
|
||||
scalar_prefetch_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
|
||||
@ -426,6 +426,7 @@ def make_kernel_function(
|
||||
return tuple(s for s in shape if s is not None)
|
||||
|
||||
_no_aval = object()
|
||||
|
||||
def _get_block_aval(bs, aval):
|
||||
if bs is pallas_core.no_block_spec or bs is None:
|
||||
return _no_aval
|
||||
@ -441,10 +442,12 @@ def make_kernel_function(
|
||||
unflat_arg_usages, unflat_kwarg_usages = tree_util.tree_unflatten(
|
||||
in_tree, invar_usages
|
||||
)
|
||||
|
||||
def sds_like(x):
|
||||
if x is _no_aval:
|
||||
return _no_aval
|
||||
return jax.ShapeDtypeStruct(x.shape, x.dtype)
|
||||
|
||||
kernel_in_type = jax.tree.map(
|
||||
sds_like, (unflat_in_block_arg_avals, unflat_in_block_kwarg_avals)
|
||||
)
|
||||
@ -688,8 +691,10 @@ def _eltwise_eval_rule(prim, ctx, x, **params):
|
||||
|
||||
|
||||
def _eltwise_pull_rule(
|
||||
prim: core.Primitive, ctx: PullRuleContext, block_spec: pallas_core.BlockSpec,
|
||||
**params
|
||||
prim: core.Primitive,
|
||||
ctx: PullRuleContext,
|
||||
block_spec: pallas_core.BlockSpec,
|
||||
**params,
|
||||
) -> Sequence[pallas_core.BlockSpec]:
|
||||
del prim, ctx, params
|
||||
return [block_spec]
|
||||
@ -702,7 +707,9 @@ def _eltwise_usage_rule(
|
||||
return [used_out]
|
||||
|
||||
|
||||
def _bcast_block_spec(block_spec: pallas_core.BlockSpec, i: int) -> pallas_core.BlockSpec:
|
||||
def _bcast_block_spec(
|
||||
block_spec: pallas_core.BlockSpec, i: int
|
||||
) -> pallas_core.BlockSpec:
|
||||
def new_index_map(i, *args):
|
||||
idx = block_spec.index_map(*args)
|
||||
assert len(idx) == len(block_spec.block_shape)
|
||||
@ -710,7 +717,9 @@ def _bcast_block_spec(block_spec: pallas_core.BlockSpec, i: int) -> pallas_core.
|
||||
return idx
|
||||
|
||||
new_block_shape = util.tuple_update(block_spec.block_shape, i, 1)
|
||||
return pallas_core.BlockSpec(new_block_shape, functools.partial(new_index_map, i))
|
||||
return pallas_core.BlockSpec(
|
||||
new_block_shape, functools.partial(new_index_map, i)
|
||||
)
|
||||
|
||||
|
||||
def _binop_usage_rule(prim, ctx, used_out: set[Usage]):
|
||||
@ -945,7 +954,9 @@ def _dynamic_slice_rule(
|
||||
return block_indices
|
||||
|
||||
new_block_spec = pallas_core.BlockSpec(block_spec.block_shape, new_index_map)
|
||||
return [new_block_spec] + [pallas_core.no_block_spec] * (len(ctx.avals_in) - 1)
|
||||
return [new_block_spec] + [pallas_core.no_block_spec] * (
|
||||
len(ctx.avals_in) - 1
|
||||
)
|
||||
|
||||
|
||||
@register_eval_rule(lax.concatenate_p)
|
||||
@ -1348,7 +1359,8 @@ def _push_block_spec_jaxpr(
|
||||
return env[atom]
|
||||
|
||||
def _write_block_spec(
|
||||
atom: core.Atom, block_spec: pallas_core.BlockSpec | pallas_core.NoBlockSpec
|
||||
atom: core.Atom,
|
||||
block_spec: pallas_core.BlockSpec | pallas_core.NoBlockSpec,
|
||||
):
|
||||
if isinstance(atom, core.Literal):
|
||||
return
|
||||
@ -1374,7 +1386,9 @@ def _push_block_spec_jaxpr(
|
||||
|
||||
util.safe_map(_write_block_spec, eqn.outvars, out_block_specs)
|
||||
out_block_specs = tuple(util.safe_map(_read_block_spec, jaxpr.outvars))
|
||||
valid_block_spec = [bs for bs in flat_block_specs if bs is not pallas_core.no_block_spec][0]
|
||||
valid_block_spec = [
|
||||
bs for bs in flat_block_specs if bs is not pallas_core.no_block_spec
|
||||
][0]
|
||||
out_block_specs = tuple(
|
||||
valid_block_spec if obs is pallas_core.no_block_spec else obs
|
||||
for obs in out_block_specs
|
||||
@ -1491,6 +1505,18 @@ def _convert_element_type_push_rule(
|
||||
return block_spec
|
||||
|
||||
|
||||
@register_push_block_spec_rule(lax.select_n_p)
|
||||
def _select_n_push_rule(
|
||||
ctx: PushRuleContext,
|
||||
*args: pallas_core.BlockSpec,
|
||||
):
|
||||
del ctx
|
||||
block_specs = [b for b in args if b is not pallas_core.no_block_spec]
|
||||
if len(block_specs) > 1:
|
||||
raise NotImplementedError('select_n with multiple inputs not supported yet')
|
||||
return block_specs[0]
|
||||
|
||||
|
||||
@register_push_block_spec_rule(custom_derivatives.custom_jvp_call_p)
|
||||
def _custom_jvp_call_push_rule(
|
||||
ctx, *block_specs, call_jaxpr: core.ClosedJaxpr, **_
|
||||
@ -1500,9 +1526,7 @@ def _custom_jvp_call_push_rule(
|
||||
|
||||
|
||||
@register_push_block_spec_rule(pjit.pjit_p)
|
||||
def _pjit_push_rule(
|
||||
ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_
|
||||
):
|
||||
def _pjit_push_rule(ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_):
|
||||
assert not jaxpr.consts
|
||||
return _push_block_spec_jaxpr(jaxpr.jaxpr, *block_specs)
|
||||
|
||||
|
@ -32,28 +32,42 @@ def _get_aval(x):
|
||||
return jax_core.raise_to_shaped(jax_core.get_aval(x))
|
||||
|
||||
|
||||
def fuse(f, *, physicalize: bool = False):
|
||||
def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
|
||||
"""Fuses a function into a single fusable.
|
||||
|
||||
Args:
|
||||
f: The function to fuse.
|
||||
physicalize: (experimental) whether to physicalize the function.
|
||||
debug: Whether to print debug information.
|
||||
|
||||
There should be a single call to a `fusable` inside the body of `f`. `fuse`
|
||||
returns a transformed function that will fuse the surrounding computation into
|
||||
the fusable and invoke it.
|
||||
"""
|
||||
def wrapper(*args, **kwargs):
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
debug_info = api_util.debug_info('fuse', f, args, kwargs)
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(f, debug_info=debug_info), in_tree
|
||||
)
|
||||
flat_avals = [_get_aval(x) for x in flat_args]
|
||||
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
|
||||
out_tree = out_tree_thunk()
|
||||
out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
|
||||
return tree_util.tree_unflatten(out_tree, out_flat)
|
||||
|
||||
if physicalize:
|
||||
wrapper = fusable_dtype.physicalize(wrapper)
|
||||
return wrapper
|
||||
def decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
debug_info = api_util.debug_info("fuse", f, args, kwargs)
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(f, debug_info=debug_info), in_tree
|
||||
)
|
||||
flat_avals = [_get_aval(x) for x in flat_args]
|
||||
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
|
||||
if debug:
|
||||
print("Jaxpr before fusion:")
|
||||
print(jaxpr)
|
||||
out_tree = out_tree_thunk()
|
||||
out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
|
||||
return tree_util.tree_unflatten(out_tree, out_flat)
|
||||
|
||||
if physicalize:
|
||||
wrapper = fusable_dtype.physicalize(wrapper)
|
||||
return wrapper
|
||||
|
||||
  if f is not None:
    return decorator(f)
  return decorator


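The reworked `fuse` above uses the standard optional-argument decorator pattern; a self-contained sketch of that pattern (names here are illustrative, not the Pallas API):

def fuse_like(f=None, *, physicalize=False, debug=False):
  def decorator(fn):
    def wrapper(*args, **kwargs):
      if debug:
        print(f"calling {fn.__name__}")
      return fn(*args, **kwargs)
    return wrapper
  if f is not None:   # used bare, as @fuse_like
    return decorator(f)
  return decorator    # used with arguments, as @fuse_like(...)

@fuse_like
def double(x):
  return 2 * x

@fuse_like(debug=True)
def triple(x):
  return 3 * x
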
_fusable: dict[jax_core.Primitive, Any] = {}
|
||||
|
@ -159,6 +159,9 @@ py_library(
|
||||
":core",
|
||||
":primitives",
|
||||
"//jax",
|
||||
"//jax:core",
|
||||
"//jax:source_info_util",
|
||||
"//jax:util",
|
||||
"//jax/_src/lib",
|
||||
"//jax/_src/pallas",
|
||||
] + py_deps("numpy"),
|
||||
|
File diff suppressed because it is too large
@ -3222,6 +3222,14 @@ def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule
|
||||
|
||||
|
||||
def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx):
|
||||
if not isinstance(x.type.element_type, ir.F32Type):
|
||||
raise ValueError("Only float32 is supported.")
|
||||
return tpu.reciprocal(x, approx=approx)
|
||||
|
||||
|
||||
lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule
|
||||
|
||||
def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):
|
||||
del ty
|
||||
(out_aval,) = ctx.avals_out
|
||||
|
@ -207,6 +207,8 @@ class BufferedRef:
|
||||
is_accumulator: whether this BufferedRef is an accumulator.
|
||||
is_input_output: whether this BufferedRef is an input/output without
|
||||
automatic accumulation.
|
||||
swap: Tracks whether the BufferedRef slots need to be swapped before next
|
||||
copy.
|
||||
"""
|
||||
spec: pl.BlockSpec # static metadata
|
||||
dtype: Any # static metadata
|
||||
@ -214,9 +216,14 @@ class BufferedRef:
|
||||
window_ref: REF | None
|
||||
accum_ref: REF | None
|
||||
current_slot: ArrayRef | None
|
||||
# TODO(ramiroleal): Unused by class. Remove argument from
|
||||
# BufferedRef instantiations.
|
||||
next_slot: ArrayRef | None
|
||||
sem_recvs: SemaphoreTuple | None
|
||||
sem_sends: SemaphoreTuple | None
|
||||
# TODO(ramiroleal): Improve prefetch/postyeet interface to avoid
|
||||
# using this ref.
|
||||
swap: ArrayRef | None
|
||||
|
||||
def tree_flatten(self):
|
||||
return (
|
||||
@ -227,6 +234,7 @@ class BufferedRef:
|
||||
self.next_slot,
|
||||
self.sem_recvs,
|
||||
self.sem_sends,
|
||||
self.swap,
|
||||
),
|
||||
(self.spec, self.dtype, self.buffer_type),
|
||||
)
|
||||
@ -240,7 +248,7 @@ class BufferedRef:
|
||||
return BufferType
|
||||
|
||||
@classmethod
|
||||
def create(cls, spec, dtype, buffer_type) -> BufferedRef:
|
||||
def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
|
||||
"""Create a BufferedRef.
|
||||
|
||||
Args:
|
||||
@ -248,6 +256,7 @@ class BufferedRef:
|
||||
dtype: dtype for buffers.
|
||||
buffer_type: enum indicating whether this is an input, output, or in/out
|
||||
accumulator buffered reference.
|
||||
needs_swap_ref: whether a swap slots tracker needs to be allocated.
|
||||
|
||||
Returns:
|
||||
Initialized BufferedRef
|
||||
@ -271,6 +280,7 @@ class BufferedRef:
|
||||
next_slot=None,
|
||||
sem_recvs=None,
|
||||
sem_sends=None,
|
||||
swap=None,
|
||||
)
|
||||
else:
|
||||
memory_space = SMEM if spec.memory_space == SMEM else VMEM
|
||||
@ -281,7 +291,7 @@ class BufferedRef:
|
||||
window_ref=memory_space((2,) + block_shape, dtype),
|
||||
accum_ref=accum_ref,
|
||||
current_slot=SMEM((1,), jnp.int32),
|
||||
next_slot=SMEM((1,), jnp.int32),
|
||||
next_slot=None,
|
||||
sem_recvs=(
|
||||
None
|
||||
if buffer_type is BufferType.OUTPUT
|
||||
@ -292,23 +302,24 @@ class BufferedRef:
|
||||
if buffer_type is BufferType.INPUT
|
||||
else SemaphoreType.DMA((2,))
|
||||
),
|
||||
swap=SMEM((1,), jnp.bool) if needs_swap_ref else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def input(cls, spec, dtype):
|
||||
return cls.create(spec, dtype, BufferType.INPUT)
|
||||
def input(cls, spec, dtype, needs_swap_ref=True):
|
||||
return cls.create(spec, dtype, BufferType.INPUT, needs_swap_ref)
|
||||
|
||||
@classmethod
|
||||
def output(cls, spec, dtype):
|
||||
return cls.create(spec, dtype, BufferType.OUTPUT)
|
||||
def output(cls, spec, dtype, needs_swap_ref=True):
|
||||
return cls.create(spec, dtype, BufferType.OUTPUT, needs_swap_ref)
|
||||
|
||||
@classmethod
|
||||
def accumulator(cls, spec, dtype):
|
||||
return cls.create(spec, dtype, BufferType.ACCUMULATOR)
|
||||
def accumulator(cls, spec, dtype, needs_swap_ref=True):
|
||||
return cls.create(spec, dtype, BufferType.ACCUMULATOR, needs_swap_ref)
|
||||
|
||||
@classmethod
|
||||
def input_output(cls, spec, dtype):
|
||||
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT)
|
||||
def input_output(cls, spec, dtype, needs_swap_ref=True):
|
||||
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, needs_swap_ref)
|
||||
|
||||
@property
|
||||
def block_shape(self):
|
||||
@ -329,7 +340,7 @@ class BufferedRef:
|
||||
if self.memory_space == VMEM:
|
||||
return self.window_ref.at[buffer_slice]
|
||||
else:
|
||||
return self.window_ref.at[(self.current_slot[0], *buffer_slice)]
|
||||
return self.window_ref.at[(self.current_slot_index, *buffer_slice)]
|
||||
|
||||
@property
|
||||
def is_input(self):
|
||||
@ -355,6 +366,14 @@ class BufferedRef:
|
||||
def is_input_output(self):
|
||||
return self.buffer_type == BufferType.INPUT_OUTPUT
|
||||
|
||||
  @property
  def current_slot_index(self):
    return self.current_slot[0]

  @property
  def next_slot_index(self):
    return lax.rem(self.current_slot_index + 1, 2)

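A standalone sketch (illustrative only, independent of the BufferedRef class) of the two-slot arithmetic these properties encode:

def next_slot(current):
  # With double buffering, the "other" slot is always (current + 1) % 2.
  return (current + 1) % 2

current = 0
for step in range(4):
  # Issue copies into next_slot(current) while compute reads slot `current`.
  current = next_slot(current)  # swap once the copies have landed
  print(step, current)          # prints: 0 1, 1 0, 2 1, 3 0
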
def bind_existing_ref(self, window_ref, indices):
|
||||
"""For handling VMEM references, the pipeline aliases the existing ref."""
|
||||
if self.memory_space == VMEM:
|
||||
@ -373,12 +392,15 @@ class BufferedRef:
|
||||
"""Initialize slot indices."""
|
||||
if self.memory_space == VMEM: return
|
||||
self.current_slot[0] = 0
|
||||
self.next_slot[0] = 0
|
||||
if self.swap is not None:
|
||||
self.swap[0] = False
|
||||
|
||||
def swap_slots(self):
|
||||
"""Switch to the next slot."""
|
||||
if self.memory_space == VMEM: return
|
||||
self.current_slot[0] = self.next_slot[0]
|
||||
self.current_slot[0] = self.next_slot_index
|
||||
if self.swap is not None:
|
||||
self.swap[0] = False
|
||||
|
||||
def get_dma_slice(self, src_shape, src_dtype, grid_indices):
|
||||
# We need to handle blocks that might go OOB in the src array. An in bounds
|
||||
@ -441,8 +463,9 @@ class BufferedRef:
|
||||
"""Starts copy of HBM dma slice into the current slot."""
|
||||
assert self.is_input
|
||||
if self.memory_space == VMEM: return
|
||||
next_slot = lax.rem(self.current_slot[0] + 1, 2)
|
||||
self.next_slot[0] = next_slot
|
||||
if self.swap is not None:
|
||||
self.swap[0] = True
|
||||
next_slot = self.next_slot_index
|
||||
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
|
||||
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
|
||||
tpu_primitives.make_async_copy(
|
||||
@ -455,8 +478,9 @@ class BufferedRef:
|
||||
"""Starts copy of HBM dma slice from the current slot."""
|
||||
assert self.is_output
|
||||
if self.memory_space == VMEM: return
|
||||
slot = self.current_slot[0]
|
||||
self.next_slot[0] = lax.rem(slot + 1, 2)
|
||||
if self.swap is not None:
|
||||
self.swap[0] = True
|
||||
slot = self.current_slot_index
|
||||
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
|
||||
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
|
||||
tpu_primitives.make_async_copy(
|
||||
@ -471,7 +495,7 @@ class BufferedRef:
|
||||
if self.memory_space == VMEM: return
|
||||
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
|
||||
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
|
||||
current_slot = self.current_slot[0]
|
||||
current_slot = self.current_slot_index
|
||||
tpu_primitives.make_async_copy(
|
||||
src_ref.at[src_slice], # nb: doesn't matter
|
||||
self.window_ref.at[current_slot].at[
|
||||
@ -484,7 +508,8 @@ class BufferedRef:
|
||||
"""Waits for output copy to finish."""
|
||||
assert self.is_output
|
||||
if self.memory_space == VMEM: return
|
||||
prev_slot = lax.rem(self.current_slot[0] + 1, 2)
|
||||
# In a double buffer, previous slot is the same as next slot.
|
||||
prev_slot = self.next_slot_index
|
||||
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
|
||||
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
|
||||
tpu_primitives.make_async_copy(
|
||||
@ -671,10 +696,7 @@ class Scheduler:
|
||||
def _start():
|
||||
if buffered_ref.is_input:
|
||||
buffered_ref.copy_in(src_ref, self.indices)
|
||||
|
||||
# In the prologue this makes it so we wait on the prologue copy to finish.
|
||||
# In other iterations this is the regular swap.
|
||||
buffered_ref.swap_slots()
|
||||
buffered_ref.swap_slots()
|
||||
|
||||
def wait_in(self, buffered_ref, src_ref, schedule=None):
|
||||
if schedule is None:
|
||||
@ -780,9 +802,32 @@ class Scheduler:
|
||||
@self._named_scope("ep_finalize")
|
||||
def _end():
|
||||
if buffered_ref.is_output:
|
||||
buffered_ref.swap_slots() # formally correct, not actually necessary.
|
||||
buffered_ref.wait_out(dst_ref, self.indices)
|
||||
|
||||
def swap_slots(self, buffered_ref, hbm_ref, schedule=None):
|
||||
if buffered_ref.swap is not None:
|
||||
swap = buffered_ref.swap[0]
|
||||
else:
|
||||
# If we are not using an SMEM `swap` tensor to keep track of
|
||||
# swaps needed, then all the copies into and out of BufferedRefs
|
||||
# are done by direct calls to the `copy_in` and `copy_out`
|
||||
# methods in the pipeline loop. To determine if the BufferedRef
|
||||
# needs a swap of slots, we recalculate the copy-in/copy-out
|
||||
# conditions.
|
||||
if schedule is None:
|
||||
schedule = _default_schedule
|
||||
pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref)
|
||||
pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref)
|
||||
|
||||
copied_in = pred_in & buffered_ref.is_input & ~self.last_step
|
||||
copied_out = pred_out & buffered_ref.is_output
|
||||
swap = copied_in | copied_out
|
||||
|
||||
@pl.when(swap)
|
||||
@self._named_scope("ep_swap")
|
||||
def _swap():
|
||||
buffered_ref.swap_slots()
|
||||
|
||||
# END SCHEDULE --------------------------------------------------------------
|
||||
|
||||
|
||||
@ -875,6 +920,7 @@ def make_pipeline_allocations(
|
||||
in_specs=None,
|
||||
out_specs=None,
|
||||
should_accumulate_out=False,
|
||||
needs_swap_ref=True,
|
||||
):
|
||||
"""Create BufferedRefs for the pipeline.
|
||||
|
||||
@ -887,6 +933,7 @@ def make_pipeline_allocations(
|
||||
out_specs: output pallas block specs
|
||||
should_accumulate_out: booleans to indicate which outputs should be treated
|
||||
as accumulators.
|
||||
needs_swap_ref: whether a swap slots tracker needs to be allocated.
|
||||
|
||||
Returns:
|
||||
A list of BufferedRefs, one corresponding to each ref specified in the
|
||||
@ -905,12 +952,12 @@ def make_pipeline_allocations(
|
||||
in_refs = refs[:num_in_specs]
|
||||
out_refs = refs[num_in_specs:]
|
||||
def make_input_bref(in_spec, in_ref):
|
||||
return BufferedRef.input(in_spec, in_ref.dtype)
|
||||
return BufferedRef.input(in_spec, in_ref.dtype, needs_swap_ref)
|
||||
in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs)
|
||||
def make_output_bref(out_spec, out_ref, accumulate):
|
||||
if accumulate:
|
||||
return BufferedRef.accumulator(out_spec, out_ref.dtype)
|
||||
return BufferedRef.output(out_spec, out_ref.dtype)
|
||||
return BufferedRef.accumulator(out_spec, out_ref.dtype, needs_swap_ref)
|
||||
return BufferedRef.output(out_spec, out_ref.dtype, needs_swap_ref)
|
||||
out_brefs = jax.tree.map(
|
||||
make_output_bref, out_specs, out_refs, should_accumulate_out)
|
||||
return (*in_brefs, *out_brefs)
|
||||
@ -1109,6 +1156,14 @@ def emit_pipeline(
|
||||
scratches = ()
|
||||
if allocations is None:
|
||||
# run with inline scoped allocations
|
||||
|
||||
# Prefetch and postyeet are arbitrary functions that can copy
|
||||
# into or out of any of the BufferedRefs. Thus, we need a ref
|
||||
# for the scheduler to mark when the prefetch or postyeet
|
||||
# functions perform a copy and the slots need to be
|
||||
# swapped. Without prefetch and postyeet, the swapping logic can
|
||||
# be performed without the need for state.
|
||||
needs_swap_ref = prefetch is not None or postyeet is not None
|
||||
return primitives.run_scoped(
|
||||
lambda allocations: pipeline(
|
||||
*refs,
|
||||
@ -1125,7 +1180,9 @@ def emit_pipeline(
|
||||
*refs,
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
should_accumulate_out=should_accumulate_out),
|
||||
should_accumulate_out=should_accumulate_out,
|
||||
needs_swap_ref=needs_swap_ref,
|
||||
),
|
||||
)
|
||||
if isinstance(allocations, list):
|
||||
allocations = tuple(allocations)
|
||||
@ -1184,6 +1241,8 @@ def emit_pipeline(
|
||||
lax.cond(step == 0,
|
||||
lambda: postyeet(*brefs, scheduler),
|
||||
lambda: None)
|
||||
|
||||
map_brefs(scheduler.swap_slots, brefs, refs, schedule)
|
||||
map_brefs(scheduler.finalize, brefs, refs, schedule)
|
||||
|
||||
return _next_index(indices, grid)
|
||||
|
@ -17,7 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence
|
||||
from collections.abc import Callable, Hashable, MutableMapping, MutableSequence, Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
@ -25,6 +25,7 @@ import math
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import jax
|
||||
from jax import api_util
|
||||
from jax import lax
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import linear_util as lu
|
||||
@ -36,6 +37,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith as arith_dialect
|
||||
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
|
||||
from jax._src.lib.mlir.dialects import math as math_dialect
|
||||
from jax._src.lib.mlir.dialects import memref as memref_dialect
|
||||
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
|
||||
from jax._src.lib.mlir.dialects import scf as scf_dialect
|
||||
@ -837,6 +839,29 @@ def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value:
|
||||
)
|
||||
|
||||
|
||||
def _lower_fun(
|
||||
fun: Callable[..., Any], *, multiple_results: bool
|
||||
) -> Callable[..., Any]:
|
||||
|
||||
def lowering_rule(ctx: LoweringRuleContext, *args, **params):
|
||||
wrapped_fun = lu.wrap_init(
|
||||
fun
|
||||
if multiple_results
|
||||
else lambda *args, **params: (fun(*args, **params),),
|
||||
params,
|
||||
debug_info=api_util.debug_info(
|
||||
"Pallas Mosaic GPU lower_fun", fun, args, params
|
||||
),
|
||||
)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
out = lower_jaxpr_to_mosaic_gpu(
|
||||
ctx.module_ctx, ctx.launch_ctx, jaxpr, args, consts
|
||||
)
|
||||
return out if multiple_results else out[0]
|
||||
|
||||
return lowering_rule
|
||||
|
||||
|
||||
@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis):
|
||||
@ -1162,6 +1187,9 @@ def _convert_element_type_lowering_rule_wg(
|
||||
cur_dtype = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
|
||||
new_dtype = mgpu_utils.dtype_to_ir_type(new_dtype)
|
||||
|
||||
if cur_dtype == new_dtype:
|
||||
return x
|
||||
|
||||
if 1 < mgpu_utils.bitwidth(cur_dtype) < 8 or 1 < mgpu_utils.bitwidth(new_dtype) < 8:
|
||||
raise NotImplementedError("Conversion involving sub-byte types unsupported")
|
||||
|
||||
@ -1170,7 +1198,29 @@ def _convert_element_type_lowering_rule_wg(
|
||||
from_integer = ir.IntegerType.isinstance(cur_dtype)
|
||||
to_integer = ir.IntegerType.isinstance(new_dtype)
|
||||
if from_float and to_float:
|
||||
if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
|
||||
cur_ty_width = ir.FloatType(cur_dtype).width
|
||||
new_ty_width = ir.FloatType(new_dtype).width
|
||||
if cur_ty_width == new_ty_width:
|
||||
# There is no instruction to perform conversions between two float types
|
||||
# of the same width. Go through the next-larger standard type.
|
||||
# TODO(bchetioui): support conversions between float types of width 8.
|
||||
# Which larger type to pick will depend on the number of bits in the
|
||||
# smallest exponent.
|
||||
if cur_ty_width != 16:
|
||||
raise NotImplementedError(
|
||||
"Conversion between float types of width other than 16 not"
|
||||
" supported"
|
||||
)
|
||||
larger_ty = ir.F32Type.get()
|
||||
if x_aval.shape:
|
||||
upcast_ty = ir.VectorType.get(x_aval.shape, larger_ty)
|
||||
else:
|
||||
upcast_ty = larger_ty
|
||||
|
||||
def convert(ty, x):
|
||||
return arith_dialect.truncf(ty, arith_dialect.extf(upcast_ty, x))
|
||||
|
||||
elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
|
||||
convert = arith_dialect.truncf
|
||||
else:
|
||||
convert = arith_dialect.extf
|
||||
@ -1190,10 +1240,26 @@ def _convert_element_type_lowering_rule_wg(
|
||||
else:
|
||||
convert = arith_dialect.uitofp
|
||||
elif from_float and to_integer:
|
||||
dst_width = mgpu_utils.bitwidth(new_dtype)
|
||||
# We clamp the float value to the min/max integer destination value
|
||||
# in order to match JAX/XLA casting behavior. Note that this differs
|
||||
# from numpy casting behavior.
|
||||
if mgpu_utils.is_signed(y_aval.dtype):
|
||||
maxint = 2 ** (dst_width - 1) - 1
|
||||
minint = -(2 ** (dst_width - 1))
|
||||
convert = arith_dialect.fptosi
|
||||
else:
|
||||
maxint = 2**dst_width - 1
|
||||
minint = 0
|
||||
convert = arith_dialect.fptoui
|
||||
|
||||
maxint = _ir_constant(maxint, cur_dtype)
|
||||
minint = _ir_constant(minint, cur_dtype)
|
||||
if x_aval.shape:
|
||||
maxint = vector_dialect.splat(x.type, maxint)
|
||||
minint = vector_dialect.splat(x.type, minint)
|
||||
x = arith_dialect.minimumf(x, maxint)
|
||||
x = arith_dialect.maximumf(x, minint)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}")
|
||||
|
||||
@ -1206,6 +1272,13 @@ mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({
|
||||
lax.not_p: lambda ctx, x: ~x,
|
||||
})
|
||||
|
||||
mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({
|
||||
lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False),
|
||||
lax.not_p: _lower_fun(
|
||||
lambda x: jnp.bitwise_xor(x, -1), multiple_results=False
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
|
||||
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
|
||||
@ -1342,48 +1415,98 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
|
||||
|
||||
@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
[x_aval] = ctx.avals_in
|
||||
x = _ensure_fa(x, x_aval.dtype)
|
||||
if y == 2:
|
||||
return x * x
|
||||
return NotImplementedError
|
||||
if y != 2:
|
||||
raise NotImplementedError
|
||||
return _square_lowering_rule(ctx, x)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _square_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
x = _ensure_fa(x, x_aval.dtype)
|
||||
return x * x
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
x = _ensure_fa(x, x_aval.dtype)
|
||||
return x * x
|
||||
if jnp.issubdtype(x_aval.dtype, jnp.integer):
|
||||
return arith_dialect.muli(x, x)
|
||||
if jnp.issubdtype(x_aval.dtype, jnp.floating):
|
||||
return arith_dialect.mulf(x, x)
|
||||
raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}")
|
||||
|
||||
|
||||
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
|
||||
fastmath = (
|
||||
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
|
||||
)
|
||||
return math_dialect.rsqrt(
|
||||
_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath
|
||||
)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math)
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math)
|
||||
fastmath = (
|
||||
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
|
||||
)
|
||||
return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.logistic_p, mgpu.ThreadSemantics.Lane)
|
||||
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
a = _ensure_fa(x, x_aval.dtype)
|
||||
return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math))
|
||||
def _logistic(x):
  return 1.0 / (1 + lax.exp(-x))


mosaic_lowering_rules[mgpu.ThreadSemantics.Lane][lax.logistic_p] = _lower_fun(
    _logistic, multiple_results=False
)
mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = (
    _lower_fun(_logistic, multiple_results=False)
)


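A small numerical sanity check (illustrative, with an assumed tolerance) that the `_logistic` helper used for this lowering agrees with the library sigmoid:

import jax
import jax.numpy as jnp

x = jnp.linspace(-5.0, 5.0, 11)
manual = 1.0 / (1.0 + jnp.exp(-x))
print(bool(jnp.allclose(manual, jax.nn.sigmoid(x), atol=1e-6)))  # expected: True
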
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
a = _ensure_fa(x, x_aval.dtype)
|
||||
return a.exp(approx=ctx.module_ctx.approx_math)
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math)
|
||||
fastmath = (
|
||||
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
|
||||
)
|
||||
return math_dialect.exp(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane)
|
||||
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
a = _ensure_fa(x, x_aval.dtype)
|
||||
return a.exp2(approx=ctx.module_ctx.approx_math)
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math)
|
||||
fastmath = (
|
||||
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
|
||||
)
|
||||
return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _log_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math)
|
||||
fastmath = (
|
||||
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
|
||||
)
|
||||
return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane)
|
||||
@ -1484,11 +1607,7 @@ def _debug_print_lowering_rule(
|
||||
)
|
||||
elif len(ctx.avals_in) == 1:
|
||||
[arg] = args
|
||||
@arg.foreach
|
||||
def _(val, idx):
|
||||
idx_fmt = ", ".join(["{}"] * len(idx))
|
||||
fmt_str = fmt.format(f"[{idx_fmt}]/{list(arg.shape)}: {{}}")
|
||||
mgpu.debug_print(fmt_str, *idx, val, uniform=False)
|
||||
arg.debug_print(fmt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"debug_print only supports printing of scalar values, or a single array"
|
||||
@ -1901,27 +2020,36 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
|
||||
|
||||
|
||||
@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(
|
||||
lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup
|
||||
)
|
||||
def _bitcast_convert_type_lowering_rule(
|
||||
ctx: LoweringRuleContext, operand, *, new_dtype
|
||||
ctx: LoweringRuleContext, x, *, new_dtype
|
||||
):
|
||||
# TODO(petebu) Handle case where src and dst types have different bitwidths
|
||||
[operand_aval] = ctx.avals_in
|
||||
operand = _ensure_fa(operand, operand_aval.dtype)
|
||||
src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype)
|
||||
[x_aval] = ctx.avals_in
|
||||
src_elem_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
|
||||
dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype)
|
||||
assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType))
|
||||
assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType))
|
||||
if src_elem_type.width != dst_elem_type.width:
|
||||
raise NotImplementedError(
|
||||
f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they"
|
||||
f"Cannot bitcast from {x_aval.dtype} to {new_dtype} because they"
|
||||
" have different widths"
|
||||
)
|
||||
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup:
|
||||
x = _ensure_ir_value(x, x_aval.dtype)
|
||||
return arith_dialect.bitcast(
|
||||
ir.VectorType.get(x_aval.shape, dst_elem_type), x
|
||||
)
|
||||
|
||||
x = _ensure_fa(x, x_aval.dtype)
|
||||
if ir.IntegerType.isinstance(dst_elem_type):
|
||||
output_is_signed = mgpu_utils.is_signed(new_dtype)
|
||||
else:
|
||||
output_is_signed = None
|
||||
return mgpu.FragmentedArray.bitcast(
|
||||
operand, dst_elem_type, output_is_signed=output_is_signed
|
||||
x, dst_elem_type, output_is_signed=output_is_signed
|
||||
)
|
||||
|
||||
|
||||
|
@ -91,6 +91,7 @@ class BufferedRef:
|
||||
self.smem_ref.at[slot], # pytype: disable=unsupported-operands
|
||||
self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands
|
||||
predicate=predicate,
|
||||
commit_group=False,
|
||||
)
|
||||
|
||||
|
||||
@ -299,6 +300,8 @@ def emit_pipeline(
|
||||
predicate=lax.bitwise_or(slices_changed, is_last_step),
|
||||
)
|
||||
|
||||
gpu_primitives.commit_smem_to_gmem_group()
|
||||
|
||||
fetch_step = step + (max_concurrent_steps - delay_release)
|
||||
fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
|
||||
|
||||
@ -344,6 +347,8 @@ def emit_pipeline(
|
||||
if bref.is_index_invariant:
|
||||
bref.copy_out(last_slot, last_indices, predicate=None)
|
||||
|
||||
gpu_primitives.commit_smem_to_gmem_group()
|
||||
|
||||
# Finalize the pipeline.
|
||||
gpu_primitives.wait_smem_to_gmem(0)
|
||||
|
||||
@ -578,6 +583,7 @@ def emit_pipeline_warp_specialized(
|
||||
bref.copy_out(_get_slot(slot, ~bref.is_index_invariant),
|
||||
indices,
|
||||
predicate=slices_changed)
|
||||
gpu_primitives.commit_smem_to_gmem_group()
|
||||
next_indices = _inc_grid_by_1(indices, grid)
|
||||
return (next_indices, new_store_slices, next_body_carry)
|
||||
init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
|
||||
@ -619,6 +625,8 @@ def emit_pipeline_warp_specialized(
|
||||
if bref.is_index_invariant:
|
||||
bref.copy_out(last_slot, last_indices, predicate=None)
|
||||
|
||||
gpu_primitives.commit_smem_to_gmem_group()
|
||||
|
||||
# Finalize the pipeline.
|
||||
gpu_primitives.wait_smem_to_gmem(0)
|
||||
|
||||
|
@ -86,6 +86,7 @@ def _copy_smem_to_gmem_lowering(
|
||||
src_transforms_treedef,
|
||||
dst_transforms_treedef,
|
||||
has_user_predicate,
|
||||
commit_group,
|
||||
):
|
||||
predicate = ctx.module_ctx.single_wg_lane_predicate
|
||||
if has_user_predicate:
|
||||
@ -106,6 +107,7 @@ def _copy_smem_to_gmem_lowering(
|
||||
src_ref=src,
|
||||
dst_ref=dst,
|
||||
predicate=predicate,
|
||||
arrive=commit_group,
|
||||
**copy_params,
|
||||
)
|
||||
return ()
|
||||
@ -119,7 +121,12 @@ def _copy_smem_to_gmem_lowering(
|
||||
assert copy_params.get("swizzle") is None
|
||||
assert not copy_params.get("gmem_transform")
|
||||
mgpu.dialect.async_store(
|
||||
src, dst, indices, slice_lengths, predicate=predicate
|
||||
src,
|
||||
dst,
|
||||
indices,
|
||||
slice_lengths,
|
||||
predicate=predicate,
|
||||
commit_group=commit_group, # type: ignore[call-arg]
|
||||
)
|
||||
return ()
|
||||
|
||||
@ -174,7 +181,11 @@ def _extract_smem_copy_params(transforms):
|
||||
|
||||
|
||||
def copy_smem_to_gmem(
|
||||
src: _Ref, dst: _Ref, predicate: jax.Array | None = None
|
||||
src: _Ref,
|
||||
dst: _Ref,
|
||||
predicate: jax.Array | None = None,
|
||||
*,
|
||||
commit_group: bool = True,
|
||||
) -> None:
|
||||
"""Asynchronously copies a SMEM reference to a GMEM reference.
|
||||
|
||||
@ -183,6 +194,9 @@ def copy_smem_to_gmem(
|
||||
dst: The GMEM reference to copy to.
|
||||
predicate: A boolean indicating whether the copy should be performed. If
|
||||
``None``, the copy is always performed.
|
||||
commit_group: If ``True``, this and any previously uncommitted copies
|
||||
are committed to a group and can be awaited jointly via
|
||||
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`.
|
||||
|
||||
See also:
|
||||
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
|
||||
@ -209,6 +223,7 @@ def copy_smem_to_gmem(
|
||||
src_transforms_treedef=src_transforms_treedef,
|
||||
dst_transforms_treedef=dst_transforms_treedef,
|
||||
has_user_predicate=predicate is not None,
|
||||
commit_group=commit_group,
|
||||
)
|
||||
return None
|
||||
|
||||
@ -475,6 +490,28 @@ def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None:
|
||||
wait_smem_to_gmem_p.bind(n, wait_read_only=wait_read_only)
|
||||
|
||||
|
||||
commit_group_p = jax_core.Primitive("commit_group")
commit_group_p.multiple_results = True


@commit_group_p.def_effectful_abstract_eval
def _commit_group_abstract_eval():
  return (), {gpu_core._memory_effect}


@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Lane)
@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Warpgroup)
def _commit_group_lowering(ctx: lowering.LoweringRuleContext):
  del ctx  # Unused.
  nvvm_dialect.cp_async_bulk_commit_group()
  return ()


def commit_smem_to_gmem_group() -> None:
  """Commits all issued but uncommitted SMEM->GMEM copies to a group."""
  commit_group_p.bind()
|
||||
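A minimal sketch of how the new knob is meant to compose inside a Pallas Mosaic GPU kernel body. The ref names and the exact re-export path of commit_smem_to_gmem_group are illustrative assumptions, not part of this change:

from jax.experimental.pallas import mosaic_gpu as plgpu

def epilogue(o_smem, o_gmem, o2_smem, o2_gmem):
  # Issue two async copies without committing them individually ...
  plgpu.copy_smem_to_gmem(o_smem, o_gmem, commit_group=False)
  plgpu.copy_smem_to_gmem(o2_smem, o2_gmem, commit_group=False)
  # ... then commit both as a single group and await the whole group.
  plgpu.commit_smem_to_gmem_group()
  plgpu.wait_smem_to_gmem(0)
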
# WGMMA on an accumulator reference
|
||||
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
|
||||
wgmma_ref_p.multiple_results = True
|
||||
|
@ -695,14 +695,44 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False,
|
||||
if precision is not None:
|
||||
raise ValueError("Only one of allow_tf32 and precision can be specified")
|
||||
precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST
|
||||
dtype = jnp.promote_types(a.dtype, b.dtype)
|
||||
out_dtype = jnp.int32 if jnp.issubdtype(dtype, jnp.integer) else jnp.float32
|
||||
return jax.lax.dot_general(
|
||||
a,
|
||||
b,
|
||||
dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
|
||||
precision=precision,
|
||||
preferred_element_type=jnp.float32,
|
||||
preferred_element_type=out_dtype,
|
||||
)
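
The accumulator dtype now follows the operand dtypes: int32 for integer inputs, float32 otherwise. The same selection spelled out against plain lax.dot_general, outside of any kernel (a standalone illustration, not this module's API):

import jax
import jax.numpy as jnp

def accum_dtype(a, b):
  dtype = jnp.promote_types(a.dtype, b.dtype)
  return jnp.int32 if jnp.issubdtype(dtype, jnp.integer) else jnp.float32

a = jnp.ones((4, 4), jnp.int8)
out = jax.lax.dot_general(
    a, a, dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=accum_dtype(a, a))
print(out.dtype)  # int32
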
|
||||
|
||||
reciprocal_p = jax_core.Primitive("reciprocal")


def reciprocal(x, *, approx=False):
  return reciprocal_p.bind(x, approx=approx)


@reciprocal_p.def_abstract_eval
def _reciprocal_abstract_eval(x, *, approx):
  del approx
  return x


def _reciprocal_lowering_rule(
    ctx: mlir.LoweringRuleContext, x, *, approx=False
):
  def _reciprocal(x, *, approx=False):
    if approx:
      return jnp.reciprocal(x.astype(jnp.bfloat16)).astype(jnp.float32)
    return jnp.reciprocal(x)

  return mlir.lower_fun(_reciprocal, multiple_results=False)(
      ctx, x, approx=approx
  )


mlir.register_lowering(reciprocal_p, _reciprocal_lowering_rule)
|
||||
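The approx=True path simply rounds through bfloat16 before taking the reciprocal. A standalone illustration of the accuracy trade-off; approx_reciprocal below is a hypothetical helper that mirrors the approx branch of the lowering above:

import jax.numpy as jnp

def approx_reciprocal(x):
  # Same computation as the approx branch of _reciprocal above.
  return jnp.reciprocal(x.astype(jnp.bfloat16)).astype(jnp.float32)

x = jnp.float32(3.0)
print(jnp.reciprocal(x))     # ~0.33333334
print(approx_reciprocal(x))  # ~0.33398438 -- bfloat16 keeps only ~8 bits of mantissa
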
class PrintEffect(effects.Effect):
|
||||
__str__ = lambda self: "Print"
|
||||
|
@ -12,35 +12,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Protocol
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import importlib.util
|
||||
|
||||
__all__ = ["Path"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
epath_installed: bool
|
||||
|
||||
class PathProtocol(Protocol):
|
||||
"""A factory that creates a PurePath."""
|
||||
def __call__(self, *pathsegments: str | os.PathLike) -> pathlib.Path:
|
||||
...
|
||||
|
||||
Path: PathProtocol
|
||||
|
||||
# If etils.epath (aka etils[epath] to pip) is present, we prefer it because it
|
||||
# can read and write to, e.g., GCS buckets. Otherwise we use the builtin
|
||||
# pathlib and can only read/write to the local filesystem.
|
||||
epath_installed = bool(
|
||||
importlib.util.find_spec("etils") and
|
||||
importlib.util.find_spec("etils.epath")
|
||||
)
|
||||
if epath_installed:
|
||||
logger.debug("etils.epath found. Using etils.epath for file I/O.")
|
||||
|
||||
def __dir__():
|
||||
return ["Path"]
|
||||
|
||||
def __getattr__(name):
|
||||
if name != "Path":
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
global Path
|
||||
from etils import epath
|
||||
Path = epath.Path
|
||||
return Path
|
||||
else:
|
||||
try:
|
||||
from etils import epath # type: ignore
|
||||
except ImportError:
|
||||
logger.debug("etils.epath was not found. Using pathlib for file I/O.")
|
||||
Path = pathlib.Path
|
||||
epath_installed = False
|
||||
else:
|
||||
logger.debug("etils.epath found. Using etils.epath for file I/O.")
|
||||
# Ultimately, epath.Path implements pathlib.Path. See:
|
||||
# https://github.com/google/etils/blob/2083f3d932a88d8a135ef57112cd1f9aff5d559e/etils/epath/abstract_path.py#L47
|
||||
Path = epath.Path
|
||||
epath_installed = True
|
||||
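The rewritten module defers importing etils.epath until Path is first accessed, using module-level __getattr__ and __dir__ (PEP 562). A minimal standalone sketch of that pattern, with a made-up heavy_dep module standing in for etils.epath:

# lazy_mod.py -- illustrative only; heavy_dep is a hypothetical module name.
import importlib

__all__ = ["Thing"]

def __dir__():
  return __all__

def __getattr__(name):
  if name != "Thing":
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
  global Thing
  Thing = importlib.import_module("heavy_dep").Thing  # imported on first access
  return Thing
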
|
157
jax/_src/pjit.py
157
jax/_src/pjit.py
@ -29,7 +29,6 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
@ -199,10 +198,9 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
|
||||
profiler = None
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
|
||||
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
|
||||
msg = _device_assignment_mismatch_error(
|
||||
fun_name, fails, args_flat, api_name, p.arg_names)
|
||||
fun_name, fails, args_flat, 'jit', p.arg_names)
|
||||
raise ValueError(msg) from None
|
||||
except xla.InvalidInputException as e:
|
||||
arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names
|
||||
@ -359,7 +357,6 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
|
||||
in_layouts_leaves=jit_info.in_layouts_leaves,
|
||||
out_layouts_treedef=jit_info.out_layouts_treedef,
|
||||
out_layouts_leaves=jit_info.out_layouts_leaves,
|
||||
use_resource_env=jit_info.use_resource_env,
|
||||
compiler_options_kvs=jit_info.compiler_options_kvs)
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
|
||||
@ -546,8 +543,7 @@ class PjitParams(NamedTuple):
|
||||
def _infer_params_impl(
|
||||
fun: Callable,
|
||||
ji: PjitInfo,
|
||||
pjit_mesh: mesh_lib.Mesh | None,
|
||||
resource_env: mesh_lib.ResourceEnv | None,
|
||||
ctx_mesh: mesh_lib.Mesh | None,
|
||||
dbg: core.DebugInfo,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
@ -559,8 +555,8 @@ def _infer_params_impl(
|
||||
raise ValueError(
|
||||
"pjit does not support kwargs when in_shardings is specified.")
|
||||
|
||||
if pjit_mesh is not None:
|
||||
if (ji.backend or ji.device) and not pjit_mesh.empty:
|
||||
if ctx_mesh is not None:
|
||||
if (ji.backend or ji.device) and not ctx_mesh.empty:
|
||||
raise ValueError(
|
||||
"Mesh context manager should not be used with jit when backend or "
|
||||
"device is also specified as an argument to jit.")
|
||||
@ -591,13 +587,12 @@ def _infer_params_impl(
|
||||
in_shardings_leaves = out_shardings_leaves = tuple(leaves)
|
||||
in_shardings_treedef = out_shardings_treedef = treedef
|
||||
else:
|
||||
jit_name = 'pjit' if pjit_mesh is not None else 'jit'
|
||||
in_shardings_leaves = tuple(
|
||||
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
|
||||
_create_sharding_for_array(ctx_mesh, x, 'in_shardings', 'jit')
|
||||
for x in ji.in_shardings_leaves)
|
||||
in_shardings_treedef = ji.in_shardings_treedef
|
||||
out_shardings_leaves = tuple(
|
||||
_create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name)
|
||||
_create_sharding_for_array(ctx_mesh, x, 'out_shardings', 'jit')
|
||||
for x in ji.out_shardings_leaves)
|
||||
out_shardings_treedef = ji.out_shardings_treedef
|
||||
|
||||
@ -655,8 +650,8 @@ def _infer_params_impl(
|
||||
out_shardings=out_shardings_flat,
|
||||
in_layouts=in_layouts_flat,
|
||||
out_layouts=out_layouts_flat,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=fun_qual_name(flat_fun),
|
||||
keep_unused=ji.keep_unused,
|
||||
inline=ji.inline,
|
||||
@ -686,38 +681,30 @@ def _infer_params_cached(
|
||||
jit_info: PjitInfo,
|
||||
signature: jax_jit.ArgumentSignature,
|
||||
in_avals: tuple[core.AbstractValue, ...],
|
||||
pjit_mesh: mesh_lib.Mesh | None,
|
||||
resource_env: mesh_lib.ResourceEnv | None,
|
||||
ctx_mesh: mesh_lib.Mesh | None,
|
||||
) -> InferParamsCacheEntry:
|
||||
return InferParamsCacheEntry()
|
||||
|
||||
def disallow_use_mesh_and_legacy_mesh_ctx_mgr_together():
|
||||
if (not mesh_lib.thread_resources.env.physical_mesh.empty and
|
||||
mesh_lib.get_concrete_mesh() is not None):
|
||||
raise ValueError(
|
||||
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
|
||||
' together is not allowed.')
|
||||
|
||||
def _infer_params(
|
||||
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[PjitParams, list[Any]]:
|
||||
disallow_use_mesh_and_legacy_mesh_ctx_mgr_together()
|
||||
) -> tuple[PjitParams, list[Any]]:
|
||||
if ji.use_resource_env:
|
||||
# We need to fetch the mesh from inside the wrapped function, because
|
||||
# meshes are dynamically scoped (i.e., with a context manager).
|
||||
resource_env = mesh_lib.thread_resources.env
|
||||
pjit_mesh = resource_env.physical_mesh
|
||||
else:
|
||||
resource_env = None
|
||||
pjit_mesh = None
|
||||
with mesh_lib.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
|
||||
return _infer_params_internal(fun, ji, args, kwargs)
|
||||
return _infer_params_internal(fun, ji, args, kwargs)
|
||||
|
||||
def _infer_params_internal(
|
||||
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[PjitParams, list[Any]]:
|
||||
ctx_mesh = mesh_lib.get_concrete_mesh()
|
||||
dbg = debug_info(
|
||||
'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
|
||||
static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
|
||||
signature=ji.fun_signature)
|
||||
|
||||
if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache
|
||||
p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, dbg,
|
||||
p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg,
|
||||
args, kwargs, in_avals=None)
|
||||
return p, p.consts + args_flat
|
||||
|
||||
@ -725,10 +712,11 @@ def _infer_params(
|
||||
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
|
||||
ji.static_argnames, tree_util.default_registry)
|
||||
avals = _infer_input_type(fun, dbg, dynargs)
|
||||
entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env)
|
||||
entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh)
|
||||
|
||||
if entry.pjit_params is None:
|
||||
p, args_flat = _infer_params_impl(
|
||||
fun, ji, pjit_mesh, resource_env, dbg, args, kwargs, in_avals=avals)
|
||||
fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
|
||||
if p.attrs_tracked:  # if attrs, don't populate the cache
|
||||
return p, p.consts + args_flat
|
||||
entry.pjit_params = p
|
||||
@ -1619,7 +1607,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
|
||||
|
||||
def _resolve_and_lower(
|
||||
args, jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
|
||||
out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
lowering_platforms, lowering_parameters, pgle_profiler,
|
||||
compiler_options_kvs):
|
||||
in_shardings = _resolve_in_shardings(args, in_shardings)
|
||||
@ -1627,8 +1615,8 @@ def _resolve_and_lower(
|
||||
jaxpr.in_avals)
|
||||
out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals)
|
||||
return _pjit_lower(
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
|
||||
donated_invars, name, keep_unused, inline, compiler_options_kvs,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs,
|
||||
lowering_platforms=lowering_platforms,
|
||||
lowering_parameters=lowering_parameters,
|
||||
pgle_profiler=pgle_profiler)
|
||||
@ -1637,7 +1625,7 @@ _pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore
|
||||
|
||||
def _pjit_call_impl_python(
|
||||
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
pgle_compile_options, pgle_profiler = {}, None
|
||||
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
|
||||
@ -1662,8 +1650,8 @@ def _pjit_call_impl_python(
|
||||
compiled = _resolve_and_lower(
|
||||
args, jaxpr=jaxpr, in_shardings=in_shardings,
|
||||
out_shardings=out_shardings, in_layouts=in_layouts,
|
||||
out_layouts=out_layouts, resource_env=resource_env,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
out_layouts=out_layouts, donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
|
||||
inline=inline, lowering_platforms=None,
|
||||
lowering_parameters=mlir.LoweringParameters(),
|
||||
pgle_profiler=pgle_profiler,
|
||||
@ -1694,7 +1682,7 @@ def _pjit_call_impl_python(
|
||||
|
||||
@weakref_lru_cache
|
||||
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
out_layouts, resource_env, donated_invars, name,
|
||||
out_layouts, donated_invars, ctx_mesh, name,
|
||||
keep_unused, inline, compiler_options_kvs):
|
||||
# The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so
|
||||
# returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to
|
||||
@ -1708,14 +1696,14 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
|
||||
def _pjit_call_impl(*args, jaxpr,
|
||||
in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
def call_impl_cache_miss(*args_, **kwargs_):
|
||||
out_flat, compiled, pgle_profiler = _pjit_call_impl_python(
|
||||
*args, jaxpr=jaxpr, in_shardings=in_shardings,
|
||||
out_shardings=out_shardings, in_layouts=in_layouts,
|
||||
out_layouts=out_layouts, resource_env=resource_env,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
out_layouts=out_layouts, donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
|
||||
inline=inline, compiler_options_kvs=compiler_options_kvs)
|
||||
fastpath_data = _get_fastpath_data(
|
||||
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
|
||||
@ -1724,7 +1712,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
|
||||
f = _get_jaxpr_as_fun(
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs)
|
||||
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
|
||||
cache_key = pxla.JitGlobalCppCacheKeys(
|
||||
@ -1733,8 +1721,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
in_shardings_treedef=None, in_shardings_leaves=in_shardings,
|
||||
out_shardings_treedef=None, out_shardings_leaves=out_shardings,
|
||||
in_layouts_treedef=None, in_layouts_leaves=in_layouts,
|
||||
out_layouts_treedef=None, out_layouts_leaves=out_layouts,
|
||||
use_resource_env=resource_env is not None)
|
||||
out_layouts_treedef=None, out_layouts_leaves=out_layouts)
|
||||
return xc._xla.pjit(
|
||||
name, f, call_impl_cache_miss, [], [], cache_key,
|
||||
tree_util.dispatch_registry, pxla.cc_shard_arg,
|
||||
@ -1749,8 +1736,8 @@ def _pjit_lower(
|
||||
out_shardings,
|
||||
in_layouts: pxla.MaybeLayout,
|
||||
out_layouts: pxla.MaybeLayout,
|
||||
resource_env,
|
||||
donated_invars,
|
||||
ctx_mesh,
|
||||
name: str,
|
||||
keep_unused: bool,
|
||||
inline: bool,
|
||||
@ -1760,14 +1747,10 @@ def _pjit_lower(
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
pgle_profiler: profiler.PGLEProfiler | None):
|
||||
util.test_event("pjit_lower")
|
||||
if resource_env is not None:
|
||||
mesh, api_name = resource_env.physical_mesh, 'pjit'
|
||||
else:
|
||||
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
|
||||
return pxla.lower_sharding_computation(
|
||||
jaxpr, api_name, name, in_shardings, out_shardings,
|
||||
jaxpr, 'jit', name, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, tuple(donated_invars),
|
||||
keep_unused=keep_unused, context_mesh=mesh,
|
||||
keep_unused=keep_unused, context_mesh=ctx_mesh,
|
||||
compiler_options_kvs=compiler_options_kvs,
|
||||
lowering_platforms=lowering_platforms,
|
||||
lowering_parameters=lowering_parameters,
|
||||
@ -1919,8 +1902,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
|
||||
|
||||
def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
|
||||
jaxpr: core.ClosedJaxpr, in_shardings,
|
||||
out_shardings, in_layouts, out_layouts, resource_env,
|
||||
donated_invars, keep_unused, inline, compiler_options_kvs):
|
||||
out_shardings, in_layouts, out_layouts, donated_invars,
|
||||
ctx_mesh, keep_unused, inline, compiler_options_kvs):
|
||||
effects = list(ctx.tokens_in.effects())
|
||||
output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
|
||||
output_types = [mlir.token_type()] * len(effects) + output_types
|
||||
@ -1929,7 +1912,7 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
|
||||
func = _pjit_cached_lower_jaxpr_to_fun(
|
||||
ctx, name, jaxpr, tuple(effects), in_shardings,
|
||||
out_shardings, in_layouts, out_layouts,
|
||||
api_name=('jit' if resource_env is None else 'pjit'))
|
||||
api_name='jit')
|
||||
|
||||
tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
|
||||
args = (*ctx.dim_var_values, *tokens_in, *args)
|
||||
@ -1950,23 +1933,20 @@ def _pjit_batcher(axis_data, vals_in,
|
||||
dims_in: tuple[int, ...],
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
|
||||
new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in)
|
||||
|
||||
if resource_env is not None:
|
||||
mesh = resource_env.physical_mesh
|
||||
else:
|
||||
mesh = None
|
||||
|
||||
# TODO(axch): prepend with Nones (?) to account for new segment_lens inputs
|
||||
in_shardings = tuple(
|
||||
_pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim)
|
||||
_pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, ctx_mesh,
|
||||
aval.ndim)
|
||||
if axis_in is not None else i
|
||||
for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals))
|
||||
out_shardings = tuple(
|
||||
_pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim)
|
||||
_pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, ctx_mesh,
|
||||
aval.ndim)
|
||||
if axis_out is not None else o
|
||||
for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals))
|
||||
# TODO(yashkatariya): Figure out how layouts should change under vmap.
|
||||
@ -1982,8 +1962,8 @@ def _pjit_batcher(axis_data, vals_in,
|
||||
out_shardings=out_shardings,
|
||||
in_layouts=in_layouts,
|
||||
out_layouts=out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2005,8 +1985,8 @@ def _insert_axis_partitions(spec, dim, val):
|
||||
|
||||
def _pjit_batcher_for_sharding(
|
||||
s: Sharding | UnspecifiedValue,
|
||||
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh,
|
||||
ndim: int):
|
||||
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None,
|
||||
mesh, ndim: int):
|
||||
if isinstance(s, UnspecifiedValue):
|
||||
return s
|
||||
hlo_s = s._to_xla_hlo_sharding(ndim)
|
||||
@ -2045,20 +2025,8 @@ def _pjit_batcher_for_sharding(
|
||||
|
||||
def _pjit_jvp(primals_in, tangents_in,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
if any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
|
||||
jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr)
|
||||
mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals)
|
||||
primals_in = [*primals_in, *mut_primals]
|
||||
tangents_in = [*tangents_in, *mut_tangents]
|
||||
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut_primals)
|
||||
in_layouts = (*in_layouts,) + (None,) * len(mut_primals)
|
||||
donated_invars = (*donated_invars,) + (False,) * len(mut_primals)
|
||||
|
||||
tangents_in = [ad_util.zeros_like_aval(a) if isinstance(a, AbstractRef) else x
|
||||
for x, a in zip(tangents_in, jaxpr.in_avals)]
|
||||
|
||||
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
|
||||
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
|
||||
jaxpr, is_nz_tangents_in, instantiate=False)
|
||||
@ -2074,8 +2042,8 @@ def _pjit_jvp(primals_in, tangents_in,
|
||||
out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)),
|
||||
in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)),
|
||||
out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)),
|
||||
resource_env=resource_env,
|
||||
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2091,7 +2059,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
|
||||
|
||||
def _pjit_linearization(nzs, *primals_in, jaxpr,
|
||||
in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs)
|
||||
# constvars will become residuals. Move them to the end of the ordinary args.
|
||||
@ -2107,8 +2075,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
|
||||
out_shardings=_filter_zeros(nzs_out, out_shardings),
|
||||
in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts,
|
||||
out_layouts=_filter_zeros(nzs_out, out_layouts),
|
||||
resource_env=resource_env,
|
||||
donated_invars=_filter_zeros(nzs, donated_invars) + res_donated,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2127,8 +2095,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
|
||||
out_shardings=(*res_shardings, *out_shardings),
|
||||
in_layouts=in_layouts,
|
||||
out_layouts=(*res_layouts, *out_layouts),
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2143,7 +2111,7 @@ ad.primitive_linearizations[pjit_p] = _pjit_linearization
|
||||
def _pjit_partial_eval(trace: pe.JaxprTrace,
|
||||
*in_tracers,
|
||||
jaxpr: core.ClosedJaxpr, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, resource_env, donated_invars,
|
||||
in_layouts, out_layouts, donated_invars, ctx_mesh,
|
||||
name, keep_unused, inline, compiler_options_kvs):
|
||||
in_pvals = [t.pval for t in in_tracers]
|
||||
|
||||
@ -2210,8 +2178,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
|
||||
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
|
||||
out_shardings=known_out_shardings,
|
||||
in_layouts=keep_where(in_layouts, known_ins),
|
||||
out_layouts=known_out_layouts, resource_env=resource_env,
|
||||
out_layouts=known_out_layouts,
|
||||
donated_invars=keep_where(donated_invars, known_ins),
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name, keep_unused=keep_unused, inline=inline,
|
||||
compiler_options_kvs=compiler_options_kvs)
|
||||
assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
|
||||
@ -2242,9 +2211,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
|
||||
out_shardings=keep_where(out_shardings, unknown_outs),
|
||||
in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts),
|
||||
out_layouts=keep_where(out_layouts, unknown_outs),
|
||||
resource_env=resource_env,
|
||||
donated_invars=(keep_where(donated_invars, unknown_ins) +
|
||||
(False,) * num_residuals),
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2330,7 +2299,7 @@ def _pjit_transpose_trace(fun: lu.WrappedFun,
|
||||
def _pjit_transpose(cts_in, *primals_in,
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
def prune_type(ty, xs, maybe_zeros):
|
||||
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
|
||||
@ -2379,8 +2348,8 @@ def _pjit_transpose(cts_in, *primals_in,
|
||||
out_shardings=transpose_out_shardings,
|
||||
in_layouts=transpose_in_layouts,
|
||||
out_layouts=transpose_out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=(False,) * len(primals_and_nz_cts_in),
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2464,9 +2433,8 @@ def _pjit_pp_rule(eqn: core.JaxprEqn,
|
||||
del params['out_layouts']
|
||||
if not params['keep_unused']:
|
||||
del params['keep_unused']
|
||||
if (params['resource_env'] is None or
|
||||
params['resource_env'].physical_mesh.empty):
|
||||
del params['resource_env']
|
||||
if params['ctx_mesh'] is None or params['ctx_mesh'].empty:
|
||||
del params['ctx_mesh']
|
||||
if not params['compiler_options_kvs']:
|
||||
del params['compiler_options_kvs']
|
||||
|
||||
@ -2536,6 +2504,11 @@ def with_sharding_constraint(x, shardings):
|
||||
This is a strict constraint for the GSPMD partitioner and not a hint. For examples
|
||||
of how to use this function, see `Distributed arrays and automatic parallelization`_.
|
||||
|
||||
Inside of a jitted computation, with_sharding_constraint makes it possible to
|
||||
constrain intermediate values to an uneven sharding. However, if such an
|
||||
unevenly sharded value is output by the jitted computation, it will come out
|
||||
as fully replicated, no matter the sharding annotation given.
|
||||
|
||||
Args:
|
||||
x: PyTree of jax.Arrays which will have their shardings constrained
|
||||
shardings: PyTree of sharding specifications. Valid values are the same as for
|
||||
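A short sketch of the caveat described above. It assumes four host CPU devices forced purely for demonstration (the flag must be set before importing jax) and is illustrative rather than part of this change:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ("x",))
uneven = NamedSharding(mesh, P("x"))  # 6 elements over 4 devices is uneven

@jax.jit
def f(a):
  b = jax.lax.with_sharding_constraint(a * 2, uneven)  # allowed as an intermediate
  return b.sum()  # per the note above, returning b itself would yield a fully replicated output

print(f(jnp.arange(6.0)))
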
@ -2561,8 +2534,6 @@ def with_sharding_constraint(x, shardings):
|
||||
flatten_axes("with_sharding_constraint layouts", tree, layouts))
|
||||
del layouts
|
||||
|
||||
disallow_use_mesh_and_legacy_mesh_ctx_mgr_together()
|
||||
|
||||
context_mesh = (
|
||||
mesh_lib.get_abstract_mesh() if mesh_lib.get_concrete_mesh() is not None
|
||||
else mesh_lib.thread_resources.env.physical_mesh)
|
||||
|
@ -100,6 +100,9 @@ if _dtypes.float8_e4m3 is not None:
|
||||
if _dtypes.float8_e8m0fnu is not None:
|
||||
_default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
|
||||
default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
|
||||
if _dtypes.float4_e2m1fn is not None:
|
||||
_default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
|
||||
default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
|
||||
|
||||
def is_python_scalar(val):
|
||||
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
|
||||
@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
|
||||
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
|
||||
if _dtypes.float8_e8m0fnu is not None:
|
||||
custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
|
||||
if _dtypes.float4_e2m1fn is not None:
|
||||
custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)
|
||||
|
||||
def maybe_upcast(x):
|
||||
if x.dtype in custom_float_dtypes:
|
||||
|
@ -2106,7 +2106,7 @@ def expi_jvp(primals, tangents):
|
||||
return expi(x), jnp.exp(x) / x * x_dot
|
||||
|
||||
|
||||
def _expn1(n: Array, x: Array) -> Array:
|
||||
def _expn1(x: Array, n: Array) -> Array:
|
||||
# exponential integral En
|
||||
_c = _lax_const
|
||||
MACHEP = jnp.finfo(x.dtype).eps
|
||||
@ -2143,7 +2143,7 @@ def _expn1(n: Array, x: Array) -> Array:
|
||||
return d["z"] ** r * psi / jnp.exp(gammaln(t)) - d["ans"]
|
||||
|
||||
|
||||
def _expn2(n: Array, x: Array) -> Array:
|
||||
def _expn2(x: Array, n: Array) -> Array:
|
||||
# x > 1.
|
||||
_c = _lax_const
|
||||
BIG = _c(x, 1.44115188075855872e17)
|
||||
@ -2194,7 +2194,7 @@ def _expn2(n: Array, x: Array) -> Array:
|
||||
return d["ans"] * jnp.exp(-x)
|
||||
|
||||
|
||||
def _expn3(n: Array, x: Array) -> Array:
|
||||
def _expn3(x: Array, n: Array) -> Array:
|
||||
# n >= 5000
|
||||
_c = _lax_const
|
||||
one = _c(x, 1.0)
|
||||
@ -2248,11 +2248,11 @@ def expn(n: ArrayLike, x: ArrayLike) -> Array:
|
||||
jnp.inf,
|
||||
one / n1, # prevent div by zero
|
||||
jnp.exp(-x) / x,
|
||||
partial(_expn3, n),
|
||||
partial(_expn2, n),
|
||||
partial(_expn1, n),
|
||||
_expn3,
|
||||
_expn2,
|
||||
_expn1,
|
||||
]
|
||||
ret = jnp.piecewise(x, conds, vals)
|
||||
ret = jnp.piecewise(x, conds, vals, n=n)
|
||||
return ret
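
The switch away from functools.partial works because jnp.piecewise, like numpy.piecewise, forwards extra positional and keyword arguments to every function in the list. A tiny self-contained illustration with arbitrary branch functions:

import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 1.0, 2.0])
out = jnp.piecewise(
    x,
    [x < 0, x >= 0],
    [lambda v, scale: scale * v,   # used where x < 0
     lambda v, scale: v / scale],  # used where x >= 0
    scale=2.0,
)
print(out)  # [-4. -2.  0.5  1. ]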
@ -305,7 +305,16 @@ class SetNameStackContextManager(contextlib.ContextDecorator):
|
||||
|
||||
|
||||
set_name_stack = SetNameStackContextManager
|
||||
reset_name_stack = lambda: SetNameStackContextManager(NameStack())
|
||||
|
||||
|
||||
# TODO(mattjj,phawkins): figure out why the commented-out reset_name_stack
|
||||
# implementation doesn't work. Luckily this context manager isn't called much so
|
||||
# the performance shouldn't matter. See blame commit message for repro.
|
||||
# reset_name_stack = lambda: SetNameStackContextManager(NameStack())
|
||||
@contextlib.contextmanager
|
||||
def reset_name_stack() -> Iterator[None]:
|
||||
with set_name_stack(NameStack()):
|
||||
yield
|
||||
|
||||
|
||||
class TransformNameStackContextManager(contextlib.ContextDecorator):
|
||||
|
@ -259,3 +259,26 @@ class NDIndexer:
|
||||
|
||||
def transform_dtype(self, dtype):
|
||||
return dtype
|
||||
|
||||
def transform_sharding(self, sharding):
|
||||
# If there are no explicit axes, do nothing.
|
||||
if all(p is None for p in sharding.spec):
|
||||
return sharding
|
||||
# If there are explicit axes, we don't support changing the shape, so we
|
||||
# don't support int indexers and instead require all slices.
|
||||
if (self.int_indexer_shape or
|
||||
not all(isinstance(idx, Slice) for idx in self.indices)):
|
||||
raise TypeError("sharded ref (array reference) can only be indexed by "
|
||||
"slices, not integers")
|
||||
# Moreover, only allow trivial slice(None) slices on explicitly sharded
|
||||
# axes. Then the sharding stays the same.
|
||||
_, slice_indexers, _ = unpack_ndindexer(self)
|
||||
for i, (d, sl, s) in enumerate(zip(self.shape, slice_indexers, sharding.spec)):
|
||||
if s is None: continue
|
||||
if not (type(sl.start) is int and sl.start == 0 and
|
||||
type(sl.size) is int and sl.size == d and
|
||||
type(sl.stride) is int and sl.stride == 1):
|
||||
raise ValueError("sharded ref (array reference) can only be sliced "
|
||||
f"along unsharded axes, but ref of shape {self.shape} "
|
||||
f"was sliced on axis {i}, which is sharded like {s}")
|
||||
return sharding
|
||||
|
@ -206,6 +206,13 @@ def _dtype_after_transforming(
|
||||
return dtype
|
||||
|
||||
|
||||
def _sharding_after_transforming(sharding, transforms):
|
||||
for transform in transforms:
|
||||
sharding = transform.transform_sharding(sharding)
|
||||
assert sharding is not None
|
||||
return sharding
|
||||
|
||||
|
||||
def _get_abstract_eval(ref_aval: AbstractRef, *args,
|
||||
tree):
|
||||
transforms = tree_util.tree_unflatten(tree, args)
|
||||
@ -214,10 +221,9 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
|
||||
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
out_shape = _shape_after_transforming(ref_aval.shape, transforms)
|
||||
out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
|
||||
# TODO(yashkatariya): Transform the sharding too instead of setting it to
|
||||
# None.
|
||||
out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype,
|
||||
sharding=core.get_cur_mesh_sharding())
|
||||
out_sharding = _sharding_after_transforming(ref_aval.sharding, transforms)
|
||||
out_aval = ref_aval.inner_aval.update(
|
||||
shape=out_shape, dtype=out_dtype, sharding=out_sharding)
|
||||
else:
|
||||
if transforms:
|
||||
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
|
||||
@ -437,6 +443,8 @@ def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any):
|
||||
ref_primal, x_primal, *idx = primals
|
||||
assert isinstance(ref_primal.aval, AbstractRef)
|
||||
ref_tangent, x_tangent, *_ = tangents
|
||||
# if type(ref_tangent) is ad_util.Zero:
|
||||
# raise Exception("you're an idiot")
|
||||
assert isinstance(ref_tangent.aval, AbstractRef)
|
||||
x_tangent = ad_util.instantiate(x_tangent)
|
||||
return (swap_p.bind(ref_primal, x_primal, *idx, **params),
|
||||
@ -657,5 +665,14 @@ mlir.register_lowering(
|
||||
|
||||
# === AD rules for mutable arrays ===
|
||||
|
||||
ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g))
|
||||
def _mut_jvp(primals, tangents):
|
||||
(init_val,), (init_val_dot,) = primals, tangents
|
||||
primal_out = core.mutable_array_p.bind(init_val)
|
||||
if type(init_val_dot) is ad_util.Zero:
|
||||
tangent_out = core.mutable_array_p.bind(ad_util.zeros_like_aval(init_val_dot.aval))
|
||||
else:
|
||||
tangent_out = core.mutable_array_p.bind(init_val_dot)
|
||||
return primal_out, tangent_out
|
||||
|
||||
ad.primitive_jvps[core.mutable_array_p] = _mut_jvp
|
||||
ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))
|
||||
|
@ -119,6 +119,12 @@ class RefBitcaster:
|
||||
del dtype # Unused
|
||||
return self.dtype
|
||||
|
||||
def transform_sharding(self, sharding):
|
||||
# If there are no explicit axes, do nothing.
|
||||
if all(p is None for p in sharding.spec):
|
||||
return sharding
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -166,6 +172,12 @@ class RefReshaper:
|
||||
del dtype # Unused
|
||||
return self.dtype
|
||||
|
||||
def transform_sharding(self, sharding):
|
||||
# If there are no explicit axes, do nothing.
|
||||
if all(p is None for p in sharding.spec):
|
||||
return sharding
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Transform(Protocol):
|
||||
|
||||
@ -189,6 +201,10 @@ class Transform(Protocol):
|
||||
"""
|
||||
return dtype
|
||||
|
||||
def transform_sharding(self, sharding):
|
||||
if all(p is None for p in sharding.spec): return sharding # no explicit axes
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RefIndexer:
|
||||
|
@ -1640,6 +1640,8 @@ class _LazyDtypes:
|
||||
float_dtypes += [_dtypes.float8_e4m3]
|
||||
if _dtypes.float8_e8m0fnu is not None:
|
||||
float_dtypes += [_dtypes.float8_e8m0fnu]
|
||||
if _dtypes.float4_e2m1fn is not None:
|
||||
float_dtypes += [_dtypes.float4_e2m1fn]
|
||||
return self.supported(float_dtypes)
|
||||
|
||||
@_cached_property
|
||||
|
@ -14,7 +14,7 @@
|
||||
"""Colocated Python API."""
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
# pylint: disable=useless-import-alias
|
||||
from jax.experimental.colocated_python.api import (
|
||||
|
@ -201,7 +201,7 @@ def _make_output_specs_and_push_result_fun(
|
||||
|
||||
devices = specialization.devices
|
||||
|
||||
def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]:
|
||||
def lowered_fun(*args, **kwargs) -> jax.Array:
|
||||
result = info.fun(*args, **kwargs)
|
||||
result_leaves, out_treedef = tree_util.tree_flatten(result)
|
||||
out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)
|
||||
|
@ -13,11 +13,9 @@
|
||||
# limitations under the License.
|
||||
"""Colocated Python serialization utilities."""
|
||||
|
||||
# TODO(jmudigonda): Use a string-typed array for output structure when it
|
||||
# becomes available. Using a fixed uint8 array is only for prototyping.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import collections
|
||||
import functools
|
||||
import io
|
||||
@ -37,12 +35,6 @@ import numpy as np
|
||||
|
||||
DeviceList = xc.DeviceList
|
||||
|
||||
# Hard-coded limit for serialized specs size.
|
||||
# TODO(jmudigonda): Use a string-typed array for output structure when it
|
||||
# becomes available. Using a fixed uint8 array is only for prototyping.
|
||||
_MAX_SERIALIZED_SPECS_SIZE = 1048576
|
||||
|
||||
|
||||
@jax.util.cache(max_size=None)
|
||||
def _get_cpu_device_map() -> dict[int, jax.Device]:
|
||||
"""Returns a map from a device id to a matching device."""
|
||||
@ -185,23 +177,14 @@ def _deserialize(serialized: bytes) -> Any:
|
||||
|
||||
def _make_specs_for_serialized_specs(
|
||||
devices: DeviceList,
|
||||
) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]:
|
||||
) -> api.ShapeDtypeStruct:
|
||||
"""Makes output specs for serialized specs."""
|
||||
# TODO(jmudigonda): Use a string-typed array for output structure when it
|
||||
# becomes available. Using a fixed uint8 array is only for prototyping.
|
||||
mesh = jax.sharding.Mesh(tuple(devices), ("x",))
|
||||
replicated_sharding = jax.sharding.NamedSharding(
|
||||
mesh, jax.sharding.PartitionSpec()
|
||||
)
|
||||
return (
|
||||
api.ShapeDtypeStruct(
|
||||
shape=(), dtype=np.int32, sharding=replicated_sharding
|
||||
),
|
||||
api.ShapeDtypeStruct(
|
||||
shape=(_MAX_SERIALIZED_SPECS_SIZE,),
|
||||
dtype=np.uint8,
|
||||
sharding=replicated_sharding,
|
||||
),
|
||||
return api.ShapeDtypeStruct(
|
||||
shape=(), dtype=np.dtypes.StringDType(), sharding=replicated_sharding # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@ -209,49 +192,49 @@ def _serialize_specs(
|
||||
specs_treedef: tree_util.PyTreeDef,
|
||||
specs_leaves: tuple[api.ShapeDtypeStruct, ...],
|
||||
devices: DeviceList,
|
||||
) -> tuple[jax.Array, ...]:
|
||||
"""Serializes the output specs into a tuple of arrays.
|
||||
) -> jax.Array:
|
||||
"""Serializes the output specs into a jax.Array of string type.
|
||||
|
||||
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
|
||||
colocated_python. See serialize() for details.
|
||||
"""
|
||||
s = _serialize((specs_treedef, specs_leaves))
|
||||
assert (
|
||||
len(s) <= _MAX_SERIALIZED_SPECS_SIZE
|
||||
), f"Too large serialized spec size: {len(s)}"
|
||||
# TODO(jmudigonda): Use a string-typed array for output structure when it
|
||||
# becomes available. Using a fixed uint8 array is only for prototyping.
|
||||
mesh = jax.sharding.Mesh(tuple(devices), ("x",))
|
||||
if not hasattr(np.dtypes, "StringDType"):
|
||||
raise TypeError(
|
||||
"Serializing Colocated Python requires StringDType. Please use"
|
||||
" numpy to 2.0.0 or later, or explicityly provide an output spec"
|
||||
" function."
|
||||
)
|
||||
|
||||
s_bytes = _serialize((specs_treedef, specs_leaves))
|
||||
s_str = base64.b64encode(s_bytes).decode("ascii")
|
||||
s_np_array = np.array(s_str, dtype=np.dtypes.StringDType()) # type: ignore
|
||||
|
||||
# TODO(jmudigonda): Revisit this when JAX supports HLO sharding for making
|
||||
# jax.Array via make_array_from_single_device_arrays. We should then use a
|
||||
# sharding that spans all the execution devices - not just the addressable
|
||||
# ones.
|
||||
addressable_devices = devices.addressable_device_list
|
||||
mesh = jax.sharding.Mesh(tuple(addressable_devices), ("x",))
|
||||
replicated_sharding = jax.sharding.NamedSharding(
|
||||
mesh, jax.sharding.PartitionSpec()
|
||||
)
|
||||
len_array = jax.make_array_from_callback(
|
||||
shape=(),
|
||||
sharding=replicated_sharding,
|
||||
data_callback=lambda _: np.array(len(s), dtype=np.int32),
|
||||
|
||||
out_arrays = [
|
||||
jax.device_put(s_np_array, device) for device in addressable_devices
|
||||
]
|
||||
return jax.make_array_from_single_device_arrays(
|
||||
arrays=out_arrays, sharding=replicated_sharding, shape=(),
|
||||
)
|
||||
data_array = jax.make_array_from_callback(
|
||||
shape=(_MAX_SERIALIZED_SPECS_SIZE,),
|
||||
sharding=replicated_sharding,
|
||||
data_callback=lambda _: np.frombuffer(
|
||||
s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
)
|
||||
return len_array, data_array
|
||||
|
||||
|
||||
def _deserialize_specs(
|
||||
serialized_specs: tuple[jax.Array, ...],
|
||||
serialized_specs: jax.Array,
|
||||
) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]:
|
||||
"""Deserializes the specs from the serialized specs.
|
||||
|
||||
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
|
||||
colocated_python. See serialize() for details.
|
||||
"""
|
||||
# TODO(jmudigonda): Use a string-typed array for output structure when it
|
||||
# becomes available. Using a fixed uint8 array is only for prototyping.
|
||||
len_array, data_array = serialized_specs
|
||||
length = int(len_array.addressable_shards[0].data)
|
||||
data = np.asarray(data_array.addressable_shards[0].data).tobytes()
|
||||
return _deserialize(data[:length])
|
||||
data_array = serialized_specs.addressable_shards[0].data
|
||||
data = base64.b64decode(data_array.item().encode("ascii"))
|
||||
return _deserialize(data)
|
||||
|
@ -1541,10 +1541,12 @@ tf_not_yet_impl = [
|
||||
"assert_consumed_value",
|
||||
"consume",
|
||||
"ragged_dot",
|
||||
"ragged_dot_general",
|
||||
"cholesky_update",
|
||||
"symmetric_product",
|
||||
"from_edtype",
|
||||
"to_edtype",
|
||||
"reciprocal",
|
||||
# Pallas TPU primitives
|
||||
"bitcast",
|
||||
"repeat",
|
||||
@ -3571,8 +3573,8 @@ def _pjit(*args: TfVal,
|
||||
in_shardings: Sequence[sharding.Sharding],
|
||||
out_shardings: Sequence[sharding.Sharding],
|
||||
in_layouts, out_layouts,
|
||||
resource_env: mesh.ResourceEnv,
|
||||
donated_invars,
|
||||
ctx_mesh,
|
||||
name: str,
|
||||
keep_unused: bool,
|
||||
inline: bool,
|
||||
|
@ -28,6 +28,7 @@ from jax._src.lib.mlir.dialects import builtin
|
||||
from jax._src.lib.mlir.dialects import func
|
||||
from jax._src.lib.mlir.dialects import gpu
|
||||
from jax._src.lib.mlir.dialects import llvm
|
||||
from jax._src.lib.mlir.dialects import math as mlir_math
|
||||
from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import nvvm
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
@ -234,11 +235,6 @@ def _vector_load_op_lowering_rule(
|
||||
ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
|
||||
)
|
||||
|
||||
if not layouts.is_strided_fragmented_layout(out_layout_attr):
|
||||
raise ValueError(
|
||||
f"{vector_load_op} has an unsupported layout: {out_layout_attr}"
|
||||
)
|
||||
|
||||
for i in vector_load_op.indices:
|
||||
index_defining_op = i.owner.opview
|
||||
if (
|
||||
@ -254,9 +250,28 @@ def _vector_load_op_lowering_rule(
|
||||
element_type = vector_load_op.result.type.element_type
|
||||
is_signed = False if ir.IntegerType.isinstance(element_type) else None
|
||||
|
||||
fragmented_array = fa.FragmentedArray.load_strided(
|
||||
vector_load_op.base, is_signed=is_signed
|
||||
)
|
||||
if layouts.is_strided_fragmented_layout(out_layout_attr):
|
||||
strided_layout = layouts.from_strided_fragmented_layout_attr(
|
||||
out_layout_attr
|
||||
)
|
||||
fragmented_array = fa.FragmentedArray.load_strided(
|
||||
vector_load_op.base,
|
||||
is_signed=is_signed,
|
||||
vec_size=strided_layout.vec_size,
|
||||
)
|
||||
elif layouts.is_wgmma_fragmented_layout(out_layout_attr):
|
||||
layout = ir.MemRefType(vector_load_op.base.type).layout
|
||||
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
|
||||
transformed_ref = transform_memref(vector_load_op.base, transforms)
|
||||
fragmented_array = fa.FragmentedArray.load_tiled(
|
||||
transformed_ref,
|
||||
swizzle=swizzle,
|
||||
is_signed=is_signed
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{vector_load_op} has an unsupported layout: {out_layout_attr}"
|
||||
)
|
||||
return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)]
|
||||
|
||||
|
||||
@ -424,10 +439,77 @@ def _mgpu_async_store_op_lowering_rule(
|
||||
gmem_transform=transforms,
|
||||
uniform=True,
|
||||
predicate=ctx.single_thread_per_warpgroup_predicate,
|
||||
arrive=store_op.commit_group,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _conversion_op_lowering_rule(
|
||||
_: LoweringContext,
|
||||
op: ir.OpView,
|
||||
source_is_signed: bool | None,
|
||||
target_is_signed: bool | None,
|
||||
) -> Sequence[ir.Value]:
|
||||
[in_layout] = inference_utils.in_layouts(op)
|
||||
[layout] = inference_utils.out_layouts(op)
|
||||
if in_layout != layout:
|
||||
raise ValueError("Layout mismatch")
|
||||
|
||||
target_ty = op.result.type.element_type # pytype: disable=attribute-error
|
||||
operand = _fragmented_array_from_ir(op.operands[0], layout, source_is_signed)
|
||||
converted = operand.astype(target_ty, is_signed=target_is_signed)
|
||||
return [_fragmented_array_to_ir(converted, op.result.type)]
|
||||
|
||||
|
||||
for op, source_is_signed, target_is_signed in [
|
||||
(arith.ExtFOp, None, None),
|
||||
(arith.ExtSIOp, True, True),
|
||||
(arith.ExtUIOp, False, False),
|
||||
(arith.FPToSIOp, None, True),
|
||||
(arith.FPToUIOp, None, False),
|
||||
(arith.SIToFPOp, True, None),
|
||||
(arith.TruncFOp, None, None),
|
||||
(arith.TruncIOp, False, False),
|
||||
(arith.UIToFPOp, False, None),
|
||||
]:
|
||||
_lowerings[op.OPERATION_NAME] = functools.partial(
|
||||
_conversion_op_lowering_rule,
|
||||
source_is_signed=source_is_signed,
|
||||
target_is_signed=target_is_signed,
|
||||
)
|
||||
|
||||
|
||||
def _unary_op_lowering_rule(
|
||||
_: LoweringContext,
|
||||
op: Any,
|
||||
impl: Callable[[fa.FragmentedArray], fa.FragmentedArray],
|
||||
is_signed: bool | None = None,
|
||||
) -> Sequence[ir.Value]:
|
||||
in_layouts = inference_utils.in_layouts(op)
|
||||
[layout] = inference_utils.out_layouts(op)
|
||||
if any(in_layout != layout for in_layout in in_layouts):
|
||||
raise ValueError("Layout mismatch")
|
||||
kwargs = {}
|
||||
if hasattr(op, "fastmath"):
|
||||
kwargs = dict(
|
||||
approx=op.fastmath == ir.Attribute.parse("#arith.fastmath<afn>")
|
||||
)
|
||||
a = _fragmented_array_from_ir(op.operand, layout, is_signed)
|
||||
return [_fragmented_array_to_ir(impl(a, **kwargs), op.result.type)]
|
||||
|
||||
|
||||
for op, impl, is_signed in [
|
||||
(mlir_math.RsqrtOp, fa.FragmentedArray.rsqrt, None),
|
||||
(mlir_math.ExpOp, fa.FragmentedArray.exp, None),
|
||||
(mlir_math.Exp2Op, fa.FragmentedArray.exp2, None),
|
||||
(mlir_math.LogOp, fa.FragmentedArray.log, None),
|
||||
(mlir_math.TanhOp, fa.FragmentedArray.tanh, None),
|
||||
]:
|
||||
_lowerings[op.OPERATION_NAME] = functools.partial(
|
||||
_unary_op_lowering_rule, impl=impl, is_signed=is_signed
|
||||
)
|
||||
|
||||
|
||||
def _binary_op_lowering_rule(
|
||||
_: LoweringContext,
|
||||
op: Any,
|
||||
@ -525,6 +607,25 @@ def _cmpf_op_lowering_rule(
|
||||
return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
|
||||
|
||||
|
||||
@_register_lowering(arith.BitcastOp)
|
||||
def _bitcast_op_lowering_rule(
|
||||
_: LoweringContext, op: arith.BitcastOp
|
||||
) -> Sequence[ir.Value]:
|
||||
in_layouts = inference_utils.in_layouts(op)
|
||||
[layout] = inference_utils.out_layouts(op)
|
||||
if any(in_layout != layout for in_layout in in_layouts):
|
||||
raise ValueError("Layout mismatch")
|
||||
in_ = _fragmented_array_from_ir(op.in_, layout)
|
||||
out_element_type = ir.VectorType(op.result.type).element_type
|
||||
out = in_.bitcast(
|
||||
out_element_type,
|
||||
output_is_signed=False
|
||||
if ir.IntegerType.isinstance(out_element_type)
|
||||
else None,
|
||||
)
|
||||
return [_fragmented_array_to_ir(out, op.result.type)]
|
||||
|
||||
|
||||
@_register_lowering(mgpu.WGMMAOp)
|
||||
def _mgpu_wgmma_op_lowering_rule(
|
||||
_: LoweringContext, wgmma_op: mgpu.WGMMAOp
|
||||
@ -689,8 +790,12 @@ def _for_op_lowering_rule(
|
||||
new_args = (new_for_op.induction_variable, *recreated_carry)
|
||||
for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True):
|
||||
old_carry.replace_all_uses_with(new_carry)
|
||||
for op in ops_to_lower:
|
||||
|
||||
for op in ops_to_lower:
|
||||
with ir.InsertionPoint(op):
|
||||
ctx.lower_op(op)
|
||||
|
||||
with ir.InsertionPoint(new_for_op.body):
|
||||
new_yield_operands = lower_carry(yield_op.operands)
|
||||
yield_op.erase()
|
||||
scf.yield_(new_yield_operands)
|
||||
|
@ -53,13 +53,12 @@ def build_kernel(
|
||||
index = ir.IndexType.get()
|
||||
|
||||
swizzle = 128
|
||||
tile_k = swizzle // 2
|
||||
swizzle_elems = tile_k = swizzle // 2
|
||||
tiling = (8, swizzle_elems)
|
||||
|
||||
in_dtype = jnp.float16
|
||||
k_loop_iter = k // tile_k
|
||||
max_concurrent_steps = min(max_concurrent_steps, k_loop_iter)
|
||||
tma_tile_m = 128
|
||||
tma_tile_kn = 64
|
||||
|
||||
block_tile_m = tile_m
|
||||
block_tile_n = tile_n
|
||||
@ -123,17 +122,14 @@ def build_kernel(
|
||||
src_ref=a,
|
||||
dst_ref=mgpu.memref_slice(a_smem, slot),
|
||||
gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)),
|
||||
gmem_transform=mgpu.TileTransform((tma_tile_m, tma_tile_kn)),
|
||||
gmem_transform=mgpu.TileTransform(tiling),
|
||||
**common_args,
|
||||
)
|
||||
ctx.async_copy(
|
||||
src_ref=b,
|
||||
dst_ref=mgpu.memref_slice(b_smem, slot),
|
||||
gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)),
|
||||
gmem_transform=(
|
||||
mgpu.TileTransform((tma_tile_kn, tma_tile_kn)),
|
||||
mgpu.TransposeTransform((1, 0, 2, 3)),
|
||||
),
|
||||
gmem_transform=mgpu.TileTransform(tiling),
|
||||
**common_args,
|
||||
)
|
||||
|
||||
@ -145,7 +141,7 @@ def build_kernel(
|
||||
tcgen05.mma(
|
||||
acc,
|
||||
mgpu.memref_slice(a_smem, slot),
|
||||
mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
|
||||
mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)),
|
||||
a_swizzle=swizzle,
|
||||
b_swizzle=swizzle,
|
||||
accumulate=accumulate,
|
||||
@ -172,26 +168,23 @@ def build_kernel(
|
||||
src_ref=d_smem,
|
||||
dst_ref=d,
|
||||
gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)),
|
||||
gmem_transform=mgpu.TileTransform((128, 64)),
|
||||
gmem_transform=mgpu.TileTransform((128, swizzle_elems)),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
ctx.await_async_copy(0)
|
||||
|
||||
compute_buffers = (
|
||||
jax.ShapeDtypeStruct(
|
||||
mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k),
|
||||
(tma_tile_m, tma_tile_kn)),
|
||||
mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k), tiling),
|
||||
jnp.float16),
|
||||
jax.ShapeDtypeStruct(
|
||||
mgpu.tile_shape((max_concurrent_steps, tile_k, block_tile_n),
|
||||
(tma_tile_kn, tma_tile_kn)),
|
||||
mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling),
|
||||
jnp.float16),
|
||||
)
|
||||
epilogue_buffer = jax.ShapeDtypeStruct(
|
||||
mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)),
|
||||
mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)),
|
||||
jnp.float16)
|
||||
smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
|
||||
assert block_tile_m == 128
|
||||
smem = (
|
||||
smem_buffers,
|
||||
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
|
||||
|
@ -22,6 +22,7 @@ import math
|
||||
from collections.abc import Callable
|
||||
from typing import Iterable, Protocol, Sequence, TypeVar
|
||||
|
||||
import itertools
|
||||
import jax
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
@ -115,6 +116,65 @@ class Tiling:
|
||||
strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled)
|
||||
return strides
|
||||
|
||||
  def tile_nested_shape_strides(
      self,
      shape: tuple[tuple[int, ...], ...],
      strides: tuple[tuple[int, ...], ...],
  ) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]:
    """A fused version of `tile_shape` and `tile_strides` for nested shapes.

    By nested shape we mean that each logical dimension (i.e. each element of
    shape/strides) is actually composed out of multiple physical dimensions.
    For example, a row-major array of logical shape (128, 128) that is tiled
    into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each
    dim is split into two sub-dims) and nested strides of
    ((2 * 64 * 64, 64), (64 * 64, 1)).
    """
    if len(shape) != len(strides):
      raise ValueError(
          f"Shape {shape} and strides {strides} must have the same length"
      )
    def fail_if(cond, shape=shape):  # Capture shape now.
      if cond:
        raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}")
    for tile in self.tiles:
      fail_if(len(tile) > len(shape))
      untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):]
      untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):]
      major_dim_shapes, major_dim_strides = [], []
      minor_dim_shapes, minor_dim_strides = [], []
      for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides):
        major_dim_shape_rev, major_dim_stride_rev = [], []
        minor_dim_shape_rev, minor_dim_stride_rev = [], []
        for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True):
          if d < t:  # We will need to tile more dims
            fail_if(t % d != 0)
            t //= d
            minor_dim_shape_rev.append(d)
            minor_dim_stride_rev.append(s)
          elif t != 1:  # Last dim to tile!
            fail_if(d % t != 0)
            minor_dim_shape_rev.append(t)
            minor_dim_stride_rev.append(s)
            if d != t:  # No need to insert singleton dims.
              major_dim_shape_rev.append(d // t)
              major_dim_stride_rev.append(s * t)
            t = 1
          else:  # Done tiling!
            major_dim_shape_rev.append(d)
            major_dim_stride_rev.append(s)
        fail_if(t != 1)
        major_dim_shapes.append(major_dim_shape_rev[::-1])
        minor_dim_shapes.append(minor_dim_shape_rev[::-1])
        major_dim_strides.append(major_dim_stride_rev[::-1])
        minor_dim_strides.append(minor_dim_stride_rev[::-1])
      shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes)
      strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides)
    return (
        tuple(tuple(d) if d else (1,) for d in shape),
        tuple(tuple(d) if d else (1,) for d in strides),
    )
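
  # A minimal, self-contained check of the docstring example above, written in
  # plain numpy (names and values below are illustrative only):
  #
  #   import numpy as np
  #   tiles = np.arange(128 * 128).reshape(2, 2, 64, 64)  # contiguous (64, 64) tiles
  #   # Nested shape:   ((2, 64), (2, 64))
  #   # Nested strides: ((2 * 64 * 64, 64), (64 * 64, 1)) == ((8192, 64), (4096, 1))
  #   def flat_index(i, j):
  #     return (i // 64) * 8192 + (i % 64) * 64 + (j // 64) * 4096 + (j % 64)
  #   for i, j in [(0, 0), (1, 2), (63, 64), (127, 127)]:
  #     assert tiles[i // 64, j // 64, i % 64, j % 64] == flat_index(i, j)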
|
||||
|
||||
def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
|
||||
for tile in self.tiles:
|
||||
untiled, tiled = indices[:-len(tile)], indices[-len(tile):]
|
||||
@ -214,7 +274,7 @@ class TiledLayout:
|
||||
index = ir.IndexType.get()
|
||||
contig_strides = utils.get_contiguous_strides(shape)
|
||||
tile_strides = self.tiling.tile_strides(contig_strides)
|
||||
dyn_tile_strides = [c(s, i32) for s in tile_strides]
|
||||
dyn_tile_strides = [c(s, i32) for s in tile_strides[-self.tiled_tiling_rank:]]
|
||||
warp_offset = utils.dyn_dot(self.warp_indices(), dyn_tile_strides)
|
||||
lane_offset = utils.dyn_dot(self.lane_indices(), dyn_tile_strides)
|
||||
dyn_offset = arith.addi(warp_offset, lane_offset)
|
||||
@ -246,7 +306,12 @@ class TiledLayout:
|
||||
so the tiled shape always ends with this suffix, no matter what array shape
|
||||
it's applied to.
|
||||
"""
|
||||
return self.tiling.tile_shape(self.base_tile_shape)
|
||||
base_tile_shape = self.base_tile_shape
|
||||
return self.tiling.tile_shape(base_tile_shape)[len(base_tile_shape):]
|
||||
|
||||
@functools.cached_property
|
||||
def tiled_tiling_rank(self) -> int:
|
||||
return len(self.tiled_tiling_shape)
|
||||
|
||||
@property
|
||||
def vector_length(self) -> int:
|
||||
@ -292,16 +357,12 @@ class TiledLayout:
|
||||
|
||||
def warp_indices(self) -> tuple[ir.Value, ...]:
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
tiled_shape = tuple(
|
||||
d if i == self.warp_dim else 1
|
||||
for i, d in enumerate_negative(self.tiled_tiling_shape)
|
||||
)
|
||||
assert math.prod(tiled_shape) == WARPS_IN_WARPGROUP
|
||||
tiled_shape_rank = len(self.tiled_tiling_shape)
|
||||
warp_idx = arith.remui(
|
||||
arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)),
|
||||
c(WARPS_IN_WARPGROUP, i32),
|
||||
)
|
||||
indices = [arith.constant(i32, 0)] * len(tiled_shape)
|
||||
indices = [arith.constant(i32, 0)] * tiled_shape_rank
|
||||
indices[self.warp_dim] = warp_idx
|
||||
return tuple(indices)
|
||||
|
||||
@ -479,6 +540,33 @@ TILED_LAYOUT_WGMMA = TiledLayout(
|
||||
lane_dims=(-4, -3),
|
||||
vector_dim=-1,
|
||||
)
|
||||
# This tiled layout is similar to the one above. Above, each warp stores an 8x8
# submatrix in the following way (we only show the first 4 rows for brevity):
#
# 0 0 1 1 2 2 3 3
# 4 4 5 5 6 6 7 7
# 8 8 9 9 10 10 11 11
# 12 12 13 13 14 14 15 15
# ...
#
# This tiled layout stores the same 8x8 submatrix in the following way:
#
# 0 4 1 5 2 6 3 7
# 0 4 1 5 2 6 3 7
# 8 12 9 13 10 14 11 15
# 8 12 9 13 10 14 11 15
# ...
#
# You can see that we have taken 2x2 submatrices from the above layout and
# transposed them. The assignment of lanes to elements is such that in both
# layouts the same two lanes map to a single 2x2 submatrix, making the transpose
# very cheap (one shuffle and one permute suffice to change between those layouts).
WGMMA_TRANSPOSED_LAYOUT = TiledLayout(
    Tiling(((64, 8), (16, 8), (8, 8), (2, 2), (2, 1))),
    warp_dim=-10,
    lane_dims=(-6, -3, -5),
    vector_dim=-2,
)
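
# The "2x2 submatrices are transposed" claim in the comment above, checked on
# the first two rows of the two pictures with plain numpy (illustrative only):
#
#   import numpy as np
#   above = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
#                     [4, 4, 5, 5, 6, 6, 7, 7]])
#   here = np.array([[0, 4, 1, 5, 2, 6, 3, 7],
#                    [0, 4, 1, 5, 2, 6, 3, 7]])
#   blocks = above.reshape(2, 4, 2)  # row x 2x2-block index x column in block
#   assert (blocks.transpose(2, 1, 0).reshape(2, 8) == here).all()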
|
||||
|
||||
@jax.tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True)
|
||||
@ -553,13 +641,22 @@ class FragmentedArray:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None):
|
||||
def load_strided(
|
||||
cls,
|
||||
ref: ir.Value,
|
||||
*,
|
||||
is_signed: bool | None = None,
|
||||
vec_size: int | None = None,
|
||||
):
|
||||
if not ir.MemRefType.isinstance(ref.type):
|
||||
raise TypeError(ref.type)
|
||||
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
shape = tuple(ref_ty.shape)
|
||||
layout = WGStridedFragLayout.from_shaped_type(ref_ty)
|
||||
if vec_size is None:
|
||||
layout = WGStridedFragLayout.from_shaped_type(ref_ty)
|
||||
else:
|
||||
layout = WGStridedFragLayout(shape=shape, vec_size=vec_size)
|
||||
vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
|
||||
try:
|
||||
# Flattening the reference potentially produces simpler PTX but
|
||||
@ -647,9 +744,35 @@ class FragmentedArray:
|
||||
|
||||
At the moment, only conversions from ``WGSplatFragLayout`` are supported.
|
||||
"""
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
if self.layout == new_layout:
|
||||
return self
|
||||
shape = self.shape
|
||||
if (
|
||||
self.layout == TILED_LAYOUT_WGMMA
|
||||
and new_layout == WGMMA_TRANSPOSED_LAYOUT
|
||||
and utils.bitwidth(self.mlir_dtype) == 16
|
||||
):
|
||||
is_even_row = arith.cmpi(
|
||||
arith.CmpIPredicate.eq,
|
||||
arith.remui(arith.divui(utils.thread_idx(), c(4, i32)), c(2, i32)),
|
||||
c(0, i32),
|
||||
)
|
||||
perm = arith.select(is_even_row, c(0x5410, i32), c(0x3276, i32))
|
||||
new_regs = []
|
||||
for reg in self.registers.flat:
|
||||
reg_ty = reg.type
|
||||
reg = utils.bitcast(reg, i32)
|
||||
reg_shfl = utils.shfl_bfly(reg, 4)
|
||||
new_reg = llvm.inline_asm(
|
||||
i32, [reg, reg_shfl, perm], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r"
|
||||
)
|
||||
new_regs.append(utils.bitcast(new_reg, reg_ty))
|
||||
return FragmentedArray(
|
||||
_registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)),
|
||||
_layout=new_layout,
|
||||
_is_signed=self.is_signed,
|
||||
)
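
      # What the prmt above computes, sketched with plain Python integers
      # (assuming each 32-bit register packs two 16-bit elements; illustrative
      # only):
      #
      #   def prmt(a, b, sel):  # prmt.b32: pick 4 bytes out of (a, b), one per nibble
      #     pool = list(a.to_bytes(4, "little") + b.to_bytes(4, "little"))
      #     return int.from_bytes(
      #         bytes(pool[(sel >> (4 * i)) & 0xF] for i in range(4)), "little")
      #   a, b = 0x2222_1111, 0x4444_3333  # this lane's register, its shfl partner
      #   assert prmt(a, b, 0x5410) == 0x3333_1111  # even rows keep the low halves
      #   assert prmt(a, b, 0x3276) == 0x2222_4444  # odd rows keep the high halves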
|
||||
if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 8 == 0:
|
||||
tiled_layout = _tiled_wgmma_layout(shape)
|
||||
if (self.layout == WGMMA_LAYOUT and new_layout == tiled_layout) or (
|
||||
@ -966,6 +1089,24 @@ class FragmentedArray:
|
||||
return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32"))
|
||||
return self._pointwise(mlir_math.exp2)
|
||||
|
||||
def log(self, *, approx: bool = False):
|
||||
if not ir.FloatType.isinstance(self.mlir_dtype):
|
||||
raise NotImplementedError
|
||||
if approx:
|
||||
dtype = self.mlir_dtype
|
||||
ln2 = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.6931471805599453))
|
||||
return self.log2(approx=True) * ln2
|
||||
return self._pointwise(mlir_math.log)
|
||||
|
||||
def log2(self, *, approx: bool = False):
|
||||
if not ir.FloatType.isinstance(self.mlir_dtype):
|
||||
raise NotImplementedError(self.mlir_dtype)
|
||||
if approx:
|
||||
if not ir.F32Type.isinstance(self.mlir_dtype):
|
||||
raise NotImplementedError(self.mlir_dtype)
|
||||
return self._pointwise(self._lift_fast_instr("lg2.approx.ftz.f32"))
|
||||
return self._pointwise(mlir_math.log2)
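
  # The approximate path in `log` above relies on the change-of-base identity
  # log(x) == log2(x) * ln(2). A quick plain-numpy check of that identity
  # (illustrative only, not part of this class):
  #
  #   import numpy as np
  #   x = np.linspace(0.1, 10.0, 101)
  #   assert np.allclose(np.log(x), np.log2(x) * 0.6931471805599453)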
|
||||
|
||||
def sin(self, *, approx: bool = False):
|
||||
if not ir.FloatType.isinstance(self.mlir_dtype):
|
||||
raise NotImplementedError
|
||||
@ -1190,7 +1331,30 @@ class FragmentedArray:
|
||||
from_integer = ir.IntegerType.isinstance(cur_dtype)
|
||||
to_integer = ir.IntegerType.isinstance(new_dtype)
|
||||
if from_float and to_float:
|
||||
if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
|
||||
cur_ty_width = ir.FloatType(cur_dtype).width
|
||||
new_ty_width = ir.FloatType(new_dtype).width
|
||||
if cur_ty_width == new_ty_width:
|
||||
# There is no instruction to perform conversions between two float types
|
||||
# of the same width. Go through the next-larger standard type.
|
||||
# TODO(bchetioui): support conversions between float types of width 8.
|
||||
# Which larger type to pick will depend on the number of bits in the
|
||||
# smallest exponent.
|
||||
if cur_ty_width != 16:
|
||||
raise NotImplementedError(
|
||||
"Conversion between float types of width other than 16 not"
|
||||
" supported"
|
||||
)
|
||||
larger_ty = ir.F32Type.get()
|
||||
match self.layout:
|
||||
case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout():
|
||||
shape = ir.VectorType(self.registers.flat[0].type).shape
|
||||
upcast_ty = ir.VectorType.get(shape, larger_ty)
|
||||
case WGMMARowFragLayout() | WGSplatFragLayout():
|
||||
upcast_ty = larger_ty
|
||||
case _:
|
||||
raise NotImplementedError(f"Unsupported layout {self.layout}")
|
||||
convert = lambda ty, x: arith.truncf(ty, arith.extf(upcast_ty, x))
|
||||
elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
|
||||
convert = arith.truncf
|
||||
else:
|
||||
convert = arith.extf
|
||||
@ -1423,19 +1587,34 @@ class FragmentedArray:
|
||||
if create_array:
|
||||
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
|
||||
|
||||
def store_untiled(self, ref: ir.Value):
|
||||
def debug_print(self, fmt: str):
|
||||
idx_fmt = ", ".join(["{}"] * len(self.shape))
|
||||
@self.foreach
|
||||
def _(val, idx):
|
||||
fmt_str = fmt.format(f"[{idx_fmt}]: {{}}")
|
||||
utils.debug_print(fmt_str, *idx, val, uniform=False)
|
||||
|
||||
def store_untiled(self, ref: ir.Value, *, vector_store: bool = True):
|
||||
if not ir.MemRefType.isinstance(ref.type):
|
||||
raise ValueError(ref)
|
||||
|
||||
def vs_unsupported():
|
||||
if not vector_store:
|
||||
raise NotImplementedError(
|
||||
f"Can't use non-vector stores with layout {self.layout}"
|
||||
)
|
||||
|
||||
match self.layout:
|
||||
case WGMMAFragLayout():
|
||||
self._store_untiled_wgmma(ref)
|
||||
case WGSplatFragLayout():
|
||||
vs_unsupported()
|
||||
self._store_untiled_splat(ref)
|
||||
case WGStridedFragLayout():
|
||||
vs_unsupported()
|
||||
self._store_untiled_wg_strided(ref)
|
||||
case TiledLayout():
|
||||
self._store_untiled_tiled(ref)
|
||||
self._store_untiled_tiled(ref, vector_store=vector_store)
|
||||
case _:
|
||||
raise NotImplementedError(self.layout)
|
||||
|
||||
@ -1502,7 +1681,7 @@ class FragmentedArray:
|
||||
col = arith.addi(col_base, c(col_tile * 8 + col_idx))
|
||||
memref.store(value, ref, [row, col])
|
||||
|
||||
def _store_untiled_tiled(self, ref: ir.Value):
|
||||
def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
|
||||
"""Stores an array with a tiled layout. Not optimized at the moment."""
|
||||
if utils.bitwidth(self.mlir_dtype) < 8:
|
||||
raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})")
|
||||
@ -1510,7 +1689,7 @@ class FragmentedArray:
|
||||
layout = self.layout
|
||||
assert isinstance(layout, TiledLayout)
|
||||
ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset()
|
||||
if ref_strides[layout.vector_dim] != 1:
|
||||
if vector_store and ref_strides[layout.vector_dim] != 1:
|
||||
raise NotImplementedError(
|
||||
"Can't use vector stores with non-unit minormost stride"
|
||||
)
|
||||
@ -1524,16 +1703,30 @@ class FragmentedArray:
|
||||
raise NotImplementedError(f"Unexpected ref space {ref_space}")
|
||||
ptr = utils.memref_ptr(ref, memory_space=memory_space)
|
||||
# Fold warp and lane offsets into the pointer once, since they are dynamic.
|
||||
dyn_strides = [arith.constant(i32, s) for s in strides]
|
||||
dyn_strides = [
|
||||
arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :]
|
||||
]
|
||||
warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides)
|
||||
lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides)
|
||||
dyn_offset = arith.addi(warp_offset, lane_offset)
|
||||
ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
|
||||
# All warp tile offsets are static and can be fused into the store.
|
||||
for tile_idx, reg in np.ndenumerate(self.registers):
|
||||
lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True))
|
||||
reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
|
||||
llvm.store(reg, reg_ptr)
|
||||
if vector_store:
|
||||
elems = [reg]
|
||||
else:
|
||||
index = ir.IndexType.get()
|
||||
elems = [
|
||||
vector.extractelement(reg, position=c(i, index))
|
||||
for i in range(ir.VectorType(reg.type).shape[0])
|
||||
]
|
||||
for i, e in enumerate(elems):
|
||||
tile_idx_local = list(tile_idx)
|
||||
tile_idx_local[layout.vector_dim] += i
|
||||
tile_idx_local = list(tile_idx_local)
|
||||
lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True))
|
||||
reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
|
||||
llvm.store(e, reg_ptr)
|
||||
|
||||
def store_tiled(self, ref, swizzle: int | None):
|
||||
match self.layout:
|
||||
@ -1714,31 +1907,42 @@ class FragmentedArray:
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
dtype = ref_ty.element_type
|
||||
if ref_ty.rank % 2:
|
||||
raise ValueError("Tiled refence must have even rank")
|
||||
ref_tiling_shape = tuple(ref_ty.shape[ref_ty.rank // 2:])
|
||||
raise ValueError("Tiled reference must have even rank")
|
||||
ref_logical_rank = ref_ty.rank // 2
|
||||
ref_tiling_shape = tuple(ref_ty.shape[ref_logical_rank:])
|
||||
ref_tiling = Tiling((ref_tiling_shape,))
|
||||
ref_strides, _ = ref_ty.get_strides_and_offset()
|
||||
if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape:
|
||||
raise ValueError()
|
||||
if len(layout.base_tile_shape) > len(ref_tiling_shape):
|
||||
raise ValueError("Memory tiling must be a multiple of the register tiling")
|
||||
ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):]
|
||||
if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)):
|
||||
raise ValueError(
|
||||
f"Memory tiling ({ref_tiling_suffix}) must be a multiple of the"
|
||||
f" register tiling ({layout.base_tile_shape})"
|
||||
)
|
||||
nested_ref_shape = tuple(
|
||||
(ref_ty.shape[i], ref_ty.shape[i + ref_logical_rank])
|
||||
for i in range(ref_logical_rank)
|
||||
)
|
||||
nested_ref_strides = tuple(
|
||||
(ref_strides[i], ref_strides[i + ref_logical_rank])
|
||||
for i in range(ref_logical_rank)
|
||||
)
|
||||
tiled_nested_shape, tiled_nested_strides = tiling.tile_nested_shape_strides(
|
||||
nested_ref_shape, nested_ref_strides
|
||||
)
|
||||
|
||||
elem_tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
|
||||
tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape)))
|
||||
# We could technically handle this case, but it would be quite complicated.
|
||||
# If tiling dimensions would have to be expanded into multiple, we'd have to
|
||||
# adjust the dimension indices in layouts, including expanding some of them
|
||||
# into multiple indices. Note that for non-tiling dims, we allow the shape
|
||||
# to be arbitrary, which is why we fix it up below in mem_idx_to_reg_idx.
|
||||
if any(
|
||||
len(dim_shape) != 1 for dim_shape in tiled_nested_shape[-layout.tiled_tiling_rank :]
|
||||
):
|
||||
raise NotImplementedError("Memory and register tiling incompatible")
|
||||
tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape))
|
||||
elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides))
|
||||
elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims]
|
||||
lane_shape = [tiled_shape[d] for d in layout.lane_dims]
|
||||
if elem_tiled_strides[layout.vector_dim] != 1:
|
||||
raise ValueError("Stride of the vectorized dimension should be 1")
|
||||
for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim):
|
||||
tiled_shape[d] = 1
|
||||
full_tiling = Tiling((ref_tiling_shape, *tiling.tiles))
|
||||
full_layout = dataclasses.replace(layout, tiling=full_tiling)
|
||||
|
||||
element_bits = mgpu.bitwidth(dtype)
|
||||
if (layout.vector_length * element_bits) % 8 != 0:
|
||||
@ -1779,9 +1983,11 @@ class FragmentedArray:
|
||||
)
|
||||
|
||||
# All offsets are in units of transfer_dtype.
|
||||
dyn_tiled_strides = [c(s) for s in transfer_tiled_strides]
|
||||
lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides)
|
||||
warp_offset = utils.dyn_dot(full_layout.warp_indices(), dyn_tiled_strides)
|
||||
dyn_tiled_strides = [
|
||||
c(s) for s in transfer_tiled_strides[-layout.tiled_tiling_rank :]
|
||||
]
|
||||
lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides)
|
||||
warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides)
|
||||
dyn_offset = arith.addi(lane_offset, warp_offset)
|
||||
if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
|
||||
raise ValueError("Tiled stores can be performed into SMEM")
|
||||
@ -1809,10 +2015,23 @@ class FragmentedArray:
|
||||
reg_ptr = utils.getelementptr(ptr, [offset], transfer_dtype)
|
||||
offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle))
|
||||
reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], transfer_dtype)
|
||||
reg_idxs = [
|
||||
tiling.tile_indices(full_tiling.untile_indices(idx))
|
||||
for idx in indices.tolist()
|
||||
]
|
||||
# Here, registers are organized in an array with shape obtained by tiling
|
||||
# the logical data bounds. But, the reference was tiled and so each
|
||||
# logical tiled dimension can map to multiple dims in tiled_shape.
|
||||
# The transform below maps this potentially higher-rank representation
|
||||
# back to the lower-rank representation used by the register arrays.
|
||||
def mem_idx_to_reg_idx(idx):
|
||||
reg_tiled_idx = []
|
||||
base_idx = 0
|
||||
for dim_shape in tiled_nested_shape[:ref_logical_rank]:
|
||||
dim_strides = utils.get_contiguous_strides(dim_shape)
|
||||
dim_idxs = idx[base_idx:base_idx + len(dim_shape)]
|
||||
base_idx += len(dim_shape)
|
||||
reg_tiled_idx.append(sum(i * s for i, s in zip(dim_idxs, dim_strides)))
|
||||
# We should have fixed up all but the tiling dims.
|
||||
assert base_idx == len(idx) - layout.tiled_tiling_rank
|
||||
return (*reg_tiled_idx, *idx[base_idx:])
|
||||
reg_idxs = [mem_idx_to_reg_idx(idx) for idx in indices.tolist()]
|
||||
def get_register(regs, reg_idxs=reg_idxs):
|
||||
return plan.select([regs[reg_idx] for reg_idx in reg_idxs])
|
||||
def update_registers(regs, new, reg_idxs=reg_idxs):
|
||||
|
@ -430,8 +430,8 @@ class LaunchContext:
|
||||
gmem_ref, smem_ref = dst_ref, src_ref
|
||||
if barrier is not None:
|
||||
raise ValueError("Barriers are unsupported for SMEM -> GMEM copies")
|
||||
if arrive is not None:
|
||||
raise ValueError("arrive is unsupported for SMEM -> GMEM copies")
|
||||
if arrive is None:
|
||||
arrive = True # Commit this copy to the async group by default
|
||||
else:
|
||||
raise ValueError("Only SMEM <-> GMEM copies supported")
|
||||
# TODO(apaszke): This is a very approximate check. Improve it!
|
||||
@ -683,7 +683,8 @@ class LaunchContext:
|
||||
nvvm.cp_async_bulk_tensor_global_shared_cta(
|
||||
tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate
|
||||
)
|
||||
nvvm.cp_async_bulk_commit_group()
|
||||
if arrive:
|
||||
nvvm.cp_async_bulk_commit_group()
|
||||
|
||||
def await_async_copy(
|
||||
self, allow_groups: int, await_read_only: bool = False
|
||||
|
@ -18,18 +18,23 @@ from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import cast
|
||||
|
||||
from jax._src.lib import mosaic_gpu_dialect as mgpu
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax._src.lib.mlir.dialects import math as mlir_math
|
||||
from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
import numpy as np
|
||||
|
||||
from . import fragmented_array as fa
|
||||
from . import inference_utils
|
||||
from . import layouts as layouts_lib
|
||||
from . import utils
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
@ -192,22 +197,44 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
|
||||
|
||||
|
||||
for op in [
|
||||
arith.AddIOp, arith.AddFOp,
|
||||
arith.AddIOp,
|
||||
arith.AddFOp,
|
||||
arith.AndIOp,
|
||||
arith.BitcastOp,
|
||||
arith.CmpFOp,
|
||||
arith.CmpIOp,
|
||||
arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp,
|
||||
arith.ExtFOp,
|
||||
arith.ExtSIOp,
|
||||
arith.ExtUIOp,
|
||||
arith.FPToSIOp,
|
||||
arith.FPToUIOp,
|
||||
arith.MaximumFOp,
|
||||
arith.MaxUIOp, arith.MaxSIOp,
|
||||
arith.MaxUIOp,
|
||||
arith.MaxSIOp,
|
||||
arith.MinimumFOp,
|
||||
arith.MinUIOp, arith.MinSIOp,
|
||||
arith.MulIOp, arith.MulFOp,
|
||||
arith.MinUIOp,
|
||||
arith.MinSIOp,
|
||||
arith.MulIOp,
|
||||
arith.MulFOp,
|
||||
arith.OrIOp,
|
||||
arith.FloorDivSIOp, arith.DivUIOp, arith.DivFOp,
|
||||
arith.RemUIOp, arith.RemSIOp, arith.RemFOp,
|
||||
arith.SubIOp, arith.SubFOp,
|
||||
arith.TruncFOp, arith.TruncIOp,
|
||||
arith.FloorDivSIOp,
|
||||
arith.DivUIOp,
|
||||
arith.DivFOp,
|
||||
arith.RemUIOp,
|
||||
arith.RemSIOp,
|
||||
arith.RemFOp,
|
||||
arith.SIToFPOp,
|
||||
arith.UIToFPOp,
|
||||
arith.SubIOp,
|
||||
arith.SubFOp,
|
||||
arith.TruncFOp,
|
||||
arith.TruncIOp,
|
||||
arith.XOrIOp,
|
||||
mlir_math.ExpOp,
|
||||
mlir_math.Exp2Op,
|
||||
mlir_math.LogOp,
|
||||
mlir_math.RsqrtOp,
|
||||
mlir_math.TanhOp,
|
||||
vector.LoadOp,
|
||||
vector.StoreOp,
|
||||
]:
|
||||
@ -487,11 +514,36 @@ def infer_layout(module: ir.Module):
|
||||
# propagated. However, it is possible for some operations to remain
|
||||
# unannotated---for example, if there were no annotations on any operation in
|
||||
# the module at the start of this function. We annotate all the remaining ops
|
||||
# that should be annotated with a strided fragmented layout.
|
||||
# that should be annotated with a strided fragmented layout, whose vector size
|
||||
# is derived from the narrowest type and vector size used in the program. We
|
||||
# make sure to derive a single vector size in order to avoid relayouts at
|
||||
# lowering time.
|
||||
default_vector_size = math.inf
|
||||
|
||||
def update_default_vector_size(op: ir.OpView):
|
||||
nonlocal default_vector_size
|
||||
for v in list(op.operands) + list(op.results):
|
||||
if ir.VectorType.isinstance(v.type):
|
||||
max_vec_size_for_v = (
|
||||
np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE
|
||||
)
|
||||
desired_vec_size = 8 // utils.bytewidth(v.type.element_type)
|
||||
default_vector_size = min(
|
||||
default_vector_size, max_vec_size_for_v, desired_vec_size
|
||||
)
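        # E.g. (illustrative numbers, assuming a 128-thread warpgroup): for a
        # vector of f16 (2 bytes) with shape (128, 64), max_vec_size_for_v is
        # 128 * 64 // 128 == 64 and desired_vec_size is 8 // 2 == 4, so the
        # default vector size becomes 4 (an 8-byte transfer per lane).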
|
||||
|
||||
for op in module.body:
|
||||
traverse_op(op, update_default_vector_size)
|
||||
|
||||
if math.isinf(default_vector_size):  # Nothing to annotate.
|
||||
return
|
||||
|
||||
def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
|
||||
if not ir.VectorType.isinstance(ty):
|
||||
return None
|
||||
layout = fa.WGStridedFragLayout.from_shaped_type(ty)
|
||||
layout = fa.WGStridedFragLayout(
|
||||
shape=cast(ir.ShapedType, ty).shape, vec_size=default_vector_size
|
||||
)
|
||||
return layouts_lib.to_strided_fragmented_layout_attr(layout)
|
||||
|
||||
def set_default_layout(op: ir.OpView):
|
||||
|
223
jax/experimental/mosaic/gpu/mma_utils.py
Normal file
@ -0,0 +1,223 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import enum
import math

from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm

from . import utils

# mypy: ignore-errors


def tiled_memref_shape(ref: ir.Value):
  """Returns the 2D untiled shape and element type of a tiled 4D memref."""
  ref_ty = ir.MemRefType(ref.type)
  if ref_ty.rank != 4:
    raise ValueError(f"Expected a 4D memref, got: {ref_ty}")
  logical_shape = (
      ref_ty.shape[0] * ref_ty.shape[2], ref_ty.shape[1] * ref_ty.shape[3]
  )
  return logical_shape, ref_ty.element_type
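
# For example (illustrative shapes only): a memref tiled as
# (m_tiles, k_tiles, 64, 32) with m_tiles == 2 and k_tiles == 4 has shape
# (2, 4, 64, 32), so the logical shape recovered above is
# (2 * 64, 4 * 32) == (128, 128):
#
#   shape = (2, 4, 64, 32)
#   assert (shape[0] * shape[2], shape[1] * shape[3]) == (128, 128)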
|
||||
|
||||
|
||||
class Dim(enum.Enum):
  K = enum.auto()
  MN = enum.auto()
|
||||
|
||||
|
||||
def create_descriptor(
|
||||
ref: ir.Value,
|
||||
swizzle: int,
|
||||
group_size: tuple[int, int], # Instruction group size on each operand dim.
|
||||
logical_k_major: bool, # False for LHS, True for RHS.
|
||||
# Soft deprecated. Use small tiling instead.
|
||||
large_tile: tuple[int, int] | None = None,
|
||||
):
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
element_bytewidth = utils.bytewidth(ref_ty.element_type)
|
||||
swizzle_elems = swizzle // element_bytewidth
|
||||
ref_strides, _ = ref_ty.get_strides_and_offset()
|
||||
ref_byte_strides = [s * element_bytewidth for s in ref_strides]
|
||||
mn_large_tile = k_large_tile = None
|
||||
if logical_k_major:
|
||||
_, mn_tiles, k_tiling, mn_tiling = ref_ty.shape
|
||||
k_tile_stride, mn_tile_stride, k_tiling_stride, mn_tiling_stride = (
|
||||
ref_byte_strides
|
||||
)
|
||||
k_group_size, mn_group_size = group_size
|
||||
if large_tile is not None:
|
||||
k_large_tile, mn_large_tile = large_tile
|
||||
else:
|
||||
mn_tiles, _, mn_tiling, k_tiling = ref_ty.shape
|
||||
mn_tile_stride, k_tile_stride, mn_tiling_stride, k_tiling_stride = (
|
||||
ref_byte_strides
|
||||
)
|
||||
mn_group_size, k_group_size = group_size
|
||||
if large_tile is not None:
|
||||
mn_large_tile, k_large_tile = large_tile
|
||||
|
||||
IGNORED = 0
|
||||
MMA_ATOM_ROWS = 8
|
||||
MMA_BYTEWIDTH_K = 32
|
||||
mma_width_k = MMA_BYTEWIDTH_K // element_bytewidth
|
||||
# As far as I can tell (which does not seem to fully align with the way MMA is
|
||||
# documented in PTX docs), MMA expects the data to be tiled into matrices
|
||||
# of shape 8 x swizzle_elems, with swizzle_elems dim being the fastest
|
||||
# changing. I call this submatrix an MMA atom.
|
||||
#
|
||||
# The role of the SMEM descriptor is to specify the striding pattern between
|
||||
# those atoms. The fastest changing dimension is called the "leading"
|
||||
# dimension and it specifies the stride between consecutive atoms that share
|
||||
# the same coordinate along that dim. The slower dimension is called a
|
||||
# "stride" dimension.
|
||||
if (
|
||||
large_tile is not None
|
||||
and k_large_tile == k_tiling
|
||||
and (mn_large_tile == mn_tiling or mn_tiles == 1 and mn_tiling < mn_large_tile)
|
||||
# There are configurations where large tiles are same size as small ones.
|
||||
# We use the small path since it has fewer restrictions.
|
||||
and set(large_tile) != {MMA_ATOM_ROWS, swizzle_elems}
|
||||
): # Large tiles.
|
||||
if (
|
||||
k_tiling_stride == element_bytewidth
|
||||
and mn_tiling_stride == k_tiling * element_bytewidth
|
||||
):
|
||||
fastest_dim = Dim.K
|
||||
leading_byte_offset = IGNORED # TC assumes K to be contiguous here.
|
||||
# MMA atoms in a group are contiguous, so we increment by the MMA atom
|
||||
# size. However, we only have one level of striding, and so if the group
|
||||
# size exceeds a single large tile (and there is more than one tile) then
|
||||
# that tiled dimension must be contiguous after tiles or else we would
|
||||
# need another striding level.
|
||||
if (
|
||||
mn_tiles > 1
|
||||
and mn_group_size > mn_tiling
|
||||
and mn_tile_stride != math.prod(large_tile) * element_bytewidth
|
||||
):
|
||||
raise ValueError(
|
||||
"MMA layout with large tiles that is K-fastest only supports"
|
||||
" multiple MN tiles when the tiled MN dimension is a contiguous"
|
||||
" stack of tiles "
|
||||
f"({mn_tiles}, {mn_tile_stride} != {math.prod(large_tile)} * {element_bytewidth})"
|
||||
)
|
||||
stride_byte_offset = MMA_ATOM_ROWS * swizzle
|
||||
desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous.
|
||||
elif (
|
||||
k_tiling_stride == k_tiling * element_bytewidth
|
||||
and mn_tiling_stride == element_bytewidth
|
||||
):
|
||||
if k_large_tile != mn_large_tile:
|
||||
raise ValueError(
|
||||
"MMA layout with large tiles that is MN-fastest is only supported"
|
||||
" when the tiling is square"
|
||||
)
|
||||
fastest_dim = Dim.MN
|
||||
# Next swizzle atom with the same K coordinate is in the next MN tile.
|
||||
leading_byte_offset = mn_tile_stride
|
||||
# MMA atoms in a group are contiguous and a group does not exceed a tile.
|
||||
assert k_large_tile == k_group_size
|
||||
stride_byte_offset = MMA_ATOM_ROWS * swizzle
|
||||
# Each row is swizzle bytes wide, and we read mma_width_k rows at a time.
|
||||
assert mn_large_tile == swizzle // element_bytewidth
|
||||
desc_k_stride = mma_width_k * swizzle
|
||||
else:
|
||||
raise ValueError("MMA tiles must be contiguous")
|
||||
else: # Small tiles.
|
||||
if k_tiling_stride > mn_tiling_stride:
|
||||
slower_tiling, faster_tiling = k_tiling, mn_tiling
|
||||
else:
|
||||
faster_tiling, slower_tiling = k_tiling, mn_tiling
|
||||
if slower_tiling != MMA_ATOM_ROWS or faster_tiling != swizzle_elems:
|
||||
raise ValueError(
|
||||
f"Tiling should be ({MMA_ATOM_ROWS}, swizzle_elems) where"
|
||||
f" swizzle_elems = swizzle // bytewidth(dtype) (= {swizzle} //"
|
||||
f" {element_bytewidth} = {swizzle_elems}), but got ({slower_tiling},"
|
||||
f" {faster_tiling})"
|
||||
)
|
||||
if k_tiling_stride == element_bytewidth and mn_tiling_stride == swizzle:
|
||||
fastest_dim = Dim.K
|
||||
leading_byte_offset = IGNORED # TC assumes K to be contiguous here.
|
||||
stride_byte_offset = mn_tile_stride
|
||||
desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous.
|
||||
elif k_tiling_stride == swizzle and mn_tiling_stride == element_bytewidth:
|
||||
fastest_dim = Dim.MN
|
||||
leading_byte_offset = mn_tile_stride
|
||||
stride_byte_offset = k_tile_stride
|
||||
k_tiles_per_mma = mma_width_k // MMA_ATOM_ROWS
|
||||
desc_k_stride = k_tile_stride * k_tiles_per_mma
|
||||
else:
|
||||
raise ValueError("MMA tiles must be contiguous")
|
||||
desc_base = encode_descriptor(
|
||||
ref,
|
||||
leading_byte_offset=leading_byte_offset,
|
||||
stride_byte_offset=stride_byte_offset,
|
||||
swizzle=swizzle,
|
||||
)
|
||||
|
||||
mn_tiles_per_group, rem = divmod(mn_group_size, mn_tiling)
|
||||
assert not rem
|
||||
mn_group_stride = mn_tile_stride * mn_tiles_per_group
|
||||
k_tiles_per_group, rem = divmod(k_group_size, k_tiling)
|
||||
assert not rem
|
||||
k_group_stride = k_tile_stride * k_tiles_per_group
|
||||
|
||||
return (
|
||||
(desc_base, desc_k_stride),
|
||||
(mn_group_stride, k_group_stride),
|
||||
fastest_dim,
|
||||
)
|
||||
|
||||
|
||||
def encode_addr(x: int):
  result = (x & 0x3FFFF) >> 4
  if result << 4 != x:
    raise ValueError(f"Cannot encode value in an MMA descriptor: {x}")
  return result
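
# The encoding above keeps bits [4, 18) of a byte offset: the value must be
# 16-byte aligned and fit in 18 bits. A few illustrative data points (plain
# Python, mirroring the arithmetic above):
#
#   assert encode_addr(0x100) == 0x10        # 256 B -> 16 units of 16 B
#   assert encode_addr(128 * 1024) == 0x2000
#   # encode_addr(8) or encode_addr(1 << 18) would raise: not representable.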
|
||||
|
||||
|
||||
def encode_descriptor(
    memref_arg,
    leading_byte_offset: int,
    stride_byte_offset: int,
    swizzle: int | mgpu_dialect.SwizzlingMode | None,
    const_init: int = 0,
):
  i64 = ir.IntegerType.get_signless(64)
  ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, 3))
  c = lambda x: arith.constant(i64, x)
  if swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
    swizzle_encoding = 0
  elif swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
    swizzle_encoding = 1
  elif swizzle == mgpu_dialect.SwizzlingMode.k64ByteSwizzle:
    swizzle_encoding = 2
  elif swizzle == mgpu_dialect.SwizzlingMode.k32ByteSwizzle:
    swizzle_encoding = 3
  else:
    raise NotImplementedError(swizzle)
  encoded_base_addr = llvm.lshr(llvm.and_(ptr_val, c(0x3FFFF)), c(4))
  # We ignore the offset
  desc_const = (
      const_init
      | (encode_addr(leading_byte_offset) << 16)
      | (encode_addr(stride_byte_offset) << 32)
  )
  desc = llvm.or_(arith.shli(c(swizzle_encoding), c(62)), c(desc_const))
  desc = llvm.or_(encoded_base_addr, desc)
  return desc
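
# Putting the fields together (plain-integer sketch of the packing done above;
# the offsets below are made up for illustration):
#
#   lbo, sbo, swizzle_encoding = 0x100, 0x400, 1   # 128-byte swizzle
#   desc_const = (encode_addr(lbo) << 16) | (encode_addr(sbo) << 32)
#   desc = (swizzle_encoding << 62) | desc_const   # base address is OR-ed in at runtime
#   assert desc == (1 << 62) | (0x10 << 16) | (0x40 << 32)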
|
@ -18,7 +18,6 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
import math
|
||||
|
||||
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
@ -27,7 +26,7 @@ import numpy as np
|
||||
|
||||
from . import utils
|
||||
from . import fragmented_array as fa
|
||||
from . import _wgmma
|
||||
from . import mma_utils
|
||||
from .launch_context import LaunchContext
|
||||
|
||||
# MyPy does a terrible job with the MLIR API.
|
||||
@ -37,21 +36,6 @@ from .launch_context import LaunchContext
|
||||
TMEM_ROWS = 128
|
||||
TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46
|
||||
|
||||
def create_smem_descriptor(
|
||||
memref_arg,
|
||||
leading_byte_offset: int,
|
||||
stride_byte_offset: int,
|
||||
swizzle: int | mgpu_dialect.SwizzlingMode | None,
|
||||
):
|
||||
return _wgmma.create_descriptor(
|
||||
memref_arg,
|
||||
leading_byte_offset,
|
||||
stride_byte_offset,
|
||||
swizzle,
|
||||
memory_space=3,
|
||||
const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
|
||||
)
|
||||
|
||||
def create_instr_descriptor(
|
||||
m: int,
|
||||
n: int,
|
||||
@ -100,70 +84,126 @@ def mma(
|
||||
collective: bool = False,
|
||||
):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
if isinstance(accumulate, bool):
|
||||
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
|
||||
if a_swizzle != b_swizzle:
|
||||
raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
|
||||
swizzle = a_swizzle
|
||||
num_cta = 2 if collective else 1
|
||||
|
||||
# Step 1. Establish the shape and element type of the operation.
|
||||
if not ir.MemRefType.isinstance(a.type):
|
||||
raise ValueError(f"A must be a memref, got {a.type}")
|
||||
if not ir.MemRefType.isinstance(b.type):
|
||||
raise ValueError(f"B must be a memref, got: {b.type}")
|
||||
if a_swizzle != b_swizzle:
|
||||
raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
|
||||
if isinstance(accumulate, bool):
|
||||
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
|
||||
(k, n), element_type = mma_utils.tiled_memref_shape(b)
|
||||
(m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
|
||||
if k != k2:
|
||||
raise ValueError(
|
||||
"MMA requires A and B to have the same contraction dimension (K),"
|
||||
f" got: {k2} and {k}"
|
||||
)
|
||||
if element_type != element_type2:
|
||||
raise ValueError(
|
||||
"MMA requires A and B to have the same element type, got:"
|
||||
f" {element_type2} and {element_type}"
|
||||
)
|
||||
if d.shape != (m, n * num_cta):
|
||||
raise ValueError(
|
||||
f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}"
|
||||
)
|
||||
f32 = ir.F32Type.get()
|
||||
if element_type == f32 or element_type == ir.BF16Type.get():
|
||||
if d.dtype != f32:
|
||||
raise ValueError(
|
||||
f"MMA with element type {element_type} only supports accumulators"
|
||||
f" of type f32, but got: {d.dtype}"
|
||||
)
|
||||
elif element_type == ir.F16Type.get():
|
||||
if d.dtype != element_type and d.dtype != f32:
|
||||
raise ValueError(
|
||||
"MMA with element type f16 only supports accumulators of type f32"
|
||||
f" or f16, but got: {d.dtype}"
|
||||
)
|
||||
|
||||
m_group_size = d.layout.elements_in_tile[0]
|
||||
if m_group_size != 128:
|
||||
# Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
|
||||
# instructions must be issued in groups of the same width as the swizzle.
|
||||
m_group_elems = d.layout.elements_in_tile[0]
|
||||
if m_group_elems != 128:
|
||||
raise NotImplementedError("Only 128-row accumulators supported for now")
|
||||
|
||||
(
|
||||
a_desc_base,
|
||||
b_desc_base,
|
||||
(m, k, n),
|
||||
(m_groups, k_groups, n_groups),
|
||||
(a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride),
|
||||
mma_params,
|
||||
) = _wgmma._validate_mma(
|
||||
a,
|
||||
b,
|
||||
a_swizzle,
|
||||
m_group_size=m_group_size,
|
||||
descriptor_const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
|
||||
)
|
||||
n_group_size = n // n_groups
|
||||
if n > 512:
|
||||
raise ValueError(f"N is too big: at most 512 is supported, but got {n}")
|
||||
num_cta = 2 if collective else 1
|
||||
k_group_elems = swizzle // utils.bytewidth(element_type)
|
||||
if n % 8:
|
||||
raise ValueError(f"N must be a multiple of 8, got: {n}")
|
||||
elif n > 256 and n != 512:
|
||||
raise ValueError("Only N below 256 or N=512 are supported")
|
||||
if num_cta == 2 and n > 256:
|
||||
raise NotImplementedError(
|
||||
"N is too big for collective MMA. Only up to 256 is supported."
|
||||
)
|
||||
n_group_elems = min(n, 256)
|
||||
if m % m_group_elems:
|
||||
raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
|
||||
if k % k_group_elems:
|
||||
raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
|
||||
if n % n_group_elems:
|
||||
raise ValueError(f"N must be a multiple of {n_group_elems}, got: {n}")
|
||||
m_groups = m // m_group_elems
|
||||
k_groups = k // k_group_elems
|
||||
n_groups = n // n_group_elems
|
||||
# TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA.
|
||||
wgmma_element_type = (
|
||||
ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
|
||||
)
|
||||
|
||||
# TODO(apaszke): Verify that the cluster shape matches the expectation of
|
||||
# collective MMA.
|
||||
expected_acc_shape = (m, n * num_cta)
|
||||
if d.shape != expected_acc_shape:
|
||||
raise ValueError(
|
||||
f"Accumulator shape mismatch: expected {expected_acc_shape}, got {d.shape}"
|
||||
)
|
||||
# Step 3. Compute the operand descriptors.
|
||||
(
|
||||
(a_desc_base, a_k_instr_stride),
|
||||
(a_m_group_stride, a_k_group_stride),
|
||||
a_fastest,
|
||||
) = mma_utils.create_descriptor(
|
||||
a,
|
||||
swizzle=swizzle,
|
||||
group_size=(m_group_elems, k_group_elems),
|
||||
logical_k_major=False,
|
||||
)
|
||||
(
|
||||
(b_desc_base, b_k_instr_stride),
|
||||
(b_n_group_stride, b_k_group_stride),
|
||||
b_fastest,
|
||||
) = mma_utils.create_descriptor(
|
||||
b,
|
||||
swizzle=swizzle,
|
||||
group_size=(k_group_elems, n_group_elems),
|
||||
logical_k_major=True,
|
||||
)
|
||||
|
||||
# Step 4. Issue the instructions.
|
||||
true = arith.constant(ir.IntegerType.get_signless(1), 1)
|
||||
for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
|
||||
a_offset = mi * a_m_group_stride + ki * a_k_group_stride
|
||||
a_mk = arith.addi(a_desc_base, utils.c(_wgmma.wgmma_encode(a_offset), i64))
|
||||
a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64))
|
||||
b_offset = ni * b_n_group_stride + ki * b_k_group_stride
|
||||
b_nk = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(b_offset), i64))
|
||||
b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64))
|
||||
if m_groups != 1:
|
||||
raise NotImplementedError("D needs to be sliced")
|
||||
acc = accumulate if ki == 0 else true
|
||||
_do_mma(
|
||||
d.slice(
|
||||
slice(None), utils.ds(ni * n_group_size, n_group_size)
|
||||
slice(None), utils.ds(ni * n_group_elems, n_group_elems)
|
||||
).address,
|
||||
a_mk,
|
||||
b_nk,
|
||||
d_type=ir.F32Type.get(),
|
||||
m=m_group_size,
|
||||
m=m_group_elems,
|
||||
n=n_group_elems,
|
||||
collective=collective,
|
||||
**mma_params,
|
||||
a_transpose=a_fastest != mma_utils.Dim.K,
|
||||
b_transpose=b_fastest != mma_utils.Dim.K,
|
||||
a_k_stride=a_k_instr_stride,
|
||||
b_k_stride=b_k_instr_stride,
|
||||
accumulate=acc,
|
||||
swizzle=swizzle,
|
||||
element_type=wgmma_element_type,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1172,4 +1172,32 @@ def getelementptr(
|
||||
|
||||
|
||||
def dyn_dot(x, y):
|
||||
assert len(x) == len(y)
|
||||
return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y)))
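
# `dyn_dot` linearizes dynamic indices against strides: it emits the MLIR
# equivalent of sum(i * s for i, s in zip(x, y)). Plain-Python analogue
# (illustrative values only):
#
#   indices, strides = (1, 2, 3), (64, 8, 1)
#   assert sum(i * s for i, s in zip(indices, strides)) == 83  # 64 + 16 + 3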
|
||||
|
||||
|
||||
def shfl_bfly(x: ir.Value, distance: int | ir.Value):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
if isinstance(distance, int):
|
||||
distance = c(distance, i32)
|
||||
assert x.type == i32
|
||||
return nvvm.shfl_sync(
|
||||
i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly,
|
||||
)
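
# A butterfly shuffle pairs lane i with lane i ^ distance. Sketched on a plain
# Python list standing in for one value per lane (illustrative only):
#
#   lanes = list(range(8))
#   assert [lanes[i ^ 4] for i in range(8)] == [4, 5, 6, 7, 0, 1, 2, 3]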
|
||||
|
||||
|
||||
def bitcast(x: ir.Value, new_type: ir.Type):
|
||||
if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type):
|
||||
new_type = ir.IntegerType(new_type)
|
||||
x_ty = ir.VectorType(x.type)
|
||||
assert new_type.width == bitwidth(x_ty.element_type) * math.prod(x_ty.shape)
|
||||
i0 = arith.ConstantOp.create_index(0)
|
||||
return vector.extractelement(
|
||||
vector.bitcast(ir.VectorType.get((1,), new_type), x), position=i0
|
||||
)
|
||||
if ir.IntegerType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
|
||||
new_type = ir.VectorType(new_type)
|
||||
x_ty = ir.IntegerType(x.type)
|
||||
assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
|
||||
return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
|
||||
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
|
||||
|
@ -14,13 +14,10 @@
|
||||
# ==============================================================================
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
@ -29,6 +26,7 @@ from jaxlib.mlir.dialects import vector
|
||||
import numpy as np
|
||||
|
||||
from . import fragmented_array as fa
|
||||
from . import mma_utils
|
||||
from . import utils
|
||||
|
||||
# mypy: ignore-errors
|
||||
@ -84,60 +82,6 @@ class WGMMAAccumulator:
|
||||
return cls(_value=value[0], _sync=False)
|
||||
|
||||
|
||||
def wgmma_encode(x: int):
|
||||
result = (x & 0x3FFFF) >> 4
|
||||
if result << 4 != x:
|
||||
raise ValueError(f"Cannot encode value in a WGMMA descriptor: {x}")
|
||||
return result
|
||||
|
||||
|
||||
def llvm_add(x, y):
|
||||
return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none)
|
||||
|
||||
|
||||
def create_descriptor(
|
||||
memref_arg,
|
||||
leading_byte_offset: int,
|
||||
stride_byte_offset: int,
|
||||
swizzle: int | mgpu_dialect.SwizzlingMode | None,
|
||||
memory_space: int | None = None,
|
||||
const_init: int = 0,
|
||||
):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, memory_space))
|
||||
if swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
|
||||
swizzle_encoding = 0
|
||||
elif swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
|
||||
swizzle_encoding = 1
|
||||
elif swizzle == mgpu_dialect.SwizzlingMode.k64ByteSwizzle:
|
||||
swizzle_encoding = 2
|
||||
elif swizzle == mgpu_dialect.SwizzlingMode.k32ByteSwizzle:
|
||||
swizzle_encoding = 3
|
||||
else:
|
||||
raise NotImplementedError(swizzle)
|
||||
encoded_base_addr = llvm.LShrOp(
|
||||
llvm.AndOp(ptr_val, c(0x3FFFF, i64)).result, c(4, i64)
|
||||
)
|
||||
# We ignore the offset
|
||||
desc_const = (
|
||||
const_init
|
||||
| (wgmma_encode(leading_byte_offset) << 16)
|
||||
| (wgmma_encode(stride_byte_offset) << 32)
|
||||
)
|
||||
desc = llvm.or_(
|
||||
arith.shli(c(swizzle_encoding, i64), c(62, i64)), c(desc_const, i64)
|
||||
)
|
||||
desc = llvm.or_(encoded_base_addr.result, desc)
|
||||
return desc
|
||||
|
||||
|
||||
def _unpack_i32(vec_ty, r):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
return vector.bitcast(
|
||||
vec_ty, vector.splat(ir.VectorType.get((1,), i32), r)
|
||||
)
|
||||
|
||||
|
||||
def _supported_wgmma_types(dtype, abtype) -> bool:
|
||||
input_types_are = lambda ty: ty.isinstance(abtype)
|
||||
if ir.F32Type.isinstance(dtype):
|
||||
@ -271,14 +215,14 @@ def wgmma_m64(
|
||||
a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
|
||||
else:
|
||||
if i > 0:
|
||||
a = llvm_add(
|
||||
a = _llvm_add(
|
||||
a,
|
||||
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
|
||||
)
|
||||
a_args = [a]
|
||||
# Advance the B descriptor.
|
||||
if i > 0:
|
||||
b_descriptor = llvm_add(
|
||||
b_descriptor = _llvm_add(
|
||||
b_descriptor,
|
||||
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
|
||||
)
|
||||
@ -297,262 +241,6 @@ def wgmma_m64(
|
||||
return to_acc_vec_regs(acc_regs)
|
||||
|
||||
|
||||
class WGMMALayout(enum.Enum):
|
||||
ROW_MAJOR = enum.auto()
|
||||
COL_MAJOR = enum.auto()
|
||||
|
||||
|
||||
def _validate_mma(
|
||||
a: Any,
|
||||
b: ir.Value,
|
||||
swizzle: int,
|
||||
m_group_size: int, # The M used by a single instruction.
|
||||
descriptor_const_init: int = 0,
|
||||
):
|
||||
# We need swizzle >= 32 to ensure that our K tiling is larger than the MMA
|
||||
# instruction's K width.
|
||||
if swizzle < 32:
|
||||
raise ValueError(f"Unsupported swizzle: {swizzle}")
|
||||
|
||||
# Get A type.
|
||||
if a_in_smem := isinstance(a, ir.Value):
|
||||
if not ir.MemRefType.isinstance(a.type):
|
||||
raise ValueError(f"When A is an ir.Value, it must be a memref, got: {a.type}")
|
||||
a_ty = ir.MemRefType(a.type)
|
||||
a_element_type = a_ty.element_type
|
||||
a_shape = tuple(a_ty.shape)
|
||||
if a_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
|
||||
raise ValueError("A must be in workgroup memory when it's a reference")
|
||||
if len(a_shape) != 4:
|
||||
raise ValueError(f"A must be 4D when it's a reference, got rank {len(a_shape)}")
|
||||
elif hasattr(a, "shape") and hasattr(a, "mlir_dtype"):
|
||||
a_element_type = a.mlir_dtype
|
||||
a_shape = a.shape
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported A type: {type(a)}")
|
||||
|
||||
# Get B type (always a reference).
|
||||
b_ty = ir.MemRefType(b.type)
|
||||
if b_ty.rank != 4:
|
||||
raise ValueError(f"B must be 4D, got rank {b_ty.rank}")
|
||||
|
||||
# Verify element types and compute the tiling.
|
||||
if (element_type := a_element_type) != b_ty.element_type:
|
||||
raise ValueError(
|
||||
f"A and B must have the same element type, got: {a_element_type} and"
|
||||
f" {b_ty.element_type}"
|
||||
)
|
||||
supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
|
||||
if element_type not in supported_types:
|
||||
raise ValueError(a_element_type)
|
||||
element_bytewidth = bytewidth(element_type)
|
||||
swizzle_elems = swizzle // element_bytewidth
|
||||
|
||||
# Verify the shape and strides of B are as expected.
|
||||
b_k_tiles, n_tiles, b_k_tiling, n_tiling = b_ty.shape
|
||||
k = b_k_tiles * b_k_tiling
|
||||
n = n_tiles * n_tiling
|
||||
|
||||
b_strides, _ = b_ty.get_strides_and_offset()
|
||||
b_byte_strides = [s * element_bytewidth for s in b_strides]
|
||||
b_k_byte_stride, b_n_byte_stride, *b_tile_byte_strides = b_byte_strides
|
||||
if (
|
||||
b_byte_strides[1] != n_tiling * b_k_tiling * element_bytewidth
|
||||
and n_tiles != 1 # When there's only one tile, we never jump between them
|
||||
):
|
||||
raise ValueError("B tiles must be contiguous along the N dimension")
|
||||
if b_tile_byte_strides == [swizzle, element_bytewidth]: # N-fastest
|
||||
b_order = WGMMALayout.ROW_MAJOR
|
||||
# This first case (n_tiles == 1) is to allow the somewhat weird case of
|
||||
# loading a small amount of N-fastest data, that needs to be padded to a
|
||||
# larger tile due to swizzle. In this case we allow slicing the big tile
|
||||
# before WGMMA to avoid unnecessary compute on padding.
|
||||
if n_tiles == 1:
|
||||
if n_tiling % 8:
|
||||
raise ValueError("N tile size must be a multiple of 8")
|
||||
elif n_tiling != swizzle_elems:
|
||||
raise ValueError(
|
||||
"Row major RHS (N-fastest) requires the N tile size to be equal to"
|
||||
f" the swizzle tile size ({swizzle_elems}), but got {n_tiling}"
|
||||
)
|
||||
if b_k_tiling not in {8, swizzle_elems}:
|
||||
raise ValueError(
|
||||
"Row major RHS (N-fastest) requires the K tile size to be either"
|
||||
f" the swizzle tile size ({swizzle_elems}) or 8, but got {b_k_tiling}"
|
||||
)
|
||||
elif b_tile_byte_strides == [element_bytewidth, swizzle]: # K-fastest
|
||||
b_order = WGMMALayout.COL_MAJOR
|
||||
if b_k_tiling != swizzle_elems:
|
||||
raise ValueError(
|
||||
"Column major RHS (K-fastest) requires the K tile size to be equal"
|
||||
f" to the swizzle tile size ({swizzle_elems}), but got {b_k_tiling}"
|
||||
)
|
||||
# See the explanation in the N-fastest case when n_tiles == 1.
|
||||
if n_tiles == 1:
|
||||
if n_tiling % 8:
|
||||
raise ValueError("N tile size must be a multiple of 8")
|
||||
elif n_tiling not in {8, swizzle_elems}:
|
||||
raise ValueError(
|
||||
"Column major RHS (K-fastest) requires the N tile size to be either"
|
||||
f" to the swizzle tile size ({swizzle_elems}) or 8, but got {n_tiling}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(b_byte_strides)
|
||||
|
||||
if n > 256 and n % 256:
|
||||
raise ValueError(
|
||||
f"N group size must be a multiple of 256 when larger than 256, got: {n}"
|
||||
)
|
||||
k_group_size = swizzle_elems
|
||||
n_group_size = min(n, 256)
|
||||
b_k_tiles_per_group = k_group_size // b_k_tiling
|
||||
b_k_group_stride = b_k_byte_stride * b_k_tiles_per_group
|
||||
n_tiles_per_group = n_group_size // n_tiling
|
||||
b_n_group_stride = b_n_byte_stride * n_tiles_per_group
|
||||
|
||||
# Verify the shape and strides of A are as expected.
|
||||
if not a_in_smem:
|
||||
m = a_shape[0]
|
||||
a_order = a_m_group_stride = a_k_group_stride = None
|
||||
else:
|
||||
a_ty = ir.MemRefType(a.type)
|
||||
m_tiles, a_k_tiles, m_tiling, a_k_tiling = a_ty.shape
|
||||
m = m_tiles * m_tiling
|
||||
# TODO(apaszke): I'm not actually convinced that we need this check.
|
||||
if m_tiling != m_group_size:
|
||||
raise ValueError(
|
||||
f"A's row tiling must be equal to {m_group_size}, got: {m_tiling}"
|
||||
)
|
||||
if a_k_tiling != swizzle_elems or a_k_tiles * a_k_tiling != k:
|
||||
raise ValueError(a_ty.shape)
|
||||
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
|
||||
a_m_byte_stride, a_k_byte_stride, *a_tile_byte_strides = [
|
||||
s * element_bytewidth for s in a_strides
|
||||
]
|
||||
if a_tile_byte_strides == [swizzle, element_bytewidth]:
|
||||
a_order = WGMMALayout.ROW_MAJOR
|
||||
elif a_tile_byte_strides == [element_bytewidth, swizzle]:
|
||||
a_order = WGMMALayout.COL_MAJOR
|
||||
else:
|
||||
raise ValueError(a_strides)
|
||||
if a_order != WGMMALayout.ROW_MAJOR and m_tiling != swizzle_elems:
|
||||
# Not sure what the layout is like, since the tiles aren't square.
|
||||
raise NotImplementedError
|
||||
a_m_tiles_per_group = m_group_size // m_tiling
|
||||
a_m_group_stride = a_m_byte_stride * a_m_tiles_per_group
|
||||
a_k_tiles_per_group = k_group_size // a_k_tiling
|
||||
a_k_group_stride = a_k_byte_stride * a_k_tiles_per_group
|
||||
|
||||
b_k_fastest = b_order == WGMMALayout.COL_MAJOR
|
||||
a_k_fastest = a_order == WGMMALayout.ROW_MAJOR
|
||||
# This is the number of rows until consecutive repeats of the swizzle pattern.
|
||||
swizzle_pattern_rows = swizzle // 16
|
||||
# A swizzle atom is a 2D matrix with the dimensions below.
|
||||
swizzle_atom_bytes = swizzle_pattern_rows * 128
|
||||
|
||||
# Here "leading" refers to the fastest changing dimension. There are two
|
||||
# strides we have to define per value:
|
||||
# Leading byte offset (LBO)
|
||||
# K-fastest: ignored
|
||||
# MN-fastest: stride between consecutive swizzle atoms that share the same
|
||||
# K coordinate.
|
||||
# Stride byte offset (SBO)
|
||||
# As far as I can tell this is just the offset between two consecutive
|
||||
# swizzle atoms along the non-leading dimension.
|
||||
IGNORED = 0
|
||||
a_desc_fields = dict(
|
||||
# I can't fully explain why WGMMA ignores LBO for A. For a_k_fastest, it
|
||||
# is documented in the PTX docs, and my best explanation for the other
|
||||
# case is that the instruction has a fixed shape and so it does not care
|
||||
# about strides. It's possible that it's an artifact of the fact that we
|
||||
# use tiling of 64.
|
||||
leading_byte_offset=IGNORED,
|
||||
stride_byte_offset=swizzle_atom_bytes,
|
||||
swizzle=swizzle,
|
||||
memory_space=3,
|
||||
)
|
||||
# If B is N-fastest, all swizzle atoms within a tile share the same N
|
||||
# coordinate, so we simply take the stride between consecutive N tiles.
|
||||
# If B is K-fastest, all swizzle atoms within a tile share the same K
|
||||
# coordinate, which forces us to lay out the tiles in N-fastest order or else
|
||||
# they would have uneven strides.
|
||||
b_desc_fields = dict(
|
||||
leading_byte_offset=IGNORED if b_k_fastest else b_n_byte_stride,
|
||||
# N tiles are contiguous, so the next N swizzle atom follows immediately.
|
||||
# K tiles are not contiguous, so we take the stride between them.
|
||||
stride_byte_offset=swizzle_atom_bytes
|
||||
if b_k_fastest or b_k_tiling == swizzle_elems
|
||||
else b_k_byte_stride,
|
||||
swizzle=swizzle,
|
||||
memory_space=3,
|
||||
)
|
||||
# The K strides indicate the stride between the consecutive places where all
|
||||
# coordinates are 0 except for K being incremented by the instruction width.
|
||||
# If an input is K-fastest, we increment the descriptor by 32 bytes, since
|
||||
# that is the K-width of all MMA instructions.
|
||||
if b_k_fastest:
|
||||
b_k_wgmma_stride = 32
|
||||
elif b_k_tiling == swizzle_elems:
|
||||
# When B is N-fastest and we use the large square tiling, the relevant
|
||||
# slices all fall within the first tile. A single MMA instruction for 16-bit
|
||||
# types reads a subtile of shape 16x(swizzle bytes), giving us the necessary
|
||||
# expression.
|
||||
assert n_tiling == swizzle_elems or n_tiles == 1
|
||||
b_k_wgmma_stride = swizzle * 16
|
||||
else:
|
||||
# If we use the small non-square tiling and N-fastest layout, each tile only
|
||||
# contains a single swizzle atom with the K coordinate. But, each tile has
|
||||
# 8 rows, while the WGMMA K width is 16, so we need to jump over 2 tiles.
|
||||
b_k_wgmma_stride = b_k_byte_stride * 2
|
||||
wgmma_params = dict(
|
||||
a_transpose=not a_k_fastest,
|
||||
b_transpose=not b_k_fastest,
|
||||
# TODO(apaszke): This explanation is quite bad. We should better figure
|
||||
# out how to do LHS transposes.
|
||||
# We only support swizzle=128 for M-fastest A. In this case the tile is
|
||||
# swizzle x 64 (= swizzle elems) and so we just take a quarter of its size.
|
||||
a_k_stride=32 if a_k_fastest else swizzle * 16,
|
||||
b_k_stride=b_k_wgmma_stride,
|
||||
swizzle=swizzle,
|
||||
n=n_group_size,
|
||||
element_type=ir.FloatTF32Type.get()
|
||||
if ir.F32Type.isinstance(element_type)
|
||||
else element_type,
|
||||
)
|
||||
if not a_in_smem:
|
||||
wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
|
||||
a_desc_base = None
|
||||
else:
|
||||
a_desc_base = create_descriptor(
|
||||
a, **a_desc_fields, const_init=descriptor_const_init
|
||||
)
|
||||
b_desc_base = create_descriptor(
|
||||
b, **b_desc_fields, const_init=descriptor_const_init
|
||||
)
|
||||
|
||||
if m % m_group_size:
|
||||
raise ValueError(f"m must be a multiple of {m_group_size}, got: {m}")
|
||||
m_groups = m // m_group_size
|
||||
if k % k_group_size:
|
||||
raise ValueError(f"k must be a multiple of {k_group_size}, got: {k}")
|
||||
k_groups = k // k_group_size
|
||||
if n % n_group_size:
|
||||
raise ValueError(f"n must be a multiple of {n_group_size}, got: {n}")
|
||||
n_groups = n // n_group_size
|
||||
|
||||
return (
|
||||
a_desc_base,
|
||||
b_desc_base,
|
||||
(m, k, n),
|
||||
(m_groups, k_groups, n_groups),
|
||||
# Group strides are always in bytes!
|
||||
(a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride),
|
||||
wgmma_params,
|
||||
)
|
||||
|
||||
|
||||
# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer
|
||||
# transpositions from memref strides.
|
||||
def wgmma(
|
||||
acc: WGMMAAccumulator,
|
||||
a: fa.FragmentedArray | ir.Value,
|
||||
@ -570,61 +258,129 @@ def wgmma(
|
||||
The refs must be contiguous or be contiguous except for having their two minor
|
||||
dimensions swapped.
|
||||
"""
|
||||
a_in_regs = isinstance(a, fa.FragmentedArray)
|
||||
if not a_in_regs and not ir.MemRefType.isinstance(a.type):
|
||||
raise ValueError(f"Unsupported A type: {type(a)}")
|
||||
# Step 1. Establish the shape and element type of the operation.
|
||||
if not ir.MemRefType.isinstance(b.type):
|
||||
raise ValueError(f"B must be a memref, got: {b.type}")
|
||||
|
||||
m_group_size = 64 # Hopper has a fixed M instruction shape.
|
||||
|
||||
(
|
||||
a_desc_base,
|
||||
b_desc_base,
|
||||
(m, k, n),
|
||||
(m_groups, k_groups, n_groups),
|
||||
(a_m_group_stride, a_k_group_stride, b_k_group_stride, _),
|
||||
wgmma_params,
|
||||
) = _validate_mma(a, b, swizzle, m_group_size=m_group_size)
|
||||
|
||||
if n_groups > 1:
|
||||
raise ValueError("N is too big for WGMMA. Only up to 256 is supported.")
|
||||
|
||||
if a_in_regs:
|
||||
(k, n), element_type = mma_utils.tiled_memref_shape(b)
|
||||
if a_in_regs := isinstance(a, fa.FragmentedArray):
|
||||
m, k2 = a.shape
|
||||
element_type2 = a.mlir_dtype
|
||||
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
|
||||
raise ValueError(
|
||||
f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}"
|
||||
)
|
||||
if a.shape[0] % m_group_size:
|
||||
raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}")
|
||||
a_m_group_stride = a_k_group_stride = None
|
||||
|
||||
elif ir.MemRefType.isinstance(a.type):
|
||||
(m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
|
||||
else:
|
||||
raise ValueError(f"Unsupported A type: {type(a)}")
|
||||
if k != k2:
|
||||
raise ValueError(
|
||||
"WGMMA requires A and B to have the same contraction dimension (K),"
|
||||
f" got: {k2} and {k}"
|
||||
)
|
||||
if element_type != element_type2:
|
||||
raise ValueError(
|
||||
"WGMMA requires A and B to have the same element type, got:"
|
||||
f" {element_type2} and {element_type}"
|
||||
)
|
||||
if acc.value.shape != (m, n):
|
||||
raise ValueError(
|
||||
f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}"
|
||||
)
|
||||
f32 = ir.F32Type.get()
|
||||
if element_type == f32 or element_type == ir.BF16Type.get():
|
||||
if acc.value.mlir_dtype != f32:
|
||||
raise ValueError(
|
||||
f"WGMMA with element type {element_type} only supports accumulators"
|
||||
f" of type f32, but got: {acc.value.mlir_dtype}"
|
||||
)
|
||||
elif element_type == ir.F16Type.get():
|
||||
if acc.value.mlir_dtype != element_type and acc.value.mlir_dtype != f32:
|
||||
raise ValueError(
|
||||
"WGMMA with element type f16 only supports accumulators of type f32"
|
||||
f" or f16, but got: {acc.value.mlir_dtype}"
|
||||
)
|
||||
|
||||
# Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
|
||||
# instructions must be issued in groups of the same width as the swizzle.
|
||||
m_group_elems = 64 # Hopper has a fixed M instruction shape.
|
||||
k_group_elems = swizzle // utils.bytewidth(element_type)
|
||||
if n > 256 or n % 8:
|
||||
raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
|
||||
n_group_elems = n # We assume only one N group below.
|
||||
if m % m_group_elems:
|
||||
raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
|
||||
if k % k_group_elems:
|
||||
raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
|
||||
m_groups = m // m_group_elems
|
||||
k_groups = k // k_group_elems
|
||||
# TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA.
|
||||
wgmma_element_type = (
|
||||
ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
|
||||
)
|
||||
|
||||
# Step 3. Compute the operand descriptors.
|
||||
if a_in_regs:
|
||||
a_desc_base = a_m_group_stride = a_k_group_stride = None
|
||||
a_instr_params = dict(a_transpose=None, a_k_stride=None)
|
||||
else:
|
||||
(
|
||||
(a_desc_base, a_k_instr_stride),
|
||||
(a_m_group_stride, a_k_group_stride),
|
||||
a_fastest,
|
||||
) = mma_utils.create_descriptor(
|
||||
a,
|
||||
swizzle=swizzle,
|
||||
large_tile=(m_group_elems, k_group_elems),
|
||||
group_size=(m_group_elems, k_group_elems),
|
||||
logical_k_major=False,
|
||||
)
|
||||
a_instr_params = dict(a_transpose=a_fastest != mma_utils.Dim.K,
|
||||
a_k_stride=a_k_instr_stride)
|
||||
(
|
||||
(b_desc_base, b_k_instr_stride),
|
||||
(b_n_group_stride, b_k_group_stride),
|
||||
b_fastest,
|
||||
) = mma_utils.create_descriptor(
|
||||
b,
|
||||
swizzle=swizzle,
|
||||
large_tile=(k_group_elems,) * 2, # It's not a typo that we use k for n.
|
||||
group_size=(k_group_elems, n_group_elems),
|
||||
logical_k_major=True,
|
||||
)
|
||||
del b_n_group_stride # We only support one N group.
|
||||
|
||||
# Step 4. Issue the instructions.
|
||||
if a_in_regs:
|
||||
a = wgmma_fence(a) # Make sure the registers are ready.
|
||||
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
new_acc_regs = acc.value.registers.copy()
|
||||
k_group_size = k // k_groups
|
||||
for mi in range(m_groups):
|
||||
for ki in range(k_groups):
|
||||
if a_in_regs:
|
||||
a_mk = a[
|
||||
mi * m_group_size : (mi + 1) * m_group_size,
|
||||
ki * k_group_size : (ki + 1) * k_group_size,
|
||||
mi * m_group_elems : (mi + 1) * m_group_elems,
|
||||
ki * k_group_elems : (ki + 1) * k_group_elems,
|
||||
]
|
||||
else:
|
||||
a_mk = llvm_add(
|
||||
a_desc_base,
|
||||
c(wgmma_encode(mi * a_m_group_stride + ki * a_k_group_stride), i64),
|
||||
a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride
|
||||
a_mk = _llvm_add(
|
||||
a_desc_base, c(mma_utils.encode_addr(a_group_offset), i64),
|
||||
)
|
||||
b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_group_stride), i64))
|
||||
b_k = _llvm_add(
|
||||
b_desc_base, c(mma_utils.encode_addr(ki * b_k_group_stride), i64)
|
||||
)
|
||||
new_acc_regs[mi : mi + 1] = wgmma_m64(
|
||||
new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
|
||||
new_acc_regs[mi : mi + 1],
|
||||
a_mk,
|
||||
b_k,
|
||||
swizzle=swizzle,
|
||||
n=n_group_elems,
|
||||
element_type=wgmma_element_type,
|
||||
b_transpose=b_fastest != mma_utils.Dim.K,
|
||||
b_k_stride=b_k_instr_stride,
|
||||
**a_instr_params,
|
||||
)
|
||||
return WGMMAAccumulator(
|
||||
_value=fa.FragmentedArray(
|
||||
@ -668,3 +424,14 @@ def _as_i32_reg(v):
|
||||
def _lc(x):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
|
||||
|
||||
|
||||
def _llvm_add(x, y):
|
||||
return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none)
|
||||
|
||||
|
||||
def _unpack_i32(vec_ty, r):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
return vector.bitcast(
|
||||
vec_ty, vector.splat(ir.VectorType.get((1,), i32), r)
|
||||
)
|
||||
|
@ -53,6 +53,7 @@ from jax._src.pallas.primitives import max_contiguous as max_contiguous
|
||||
from jax._src.pallas.primitives import multiple_of as multiple_of
|
||||
from jax._src.pallas.primitives import num_programs as num_programs
|
||||
from jax._src.pallas.primitives import program_id as program_id
|
||||
from jax._src.pallas.primitives import reciprocal as reciprocal
|
||||
from jax._src.pallas.primitives import run_scoped as run_scoped
|
||||
from jax._src.pallas.primitives import store as store
|
||||
from jax._src.pallas.primitives import swap as swap
|
||||
|
@ -35,6 +35,7 @@ from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arri
|
||||
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
|
||||
from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota
|
||||
from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout
|
||||
|
721
jax/experimental/pallas/ops/tpu/ragged_paged_attention.py
Normal file
@ -0,0 +1,721 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TPU-Friendly Ragged Paged Attention kernel.
|
||||
|
||||
This kernel offers a highly optimized implementation of ragged paged attention,
|
||||
specifically designed for TPU and compatible with a wide range of model
|
||||
specifications. It supports mixed prefill and decoding, enhancing throughput
|
||||
during inference.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
import jax.numpy as jnp
|
||||
|
||||
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
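# Note: the mask value above is a large negative but finite number, presumably
# so that masked logits underflow to zero in the softmax without risking NaNs
# from -inf arithmetic.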
|
||||
|
||||
|
||||
class MultiPageAsyncCopyDescriptor:
|
||||
"""Descriptor for async copy of multiple K/V pages from HBM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim]
|
||||
vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim]
|
||||
sem,
|
||||
page_indices_ref, # i32[max_num_seqs, pages_per_seq]
|
||||
offset, # [seq_idx, kv_pages_start]
|
||||
):
|
||||
self._vmem_buf = vmem_buf
|
||||
seq_id, kv_pages_start = offset
|
||||
self._async_copies = [
|
||||
pltpu.make_async_copy(
|
||||
pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]],
|
||||
vmem_buf.at[i],
|
||||
sem,
|
||||
)
|
||||
for i in range(vmem_buf.shape[0])
|
||||
]
|
||||
|
||||
def start(self):
|
||||
"""Starts the async copies."""
|
||||
for async_copy in self._async_copies:
|
||||
async_copy.start()
|
||||
|
||||
def wait(self):
|
||||
for async_copy in self._async_copies:
|
||||
async_copy.wait()
|
||||
return self._vmem_buf
|
||||
|
||||
|
||||
def ref_ragged_paged_attention(
|
||||
queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
|
||||
k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
kv_lens: jax.Array, # i32[max_num_seqs]
|
||||
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
|
||||
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
||||
num_seqs: jax.Array, # i32[1],
|
||||
*,
|
||||
sm_scale: float = 1.0,
|
||||
mask_value: float = DEFAULT_MASK_VALUE,
|
||||
):
|
||||
_, _, num_kv_heads, head_dim = k_pages.shape
|
||||
num_q_heads = queries.shape[1]
|
||||
assert num_q_heads % num_kv_heads == 0
|
||||
num_query_per_kv = num_q_heads // num_kv_heads
|
||||
outputs = []
|
||||
for i in range(num_seqs[0]):
|
||||
q_start = cu_q_lens[i]
|
||||
q_end = cu_q_lens[i + 1]
|
||||
q_len = q_end - q_start
|
||||
kv_len = kv_lens[i]
|
||||
indices = page_indices[i]
|
||||
q = queries[q_start:q_end]
|
||||
k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
|
||||
v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
|
||||
k = jnp.repeat(k, num_query_per_kv, axis=1)
|
||||
v = jnp.repeat(v, num_query_per_kv, axis=1)
|
||||
attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32)
|
||||
attn *= sm_scale
|
||||
q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
|
||||
jnp.int32, attn.shape, 1
|
||||
)
|
||||
kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
|
||||
attn += jnp.where(q_span < kv_span, mask_value, 0.0)
|
||||
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
|
||||
out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
|
||||
outputs.append(out)
|
||||
|
||||
return jnp.concatenate(outputs, axis=0)
|
||||
|
||||
|
||||
# Expect to run these checks at runtime.
|
||||
def validate_inputs_on_runtime(
|
||||
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
|
||||
k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
kv_lens: jax.Array, # i32[max_num_seqs]
|
||||
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
|
||||
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
||||
num_seqs, # i32[1]
|
||||
):
|
||||
check_inputs_shapes(
|
||||
q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs
|
||||
)
|
||||
max_num_batched_tokens = q.shape[0]
|
||||
page_size = k_pages.shape[1]
|
||||
max_num_seqs, pages_per_seq = page_indices.shape
|
||||
if num_seqs[0] > max_num_seqs:
|
||||
raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}")
|
||||
max_kv_len = jnp.max(kv_lens)
|
||||
min_pages_per_seq = ceil_div(max_kv_len, page_size)
|
||||
if pages_per_seq < min_pages_per_seq:
|
||||
raise ValueError(
|
||||
f"{pages_per_seq=} must be greater or equal to"
|
||||
f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}."
|
||||
)
|
||||
if cu_q_lens[num_seqs[0]] > max_num_batched_tokens:
|
||||
raise ValueError(
|
||||
f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to"
|
||||
f" {max_num_batched_tokens=}."
|
||||
)
|
||||
for i in range(num_seqs[0]):
|
||||
q_len = cu_q_lens[i + 1] - cu_q_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
if q_len > kv_len:
|
||||
raise ValueError(
|
||||
f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
|
||||
)
|
||||
|
||||
|
||||
# Expect to run these checks during compile time.
|
||||
def check_inputs_shapes(
|
||||
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
|
||||
k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
kv_lens: jax.Array, # i32[max_num_seqs]
|
||||
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
|
||||
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
||||
num_seqs, # i32[1]
|
||||
):
|
||||
_, num_q_heads, head_dim = q.shape
|
||||
_, _, num_kv_heads, head_dim_k = k_pages.shape
|
||||
max_num_seqs, _ = page_indices.shape
|
||||
if num_seqs.shape != (1,):
|
||||
raise ValueError(f"{num_seqs.shape=} must be (1,)")
|
||||
if k_pages.shape != v_pages.shape:
|
||||
raise ValueError(
|
||||
f"{k_pages.shape=} and {v_pages.shape=} must have the same shape."
|
||||
)
|
||||
if head_dim_k != head_dim:
|
||||
raise ValueError(
|
||||
f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}."
|
||||
)
|
||||
if kv_lens.shape != (max_num_seqs,):
|
||||
raise ValueError(
|
||||
f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where"
|
||||
" `max_num_seqs` is `page_indices.shape[0]`."
|
||||
)
|
||||
if cu_q_lens.shape != (max_num_seqs + 1,):
|
||||
raise ValueError(
|
||||
f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where"
|
||||
" `max_num_seqs` is `page_indices.shape[0]`."
|
||||
)
|
||||
if (
|
||||
kv_lens.dtype != jnp.int32
|
||||
or page_indices.dtype != jnp.int32
|
||||
or cu_q_lens.dtype != jnp.int32
|
||||
):
|
||||
raise ValueError(
|
||||
"The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be"
|
||||
f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=},"
|
||||
f" {cu_q_lens.dtype=}."
|
||||
)
|
||||
if num_q_heads % num_kv_heads != 0:
|
||||
raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
|
||||
|
||||
|
||||
def ragged_paged_attention_kernel(
|
||||
# Prefetch
|
||||
kv_lens_ref, # [max_num_seqs]
|
||||
page_indices_ref, # [max_num_seqs, pages_per_seq]
|
||||
cu_q_lens_ref, # [max_num_seqs + 1]
|
||||
seq_buf_idx_ref,
|
||||
# TODO(jevinjiang): if OOM in SMEM, consider pack to other scalar refs.
|
||||
num_seqs_ref,
|
||||
# Input
|
||||
q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
|
||||
k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
# Output
|
||||
o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
|
||||
# Scratch
|
||||
k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim]
|
||||
v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim]
|
||||
sems, # [2, 2]
|
||||
l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
|
||||
m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
|
||||
*,
|
||||
sm_scale: float,
|
||||
mask_value: float,
|
||||
):
|
||||
num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
|
||||
num_seqs = num_seqs_ref[0]
|
||||
_, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape
|
||||
num_kv_per_blk = num_kv_pages_per_blk * page_size
|
||||
num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk
|
||||
heads_blk_idx, q_blk_idx = (
|
||||
pl.program_id(0),
|
||||
pl.program_id(1),
|
||||
)
|
||||
num_heads_blks = pl.num_programs(0)
|
||||
init_seq_idx = seq_buf_idx_ref[0]
|
||||
init_buf_idx = seq_buf_idx_ref[1]
|
||||
q_len_start = q_blk_idx * num_q_per_blk
|
||||
q_len_end = q_len_start + num_q_per_blk
|
||||
|
||||
def create_kv_async_copy_descriptors(
|
||||
heads_blk_idx, seq_idx, kv_blk_idx, buf_idx
|
||||
):
|
||||
offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk)
|
||||
heads_start = heads_blk_idx * num_kv_heads_per_blk
|
||||
async_copy_k = MultiPageAsyncCopyDescriptor(
|
||||
k_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :],
|
||||
k_bufs.at[buf_idx],
|
||||
sems.at[buf_idx, 0],
|
||||
page_indices_ref,
|
||||
offset,
|
||||
)
|
||||
async_copy_v = MultiPageAsyncCopyDescriptor(
|
||||
v_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :],
|
||||
v_bufs.at[buf_idx],
|
||||
sems.at[buf_idx, 1],
|
||||
page_indices_ref,
|
||||
offset,
|
||||
)
|
||||
return async_copy_k, async_copy_v
|
||||
|
||||
# TODO(jevinjiang): Add these to Mosaic:
|
||||
# 1. Support arbitrary strided load/store for any dtype.
|
||||
# 2. Support arbitrary strided load/store for any last dimension.
|
||||
def strided_load_kv(ref, start, step):
|
||||
if ref.dtype == jnp.float32:
|
||||
return ref[start::step, :]
|
||||
packing = get_dtype_packing(ref.dtype)
|
||||
assert ref.dtype == jnp.bfloat16
|
||||
assert step % packing == 0
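# For bfloat16 (packing == 2), two values share one 32-bit word: the code below
# loads the data as int32, shifts the selected 16-bit half into the upper half
# of the word, and bitcasts to float32, whose conversion back to bfloat16 then
# recovers the original value exactly (the low 16 bits are zero).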
|
||||
b_start = start // packing
|
||||
b_offset = start % packing
|
||||
b_step = step // packing
|
||||
b_ref = ref.bitcast(jnp.int32)
|
||||
b = b_ref[b_start::b_step, :]
|
||||
bw = 32 // packing
|
||||
b = jnp.right_shift(b, bw * b_offset)
|
||||
b = jnp.left_shift(b, bw * (packing - 1))
|
||||
return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16)
|
||||
|
||||
def fold_on_2nd_minor(vec):
|
||||
assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32
|
||||
assert len(vec.shape) >= 2
|
||||
last_dim = vec.shape[-1]
|
||||
packing = get_dtype_packing(vec.dtype)
|
||||
if vec.shape[-2] % packing != 0:
|
||||
vec = vec.astype(jnp.float32)
|
||||
return vec.reshape(-1, last_dim)
|
||||
|
||||
@pl.when(heads_blk_idx + q_blk_idx == 0)
|
||||
def prefetch_first_kv_blk():
|
||||
async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
|
||||
heads_blk_idx, init_seq_idx, 0, init_buf_idx
|
||||
)
|
||||
async_copy_k.start()
|
||||
async_copy_v.start()
|
||||
|
||||
def is_cur_q_blk_needed(q_states):
|
||||
done, cur_seq_idx, _ = q_states
|
||||
return jnp.logical_and(done == 0, cur_seq_idx < num_seqs)
|
||||
|
||||
def compute_with_cur_q_blk(q_states):
|
||||
done, cur_seq_idx, cur_buf_idx = q_states
|
||||
q_start = cu_q_lens_ref[cur_seq_idx]
|
||||
q_end = cu_q_lens_ref[cur_seq_idx + 1]
|
||||
q_len = q_end - q_start
|
||||
kv_len = kv_lens_ref[cur_seq_idx]
|
||||
|
||||
def get_next_prefetch_ids(
|
||||
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
|
||||
):
|
||||
next_kv_blk_idx = kv_blk_idx + 1
|
||||
is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len
|
||||
next_kv_blk_idx = lax.select(
|
||||
is_last_kv_blk,
|
||||
0,
|
||||
next_kv_blk_idx,
|
||||
)
|
||||
is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end
|
||||
next_seq_idx = lax.select(
|
||||
is_last_kv_blk,
|
||||
lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1, cur_seq_idx),
|
||||
cur_seq_idx,
|
||||
)
|
||||
is_last_seq = next_seq_idx == num_seqs
|
||||
next_seq_idx = lax.select(
|
||||
is_last_seq,
|
||||
0,
|
||||
next_seq_idx,
|
||||
)
|
||||
next_heads_blk_idx = lax.select(
|
||||
is_last_seq,
|
||||
heads_blk_idx + 1,
|
||||
heads_blk_idx,
|
||||
)
|
||||
next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0)
|
||||
return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx
|
||||
|
||||
def flash_attention(
|
||||
q, # [num_q_per_blk * num_q_heads_per_kv_head, head_dim]
|
||||
k, # [num_kv_per_blk, head_dim]
|
||||
v, # [num_kv_per_blk, head_dim]
|
||||
head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128]
|
||||
head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128]
|
||||
head_o_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
|
||||
*,
|
||||
kv_blk_idx,
|
||||
):
|
||||
assert q.shape == (
|
||||
num_q_per_blk * num_q_heads_per_kv_head,
|
||||
head_dim,
|
||||
)
|
||||
assert k.shape == (
|
||||
num_kv_per_blk,
|
||||
head_dim,
|
||||
), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}"
|
||||
assert v.shape == (num_kv_per_blk, head_dim)
|
||||
assert head_m_ref.shape == (
|
||||
num_q_per_blk * num_q_heads_per_kv_head,
|
||||
128,
|
||||
)
|
||||
assert head_l_ref.shape == (
|
||||
num_q_per_blk * num_q_heads_per_kv_head,
|
||||
128,
|
||||
)
|
||||
assert head_o_ref.shape == (
|
||||
num_q_per_blk,
|
||||
num_q_heads_per_kv_head,
|
||||
head_dim,
|
||||
)
|
||||
kv_len_start = kv_blk_idx * num_kv_per_blk
|
||||
|
||||
def masked_store(ref, val, start, end, group=1):
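# Rows of `ref` are (query, q_head) pairs when group > 1, so dividing the row
# iota by `group` recovers the query index; `start`/`end` are query positions
# within this block, keeping stores inside the current sequence's range.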
|
||||
iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group
|
||||
mask = jnp.logical_and(iota >= start, iota < end)
|
||||
pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask)
|
||||
|
||||
qk = (
|
||||
jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32)
|
||||
* sm_scale
|
||||
)
|
||||
store_start = jnp.maximum(q_start - q_len_start, 0)
|
||||
store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk)
|
||||
|
||||
@pl.when(kv_blk_idx == 0)
|
||||
def init_scratch_ref():
|
||||
masked_store(
|
||||
head_m_ref,
|
||||
jnp.full_like(head_m_ref, -jnp.inf),
|
||||
store_start,
|
||||
store_end,
|
||||
num_q_heads_per_kv_head,
|
||||
)
|
||||
masked_store(
|
||||
head_l_ref,
|
||||
jnp.zeros_like(head_l_ref),
|
||||
store_start,
|
||||
store_end,
|
||||
num_q_heads_per_kv_head,
|
||||
)
|
||||
masked_store(
|
||||
head_o_ref,
|
||||
jnp.zeros_like(head_o_ref),
|
||||
store_start,
|
||||
store_end,
|
||||
)
|
||||
|
||||
row_ids = (
|
||||
(kv_len - q_len)
|
||||
+ q_len_start
|
||||
- q_start
|
||||
+ jax.lax.broadcasted_iota(
|
||||
jnp.int32,
|
||||
(num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
|
||||
0,
|
||||
)
|
||||
// num_q_heads_per_kv_head
|
||||
)
|
||||
col_ids = kv_len_start + jax.lax.broadcasted_iota(
|
||||
jnp.int32,
|
||||
(num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
|
||||
1,
|
||||
)
|
||||
causal_mask = row_ids < col_ids
|
||||
qk += jnp.where(causal_mask, mask_value, 0.0)
|
||||
m_curr = jnp.max(qk, axis=1, keepdims=True)
|
||||
s_curr = jnp.exp(qk - m_curr)
|
||||
qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32)
|
||||
lm_store_shape = head_m_ref.shape
|
||||
m_curr = jnp.broadcast_to(m_curr, lm_store_shape)
|
||||
l_curr = jnp.broadcast_to(
|
||||
s_curr.sum(axis=1, keepdims=True), lm_store_shape
|
||||
)
|
||||
m_prev = head_m_ref[...]
|
||||
l_prev = head_l_ref[...]
|
||||
m_next = jnp.maximum(m_prev, m_curr)
|
||||
masked_store(
|
||||
head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head
|
||||
)
|
||||
alpha = jnp.exp(m_prev - m_next)
|
||||
beta = jnp.exp(m_curr - m_next)
|
||||
l_alpha = alpha * l_prev
|
||||
l_next = l_alpha + beta * l_curr
|
||||
l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)
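# Online-softmax update: alpha and beta rescale the previous and current
# partial sums to the new running max m_next, and l_next accumulates the
# softmax denominator. Rows that have seen no unmasked keys yet have
# l_next == 0, so 1.0 is substituted to avoid dividing by zero below.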
|
||||
masked_store(
|
||||
head_l_ref,
|
||||
l_next_safe,
|
||||
store_start,
|
||||
store_end,
|
||||
num_q_heads_per_kv_head,
|
||||
)
|
||||
|
||||
def broadcast_to_shape(arr, shape):
|
||||
if arr.shape == shape:
|
||||
return arr
|
||||
assert len(arr.shape) == len(shape)
|
||||
assert arr.shape[0] == shape[0]
|
||||
assert shape[1] % arr.shape[1] == 0
|
||||
# no-op concatenation.
|
||||
return jnp.concatenate(
|
||||
[arr for _ in range(shape[1] // arr.shape[1])], axis=1
|
||||
)
|
||||
|
||||
o_curr = head_o_ref[...].reshape(-1, head_dim)
|
||||
l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
|
||||
beta = broadcast_to_shape(beta, qkv.shape)
|
||||
l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
|
||||
out = lax.div(
|
||||
l_alpha * o_curr + beta * qkv,
|
||||
l_next_safe,
|
||||
).astype(head_o_ref.dtype)
|
||||
masked_store(
|
||||
head_o_ref,
|
||||
out.reshape(head_o_ref.shape),
|
||||
store_start,
|
||||
store_end,
|
||||
)
|
||||
|
||||
def is_valid_kv_blk_in_cur_seq(kv_states):
|
||||
kv_blk_idx, _ = kv_states
|
||||
return kv_blk_idx * num_kv_per_blk < kv_len
|
||||
|
||||
def compute_with_kv_blk_in_cur_seq(kv_states):
|
||||
kv_blk_idx, cur_buf_idx = kv_states
|
||||
next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = (
|
||||
get_next_prefetch_ids(
|
||||
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
|
||||
)
|
||||
)
|
||||
|
||||
@pl.when(next_heads_blk_idx < num_heads_blks)
|
||||
def prefetch_next_kv_blk():
|
||||
# TODO(jevinjiang): reuse the same buffer if it is already prefetched!
|
||||
# TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and
|
||||
# DMA to fixed size buffer!
|
||||
next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors(
|
||||
next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx
|
||||
)
|
||||
next_async_copy_k.start()
|
||||
next_async_copy_v.start()
|
||||
|
||||
cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors(
|
||||
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
|
||||
)
|
||||
kv_to_load_shape = (
|
||||
num_kv_pages_per_blk * page_size * num_kv_heads_per_blk,
|
||||
head_dim,
|
||||
)
|
||||
k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape)
|
||||
v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape)
|
||||
for kv_head_idx in range(num_kv_heads_per_blk):
|
||||
q_head_idx = kv_head_idx * num_q_heads_per_kv_head
|
||||
# TODO(jevinjiang): extra handling for packed types that can start at an
|
||||
# unaligned position!
|
||||
q = fold_on_2nd_minor(
|
||||
q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :]
|
||||
)
|
||||
k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk)
|
||||
v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk)
|
||||
flash_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
l_ref.at[kv_head_idx],
|
||||
m_ref.at[kv_head_idx],
|
||||
o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :],
|
||||
kv_blk_idx=kv_blk_idx,
|
||||
)
|
||||
return kv_blk_idx + 1, next_buf_idx
|
||||
|
||||
_, next_buf_idx = lax.while_loop(
|
||||
is_valid_kv_blk_in_cur_seq,
|
||||
compute_with_kv_blk_in_cur_seq,
|
||||
(0, cur_buf_idx), # (kv_blk_idx, buf_idx)
|
||||
)
|
||||
next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx)
|
||||
done = lax.select(q_end < q_len_end, done, 1)
|
||||
return done, next_seq_idx, next_buf_idx
|
||||
|
||||
_, seq_idx, buf_idx = lax.while_loop(
|
||||
is_cur_q_blk_needed,
|
||||
compute_with_cur_q_blk,
|
||||
(0, init_seq_idx, init_buf_idx), # (done, seq_idx, buf_idx)
|
||||
)
|
||||
# Reset seq_idx for the next kv_heads_blk if we have run out of seqs!
|
||||
seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
|
||||
seq_buf_idx_ref[1] = buf_idx
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
assert b != 0
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
def get_dtype_packing(dtype):
|
||||
if dtype == jnp.float32:
|
||||
return 1
|
||||
if dtype == jnp.bfloat16:
|
||||
return 2
|
||||
if dtype == jnp.int8:
|
||||
return 4
|
||||
if dtype == jnp.int4:
|
||||
return 8
|
||||
raise ValueError(f"Not implemented: unsupported {dtype=}")
|
||||
|
||||
|
||||
def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype):
|
||||
q_packing = get_dtype_packing(q_dtype)
|
||||
kv_packing = get_dtype_packing(kv_dtype)
|
||||
|
||||
def can_be_xla_fully_tiled(x, packing):
|
||||
if x % packing != 0:
|
||||
return False
|
||||
x //= packing
|
||||
return x in (1, 2, 4, 8) or x % 8 == 0
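# Presumably: after accounting for dtype packing, the head count must either
# divide the TPU sublane count of 8 or be a multiple of it, so that the
# second-minor dimension maps onto (8, 128) tiles without padding.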
|
||||
|
||||
# TODO(jevinjiang): support unaligned number of heads!
|
||||
if not can_be_xla_fully_tiled(num_kv_heads, kv_packing):
|
||||
raise ValueError(
|
||||
f"Not implemented: {num_kv_heads=} can not be XLA fully tiled."
|
||||
)
|
||||
assert num_q_heads % num_kv_heads == 0
|
||||
ratio = num_q_heads // num_kv_heads
|
||||
# TODO(jevinjiang): we can choose smaller tiling for packed type if large
|
||||
# second-minor tiling is not enabled.
|
||||
max_kv_tiling = 8 * kv_packing
|
||||
min_kv_heads = (
|
||||
max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads
|
||||
)
|
||||
min_q_heads = min_kv_heads * ratio
|
||||
if can_be_xla_fully_tiled(min_q_heads, q_packing):
|
||||
return min_q_heads, min_kv_heads
|
||||
return num_q_heads, num_kv_heads
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=[
|
||||
"sm_scale",
|
||||
"mask_value",
|
||||
"num_kv_pages_per_block",
|
||||
"num_queries_per_block",
|
||||
"vmem_limit_bytes",
|
||||
],
|
||||
)
|
||||
def ragged_paged_attention(
|
||||
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
|
||||
# TODO(jevinjiang): create a write_to_kv_cache kernel!
|
||||
k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
|
||||
kv_lens: jax.Array, # i32[max_num_seqs]
|
||||
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
|
||||
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
||||
num_seqs: jax.Array, # i32[1]
|
||||
*,
|
||||
sm_scale: float = 1.0,
|
||||
mask_value: float = DEFAULT_MASK_VALUE,
|
||||
num_kv_pages_per_block: int = 16,
|
||||
num_queries_per_block: int = 128,
|
||||
vmem_limit_bytes: int | None = None,
|
||||
):
|
||||
"""Ragged paged attention that supports mixed prefill and decode.
|
||||
|
||||
Args:
|
||||
q: concatenated all sequences' queries.
|
||||
k_pages: paged K cache. Normally in HBM.
|
||||
v_pages: paged V cache. Normally in HBM.
|
||||
kv_lens: padded kv lengths. Only the first num_seqs values are valid.
|
||||
page_indices: the first index indicates which page to use in the kv cache
|
||||
for each sequence. Only the first num_seqs values are valid.
|
||||
cu_q_lens: the cumulative sum of the effective query lengths. Similar to
|
||||
kv_lens, only the first num_seqs+1 values are valid.
|
||||
num_seqs: the dynamic number of sequences.
|
||||
sm_scale: the softmax scale which will be applied to the Q@K^T.
|
||||
mask_value: mask value for causal mask.
|
||||
num_kv_pages_per_block: number of kv pages to be processed in one flash
|
||||
attention block in the pallas kernel.
|
||||
num_queries_per_block: number of kv pages to be processed in one flash
|
||||
attention block in the pallas kernel.
|
||||
vmem_limit_bytes: the vmem limit for the pallas kernel.
|
||||
|
||||
Returns:
|
||||
The output of the attention.
|
||||
"""
|
||||
check_inputs_shapes(
|
||||
q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs
|
||||
)
|
||||
_, num_q_heads, head_dim = q.shape
|
||||
_, page_size, num_kv_heads, _ = k_pages.shape
|
||||
num_q_per_blk = num_queries_per_block
|
||||
num_kv_pages_per_blk = num_kv_pages_per_block
|
||||
num_q_heads_per_kv_head = num_q_heads // num_kv_heads
|
||||
num_q_blks = ceil_div(cu_q_lens[num_seqs[0]], num_q_per_blk)
|
||||
num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk(
|
||||
num_q_heads, num_kv_heads, q.dtype, k_pages.dtype
|
||||
)
|
||||
assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0
|
||||
num_heads_blks = num_q_heads // num_q_heads_per_blk
|
||||
grid = (num_heads_blks, num_q_blks)
|
||||
|
||||
def q_index_map(heads_blk_idx, q_blk_idx, *_):
|
||||
return (q_blk_idx, heads_blk_idx, 0)
|
||||
|
||||
q_block_spec = pl.BlockSpec(
|
||||
(num_q_per_blk, num_q_heads_per_blk, head_dim),
|
||||
q_index_map,
|
||||
)
|
||||
in_specs = [
|
||||
q_block_spec,
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
]
|
||||
out_specs = q_block_spec
|
||||
lm_scratch = pltpu.VMEM(
|
||||
# TODO(jevinjiang): we use 128 instead of 1 because Mosaic does not support
# unaligned slicing!
|
||||
(num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
|
||||
jnp.float32,
|
||||
)
|
||||
double_buf_scratch = pltpu.VMEM(
|
||||
(
|
||||
2, # For double buffering during DMA copies.
|
||||
num_kv_pages_per_blk,
|
||||
page_size,
|
||||
num_kv_heads_per_blk,
|
||||
head_dim,
|
||||
),
|
||||
k_pages.dtype,
|
||||
)
|
||||
scratch_shapes = [
|
||||
double_buf_scratch, # k_bufs
|
||||
double_buf_scratch, # v_bufs
|
||||
pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem]
|
||||
lm_scratch, # l_ref
|
||||
lm_scratch, # m_ref
|
||||
]
|
||||
scalar_prefetches = (
|
||||
kv_lens,
|
||||
page_indices,
|
||||
cu_q_lens,
|
||||
jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx
|
||||
num_seqs,
|
||||
)
|
||||
kernel = pl.pallas_call(
|
||||
functools.partial(
|
||||
ragged_paged_attention_kernel,
|
||||
sm_scale=sm_scale,
|
||||
mask_value=mask_value,
|
||||
),
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=len(scalar_prefetches),
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
grid=grid,
|
||||
scratch_shapes=scratch_shapes,
|
||||
),
|
||||
compiler_params=pltpu.TPUCompilerParams(
|
||||
dimension_semantics=(
|
||||
"arbitrary",
|
||||
"arbitrary",
|
||||
),
|
||||
vmem_limit_bytes=vmem_limit_bytes,
|
||||
),
|
||||
out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32),
|
||||
name="ragged_paged_attention_kernel",
|
||||
)
|
||||
# TODO(jevinjiang): Use f32 acc scratch for output! So we only need
|
||||
# to transfer output with desired dtype back to HBM.
|
||||
return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype)
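# A minimal usage sketch (hypothetical shapes, not part of the original file):
#
#   q = jnp.zeros((32, 8, 128), jnp.bfloat16)            # 32 batched query tokens
#   k_pages = jnp.zeros((64, 16, 2, 128), jnp.bfloat16)  # 64 pages of 16 tokens
#   v_pages = jnp.zeros((64, 16, 2, 128), jnp.bfloat16)
#   kv_lens = jnp.array([24, 8], jnp.int32)              # per-sequence kv lengths
#   page_indices = jnp.zeros((2, 4), jnp.int32)          # 2 seqs, 4 page slots each
#   cu_q_lens = jnp.array([0, 24, 32], jnp.int32)        # cumulative query lengths
#   num_seqs = jnp.array([2], jnp.int32)
#   out = ragged_paged_attention(q, k_pages, v_pages, kv_lens, page_indices,
#                                cu_q_lens, num_seqs)
#   assert out.shape == q.shape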
|
@ -55,6 +55,57 @@ for prim in it.chain(
|
||||
roofline.register_standard_roofline(prim)
|
||||
|
||||
|
||||
def _unary_p_roofline(
ctx: roofline.RooflineRuleContext,
*args,
**kw,
) -> roofline.RooflineResult:
(x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
return roofline.RooflineResult(
unfused_flops=x.size,
unfused_hbm_bytes=(
x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
),
)
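# For example, lax.abs on a float32 array of 1024 elements would be counted as
# 1024 unfused flops and 4 * 1024 + 4 * 1024 = 8192 unfused HBM bytes (one read
# of the input plus one write of the output).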
|
||||
|
||||
roofline.register_roofline(lax.abs_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.acos_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.asin_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.atan_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.cbrt_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.ceil_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.conj_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.cos_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.cosh_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.exp_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.expm1_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.floor_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.imag_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.integer_pow_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.is_finite_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.log_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.log1p_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.logistic_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.neg_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.not_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.real_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.round_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.rsqrt_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.sign_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.sin_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.sinh_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.sqrt_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.square_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(lax.tan_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(special.bessel_i0e_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(special.bessel_i1e_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(special.digamma_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(special.erf_inv_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(special.erf_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(special.erfc_p)(_unary_p_roofline)
|
||||
roofline.register_roofline(special.lgamma_p)(_unary_p_roofline)
|
||||
|
||||
def _binary_p_roofline(
|
||||
ctx: roofline.RooflineRuleContext,
|
||||
*args,
|
||||
|
@ -544,6 +544,8 @@ def _shard_map_staging(
|
||||
return out_tracers
|
||||
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
|
||||
|
||||
# TODO add underscore version, for direct-linearize to consume
|
||||
|
||||
def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
|
||||
assert isinstance(aval, core.ShapedArray)
|
||||
return aval
|
||||
@ -742,9 +744,8 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
|
||||
out_avals_ = [x.aval for x in jaxpr.outvars]
|
||||
in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
|
||||
in_avals_, in_nodes)
|
||||
new_axis_context = sharding_impls.SPMDAxisContext(
|
||||
mesh, frozenset(mesh.axis_names) - auto
|
||||
)
|
||||
manual_axes = frozenset(mesh.axis_names) - auto
|
||||
new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes)
|
||||
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
|
||||
with _extend_axis_env(mesh, auto):
|
||||
out_nodes_, tokens_out = mlir.call_lowering(
|
||||
@ -895,7 +896,6 @@ def _match_spec(mesh: Mesh, check_rep: bool,
|
||||
|
||||
def _match(mesh, check_rep, pspec, x):
|
||||
src = P(mesh.axis_names)
|
||||
# TODO put back (?) needed for rep checking in eager? for now test rewrite
|
||||
return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x)
|
||||
|
||||
def _rem_singleton(x): return jnp.squeeze(x, axis=0)
|
||||
@ -914,6 +914,7 @@ class ShardMapTrace(core.Trace):
|
||||
__slots__ = ("mesh", "auto", "check", "context_mesh")
|
||||
|
||||
mesh: Mesh
|
||||
auto: frozenset[AxisName]
|
||||
check: bool
|
||||
context_mesh: AbstractMesh
|
||||
|
||||
@ -927,7 +928,7 @@ class ShardMapTrace(core.Trace):
|
||||
if isinstance(val, ShardMapTracer):
|
||||
return val.val, val.rep
|
||||
elif isinstance(val, Tracer):
|
||||
raise Exception("Shouldn't have any non-shard_map tracers")
|
||||
raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
|
||||
else:
|
||||
val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh)
|
||||
return val_, None
|
||||
@ -1609,34 +1610,40 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
|
||||
out_names_thunk, check_rep, rewrite, auto):
|
||||
primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
|
||||
nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
|
||||
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in,
|
||||
f.debug_info)
|
||||
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info)
|
||||
f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
|
||||
tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz]
|
||||
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
|
||||
res_names = _all_newly_manual_mesh_names(mesh, auto, trace)
|
||||
|
||||
@as_hashable_function(closure=(linearize_outs_thunk))
|
||||
@as_hashable_function(closure=linearize_outs_thunk)
|
||||
def primal_out_names_thunk():
|
||||
residual_avals, _, _ = linearize_outs_thunk()
|
||||
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
|
||||
out_names = out_names_thunk()
|
||||
# This is incorrect so we set `check_rep=False` as we do in the JVP rule.
|
||||
return (*({0: all_names} for _ in residual_avals), *out_names)
|
||||
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
# This is incorrect so we set `check_rep=False` in the tangent (as in JVP).
|
||||
return (*({0: res_names} for _ in range(num_res_out)), *out_names)
|
||||
primal_params = dict(
|
||||
mesh=mesh, in_names=in_names,
|
||||
out_names_thunk=primal_out_names_thunk, check_rep=check_rep,
|
||||
rewrite=rewrite, auto=auto)
|
||||
all_primal_results = shard_map_p.bind_with_trace(
|
||||
trace.parent_trace, (f_primal,) + tuple(primals), primal_params)
|
||||
residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
|
||||
num_residuals = len(residual_avals)
|
||||
residuals = all_primal_results[:num_residuals]
|
||||
primals_out = all_primal_results[num_residuals:]
|
||||
args_to_promote = [getattr(aval, 'shape', ()) == () for aval in residual_avals]
|
||||
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
|
||||
trace.parent_trace, (f_primal, *primals), primal_params)
|
||||
residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
|
||||
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
non_fwd_res = all_primal_results[:num_res_out]
|
||||
primals_out = all_primal_results[num_res_out:]
|
||||
residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
|
||||
args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None
|
||||
for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)]
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
|
||||
out_names = out_names_thunk()
|
||||
new_in_names = (*({0: all_names} for _ in residual_avals),
|
||||
residual_names = [in_names[f1] if f1 is not None else
|
||||
out_names[f2] if f2 is not None else
|
||||
{0: res_names} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
new_in_names = (*residual_names, *({} for _ in range(len(env))),
|
||||
*(ax for ax, nz in zip(in_names, nzs_in) if nz))
|
||||
new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),)
|
||||
new_out_names = tuple(ax for ax, nz in zip(out_names, nzs_out) if nz)
|
||||
@as_hashable_function(closure=(new_out_names))
|
||||
def tangent_out_names_thunk():
|
||||
return new_out_names
|
||||
@ -1645,15 +1652,14 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
|
||||
out_names_thunk=tangent_out_names_thunk, check_rep=False,
|
||||
rewrite=rewrite, auto=auto)
|
||||
|
||||
# TODO TODO don't round-trip
|
||||
def f_tangent(*args):
|
||||
residuals = args[:num_residuals]
|
||||
nz_tangents = args[num_residuals:]
|
||||
return core.eval_jaxpr(lin_jaxpr, (), *residuals, *nz_tangents)
|
||||
return core.eval_jaxpr(lin_jaxpr, (), *args)
|
||||
|
||||
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
|
||||
nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace,
|
||||
(lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
|
||||
*residuals, *nz_tangents_in), tangent_params)
|
||||
*residuals, *env, *nz_tangents_in), tangent_params)
|
||||
nz_tangents_out_iter = iter(nz_tangents_out)
|
||||
tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
|
||||
for nz, primal in zip(nzs_out, primals_out)]
|
||||
@ -1663,13 +1669,13 @@ ad.LinearizeTrace.process_shard_map = _shard_map_linearize
|
||||
@lu.transformation2
|
||||
def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs):
|
||||
ans = f(*args, **kwargs)
|
||||
residual_avals, _, _ = linearize_outs_thunk()
|
||||
num_residuals = len(residual_avals)
|
||||
residuals = ans[:num_residuals]
|
||||
primals = ans[num_residuals:]
|
||||
residuals = tuple(jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
|
||||
for x in residuals)
|
||||
return residuals + primals
|
||||
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
|
||||
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
residuals = ans[:num_res_out]
|
||||
primals = ans[num_res_out:]
|
||||
residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
|
||||
for x in residuals]
|
||||
return *residuals, *primals
|
||||
|
||||
@lu.transformation2
|
||||
def _promote_scalar_residuals(f: Callable, *args, **kwargs):
|
||||
@ -1798,10 +1804,10 @@ def _partial_eval_jaxpr_custom_rule(
|
||||
_, ins_staged = partition_list(inst_in, eqn.invars)
|
||||
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
|
||||
newvar = core.gensym()
|
||||
params_known, params_staged, all_names = _pe_custom_params(
|
||||
params_known, params_staged, res_names = _pe_custom_params(
|
||||
unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
|
||||
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
|
||||
residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval))
|
||||
residuals = [newvar(_unshard_aval(mesh, {0: res_names}, var.aval))
|
||||
for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
|
||||
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
|
||||
eqn.primitive, params_known, jaxpr_known.effects,
|
||||
@ -1853,10 +1859,10 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
|
||||
# prune inputs to jaxpr_known according to unks_in
|
||||
mesh = params_known['mesh']
|
||||
auto = params_known['auto']
|
||||
all_names = _all_newly_manual_mesh_names(mesh, auto)
|
||||
res_names_ = _all_newly_manual_mesh_names(mesh, auto)
|
||||
in_names_known, _ = partition_list(unks_in, params_known['in_names'])
|
||||
_, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
|
||||
out_names_known = out_names_known + [{0: all_names}] * sum(which)
|
||||
out_names_known = out_names_known + [{0: res_names_}] * sum(which)
|
||||
new_params_known = dict(params_known, in_names=tuple(in_names_known),
|
||||
out_names=tuple(out_names_known))
|
||||
|
||||
@ -1864,12 +1870,12 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
|
||||
_, in_names_staged = partition_list(inst_in, params_staged['in_names'])
|
||||
res_names = [in_names_known[f1] if f1 is not None else
|
||||
out_names_known[f2] if f2 is not None else
|
||||
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
{0: res_names_} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
in_names_staged = res_names + in_names_staged
|
||||
_, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
|
||||
new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
|
||||
out_names=tuple(out_names_staged), check_rep=False)
|
||||
return new_params_known, new_params_staged, all_names
|
||||
return new_params_known, new_params_staged, res_names_
|
||||
|
||||
# TODO(mattjj): remove this mechanism when we revise mesh scopes
|
||||
def _all_mesh_names_except_spmd(
|
||||
@ -1880,15 +1886,21 @@ def _all_mesh_names_except_spmd(
|
||||
return tuple(name for name in mesh.axis_names if name not in spmd_names and
|
||||
name not in auto)
|
||||
|
||||
# TODO(mattjj): remove this mechanism when we revise mesh scopes
|
||||
def _all_newly_manual_mesh_names(
|
||||
mesh: Mesh, auto: frozenset[AxisName], trace=None
|
||||
) -> tuple[AxisName, ...]:
|
||||
axis_env = core.get_axis_env()
|
||||
spmd_names = axis_env.spmd_axis_names
|
||||
axis_sizes = axis_env.axis_sizes
|
||||
return tuple(name for name in mesh.axis_names if name not in spmd_names and
|
||||
name not in auto and name not in axis_sizes)
|
||||
if not (ctx_mesh := get_abstract_mesh()).empty:
|
||||
del mesh
|
||||
already_manual_names = set(ctx_mesh.axis_types.get(AxisTypes.Manual, ()))
|
||||
return tuple(name for name in ctx_mesh.axis_names
|
||||
if name not in auto | already_manual_names)
|
||||
else:
|
||||
# TODO(mattjj): remove this mechanism when we revise mesh scopes
|
||||
axis_env = core.get_axis_env()
|
||||
vmap_spmd_names = set(axis_env.spmd_axis_names)
|
||||
already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names
|
||||
return tuple(name for name in mesh.axis_names
|
||||
if name not in auto | vmap_spmd_names | already_manual_names)
|
||||
|
||||
# DCE
|
||||
|
||||
|
@ -19,8 +19,18 @@ import math
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import ffi
|
||||
from jax._src import util
|
||||
from jax._src.typing import Array
|
||||
from jax._src.lib import gpu_sparse
|
||||
|
||||
|
||||
if hasattr(gpu_sparse, "registrations"):
|
||||
for platform, targets in gpu_sparse.registrations().items():
|
||||
for name, value, api_version in targets:
|
||||
ffi.register_ffi_target(
|
||||
name, value, platform=platform, api_version=api_version
|
||||
)
|
||||
|
||||
|
||||
class JAXSparse(util.StrictABC):
|
||||
|
@ -775,7 +775,7 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse
|
||||
|
||||
|
||||
def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, resource_env, donated_invars, name,
|
||||
in_layouts, out_layouts, donated_invars, ctx_mesh, name,
|
||||
keep_unused, inline, compiler_options_kvs):
|
||||
if any(donated_invars):
|
||||
raise NotImplementedError("sparse xla_call with donated_invars")
|
||||
@ -808,8 +808,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
|
||||
out_shardings=out_shardings,
|
||||
in_layouts=in_layouts,
|
||||
out_layouts=out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
from jax._src.lax.lax import (
|
||||
DotDimensionNumbers as DotDimensionNumbers,
|
||||
RaggedDotDimensionNumbers as RaggedDotDimensionNumbers,
|
||||
Precision as Precision,
|
||||
PrecisionLike as PrecisionLike,
|
||||
DotAlgorithm as DotAlgorithm,
|
||||
@ -158,6 +159,7 @@ from jax._src.lax.lax import (
|
||||
pow as pow,
|
||||
pow_p as pow_p,
|
||||
ragged_dot as ragged_dot,
|
||||
ragged_dot_general as ragged_dot_general,
|
||||
real as real,
|
||||
real_p as real_p,
|
||||
reciprocal as reciprocal,
|
||||
|