Merge pull request #271 from ROCm/ci-upstream-sync-142_1

CI: 03/11/25 upstream sync
github-actions[bot] 2025-03-11 14:10:03 -05:00 committed by GitHub
commit 6ee76a8a6d
147 changed files with 6365 additions and 2088 deletions

View File

@ -54,6 +54,12 @@ build:macos --apple_platform_type=macos
build:macos --linkopt=-Wl,-undefined,dynamic_lookup
build:macos --host_linkopt=-Wl,-undefined,dynamic_lookup
# Use cc toolchains from apple_support for Apple builds.
# https://github.com/bazelbuild/apple_support/tree/master?tab=readme-ov-file#bazel-6-setup
build:macos --apple_crosstool_top=@local_config_apple_cc//:toolchain
build:macos --crosstool_top=@local_config_apple_cc//:toolchain
build:macos --host_crosstool_top=@local_config_apple_cc//:toolchain
# Windows has a relatively short command line limit, which JAX has begun to hit.
# See https://docs.bazel.build/versions/main/windows.html
build:windows --features=compiler_param_file

View File

@ -28,7 +28,7 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: 'd982a6245400295477f5da5afa1c4a2a5e641ea4' # Latest commit as of 2025-01-30
ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}

View File

@ -29,11 +29,6 @@ on:
type: string
required: true
default: "0"
install-jax-current-commit:
description: "Should the 'jax' package be installed from the current commit?"
type: string
required: true
default: "1"
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
required: true
@ -62,7 +57,6 @@ jobs:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -88,7 +82,7 @@ jobs:
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
- name: Download jaxlib wheel from GCS (non-Windows runs)
- name: Download wheels from GCS (non-Windows runs)
id: download-wheel-artifacts-nw
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the step below so that we can print a more
@ -96,14 +90,10 @@ jobs:
continue-on-error: true
if: ${{ !contains(inputs.runner, 'windows-x86') }}
run: |
mkdir -p $(pwd)/dist &&
mkdir -p $(pwd)/dist
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
fi
- name: Download jaxlib wheel from GCS (Windows runs)
- name: Download wheels from GCS (Windows runs)
id: download-wheel-artifacts-w
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the step below so that we can print a more
@ -115,12 +105,8 @@ jobs:
mkdir dist
@REM Use `call` so that we can run sequential gsutil commands on Windows
@REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
@REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
if not "${{ inputs.install-jax-current-commit }}"=="1" (
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
)
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
run: |

View File

@ -34,11 +34,6 @@ on:
type: string
required: true
default: "0"
install-jax-current-commit:
description: "Should the 'jax' package be installed from the current commit?"
type: string
required: true
default: "1"
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
required: true
@ -66,7 +61,6 @@ jobs:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -86,7 +80,7 @@ jobs:
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
- name: Download the wheel artifacts from GCS
- name: Download wheels from GCS
id: download-wheel-artifacts
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the next step so that we can print a more
@ -94,14 +88,10 @@ jobs:
continue-on-error: true
run: |
mkdir -p $(pwd)/dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
fi
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |

View File

@ -26,8 +26,6 @@ race_top:PyMember_GetOne
# https://github.com/python/cpython/issues/129547
race:type_get_annotations
# https://github.com/python/cpython/issues/130547
race:split_keys_entry_added
# https://github.com/python/cpython/issues/129748
race:mi_block_set_nextx
@ -64,3 +62,6 @@ race:gemm_oncopy
# https://github.com/python/cpython/issues/130571
# race:_PyObject_GetMethod
# https://github.com/python/cpython/issues/130547
# race:split_keys_entry_added

View File

@ -35,7 +35,7 @@ jobs:
apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
libffi-dev liblzma-dev
libffi-dev liblzma-dev file zip
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: jax

View File

@ -27,6 +27,16 @@ concurrency:
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
jobs:
build-jax-artifact:
uses: ./.github/workflows/build_artifacts.yml
with:
# Note that since jax is a pure python package, the runner OS and Python values do not
# matter. In addition, cloning main XLA also has no effect.
runner: "linux-x86-n2-16"
artifact: "jax"
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
build-jaxlib-artifact:
uses: ./.github/workflows/build_artifacts.yml
strategy:
@ -66,7 +76,7 @@ jobs:
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: build-jaxlib-artifact
needs: [build-jax-artifact, build-jaxlib-artifact]
uses: ./.github/workflows/pytest_cpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
@ -80,7 +90,6 @@ jobs:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
install-jax-current-commit: 1
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
run-pytest-cuda:
@ -88,7 +97,7 @@ jobs:
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jaxlib-artifact, build-cuda-artifacts]
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
uses: ./.github/workflows/pytest_cuda.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
@ -111,7 +120,6 @@ jobs:
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
install-jax-current-commit: 1
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

View File

@ -40,9 +40,6 @@ jobs:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
# GCS bucket.
install-jax-current-commit: 0
gcs_download_uri: ${{inputs.gcs_download_uri}}
run-pytest-cuda:
@ -61,7 +58,4 @@ jobs:
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
# GCS bucket.
install-jax-current-commit: 0
gcs_download_uri: ${{inputs.gcs_download_uri}}

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
load(
"//jaxlib:jax.bzl",
"jax_wheel",

View File

@ -23,6 +23,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
true, matching the current behavior. If set to false, JAX does not need to
emit code clamping negative indices, which improves code size.
## jax 0.5.2 (Mar 4, 2025)
Patch release of 0.5.1
* Bug fixes
* Fixes TPU metric logging and `tpu-info`, which was broken in 0.5.1
## jax 0.5.1 (Feb 24, 2025)
* New Features
@ -54,6 +61,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
A downstream effect of this is that several other internal functions need debug
info. This change does not affect public APIs.
See https://github.com/jax-ml/jax/issues/26480 for more detail.
* In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`,
non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
* Bug fixes
* TPU runtime startup and shutdown time should be significantly improved on
@ -169,8 +178,6 @@ to signify this.
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
## jax 0.4.37
* Bug fixes
* Fixed a bug where `jit` would error if an argument was named `f` (#25329).
* Fix a bug that will throw `index out of range` error in

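The `jax.numpy.ndim` / `shape` / `size` deprecation noted in the changelog excerpt above can be illustrated with a short sketch. The workaround shown (converting with `jnp.asarray` first) is an assumption based on the usual JAX deprecation pattern; whether a given version warns or errors for list inputs is version dependent.

```python
import jax.numpy as jnp

# Array inputs are unaffected.
print(jnp.ndim(jnp.zeros((3, 4))))      # 2

# Passing a bare Python list is now deprecated; converting explicitly avoids it.
nested = [[1.0, 2.0], [3.0, 4.0]]
print(jnp.ndim(jnp.asarray(nested)))    # 2
```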
View File

@ -70,7 +70,7 @@ jax_python_wheel_repository(
)
load(
"@tsl//third_party/py:python_wheel.bzl",
"@xla//third_party/py:python_wheel.bzl",
"python_wheel_version_suffix_repository",
)
python_wheel_version_suffix_repository(
@ -78,7 +78,7 @@ python_wheel_version_suffix_repository(
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
)
@ -90,7 +90,7 @@ load(
"CUDNN_REDISTRIBUTIONS",
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"cuda_redist_init_repositories",
"cudnn_redist_init_repository",
)
@ -104,21 +104,21 @@ cudnn_redist_init_repository(
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"cuda_configure",
)
cuda_configure(name = "local_config_cuda")
load(
"@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"nccl_redist_init_repository",
)
nccl_redist_init_repository()
load(
"@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
"@xla//third_party/nccl/hermetic:nccl_configure.bzl",
"nccl_configure",
)

View File

@ -475,7 +475,7 @@ def bench_pjit_check_aval_sharding(state):
aval = jax.core.ShapedArray((8, 2), np.int32)
while state:
pjit_check_aval_sharding([s] * 100, [aval] * 100, None, 'benchmark', False)
pjit_check_aval_sharding([s] * 100, [aval] * 100, [''] * 100, 'benchmark', False)
@google_benchmark.register

View File

@ -63,6 +63,17 @@ WHEEL_BUILD_TARGET_DICT = {
"jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
}
# Dictionary with the new wheel build rule targets. Note that once JAX fully
# migrates to the new wheel build rule, the build CLI will switch to it as the
# default.
WHEEL_BUILD_TARGET_DICT_NEW = {
"jax": "//:jax_wheel",
"jaxlib": "//jaxlib/tools:jaxlib_wheel",
"jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
"jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
"jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel",
"jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
}
def add_global_arguments(parser: argparse.ArgumentParser):
"""Adds all the global arguments that applies to all the CLI subcommands."""
@ -147,6 +158,16 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
""",
)
parser.add_argument(
"--use_new_wheel_build_rule",
action="store_true",
help=
"""
Whether to use the new wheel build rule. This is a temporary flag that will be
removed once JAX fully migrates to the new wheel build rule.
""",
)
parser.add_argument(
"--editable",
action="store_true",
@ -386,7 +407,10 @@ async def main():
for option in args.bazel_startup_options:
bazel_command_base.append(option)
bazel_command_base.append("run")
if not args.use_new_wheel_build_rule or args.command == "requirements_update":
bazel_command_base.append("run")
else:
bazel_command_base.append("build")
if args.python_version:
# Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version
@ -592,13 +616,19 @@ async def main():
wheel_build_command_base.append("--config=cuda_libraries_from_stubs")
with open(".jax_configure.bazelrc", "w") as f:
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list())
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule)
if not jax_configure_options:
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
sys.exit(1)
f.write(jax_configure_options)
logging.info("Bazel options written to .jax_configure.bazelrc")
if args.use_new_wheel_build_rule:
logging.info("Using new wheel build rule")
wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW
else:
wheel_build_targets = WHEEL_BUILD_TARGET_DICT
if args.configure_only:
logging.info("--configure_only is set so not running any Bazel commands.")
else:
@ -611,7 +641,7 @@ async def main():
if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
wheel = "jax-" + wheel
if wheel not in WHEEL_BUILD_TARGET_DICT.keys():
if wheel not in wheel_build_targets.keys():
logging.error(
"Incorrect wheel name provided, valid choices are jaxlib,"
" jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
@ -629,32 +659,33 @@ async def main():
)
# Append the build target to the Bazel command.
build_target = WHEEL_BUILD_TARGET_DICT[wheel]
build_target = wheel_build_targets[wheel]
wheel_build_command.append(build_target)
wheel_build_command.append("--")
if not args.use_new_wheel_build_rule:
wheel_build_command.append("--")
if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")
if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")
wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")
wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")
if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")
if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")
if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")
if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
# Exit with error if any wheel build fails.

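For orientation, a minimal sketch of the verb/target selection added to build.py above. The dictionaries here are assumed stand-ins (only a jaxlib entry, and the legacy target name is an assumption); the point is that the new rule switches the CLI from `bazel run` on a wheel-building binary to `bazel build` on a `*_wheel` target.

```python
# Stand-in subsets of WHEEL_BUILD_TARGET_DICT / WHEEL_BUILD_TARGET_DICT_NEW.
OLD_TARGETS = {"jaxlib": "//jaxlib/tools:build_wheel"}   # assumed legacy target
NEW_TARGETS = {"jaxlib": "//jaxlib/tools:jaxlib_wheel"}  # from the new dict above

def select(wheel: str, use_new_wheel_build_rule: bool) -> tuple[str, str]:
    """Mirror of the CLI logic: pick the Bazel verb and the wheel target."""
    verb = "build" if use_new_wheel_build_rule else "run"
    targets = NEW_TARGETS if use_new_wheel_build_rule else OLD_TARGETS
    return verb, targets[wheel]

print(select("jaxlib", use_new_wheel_build_rule=True))   # ('build', '//jaxlib/tools:jaxlib_wheel')
print(select("jaxlib", use_new_wheel_build_rule=False))  # ('run', '//jaxlib/tools:build_wheel')
```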
View File

@ -18,5 +18,4 @@ ml_dtypes>=0.4.0
opt_einsum
zstandard
etils[epath]
# TODO(ybaturina): remove setuptools version
setuptools<71.0.0
setuptools

View File

@ -634,9 +634,9 @@ zstandard==0.22.0 \
# via -r build/requirements.in
# The following packages are considered to be unsafe in a requirements file:
setuptools==69.5.1 \
--hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \
--hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32
setuptools==76.0.0 \
--hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \
--hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4
# via
# -r build/requirements.in
# -r build/test-requirements.txt

View File

@ -623,9 +623,9 @@ zstandard==0.22.0 \
# via -r build/requirements.in
# The following packages are considered to be unsafe in a requirements file:
setuptools==69.5.1 \
--hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \
--hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32
setuptools==76.0.0 \
--hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \
--hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4
# via
# -r build/requirements.in
# -r build/test-requirements.txt

View File

@ -623,9 +623,9 @@ zstandard==0.22.0 \
# via -r build/requirements.in
# The following packages are considered to be unsafe in a requirements file:
setuptools==69.5.1 \
--hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \
--hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32
setuptools==76.0.0 \
--hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \
--hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4
# via
# -r build/requirements.in
# -r build/test-requirements.txt

View File

@ -747,9 +747,9 @@ zstandard==0.23.0 \
# via -r build/requirements.in
# The following packages are considered to be unsafe in a requirements file:
setuptools==70.3.0 \
--hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \
--hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc
setuptools==76.0.0 \
--hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \
--hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4
# via
# -r build/requirements.in
# -r build/test-requirements.txt

View File

@ -12,8 +12,7 @@ portpicker; python_version<"3.13"
pytest-xdist
wheel
rich
# TODO(ybaturina): remove setuptools version
setuptools<71.0.0
setuptools
# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement
# below.
matplotlib~=3.8.4; python_version=="3.10"

View File

@ -213,11 +213,15 @@ def get_gcc_major_version(gcc_path: str):
return major_version
def get_jax_configure_bazel_options(bazel_command: list[str]):
def get_jax_configure_bazel_options(bazel_command: list[str], use_new_wheel_build_rule: bool):
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
# Get the index of the "run" parameter. Build options will come after "run" so
# we find the index of "run" and filter everything after it.
start = bazel_command.index("run")
# we find the index of "run" and filter everything after it. If we are using
# the new wheel build rule, we will find the index of "build" instead.
if use_new_wheel_build_rule:
start = bazel_command.index("build")
else:
start = bazel_command.index("run")
jax_configure_bazel_options = ""
try:
for i in range(start + 1, len(bazel_command)):

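A standalone sketch of the option-filtering idea used by `get_jax_configure_bazel_options`: everything after the Bazel verb ("run", or "build" under the new wheel rule) is collected as the options written to .jax_configure.bazelrc. The exact output formatting of the real helper is an assumption.

```python
def collect_configure_options(bazel_command: list[str],
                              use_new_wheel_build_rule: bool) -> list[str]:
    # Options come after the verb, so find its index and keep everything after it.
    verb = "build" if use_new_wheel_build_rule else "run"
    start = bazel_command.index(verb)
    return bazel_command[start + 1:]

cmd = ["bazel", "--bazelrc=ci.bazelrc", "build",
       "--config=ci_linux_x86_64", "--verbose_failures"]
print(collect_configure_options(cmd, use_new_wheel_build_rule=True))
# ['--config=ci_linux_x86_64', '--verbose_failures']
```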
View File

@ -45,52 +45,82 @@ if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then
arch="amd64"
fi
# Determine the artifact tag flags based on the artifact type. A release
# wheel is tagged with the release version (e.g. 0.5.1), a nightly wheel is
# tagged with the release version and a nightly suffix that contains the
# current date (e.g. 0.5.2.dev20250227), and a default wheel is tagged with
# the git commit hash of the HEAD of the current branch and the date of the
# commit (e.g. 0.5.1.dev20250128+3e75e20c7).
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "nightly" ]]; then
current_date=$(date +%Y%m%d)
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "default" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=custom --bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --bazel_options=--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)"
else
echo "Error: Invalid artifact type: $JAXCI_ARTIFACT_TYPE. Allowed values are: release, nightly, default"
exit 1
fi
if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"
# Build the jax artifact
if [[ "$artifact" == "jax" ]]; then
python -m build --outdir $JAXCI_OUTPUT_DIR
# On platforms with no RBE support, we can use the Bazel remote cache. Set
# it to be empty by default to avoid unbound variable errors.
bazel_remote_cache=""
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
bazelrc_config="rbe_${bazelrc_config}"
else
bazelrc_config="ci_${bazelrc_config}"
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"
# On platforms with no RBE support, we can use the Bazel remote cache. Set
# it to be empty by default to avoid unbound variable errors.
bazel_remote_cache=""
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
bazelrc_config="rbe_${bazelrc_config}"
# Set remote cache flags. Pushes to the cache bucket are limited to JAX's
# CI system.
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
bazel_remote_cache="--bazel_options=--config=public_cache_push"
else
bazelrc_config="ci_${bazelrc_config}"
# Set remote cache flags. Pushes to the cache bucket are limited to JAX's
# CI system.
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
bazel_remote_cache="--bazel_options=--config=public_cache_push"
else
bazel_remote_cache="--bazel_options=--config=public_cache"
fi
bazel_remote_cache="--bazel_options=--config=public_cache"
fi
fi
# Use the "_cuda" configs when building the CUDA artifacts.
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
bazelrc_config="${bazelrc_config}_cuda"
fi
# Use the "_cuda" configs when building the CUDA artifacts.
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
bazelrc_config="${bazelrc_config}_cuda"
fi
# Build the artifact.
# Build the artifact.
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
$artifact_tag_flags
# If building release artifacts, we also build a release candidate ("rc")
# tagged wheel.
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
$artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION"
fi
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
if [[ "$os" == "linux" ]]; then
./ci/utilities/run_auditwheel.sh
fi
# Move the built artifacts from the Bazel cache directory to the output
# directory.
if [[ "$artifact" == "jax" ]]; then
mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR"
mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR"
else
mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR"
fi
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then
./ci/utilities/run_auditwheel.sh
fi
else

View File

@ -50,6 +50,15 @@ export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
# flag is enabled only for CI builds.
export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0}
# Type of artifacts to build. Valid values are "default", "release", "nightly".
# This affects the wheel naming/tag.
export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"}
# When building release artifacts, we build a release candidate wheel ("rc"
# tagged wheel) in addition to the release wheel. This environment variable
# sets the version of the release candidate ("RC") artifact to build.
export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-}
# #############################################################################
# Test script specific environment variables.
# #############################################################################
@ -65,9 +74,4 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels
# on the system. By default, it is set to match the version of the hermetic
# Python used by Bazel for building the wheels.
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
# Installs the JAX package in editable mode at the current commit. Enabled by
# default. Nightly/Release builds disable this flag in the Github action
# workflow files.
export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-"1"}
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}

View File

@ -19,7 +19,7 @@
# avoid using the Windows version of `find` on Msys.
WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )
if [[ -z "$WHEELS" ]]; then
if [[ -z "${WHEELS[@]}" ]]; then
echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
exit 1
fi
@ -38,10 +38,4 @@ if [[ $(uname -s) =~ "MSYS_NT" ]]; then
"$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}")
else
"$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}"
fi
if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then
echo "Installing the JAX package at the current commit..."
# Install JAX package at the current commit.
"$JAXCI_PYTHON" -m uv pip install .
fi

View File

@ -98,3 +98,6 @@ function retry {
# Retry "bazel --version" 3 times to avoid flakiness when downloading bazel.
retry "bazel --version"
# Create the output directory if it doesn't exist.
mkdir -p "$JAXCI_OUTPUT_DIR"

View File

@ -10,7 +10,7 @@ DeepMind](https://deepmind.google/), Alphabet more broadly,
and elsewhere.
At the heart of the project is the [JAX
core](http://github.com/google/jax) library, which focuses on the
core](http://github.com/jax-ml/jax) library, which focuses on the
fundamentals of machine learning and numerical computing, at scale.
When [developing](#development) the core, we want to maintain agility

View File

@ -91,8 +91,8 @@ def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# Prints: x: 0.0
# x: 1.0
# OR
# Prints: x: 1.0
# x: 0.0

View File

@ -195,7 +195,7 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
# and then use `updates` instead of `grad` to actually update the params.
# (And we'd include `new_optimizer_state` in the output, naturally.)
new_params = jax.tree_map(
new_params = jax.tree.map(
lambda param, g: param - g * LEARNING_RATE, params, grad)
return new_params

View File

@ -69,7 +69,7 @@ def dots_saveable(prim, *_, **__) -> bool:
lax_convolution.conv_general_dilated_p}
checkpoint_dots = dots_saveable
def dot_with_no_batch_dims_saveable(prim, *_, **params) -> bool:
def dots_with_no_batch_dims_saveable(prim, *_, **params) -> bool:
# This is a useful heuristic for transformers.
if prim is lax_internal.dot_general_p:
(_, _), (lhs_b, rhs_b) = params['dimension_numbers']
@ -160,8 +160,8 @@ checkpoint_policies = types.SimpleNamespace(
nothing_saveable=nothing_saveable,
dots_saveable=dots_saveable,
checkpoint_dots=dots_saveable,
dots_with_no_batch_dims_saveable=dot_with_no_batch_dims_saveable,
checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims_saveable,
dots_with_no_batch_dims_saveable=dots_with_no_batch_dims_saveable,
checkpoint_dots_with_no_batch_dims=dots_with_no_batch_dims_saveable,
offload_dot_with_no_batch_dims=offload_dot_with_no_batch_dims,
save_anything_except_these_names=save_anything_except_these_names,
save_any_names_but_these=save_any_names_but_these,
@ -355,7 +355,7 @@ def _remat_static_argnums(fun, static_argnums, args):
raise ValueError("the `static_argnums` argument to `jax.checkpoint` / "
"`jax.remat` can only take integer values greater than or "
"equal to `-len(args)` and less than `len(args)`, but got "
f"{static_argnums}")
f"{static_argnums}, while `len(args)` = {len(args)}")
if not static_argnums:
return fun, args

View File

@ -1094,6 +1094,9 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
return f"args{keystr(key_path)}"
# args is a tuple, so key_path[0].idx is the index into args.
i = key_path[0].idx
# This can happen with star arguments (*args)
if i >= len(signature_parameters):
return f"args{keystr(key_path)}"
res = f"argument {signature_parameters[i]}"
if len(key_path) > 1:
res += keystr(key_path[1:])
@ -1135,8 +1138,8 @@ def pmap(
fun: Callable,
axis_name: AxisName | None = None,
*,
in_axes=0,
out_axes=0,
in_axes: int | None | Sequence[Any] = 0,
out_axes: Any = 0,
static_broadcasted_argnums: int | Iterable[int] = (),
devices: Sequence[xc.Device] | None = None, # noqa: F811
backend: str | None = None,
@ -2002,8 +2005,8 @@ def vjp(
raise NotImplementedError("reduce_axes argument to vjp is deprecated")
del reduce_axes
check_callable(fun)
wrapped_fun = lu.wrap_init(fun,
debug_info=debug_info("vjp", fun, primals, {}))
wrapped_fun = lu.wrap_init(
fun, debug_info=debug_info("vjp", fun, primals, {}))
return _vjp(wrapped_fun, *primals, has_aux=has_aux)
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):

View File

@ -382,6 +382,9 @@ def is_hashable(arg):
return False
SENTINEL = object()
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
# given an axis spec tree axis_tree (a pytree with integers and Nones at the
# leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
@ -389,7 +392,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
# and return the flattened result
# TODO(mattjj,phawkins): improve this implementation
proxy = object()
dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
dummy = tree_unflatten(treedef, [SENTINEL] * treedef.num_leaves)
axes = []
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
try:
@ -564,8 +567,9 @@ def resolve_kwargs(fun: Callable, args, kwargs) -> tuple[Any, ...]:
passed_kwargs = [k for k in ba.kwargs if k in kwargs]
if passed_kwargs:
raise TypeError(
f"keyword arguments ({passed_kwargs}) could not be resolved to "
"positions")
"The following keyword arguments could not be resolved to positions: "
f"{', '.join(passed_kwargs)}"
)
return ba.args

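The improved error message above comes from `resolve_kwargs`, which rebinds keyword arguments to positional form and refuses anything (such as keyword-only parameters) that cannot be resolved. A minimal self-contained sketch of that contract, assuming the real helper also applies defaults:

```python
import inspect

def resolve_kwargs_sketch(fun, args, kwargs):
    # Bind the call to fun's signature, then reject kwargs that stayed keyword-only.
    ba = inspect.signature(fun).bind(*args, **kwargs)
    ba.apply_defaults()
    passed_kwargs = [k for k in ba.kwargs if k in kwargs]
    if passed_kwargs:
        raise TypeError(
            "The following keyword arguments could not be resolved to positions: "
            + ", ".join(passed_kwargs))
    return ba.args

def f(x, y, *, flag=False):
    return x + y

print(resolve_kwargs_sketch(f, (1,), {"y": 2}))      # (1, 2)
# resolve_kwargs_sketch(f, (1, 2), {"flag": True})   # raises TypeError
```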
View File

@ -901,7 +901,7 @@ error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings,
in_layouts, out_layouts,
resource_env, donated_invars, name, inline, keep_unused,
donated_invars, ctx_mesh, name, inline, keep_unused,
compiler_options_kvs):
# jaxpr to checked_jaxpr
err_vals, err_tree = jtu.tree_flatten(error)
@ -928,8 +928,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
out_shardings=new_out_shardings,
in_layouts=new_in_layouts,
out_layouts=new_out_layouts,
resource_env=resource_env,
donated_invars=new_donated_invars,
ctx_mesh=ctx_mesh,
name=name,
inline=inline,
keep_unused=keep_unused,

View File

@ -15,6 +15,7 @@
import datetime
import os
import re
import warnings
from jax import version
from jax._src import config
from jax._src import hardware_utils
@ -72,7 +73,19 @@ def cloud_tpu_init() -> None:
# Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed.
libtpu_path = get_tpu_library_path()
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id()
if (
tpu_id is not None
and tpu_id >= hardware_utils.TpuVersion.v5e
and not hardware_utils.transparent_hugepages_enabled()
):
warnings.warn(
'Transparent hugepages are not enabled. TPU runtime startup and'
' shutdown time should be significantly improved on TPU v5e and newer.'
' If not already set, you may need to enable transparent hugepages in'
' your VM image (sudo sh -c "echo always >'
' /sys/kernel/mm/transparent_hugepage/enabled")'
)
if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init():
return

View File

@ -15,8 +15,8 @@
from __future__ import annotations
import abc
import pathlib
from jax._src import path as pathlib
from jax._src import util

View File

@ -1815,6 +1815,7 @@ def _make_lengths_same(sharding, ndim):
if ndim > len(sharding.spec):
return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim))
if ndim < len(sharding.spec):
assert all(s is None for s in sharding.spec[ndim:])
return sharding.with_spec(sharding.spec[:ndim])
assert False, "unreachable"
@ -1840,6 +1841,8 @@ def _maybe_modify_sharding(sharding, ndim):
return sharding
if sharding.mesh._are_all_axes_explicit:
if ndim > len(sharding.spec):
return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim))
return sharding
out = sharding.with_spec(modify_spec_for_auto_manual(
@ -1849,9 +1852,22 @@ def _maybe_modify_sharding(sharding, ndim):
out = _make_lengths_same(out, ndim)
return out
def _check_divisibility(sharding, shape):
mesh = sharding.mesh
for dim, (spec, sh) in enumerate(zip(sharding.spec, shape)):
if spec is None:
continue
spec = spec if isinstance(spec, tuple) else (spec,)
size = math.prod(mesh.shape[s] for s in spec)
_, remainder = divmod(sh, size)
if remainder != 0:
raise ValueError(
f"Sharding spec {spec} implies that array axis {dim} is partitioned"
f" {size} times, but does not evenly divide the dimension size {sh}."
f" Got shape: {shape} and sharding {sharding}")
@cache(max_size=4096, trace_context_in_key=True)
def get_sharding(sharding, ndim):
def get_sharding(sharding, shape):
"""Modifies and checks the sharding.
Some modifications/checks include:
@ -1860,6 +1876,7 @@ def get_sharding(sharding, ndim):
* Checking for len(spec)-ndim match
* Checking if the mesh is an AbstractMesh.
"""
ndim = len(shape)
if sharding is None:
return NamedSharding(mesh_lib.empty_abstract_mesh, P(*[None] * ndim))
@ -1871,6 +1888,7 @@ def get_sharding(sharding, ndim):
if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
raise ValueError("Mesh of an aval must be an AbstractMesh. "
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
_check_divisibility(out_s, shape)
return out_s
@ -1882,7 +1900,7 @@ class ShapedArray(UnshapedArray):
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
self.sharding = get_sharding(sharding, len(self.shape))
self.sharding = get_sharding(sharding, self.shape)
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:
@ -2489,8 +2507,8 @@ class MapPrimitive(Primitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr: Jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info), jaxpr, ())
subfun = lu.hashable_partial(
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params

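A standalone illustration of the divisibility rule enforced by the new `_check_divisibility` above: each partitioned array dimension must be evenly divisible by the product of the mesh axis sizes sharding it. The mesh and spec values here are made up for the example.

```python
import math

mesh_shape = {"x": 4, "y": 2}        # assumed example mesh
spec = ("x", ("x", "y"), None)       # assumed example PartitionSpec-like tuple
shape = (8, 16, 5)

for dim, (entry, size) in enumerate(zip(spec, shape)):
    if entry is None:
        continue  # unsharded dimension, nothing to check
    axes = entry if isinstance(entry, tuple) else (entry,)
    parts = math.prod(mesh_shape[a] for a in axes)
    assert size % parts == 0, f"axis {dim}: {size} not divisible by {parts}"
print("all sharded dimensions divide evenly")
```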
View File

@ -333,7 +333,7 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
def check_is_flash_attention(
query, key, layout: int, cudnn_version, has_bias, is_training, is_packed,
query, key, layout: int, cudnn_version, has_bias, is_training, is_packed=False,
is_fp8=False):
# Extract sequence length (T) and head dim (H) based on layout
if layout == AttentionLayout.BNTH.value:

View File

@ -143,7 +143,15 @@ class custom_vmap:
def __call__(self, *args, **kwargs):
debug_fun = api_util.debug_info("custom_vmap fun", self.fun,
args, kwargs)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
try:
args = api_util.resolve_kwargs(self.fun, args, kwargs)
except TypeError as e:
raise TypeError(
"The input arguments to the custom_vmap-decorated function "
f"{debug_fun.func_name} could not be resolved to positional-only "
f"arguments. Binding failed with the error:\n{e}"
) from e
if not self.vmap_rule:
raise AttributeError(
f"No batching rule defined for custom_vmap function {debug_fun.func_name} "

View File

@ -133,7 +133,15 @@ class custom_dce:
debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
try:
args = api_util.resolve_kwargs(self.fun, args, kwargs)
except TypeError as e:
raise TypeError(
"The input arguments to the custom_dce-decorated function "
f"{debug.func_name} could not be resolved to positional-only "
f"arguments. Binding failed with the error:\n{e}"
) from e
if self.static_argnums:
static_argnums = set(self.static_argnums)
for i in static_argnums:

View File

@ -250,7 +250,15 @@ class custom_jvp(Generic[ReturnValue]):
msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
raise AttributeError(msg)
args = resolve_kwargs(self.fun, args, kwargs)
try:
args = resolve_kwargs(self.fun, args, kwargs)
except TypeError as e:
raise TypeError(
"The input arguments to the custom_jvp-decorated function "
f"{primal_name} could not be resolved to positional-only arguments. "
f"Binding failed with the error:\n{e}"
) from e
if self.nondiff_argnums:
nondiff_argnums = set(self.nondiff_argnums)
args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
@ -634,7 +642,16 @@ class custom_vjp(Generic[ReturnValue]):
if not self.fwd or not self.bwd:
msg = f"No VJP defined for custom_vjp function {debug_fun.func_name} using defvjp."
raise AttributeError(msg)
args = resolve_kwargs(self.fun, args, kwargs)
try:
args = resolve_kwargs(self.fun, args, kwargs)
except TypeError as e:
raise TypeError(
"The input arguments to the custom_vjp-decorated function "
f"{debug_fun.func_name} could not be resolved to positional-only "
f"arguments. Binding failed with the error:\n{e}"
) from e
debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs,
static_argnums=self.nondiff_argnums)
# TODO(necula): figure out how to construct the debug_bwd args
@ -1238,7 +1255,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value
# with a nontrivial tangent attached), up to heuristics, and True otherwise.
# See https://github.com/google/jax/issues/6415 for motivation.
# See https://github.com/jax-ml/jax/issues/6415 for motivation.
if not isinstance(x, core.Tracer):
# If x is not a Tracer, it can't be perturbed.
return False

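A sketch of the failure mode that the new try/except around `resolve_kwargs` reports more clearly for `custom_jvp` (and likewise for `custom_vjp`, `custom_vmap`, and `custom_dce`): a keyword-only parameter on the decorated function cannot be bound to a position. The function below is invented for illustration.

```python
import jax

@jax.custom_jvp
def scale(x, *, factor=2.0):
    return factor * x

@scale.defjvp
def scale_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    return scale(x), 2.0 * dx

print(scale(3.0))           # 6.0 -- the keyword-only default is applied
# scale(3.0, factor=4.0)    # TypeError: the keyword argument cannot be
#                           # resolved to a positional-only argument
```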
View File

@ -181,7 +181,8 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
*tiled_args
)
if closed_jaxpr.out_avals != tiled_results:
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
[(t.shape, t.dtype) for t in tiled_results]):
raise ValueError(
"Mismatch in result shapes. %s vs %s"
% (repr(closed_jaxpr.out_avals), repr(tiled_results))

View File

@ -109,6 +109,12 @@ _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)
# fp4 support
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
float4_e2m1fn: type[np.generic] | None = None
_float4_e2m1fn_dtype: np.dtype | None = None
def supports_inf(dtype: DTypeLike) -> bool:
"""Return true if the dtype supports infinity, else return False."""
typ = np.dtype(dtype).type
@ -144,6 +150,8 @@ _float8_dtypes = [
_float8_e5m2fnuz_dtype,
]
_float4_dtypes: list[np.dtype] = []
# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
if hasattr(ml_dtypes, "float8_e4m3"):
float8_e4m3 = ml_dtypes.float8_e4m3
@ -163,6 +171,12 @@ if hasattr(ml_dtypes, "float8_e8m0fnu"):
_custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
if hasattr(ml_dtypes, "float4_e2m1fn"):
float4_e2m1fn = ml_dtypes.float4_e2m1fn
_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
_custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
_float4_dtypes.insert(0, _float4_e2m1fn_dtype)
# 2-bit integer support
int2: type[np.generic] | None = None
@ -716,6 +730,12 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
"promotion path. To avoid unintended promotion, 8-bit floats do not support "
"implicit promotion. If you'd like your inputs to be promoted to another type, "
"you can do so explicitly using e.g. x.astype('float32')")
elif any(n in _float4_dtypes for n in nodes):
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path. To avoid unintended promotion, 4-bit floats do not support "
"implicit promotion. If you'd like your inputs to be promoted to another type, "
"you can do so explicitly using e.g. x.astype('float32')")
elif any(n in _intn_dtypes for n in nodes):
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "

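A metadata-only sketch of the new refusal to implicitly promote 4-bit floats; it requires ml_dtypes >= 0.5.0 so that `float4_e2m1fn` exists, and the exact exception type is an assumption.

```python
import ml_dtypes
import jax.numpy as jnp

try:
    jnp.result_type(ml_dtypes.float4_e2m1fn, jnp.float32)
except Exception as err:  # a promotion error with the message added above
    print(f"{type(err).__name__}: {err}")

# The message points at the fix: promote explicitly, e.g. x.astype('float32').
```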
View File

@ -33,12 +33,11 @@ class JaxValueError(ValueError):
"""Exception raised for failed runtime error checks in JAX."""
#: The default error code for no error.
#:
#: This value is chosen because we can use `jnp.min()` to obtain the
#: first error when performing reductions.
_NO_ERROR = jnp.iinfo(jnp.uint32).max
"""The default error code for no error.
We choose this value because when performing reductions, we can use `min` to
obtain the smallest error code.
"""
_error_list_lock = threading.Lock()
@ -62,7 +61,7 @@ def _initialize_error_code_ref() -> None:
def set_error_if(pred: jax.Array, msg: str) -> None:
"""Set error if pred is true.
"""Set error if any element of pred is true.
If the error is already set, the new error will be ignored. It will not
override the existing error.
@ -74,7 +73,7 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
traceback = source_info_util.current().traceback
assert traceback is not None
with _error_list_lock:
new_error_code = len(_error_list)
new_error_code = jnp.uint32(len(_error_list))
_error_list.append((msg, traceback))
pred = pred.any()
@ -86,18 +85,26 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
def raise_if_error() -> None:
"""Raise error if an error is set."""
if _error_storage.ref is None:
return # if not initialized, do nothing
"""Raise error if an error is set.
This function should be called after the computation is finished. It should
not be called within a traced context, such as within a jitted function.
"""
if _error_storage.ref is None: # if not initialized, do nothing
return
error_code = _error_storage.ref[...]
if isinstance(error_code, core.Tracer):
raise ValueError(
"raise_if_error() should not be called within a traced context, such as"
" within a jitted function."
)
if error_code == jnp.uint32(_NO_ERROR):
return
try:
msg, traceback = _error_list[error_code]
exc = JaxValueError(msg)
traceback = traceback.as_python_traceback()
filtered_traceback = traceback_util.filter_traceback(traceback)
raise exc.with_traceback(filtered_traceback)
finally:
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)
msg, traceback = _error_list[error_code]
exc = JaxValueError(msg)
traceback = traceback.as_python_traceback()
filtered_traceback = traceback_util.filter_traceback(traceback)
raise exc.with_traceback(filtered_traceback)

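A usage sketch of the error-checking API documented above. The module path and the lazy initialization are assumptions; only `set_error_if`/`raise_if_error` and the rule that `raise_if_error` must run outside any traced context come from the code itself.

```python
import jax
import jax.numpy as jnp
from jax._src import error_check  # assumed module path

@jax.jit
def safe_log(x):
    # Record the error inside the traced computation.
    error_check.set_error_if(x <= 0, "log requires strictly positive input")
    return jnp.log(x)

y = safe_log(jnp.float32(-1.0))
error_check.raise_if_error()  # raises JaxValueError once we are back in Python
```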
View File

@ -75,6 +75,7 @@ enum DType: byte {
f8_e5m2 = 20,
f8_e5m2fnuz = 21,
f8_e8m0fnu = 25,
f4_e2m1fn = 26,
}
table AbstractValue {

View File

@ -365,6 +365,8 @@ if dtypes._float8_e4m3_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
if dtypes._float8_e8m0fnu_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
if dtypes._float4_e2m1fn_dtype is not None:
_dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}

View File

@ -62,6 +62,7 @@ class DType(object):
f8_e5m2fnuz = 21
f0 = 22
f8_e8m0fnu = 25
f4_e2m1fn = 26
class ShardingKind(object):

View File

@ -12,25 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import os
import pathlib
import glob
_GOOGLE_PCI_VENDOR_ID = '0x1ae0'
_TPU_PCI_DEVICE_IDS = [
# TPU v2, v3
'0x0027',
# No public name (plc)
'0x0056',
# TPU v4
'0x005e',
# TPU v5p
'0x0062',
# TPU v5e
'0x0063',
# TPU v6e
'0x006f',
]
_NVIDIA_GPU_DEVICES = [
'/dev/nvidia0',
@ -38,10 +25,36 @@ _NVIDIA_GPU_DEVICES = [
'/dev/dxg', # WSL2
]
class TpuVersion(enum.IntEnum):
# TPU v2, v3
v2 = 0
v3 = 1
# No public name (plc)
plc = 2
# TPU v4
v4 = 3
# TPU v5p
v5p = 4
# TPU v5e
v5e = 5
# TPU v6e
v6e = 6
_TPU_PCI_DEVICE_IDS = {
'0x0027': TpuVersion.v3,
'0x0056': TpuVersion.plc,
'0x005e': TpuVersion.v4,
'0x0062': TpuVersion.v5p,
'0x0063': TpuVersion.v5e,
'0x006f': TpuVersion.v6e,
}
def num_available_tpu_chips_and_device_id():
"""Returns the device id and number of TPU chips attached through PCI."""
num_chips = 0
device_id = ''
tpu_version = None
for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'):
vendor_id = pathlib.Path(vendor_path).read_text().strip()
if vendor_id != _GOOGLE_PCI_VENDOR_ID:
@ -50,12 +63,20 @@ def num_available_tpu_chips_and_device_id():
device_path = os.path.join(os.path.dirname(vendor_path), 'device')
device_id = pathlib.Path(device_path).read_text().strip()
if device_id in _TPU_PCI_DEVICE_IDS:
tpu_version = _TPU_PCI_DEVICE_IDS[device_id]
num_chips += 1
return num_chips, device_id
return num_chips, tpu_version
def has_visible_nvidia_gpu() -> bool:
"""True if there's a visible nvidia gpu available on device, False otherwise."""
return any(os.path.exists(d) for d in _NVIDIA_GPU_DEVICES)
def transparent_hugepages_enabled() -> bool:
# See https://docs.kernel.org/admin-guide/mm/transhuge.html for more
# information about transparent huge pages.
path = pathlib.Path('/sys/kernel/mm/transparent_hugepage/enabled')
return path.exists() and path.read_text().strip() == '[always] madvise never'

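For reference, a sketch of how the new `(num_chips, tpu_version)` return value and the hugepages helper are consumed, mirroring the `cloud_tpu_init` change earlier in this diff.

```python
from jax._src import hardware_utils

num_chips, tpu_version = hardware_utils.num_available_tpu_chips_and_device_id()
if (
    tpu_version is not None
    and tpu_version >= hardware_utils.TpuVersion.v5e
    and not hardware_utils.transparent_hugepages_enabled()
):
    print("TPU v5e+ detected without transparent hugepages; "
          "enabling them speeds up runtime startup/shutdown")
```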
View File

@ -39,7 +39,7 @@ from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
from jax._src.dtypes import dtype, float0
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
as_hashable_function, weakref_lru_cache,
partition_list)
partition_list, subs_list2)
zip = safe_zip
map = safe_map
@ -91,6 +91,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
*primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(debug_info)
tangent_trace.tag = _tag
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
tracers = [LinearizeTracer(linearize_trace, p,
tangent_trace.new_arg(get_aval(p).to_tangent_aval()))
@ -104,11 +105,23 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
_store.store((residual_avals, nzs_out, jaxpr))
return tuple(consts) + tuple(out_primals)
which_env = [(isinstance(c, pe.DynamicJaxprTracer) and
getattr(c._trace, 'tag', None) is _tag) for c in consts]
jaxpr = pe.move_envvars(jaxpr, tuple(which_env))
res, env = partition_list(which_env, consts)
residual_avals = map(get_aval, res)
# Which residuals are just forwarded inputs? Check object id.
id_map = {id(p): i for i, p in enumerate(primals)}
in_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
# Which residuals are already primal outputs? Check object id.
id_map = {id(p): i for i, p in enumerate(out_primals)}
out_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
# Prune residuals not to include forwarded primal inputs or outputs.
res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd) if f1 is None and f2 is None]
_store.store((residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd))
return *res, *out_primals
@lu.transformation2
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
@ -157,6 +170,7 @@ def _linearize_jaxpr(
primal_trace = pe.DynamicJaxprTrace(dbg)
tangent_trace = pe.DynamicJaxprTrace(dbg)
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
tangent_trace.tag = lin_trace.tag
def new_arg(trace, primal_aval, nz):
primal = primal_trace.new_arg(primal_aval)
@ -197,6 +211,7 @@ def direct_linearize(traceable: lu.WrappedFun,
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
tangent_trace.tag = linearize_trace.tag
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
tracers = [t.full_lower() for t in tracers]
with (core.set_current_trace(linearize_trace, check_leaks=True),
@ -217,6 +232,10 @@ def direct_linearize(traceable: lu.WrappedFun,
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
tangent_trace.invalidate()
jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
jaxpr, [True] * len(jaxpr.outvars),
[False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
consts = [c for c, used in zip(consts, used_consts) if used]
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
pe.PartialVal.known(zeros_like_aval(t.aval))
for t, nz in zip(out_tangents, out_nzs)]
@ -330,12 +349,34 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
# forces primal_in to contain UndefinedPrimals for tangent values!
map(write_primal, jaxpr.invars, primals_in)
# Start with a forward pass to evaluate any side-effect-free JaxprEqns that
# only operate on primals. This is required to support primitives with
# linearization rules that include computations on the residuals.
lin_eqns = []
for eqn in jaxpr.eqns:
# TODO (dfm): The effects check is probably stricter than necessary.
# Consider adding an allowlist of effects here.
if jaxpr.effects or any(
type(x) is not Literal and x not in primal_env for x in eqn.invars):
lin_eqns.append(eqn)
continue
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
traceback = eqn.source_info.traceback
with source_info_util.user_context(
traceback, name_stack=name_stack), eqn.ctx.manager:
ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
ct_env: dict[Any, Any] = {}
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
else contextlib.nullcontext())
with ctx:
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
for eqn in lin_eqns[::-1]:
if eqn.primitive.ref_primitive:
if eqn.primitive is core.mutable_array_p:
val_var, = eqn.invars
@ -586,7 +627,7 @@ def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = get_aval(primal).strip_weak_type()
tangent_aval = get_aval(tangent).strip_weak_type()
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape)
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape)
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
@ -641,6 +682,7 @@ class LinearizeTrace(Trace):
return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
dict(symbolic_zeros=symbolic_zeros))
@partial(lu.wrap_init, debug_info=f_jvp.debug_info)
def _f_jvp(primals, tangents):
outs = f_jvp.call_wrapped(*primals, *tangents)
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
@ -651,7 +693,7 @@ class LinearizeTrace(Trace):
nonzeros_in = [type(t) is not Zero for t in tangents_in]
primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
_f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros,
f_jvp.debug_info, primals_in, {})
primals_in, {})
with core.set_current_trace(self.tangent_trace):
tangents_out = linearized(residuals, *tangents_in)
@ -690,53 +732,64 @@ class LinearizeTrace(Trace):
assert call_primitive.multiple_results
primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not Zero for t in tangents)
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in,
f.debug_info)
f_primal, linearize_outs_thunk = linearize_subtrace(
f, self.tag, nzs_in, f.debug_info)
if isinstance(call_primitive, core.MapPrimitive):
@as_hashable_function(closure=(linearize_outs_thunk))
out_axes_thunk = params['out_axes_thunk']
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
residual_avals, _, _ = linearize_outs_thunk()
out_axes = params['out_axes_thunk']()
return (*(0 for _ in residual_avals), *out_axes)
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
out_axes = out_axes_thunk()
return (*(0 for _ in range(num_res_out)), *out_axes)
primal_params = dict(params, out_axes_thunk=new_out_axes_thunk)
else:
primal_params = params
all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params)
residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
num_residuals = len(residual_avals)
residuals = all_primal_results[:num_residuals]
primals_out = all_primal_results[num_residuals:]
residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
non_fwd_res = all_primal_results[:num_res_out]
primals_out = all_primal_results[num_res_out:]
residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
if isinstance(call_primitive, core.MapPrimitive):
in_axes = params['in_axes']
out_axes = params['out_axes_thunk']()
residual_avals = map(get_aval, residuals)
new_in_axes = (*(0 for _ in residual_avals),
residual_axes = [in_axes[f1] if f1 is not None else
out_axes[f2] if f2 is not None else
0 for f1, f2 in zip(in_fwd, out_fwd)]
new_in_axes = (*residual_axes, *(None for _ in range(len(env))),
*(ax for ax, nz in zip(in_axes, nzs_in) if nz))
new_out_axes = (*(ax for ax, nz in zip(out_axes, nzs_out) if nz),)
# NOTE: This assumes that the output tangents being zero is a
# deterministic function of which input tangents were zero.
@as_hashable_function(closure=(new_out_axes))
@as_hashable_function(closure=new_out_axes)
def new_out_axes_thunk():
return new_out_axes
params = dict(params,
in_axes=new_in_axes,
out_axes_thunk=new_out_axes_thunk)
params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
update_params = call_linearize_param_updaters.get(call_primitive)
new_params = update_params(params, residual_avals, nzs_in) if update_params else params
num_new_args = len(residuals) + len(env)
new_params = update_params(params, num_new_args, nzs_in) if update_params else params
num_residuals = len(residual_avals)
@as_hashable_function(closure=(num_residuals, lin_jaxpr))
def f_tangent(*args):
residuals = args[:num_residuals]
consts = args[:num_residuals]
nz_tangents = args[num_residuals:]
return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)
return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents)
# TODO(mattjj,dougalm): this tag is read by DynamicJaxprTrace.process_map to
# avoid round-tripping the jaxpr and thus getting grad-of-pmap cache misses.
# Remove when we replace the pmap implementation.
f_tangent._pmap_tag = isinstance(call_primitive, core.MapPrimitive)
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = call_primitive.bind_with_trace(
self.tangent_trace, (lu.wrap_init(f_tangent,
debug_info=lin_jaxpr.debug_info),
*residuals, *nz_tangents_in), new_params)
self.tangent_trace,
(lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
*residuals, *env, *nz_tangents_in), new_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
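Illustrative sketch (not part of the diff): the residual bookkeeping above uses a forwarding convention in which, for each residual, `in_fwd` may point at a primal input and `out_fwd` at a primal output; only residuals with neither forward are actually produced by the primal call. A self-contained stand-in for the `subs_list2` substitution, written only to make that convention concrete:

def _subs_list2_sketch(in_fwd, out_fwd, primals, primals_out, non_fwd_res):
  non_fwd = iter(non_fwd_res)
  res = [primals[f1] if f1 is not None else
         primals_out[f2] if f2 is not None else
         next(non_fwd)
         for f1, f2 in zip(in_fwd, out_fwd)]
  assert next(non_fwd, None) is None  # every computed residual is consumed
  return res

primals, primals_out, non_fwd_res = ['p0', 'p1'], ['o0'], ['r0']
in_fwd, out_fwd = [1, None, None], [None, 0, None]
assert _subs_list2_sketch(in_fwd, out_fwd, primals, primals_out,
                          non_fwd_res) == ['p1', 'o0', 'r0']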
@ -762,14 +815,14 @@ def fallback_linearize_rule(_prim: core.Primitive,
msg = f"Differentiation rule for '{_prim}' not implemented"
raise NotImplementedError(msg)
debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params)
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False,
debug_jvp, primals, params)
return linearize_from_jvp(lu.wrap_init(jvp, debug_info=debug_jvp),
_prim.multiple_results, _nonzeros, False, False,
primals, params)
def linearize_from_jvp(jvp: Callable,
def linearize_from_jvp(jvp: lu.WrappedFun,
multiple_results: bool,
nonzeros: Sequence[bool],
user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
debug_info: core.DebugInfo,
primals, params):
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
@ -792,13 +845,18 @@ def linearize_from_jvp(jvp: Callable,
tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
for aval, nz in zip(tangent_avals, nonzeros))
with core.set_current_trace(trace):
out_primals, out_tangents = jvp(primals, tangent_args, **params)
out_primals, out_tangents = jvp.call_wrapped(primals, tangent_args, **params)
if not multiple_results:
out_primals = [out_primals]
out_tangents = [out_tangents]
out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals]
if any(p is None for p in out_primals):
raise ValueError(
"Linearization failed to produce known values for all output primals. "
"This is typically caused by attempting to differentiate a function "
"uses an operation that does not support reverse-mode autodiff.")
out_nzs = [type(t) is not zero_type and not trace.to_jaxpr_tracer(t).is_known()
for t in out_tangents]
@ -806,7 +864,7 @@ def linearize_from_jvp(jvp: Callable,
out_nz_tracers = [trace.to_jaxpr_tracer(r)
for (r, nz) in zip(out_tangents, out_nzs) if nz]
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, debug_info)
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info)
def linearized(residuals, *tangents):
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
@ -973,9 +1031,8 @@ def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
else:
consts = ()
all_args, in_tree_def = tree_flatten((consts, args, ct))
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
debug_info=call_jaxpr.debug_info),
call_jaxpr, False)
fun = lu.hashable_partial(lu.wrap_init(
backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
@ -1013,9 +1070,8 @@ def map_transpose(primitive: core.Primitive, params,
call_jaxpr: core.Jaxpr, args, ct, _):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
# TODO(necula): use the right debug_info for the backwards pass
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
debug_info=call_jaxpr.debug_info),
call_jaxpr, False)
fun = lu.hashable_partial(lu.wrap_init(
backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False)
fun, nz_arg_cts = nonzero_outputs(fun)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
@ -1083,8 +1139,8 @@ def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
assert len(jaxpr.in_avals) == len(nonzeros)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
debug_info=jaxpr.jaxpr.debug_info)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
f_jvp, out_nonzeros = f_jvp_traceable(
jvp(f, instantiate=instantiate, transform_stack=False), nonzeros)
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(

View File

@ -199,6 +199,9 @@ if dtypes.float8_e4m3 is not None:
if dtypes.float8_e8m0fnu is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
if dtypes.float4_e2m1fn is not None:
_dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
# TODO Support different-size underlying dtypes to take advantage of the
@ -939,7 +942,7 @@ def sharded_aval(aval: core.AbstractValue,
return aval
if not isinstance(aval, (core.ShapedArray, core.DShapedArray)):
raise NotImplementedError
return aval.update(sharding.shard_shape(aval.shape)) # type: ignore
return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore
def eval_dynamic_shape(ctx: LoweringRuleContext,

View File

@ -46,7 +46,8 @@ from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache, subs_list)
as_hashable_function, weakref_lru_cache, subs_list,
HashableFunction)
map, unsafe_map = safe_map, map
@ -837,6 +838,11 @@ def tracers_to_jaxpr(
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals
@weakref_lru_cache
def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr:
constvars, envvars = partition_list(which, jaxpr.constvars)
return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars])
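`move_envvars` splits the constvars with a boolean mask and promotes the flagged ones to leading invars. A standalone sketch of that partition step (hypothetical helper name; the real code uses `partition_list` from jax._src.util):

def partition_by_flags(which, xs):
  kept = [x for w, x in zip(which, xs) if not w]   # stay as constvars
  moved = [x for w, x in zip(which, xs) if w]      # become leading invars
  return kept, moved

assert partition_by_flags([False, True, False],
                          ['a', 'b', 'c']) == (['a', 'c'], ['b'])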
@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
@ -1840,7 +1846,7 @@ def _inline_literals(
class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",)
__slots__ = ("frame", "tag")
def __init__(self, debug_info: core.DebugInfo):
self.frame = JaxprStackFrame(debug_info)
@ -1972,17 +1978,18 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
def process_map(self, map_primitive, f: lu.WrappedFun,
tracers: Sequence[core.Tracer], params):
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
tracers = map(self.to_jaxpr_tracer, tracers)
in_avals = [t.aval for t in tracers]
axis_name, axis_size = params['axis_name'], params['axis_size']
reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
if in_axis is not None else a
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
f, reduced_in_avals)
jaxpr, consts = _linearize_of_pmap_hack(f, jaxpr, consts)
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
@ -2582,3 +2589,13 @@ def inline_jaxpr_into_trace(
return tracer
return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env
else new_tracer(x) for x in jaxpr.outvars]
# TODO(mattjj,dougalm): this special handling is to avoid round-tripping the
# jaxpr when we do grad-of-pmap. The tag is set by LinearizeTrace.process_call's
# handling of pmap. Remove when we replace the pmap implementation.
def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, list]:
if (not f.transforms and type(f.f) is HashableFunction and
getattr(f.f, '_pmap_tag', None)):
_, jaxpr = f.f.closure
return convert_constvars_jaxpr(jaxpr), []
return jaxpr, consts

View File

@ -1394,9 +1394,9 @@ def xla_call_jvp_update_params(params, nz_tangents):
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)
def _xla_call_linearize_update_params(params, residual_avals, nz_tangents):
def _xla_call_linearize_update_params(params, num_new_inputs, nz_tangents):
donated_invars_prev = params['donated_invars']
donated_invars = (*(False for _ in residual_avals),
donated_invars = (*(False for _ in range(num_new_inputs)),
*(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz))
return dict(params, donated_invars=donated_invars)
@ -1663,7 +1663,7 @@ class MismatchType(enum.Enum):
elif self.name == 'OUT_SHARDING':
return 'explicit output sharding'
elif self.name == 'CONTEXT_DEVICES':
return 'devices'
return 'context mesh'
return f'{self.name}'
@ -3060,7 +3060,6 @@ class JitGlobalCppCacheKeys:
in_layouts_leaves: tuple[Any, ...] | None = None
out_layouts_treedef: PyTreeDef | None = None
out_layouts_leaves: tuple[Any, ...] | None = None
use_resource_env: bool = False
compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None
@functools.cached_property

View File

@ -27,7 +27,6 @@ from jax._src.lax import lax
from jax._src import effects
from jax._src import ad_util
from jax._src import state
from jax._src import util
from jax._src.util import weakref_lru_cache, safe_map, partition_list
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef
@ -144,52 +143,55 @@ def _initial_style_jaxprs_with_common_consts(
# b[] <- 2.0
# in () }
canonical_ref_indices = []
canonical_non_ref_indices = []
canonical_refs: list[Any] = []
tracer_id_to_canonical_id = {}
all_nonref_consts = []
canonical_non_refs: list[Any] = []
tracer_id_to_canonical_ref_id = {}
tracer_id_to_canonical_non_ref_id = {}
canonical_ref_avals = []
all_nonref_const_avals = []
canonical_non_ref_avals = []
for consts, consts_avals in zip(all_consts, all_const_avals):
ref_indices = []
nonref_consts = []
nonref_const_avals = []
non_ref_indices = []
for c, aval in zip(consts, consts_avals):
tracer_id = id(c)
if isinstance(aval, state.AbstractRef):
tracer_id = id(c)
if tracer_id not in tracer_id_to_canonical_id:
if tracer_id not in tracer_id_to_canonical_ref_id:
canonical_id = len(canonical_refs)
canonical_refs.append(c)
tracer_id_to_canonical_id[tracer_id] = canonical_id
tracer_id_to_canonical_ref_id[tracer_id] = canonical_id
canonical_ref_avals.append(aval)
canonical_id = tracer_id_to_canonical_id[tracer_id]
canonical_id = tracer_id_to_canonical_ref_id[tracer_id]
ref_indices.append(canonical_id)
else:
nonref_consts.append(c)
nonref_const_avals.append(aval)
all_nonref_consts.append(nonref_consts)
all_nonref_const_avals.append(tuple(nonref_const_avals))
if tracer_id not in tracer_id_to_canonical_non_ref_id:
canonical_id = len(canonical_non_refs)
canonical_non_refs.append(c)
tracer_id_to_canonical_non_ref_id[tracer_id] = canonical_id
canonical_non_ref_avals.append(aval)
canonical_id = tracer_id_to_canonical_non_ref_id[tracer_id]
non_ref_indices.append(canonical_id)
canonical_ref_indices.append(tuple(ref_indices))
canonical_non_ref_indices.append(tuple(non_ref_indices))
consts = [*canonical_refs, *util.concatenate(all_nonref_consts)]
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*all_nonref_const_avals,))
consts = [*canonical_refs, *canonical_non_refs]
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,))
for i, jaxpr in enumerate(jaxprs))
return jaxprs, consts, all_out_trees
@weakref_lru_cache
def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
all_nonref_const_avals):
canonical_non_ref_avals, canonical_non_ref_indices):
is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars]
nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars)
newvar = core.gensym(suffix='_')
unused_const_vars = [tuple(map(newvar, const_avals))
for const_avals in all_nonref_const_avals]
padded_ref_constvars = map(newvar, canonical_ref_avals)
padded_non_ref_constvars = map(newvar, canonical_non_ref_avals)
for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars):
padded_ref_constvars[canonical_id] = ref_var
const_prefix = util.concatenate(unused_const_vars[:i])
const_suffix = util.concatenate(unused_const_vars[i + 1:])
constvars = [*padded_ref_constvars, *const_prefix, *nonref_constvars,
*const_suffix]
for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars):
padded_non_ref_constvars[canonical_id] = non_ref_var
constvars = [*padded_ref_constvars, *padded_non_ref_constvars]
jaxpr = jaxpr.replace(constvars=constvars)
effects = pe.make_jaxpr_effects(jaxpr.constvars, jaxpr.invars,
jaxpr.outvars, jaxpr.eqns)
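Illustrative sketch (not part of the diff): the branch constants above are deduplicated by Python object identity before the jaxprs are padded to a shared constvar signature. A simplified standalone version of that canonicalization for a single kind of constant (hypothetical helper name):

def canonicalize_consts(all_consts):
  canonical, indices, seen = [], [], {}
  for consts in all_consts:
    idxs = []
    for c in consts:
      if id(c) not in seen:
        seen[id(c)] = len(canonical)
        canonical.append(c)
      idxs.append(seen[id(c)])
    indices.append(tuple(idxs))
  return canonical, indices

a, b = object(), object()
canonical, indices = canonicalize_consts([[a, b], [b]])
assert canonical == [a, b] and indices == [(0, 1), (1,)]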

View File

@ -281,8 +281,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
num_consts = len(consts)
out_ = iter(out)
all_inputs = [*consts, *ops]
out = [
next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts])
next(out_) if fwd is None else lax.asarray(all_inputs[fwd])
for fwd in in_fwd
]
assert next(out_, None) is None
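Toy illustration (not part of the diff) of the forwarding convention in the comprehension above: `in_fwd[i]` is None when the branches compute output `i`, and otherwise indexes into `[*consts, *ops]`:

consts, ops = ['c0'], ['x0', 'x1']
all_inputs = [*consts, *ops]
computed = iter(['y0'])        # outputs actually produced by the branches
in_fwd = [None, 2]             # second output is just all_inputs[2] == 'x1'
out = [next(computed) if fwd is None else all_inputs[fwd] for fwd in in_fwd]
assert out == ['y0', 'x1'] and next(computed, None) is None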

View File

@ -443,42 +443,26 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
consts, carry, xs_ = split_list(args, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
num_trips, remainder = divmod(length, unroll)
if unroll != 1 and num_trips == 1 and remainder == 0:
# In that case, we explicitly want to fully unroll the loop. Put everything
# into the remainder block and avoid lowering to a while loop.
num_trips, remainder = 0, length
if unroll == 1:
xss = xs_
yss = _map(partial(_empty_array, (length,), None), y_avals)
yss = _map(partial(_empty_array, (length,), (None,)), y_avals)
else:
if remainder:
if not reverse:
xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
else:
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals)
if num_trips:
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals)
else:
yss = _map(partial(_empty_array, (num_trips * unroll,), (None,)), y_avals)
def cond_fun(while_carry):
i, _, _ = while_carry
return i < num_trips
def body_fun(while_carry):
i_, carry, yss = while_carry
i = num_trips - i_ - 1 if reverse else i_
xs = [
slicing.dynamic_index_in_dim(
xs, i, keepdims=False, allow_negative_indices=False
)
for xs in xss
]
carry, ys = inner(unroll, carry, xs)
yss = [
slicing.dynamic_update_index_in_dim(
ys, upd, i, 0, allow_negative_indices=False
)
for ys, upd in zip(yss, ys)
]
return i_ + 1, carry, yss
def inner(n, carry, xs):
ys = []
if unroll == 1:
@ -493,10 +477,26 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
ys = list(reversed(ys)) if reverse else ys
return carry, _map(_stack, zip(*ys))
def body_fun(while_carry):
i_, carry, yss = while_carry
i = num_trips - i_ - 1 if reverse else i_
xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False,
allow_negative_indices=False)
for xs in xss]
carry, ys = inner(unroll, carry, xs)
yss = [slicing.dynamic_update_index_in_dim(y, upd, i, 0,
allow_negative_indices=False)
for y, upd in zip(yss, ys)]
return i_ + 1, carry, yss
def cond_fun(while_carry):
i, _, _ = while_carry
return i < num_trips
if num_trips:
i = lax._const(num_trips, 0)
_, carry, yss = while_loop(cond_fun, body_fun, (i, carry, yss))
if unroll != 1:
if unroll != 1 and num_trips != 0:
ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
else:
ys = yss
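A quick standalone check of the unroll bookkeeping used above (illustrative only): the while loop runs `num_trips` iterations of `unroll` steps, a remainder block handles the leftover, and the fully-unrolled case puts everything in the remainder:

def unroll_split(length, unroll):
  num_trips, remainder = divmod(length, unroll)
  if unroll != 1 and num_trips == 1 and remainder == 0:
    # Fully unroll instead of emitting a one-trip while loop.
    num_trips, remainder = 0, length
  return num_trips, remainder

assert unroll_split(10, 1) == (10, 0)  # plain scan
assert unroll_split(10, 4) == (2, 2)   # two trips of 4, remainder of 2
assert unroll_split(4, 4) == (0, 4)    # fully unrolled special case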
@ -512,7 +512,7 @@ def _split_leading(sz, x):
def _concat(a, b): return lax.concatenate([a, b], 0)
def _empty_array(prefix, length_spec, aval):
sharding = aval.sharding.with_spec((length_spec, *aval.sharding.spec))
sharding = aval.sharding.with_spec((*length_spec, *aval.sharding.spec))
return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape),
out_sharding=sharding)

View File

@ -16,6 +16,7 @@ from __future__ import annotations
import builtins
from collections.abc import Callable, Sequence
import dataclasses
import enum
import functools
from functools import partial
@ -2227,10 +2228,122 @@ def ragged_dot(
Results:
(m, n) shaped array with preferred_element_type element type.
"""
return ragged_dot_p.bind(lhs, rhs, group_sizes,
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type,
group_offset=group_offset)
return ragged_dot_general(
lhs,
rhs,
group_sizes,
ragged_dot_dimension_numbers=_BASIC_RAGGED_DOT_DIMENSION_NUMBERS,
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type,
group_offset=group_offset,
)
@dataclasses.dataclass(frozen=True)
class RaggedDotDimensionNumbers():
"""Describes ragged, group, and dot dimensions for ragged dot general.
Args:
dot_dimension_numbers: a tuple of tuples of sequences of ints of the form
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
rhs_batch_dims))`.
lhs_ragged_dimensions: a sequence of ints indicating the 'lhs' ragged
dimensions.
rhs_group_dimensions: a sequence of ints indicating the 'rhs' group
dimensions.
"""
dot_dimension_numbers: DotDimensionNumbers
lhs_ragged_dimensions: Sequence[int]
rhs_group_dimensions: Sequence[int]
def __init__(
self, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions
):
super().__setattr__(
'dot_dimension_numbers',
tuple(tuple(map(tuple, t)) for t in dot_dimension_numbers),
)
super().__setattr__('lhs_ragged_dimensions', tuple(lhs_ragged_dimensions))
super().__setattr__('rhs_group_dimensions', tuple(rhs_group_dimensions))
def _from_maybe_ragged(
dot_dimension_numbers: RaggedDotDimensionNumbers | DotDimensionNumbers,
) -> DotDimensionNumbers:
return (
dot_dimension_numbers.dot_dimension_numbers
if isinstance(dot_dimension_numbers, RaggedDotDimensionNumbers)
else dot_dimension_numbers
)
# RaggedDotDimensionNumbers that specifies the simple case (i.e., lax.ragged_dot).
_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [1]), ([], [])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[0],
)
def ragged_dot_general(
lhs: Array,
rhs: Array,
group_sizes: Array,
ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
group_offset: Array | None = None,
) -> Array:
"""Ragged matrix multiplication.
Ragged dot takes three arrays---``lhs``, ``rhs``, and ``group_sizes``---and
a ``ragged_dot_dimension_numbers`` argument. Like `dot_general`, ``lhs`` and
``rhs`` are allowed arbitrary batch and contracting dimensions. Additionally,
``lhs`` is required to have one ragged dimension, and ``rhs`` may have at
most one group dimension.
Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has
three modes, depending on the kind of the lhs ragged dimension:
1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`.
Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``,
and `x...` are the lhs non-contracting dims outer to the ragged dim.
2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`.
Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and
``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim.
3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`.
Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and
``rhs``, and `x...` are the lhs batch dims outer to the ragged dim.
If ``group_sizes`` is passed in with shape `[g]`, it is broadcast according
to the rules above.
Args:
lhs: an array
rhs: an array
group_sizes: an array with integer element type
ragged_dot_dimension_numbers: a ``RaggedDotDimensionNumbers`` object to
specify the dot dimension numbers, lhs ragged dimension, and rhs group
dimension.
precision: Optional. Consistent with precision argument for
:func:`jax.lax.dot`.
preferred_element_type: Optional. Consistent with the preferred_element_type
argument for :func:`jax.lax.dot`.
group_offset: Optional. (1,) shaped array that indicates the group in
group_sizes to start computing from. If not specified, defaults to [0].
Results:
An array whose shape is the same as that produced by `dot_general`, with an
extra leading dimension of size `g` in the case where the lhs ragged
dimension is a contracting dimension.
"""
return ragged_dot_general_p.bind(
lhs,
rhs,
group_sizes,
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type,
group_offset=group_offset,
)
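A short usage sketch of the mode-1 (ragged noncontracting) case described in the docstring above. It assumes `ragged_dot_general` and `RaggedDotDimensionNumbers` are re-exported under `jax.lax`, as the rest of this change suggests; shapes follow the `[m,k], [g,k,n], [g] -> [m,n]` pattern:

import jax.numpy as jnp
from jax import lax

dims = lax.RaggedDotDimensionNumbers(
    dot_dimension_numbers=(([1], [1]), ([], [])),  # contract lhs k with rhs k
    lhs_ragged_dimensions=[0],                     # the lhs m dimension is ragged
    rhs_group_dimensions=[0],                      # rhs stacks g group matrices
)
lhs = jnp.ones((6, 4))               # [m, k]; the 6 rows split into 3 groups
rhs = jnp.ones((3, 4, 5))            # [g, k, n]
group_sizes = jnp.array([2, 3, 1])   # sums to m = 6
out = lax.ragged_dot_general(lhs, rhs, group_sizes,
                             ragged_dot_dimension_numbers=dims)
assert out.shape == (6, 5)           # same result shape as basic lax.ragged_dot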
def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None
@ -3723,10 +3836,11 @@ def _sin_complex(x):
# 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2
a, b = real(x), imag(x)
a_is_zero = eq(a, _const(a, 0))
two = _const(a, 2)
sn, cs = sin(a), cos(a)
e1m, e2m = expm1(b), expm1(-b)
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
re, im = sn * csh, cs * snh
e1m, e2m = expm1(b), expm1(neg(b))
snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two)
re, im = mul(sn, csh), mul(cs, snh)
# avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf
return select(a_is_zero, complex(_const(a, 0), im), complex(re, im))
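Standalone numerical check (not part of the diff) of the identity the rewrite above relies on, sin(a+ib) = sin(a)*cosh(b) + i*cos(a)*sinh(b), with sinh and cosh built from expm1 exactly as in the code:

import numpy as np

a, b = 0.3, 1.7
e1m, e2m = np.expm1(b), np.expm1(-b)
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
np.testing.assert_allclose(np.sin(a) * csh + 1j * np.cos(a) * snh,
                           np.sin(a + 1j * b), rtol=1e-12)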
@ -3736,14 +3850,14 @@ def _sin_lowering(ctx, x):
return sine(ctx, x)
return _nary_lower_hlo(hlo.sine, ctx, x)
def _sin_p_lin(nzs, x):
def _sin_lin(nzs, x):
nz, = nzs
cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_))
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
ad.primitive_linearizations[sin_p] = _sin_p_lin
ad.primitive_linearizations[sin_p] = _sin_lin
mlir.register_lowering(sin_p, _sin_lowering)
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
@ -3752,10 +3866,11 @@ def _cos_complex(x):
# see also _sin_complex
a, b = real(x), imag(x)
a_is_zero = eq(a, _const(a, 0))
two = _const(a, 2)
sn, cs = sin(a), cos(a)
e1m, e2m = expm1(b), expm1(-b)
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
re, im = cs * csh, -sn * snh
e1m, e2m = expm1(b), expm1(neg(b))
snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two)
re, im = mul(cs, csh), mul(neg(sn), snh)
return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im))
def _cos_lowering(ctx, x):
@ -3769,28 +3884,28 @@ ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
mlir.register_lowering(cos_p, _cos_lowering)
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans))))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
asin_p = standard_unop(_float | _complex, 'asin')
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x)))))
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin))
acos_p = standard_unop(_float | _complex, 'acos')
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
ad.defjvp(acos_p, lambda g, x: mul(g, neg(rsqrt(sub(_const(x, 1), square(x))))))
mlir.register_lowering(acos_p, partial(_nary_lower_hlo, chlo.acos))
def atan_impl(x):
return atan2(x, _const(x, 1))
atan_p = standard_unop(_float | _complex, 'atan')
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
ad.defjvp(atan_p, lambda g, x: div(g, add(_const(x, 1), square(x))))
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan))
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
ad.defjvp(atan2_p,
lambda g, x, y: g * (y / (square(x) + square(y))),
lambda g, x, y: g * -x / (square(x) + square(y)))
lambda g, x, y: mul(g, div(y, add(square(x), square(y)))),
lambda g, x, y: mul(g, div(neg(x), add(square(x), square(y)))))
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2))
sinh_p = standard_unop(_float | _complex, 'sinh')
@ -3802,17 +3917,17 @@ ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh))
asinh_p = standard_unop(_float | _complex, 'asinh')
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(add(square(x), _one(x)))))
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh))
acosh_p = standard_unop(_float | _complex, 'acosh')
ad.defjvp(acosh_p,
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x))))))
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh))
atanh_p = standard_unop(_float | _complex, 'atanh')
ad.defjvp(atanh_p,
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
lambda g, x: mul(reciprocal(add(_one(x), x)), div(g, sub(_one(x), x))))
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh))
real_p = unop(_complex_basetype, _complex, 'real')
@ -3906,11 +4021,11 @@ def _square_complex(x):
a, b = real(x), imag(x)
# zero square(x).real is handled explicitly for abs(a)==abs(b) cases
# where for finite a, 2 * a is non-finite:
zero_re = is_finite(a) & (eq(a, b) | eq(a, -b))
zero_re = is_finite(a) & (eq(a, b) | eq(a, neg(b)))
# equivalent to a**2 - b**2 but avoids overflow errors for large a
# and large b cases:
re = (a - b) * (a + b)
im = a * b * 2
re = mul(sub(a, b), add(a, b))
im = mul(mul(a, b), _const(a, 2))
return select(zero_re, complex(_const(a, 0), im), complex(re, im))
def _square_lower_hlo(ctx, x):
@ -4591,7 +4706,7 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
out_sharding):
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
raise NotImplementedError
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
for d in (lhs_contracting, lhs_batch)):
msg = ("dot_general requires lhs dimension numbers to be nonnegative and "
@ -4652,12 +4767,17 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)
def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
batch_shape = tuple(lhs_shape[i] for i in lhs_batch)
lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch)
rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch)
rhs_group = ()
if isinstance(dimension_numbers, RaggedDotDimensionNumbers):
rhs_group = tuple(dimension_numbers.rhs_group_dimensions)
rhs_contract_or_batch_or_group = tuple(
sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group)
)
rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch_or_group)
return batch_shape + lhs_tensored_shape + rhs_tensored_shape
@ -4721,7 +4841,7 @@ def tuple_delete(tup, idx):
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
out_sharding):
out_sharding, name: str = 'lax.dot_general'):
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
raise NotImplementedError
del dimension_numbers # unused
@ -4742,8 +4862,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
result_dtype = rhs.dtype
else:
if lhs.dtype != rhs.dtype:
raise TypeError(
f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
raise TypeError(f'{name} argument type error: {lhs.dtype}, {rhs.dtype}')
result_dtype = lhs.dtype
has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset))
return _maybe_upcast(result_dtype, preferred_element_type,
@ -4882,8 +5001,9 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
# explicitly present dimensions that this dot_general is zipping together.
lbd, rbd = batch_dims
assert lbd is not None or rbd is not None
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
is_ragged_dot = isinstance(dimension_numbers, RaggedDotDimensionNumbers)
def bump_dims(dims, b):
return tuple(np.add(dims, np.greater_equal(dims, b)))
@ -4906,8 +5026,14 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
elif (type(rbd) is int and lbd is None):
# The right vmapped dimension becomes an additional tensor dimension in the
# batched dot_general.
rhs_tensor = [d for d in range(rhs_ndim)
if d not in rhs_batch and d not in rhs_contract]
rhs_tensor = list(
remaining(
range(rhs_ndim),
rhs_batch,
rhs_contract,
dimension_numbers.rhs_group_dimensions if is_ragged_dot else [],
)
)
result_batch_dim = (lhs_ndim - len(lhs_contract) +
int(sum(np.less(rhs_tensor, rbd))))
rhs_batch = bump_dims(rhs_batch, rbd)
@ -4917,6 +5043,16 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
assert False
new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
if is_ragged_dot:
new_dimension_numbers = RaggedDotDimensionNumbers(
dot_dimension_numbers=new_dimension_numbers,
lhs_ragged_dimensions=bump_dims(
dimension_numbers.lhs_ragged_dimensions, lbd
),
rhs_group_dimensions=bump_dims(
dimension_numbers.rhs_group_dimensions, rbd
),
)
return new_dimension_numbers, result_batch_dim
def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
@ -5008,15 +5144,6 @@ def _dot_general_batch_unpack_dims(batch_dims):
lbd, rbd = batch_dims
return (lbd, rbd)
# DotDimensionNumbers used in the dot_general call for ragged_dot().
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
([2, 0], [1, 0]),
([], []),
)
_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
([3, 1], [2, 1]),
([0], [0]),
)
ad.defbilinear(dot_general_p,
_dot_general_transpose_lhs, _dot_general_transpose_rhs)
@ -5184,58 +5311,181 @@ for platform in ["cpu", "tpu"]:
platform=platform)
def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
if len(lhs.shape) == 3:
# Batched case
b, m, k = lhs.shape
b2, group_count, rk, n = rhs.shape
b3 = group_sizes.shape[0]
if b != b2:
raise TypeError(
f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and'
f' {b2}.'
)
if b3 != b:
raise TypeError(
'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got'
f' {b3} and {b}.'
)
if k != rk:
raise TypeError(
f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and'
f' {rk}.'
)
num_groups = group_sizes.shape[1]
if group_count != num_groups:
raise TypeError(
'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got'
f' {group_count} and {num_groups}.'
)
return (b, m, n)
class RaggedDotMode(enum.Enum):
RAGGED_NONCONTRACTING = 1 # [b,m,k], [g,b,k,n], [b,g] -> [b,m,n]
RAGGED_CONTRACTING = 2 # [b,m,k], [b,k,n], [b,g] -> [g,b,m,n]
RAGGED_BATCH = 3 # [b,m,k], [b,k,n], [g] -> [b,m,n]
m, k = lhs.shape
group_count, rk, n = rhs.shape
if k != rk:
raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
num_groups = group_sizes.shape[0]
if group_count != num_groups:
raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
return (m, n)
def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
precision, preferred_element_type: DTypeLike | None,
**_) -> np.dtype:
def _ragged_dot_mode_and_dim(
lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
) -> tuple[RaggedDotMode, int]:
assert len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) == 1
lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
(lhs_contracting, _), (lhs_batch, _) = ragged_dot_dimension_numbers.dot_dimension_numbers
lhs_noncontracting = remaining(range(lhs_rank), lhs_contracting, lhs_batch)
if lhs_ragged_dim in lhs_noncontracting:
mode = RaggedDotMode.RAGGED_NONCONTRACTING
elif lhs_ragged_dim in lhs_contracting:
mode = RaggedDotMode.RAGGED_CONTRACTING
elif lhs_ragged_dim in lhs_batch:
mode = RaggedDotMode.RAGGED_BATCH
else:
raise TypeError(
f'lhs_ragged_dim {lhs_ragged_dim} not found in '
f'lhs_noncontracting {lhs_noncontracting}, '
f'lhs_contracting {lhs_contracting}, or '
f'lhs_batch {lhs_batch}.'
)
return mode, lhs_ragged_dim
def _ragged_dot_mode(
lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
) -> RaggedDotMode:
return _ragged_dot_mode_and_dim(lhs_rank, ragged_dot_dimension_numbers)[0]
def _is_ragged_contracting(
lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
) -> bool:
return (
_ragged_dot_mode(lhs_rank, ragged_dot_dimension_numbers)
== RaggedDotMode.RAGGED_CONTRACTING
)
def _ragged_dot_prefix_dims(mode, rank, ragged_dim, batch, contract):
batch, contract = map(list, (batch, contract))
noncontract = remaining(range(rank), contract, batch)
match mode:
case RaggedDotMode.RAGGED_NONCONTRACTING:
return batch + noncontract[: noncontract.index(ragged_dim)]
case RaggedDotMode.RAGGED_CONTRACTING:
return batch + contract[: contract.index(ragged_dim)]
case RaggedDotMode.RAGGED_BATCH:
return batch[: batch.index(ragged_dim)]
def _ragged_dot_general_shape_rule(
lhs,
rhs,
group_sizes,
*,
ragged_dot_dimension_numbers,
precision,
preferred_element_type: DTypeLike | None,
**_,
):
def _check_in_range(dim, rank, dim_name, arg_name):
if dim < 0 or dim >= rank:
raise TypeError(
f'ragged_dot_general requires {dim_name} numbers to be nonnegative '
f'and less than the number of axes of the {arg_name} value, '
f'got {dim} for {arg_name} of rank {rank}.'
)
# Validate the lhs ragged dimension, and find out which mode we're in.
if len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) != 1:
raise TypeError(
'ragged_dot_general expects exactly one lhs ragged dimension.'
)
lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
_check_in_range(lhs_ragged_dim, lhs.ndim, 'lhs ragged dimension', 'lhs')
mode = _ragged_dot_mode(lhs.ndim, ragged_dot_dimension_numbers)
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = (
ragged_dot_dimension_numbers.dot_dimension_numbers
)
# Validate the shape of group_sizes, if it is something other than [g].
if group_sizes.ndim == 0:
raise TypeError('expected rank of group_sizes to be >=1.')
if group_sizes.ndim != 1:
# Construct the expected shape [b...,x...,g] of group_sizes.
prefix_dims = _ragged_dot_prefix_dims(
mode, lhs.ndim, lhs_ragged_dim, lhs_batch, lhs_contracting
)
expected_gs_shape = tuple(lhs.shape[i] for i in prefix_dims)
expected_gs_shape += (group_sizes.shape[-1],)
# TODO(pravnar): Permit other broadcastable shapes.
if not core.definitely_equal_shape(group_sizes.shape, expected_gs_shape):
raise TypeError(
'expected group_sizes to have shape '
f'{expected_gs_shape}, got {group_sizes.shape}.'
)
num_groups = group_sizes.shape[-1]
# Validate properties of the rhs group dimension(s).
rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
match mode:
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
if len(rhs_group_dims) != 0:
raise TypeError(
'ragged_dot_general requires zero group dimensions in the rhs '
'when lhs ragged dimension is contracting or batch.'
)
case RaggedDotMode.RAGGED_NONCONTRACTING:
if len(rhs_group_dims) != 1:
raise TypeError(
'ragged_dot_general requires exactly one rhs group dimension '
'when lhs ragged dimension is noncontracting.'
)
rhs_group_dim = rhs_group_dims[0]
_check_in_range(rhs_group_dim, rhs.ndim, 'rhs group dimension', 'rhs')
if rhs_group_dim in rhs_batch or rhs_group_dim in rhs_contracting:
raise TypeError(
'ragged_dot_general requires rhs group dimension numbers to be '
'distinct from contracting and batch dimensions.'
)
if rhs.shape[rhs_group_dim] != num_groups:
raise TypeError(
'expected rhs group dimension size to be '
f'{num_groups}, got {rhs.shape[rhs_group_dim]}.'
)
out_shape = _dot_general_shape_rule(
lhs,
rhs,
dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
out_sharding=None,
)
if mode == RaggedDotMode.RAGGED_CONTRACTING:
out_shape = (num_groups,) + out_shape
return out_shape
def _ragged_dot_general_dtype_rule(
lhs: Array,
rhs: Array,
group_sizes: Array,
ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
precision,
preferred_element_type: DTypeLike | None,
**_,
) -> np.dtype:
if not dtypes.issubdtype(group_sizes.dtype, np.integer):
raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
# defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
raise TypeError(
'ragged_dot_general requires that '
'group_sizes.dtype is subtype of np.integer.'
)
# defer the output dtype to dot_general, which is part of the _ragged_dot_general_impl.
return _dot_general_dtype_rule(
lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
precision=precision, preferred_element_type=preferred_element_type,
out_sharding=None)
lhs,
rhs,
dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
out_sharding=None,
name='lax.ragged_dot_general',
)
def _ragged_dot_jvp_rule(
primals, tangents, precision, preferred_element_type, group_offset
def _ragged_dot_general_jvp_rule(
primals, tangents, ragged_dot_dimension_numbers,
precision, preferred_element_type, group_offset
):
# note: we could ostensibly just get this by passing the value on to
# ragged_dot_general below, but this feels cleaner.
@ -5245,20 +5495,22 @@ def _ragged_dot_jvp_rule(
dx, dy, _ = tangents # no tan on the gs
# primal
primal_out = ragged_dot(
primal_out = ragged_dot_general(
x,
y,
gs,
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
# tangent
dx_out = (
ragged_dot(
ragged_dot_general(
dx,
y,
gs,
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
@ -5266,73 +5518,127 @@ def _ragged_dot_jvp_rule(
else _zeros(primal_out)
)
dy_out = (
ragged_dot(
ragged_dot_general(
x,
dy,
gs,
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
if type(dy) is not ad_util.Zero
else _zeros(primal_out)
)
tangent_out = dx_out + dy_out
tangent_out = add(dx_out, dy_out)
return primal_out, tangent_out
def _ragged_to_dense(x, y, group_sizes):
from jax._src.lax import control_flow # avoid circular imports
shape = (y.shape[0], x.shape[0], x.shape[1])
x = broadcast_in_dim(x, shape, [1, 2])
iota = broadcasted_iota(group_sizes.dtype, shape, 1)
group_ends = control_flow.cumsum(group_sizes)
group_starts = concatenate(
[_zeros(group_sizes)[:1], group_ends[:-1]],
dimension=0,
)
group_ends = broadcast_in_dim(group_ends, shape, (0,))
group_starts = broadcast_in_dim(group_starts, shape, (0,))
mask = bitwise_and(group_starts <= iota, iota < group_ends)
x = select(mask, x, _zeros(x))
return x
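The densification above zeroes each row of `x` outside its group's [start, end) range, producing one masked copy of `x` per group. A standalone numpy restatement of the same masking pattern (illustrative only):

import numpy as np

group_sizes = np.array([2, 1])             # g = 2 groups over m = 3 rows
ends = np.cumsum(group_sizes)              # [2, 3]
starts = np.concatenate([[0], ends[:-1]])  # [0, 2]
iota = np.arange(3)                        # row indices
mask = (starts[:, None] <= iota) & (iota < ends[:, None])  # [g, m]
x = np.arange(6.0).reshape(3, 2)           # [m, k]
dense = np.where(mask[:, :, None], x, 0.0)                 # [g, m, k]
assert dense.shape == (2, 3, 2)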
def _ragged_dot_transpose_rule(
ct, *operands, precision, preferred_element_type, group_offset
def _ragged_dot_general_transpose_rule(
ct,
x,
y,
group_sizes,
*,
ragged_dot_dimension_numbers,
precision,
preferred_element_type: DTypeLike | None,
group_offset: Array | None,
):
x, y, gs = operands
if group_offset is not None:
raise NotImplementedError('Unimplemented group_offset support.')
if ad.is_undefined_primal(y):
grad_x = None
else:
y_t = _matrix_transpose(y)
grad_x = ragged_dot(
ct,
y_t,
gs,
precision=precision,
preferred_element_type=preferred_element_type,
)
(x_contract, y_contract), (x_batch, y_batch) = ragged_dot_dimension_numbers.dot_dimension_numbers
x_ndim = x.aval.ndim if ad.is_undefined_primal(x) else np.ndim(x)
y_ndim = y.aval.ndim if ad.is_undefined_primal(y) else np.ndim(y)
x_kept = remaining(range(x_ndim), x_contract, x_batch)
y_group = ragged_dot_dimension_numbers.rhs_group_dimensions
y_kept = remaining(range(y_ndim), y_contract, y_batch, y_group)
mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
x_ndim, ragged_dot_dimension_numbers
)
if ad.is_undefined_primal(x):
grad_y = None
else:
y = y.aval if ad.is_undefined_primal(y) else y
x_dense = _ragged_to_dense(x, y, group_sizes=gs)
ct_dense = _ragged_to_dense(ct, y, group_sizes=gs)
dimension_numbers = (([1], [1]), ([0], [0]))
grad_y = dot_general(
x_dense,
ct_dense,
dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
unimplemented = lambda fn_name, ragged_dot_mode: NotImplementedError(
f'Unimplemented {fn_name} for ragged dot general in mode '
f'{ragged_dot_mode.name}.'
)
return grad_x, grad_y, None
# This is a hack to ensure we continue to emit the `_matrix_transpose` for the
# grad_x case. This isn't strictly necessary since we have dot_dim_nums.
# TODO(pravnar): Remove this once we no longer care to emit the transpose.
_is_basic_ragged_dot = (
x_ndim == 2
and y_ndim == 3
and ragged_dot_dimension_numbers == _BASIC_RAGGED_DOT_DIMENSION_NUMBERS
)
def grad_x_dims():
match mode:
case RaggedDotMode.RAGGED_NONCONTRACTING:
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
dims = (
ragged_dot_dimension_numbers
if _is_basic_ragged_dot
else RaggedDotDimensionNumbers(
dot_dimension_numbers=((ans_y, y_kept), (ans_batch, y_batch)),
lhs_ragged_dimensions=[
len(x_batch) + x_kept.index(lhs_ragged_dim)
],
rhs_group_dimensions=y_group,
)
)
x_contract_sorted_by_y = list(
np.take(x_contract, np.argsort(y_contract))
)
unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
raise unimplemented('grad_x_dims', mode)
return dims, unsorted_axes
def grad_y_dims():
match mode:
case RaggedDotMode.RAGGED_NONCONTRACTING:
ans_batch, ans_x, _ = ranges_like(x_batch, x_kept, y_kept)
dims = RaggedDotDimensionNumbers(
dot_dimension_numbers=((x_kept, ans_x), (x_batch, ans_batch)),
lhs_ragged_dimensions=[lhs_ragged_dim],
rhs_group_dimensions=[],
)
y_contract_sorted_by_x = list(
np.take(y_contract, np.argsort(x_contract))
)
unsorted_axes = (
list(y_group) + list(y_batch) + y_contract_sorted_by_x + y_kept
)
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
raise unimplemented('grad_y_dims', mode)
return dims, unsorted_axes
def _ragged_dot_grad(lhs, rhs, dims_fn, aval):
dims, unsorted_axes = dims_fn()
ragged_dot_general_out = ragged_dot_general(
lhs, rhs, group_sizes, dims, precision=precision,
preferred_element_type=preferred_element_type,
group_offset=group_offset)
result = transpose(ragged_dot_general_out, tuple(np.argsort(unsorted_axes)))
if result.dtype != aval.dtype:
result = _convert_element_type(result, aval.dtype, aval.weak_type)
return result
x_bar = (
None
if ad.is_undefined_primal(y)
else _ragged_dot_grad(ct,
_matrix_transpose(y) if _is_basic_ragged_dot else y,
grad_x_dims,
x.aval)
)
y_bar = (
None
if ad.is_undefined_primal(x)
else _ragged_dot_grad(x, ct, grad_y_dims, y.aval)
)
return x_bar, y_bar, None
def _ragged_dot_batch_unpack_args(batched_args):
@ -5347,62 +5653,71 @@ def _ragged_dot_batch_unpack_dims(batch_dims):
return (lbd, rbd)
def _ragged_dot_invoke_prim(
def _ragged_dot_general_invoke_prim(
group_sizes,
lhs,
rhs,
new_dimension_numbers,
new_ragged_dot_dimension_numbers,
precision,
preferred_element_type,
out_sharding,
):
del out_sharding
return ragged_dot(
return ragged_dot_general(
lhs,
rhs,
group_sizes,
ragged_dot_dimension_numbers=new_ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
def _ragged_dot_batch_rule(
def _ragged_dot_general_batch_rule(
axis_data,
batched_args,
batch_dims,
*,
ragged_dot_dimension_numbers,
precision,
preferred_element_type: DTypeLike | None,
**_,
):
invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
return _dot_batch_rule(
invoke = partial(_ragged_dot_general_invoke_prim, batched_args[2])
batched_out, result_batch_dim = _dot_batch_rule(
_ragged_dot_batch_unpack_args,
_ragged_dot_batch_unpack_dims,
invoke,
axis_data,
batched_args,
batch_dims,
dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
out_sharding=None,
)
if _is_ragged_contracting(batched_args[0].ndim - 1,
ragged_dot_dimension_numbers):
result_batch_dim += 1
return batched_out, result_batch_dim
ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
_ragged_dot_dtype_rule, 'ragged_dot')
ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
batching.skippable_batchers[ragged_dot_p] = lambda _: ()
ragged_dot_general_p = standard_primitive(
_ragged_dot_general_shape_rule,
_ragged_dot_general_dtype_rule,
'ragged_dot_general',
)
ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule
ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule
batching.fancy_primitive_batchers[ragged_dot_general_p] = _ragged_dot_general_batch_rule
batching.skippable_batchers[ragged_dot_general_p] = lambda _: ()
def _ragged_dot_impl(
def _ragged_dot_general_impl(
lhs: Array,
rhs: Array,
group_sizes: Array,
ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
group_offset: Array | None = None,
@ -5410,24 +5725,100 @@ def _ragged_dot_impl(
if group_offset is not None:
raise NotImplementedError("Unimplemented group_offset support.")
if len(lhs.shape) == 3:
ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS
ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0))
else:
ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS
ragged_to_dense = _ragged_to_dense
def ragged_to_dense(x: Array, gs: Array, *, dim: int):
from jax._src.lax import control_flow # avoid circular imports
assert gs.ndim == 1
shape = gs.shape + x.shape
x = broadcast_in_dim(x, shape, list(range(1, len(shape))))
iota = broadcasted_iota(gs.dtype, shape, dim+1)
group_ends = control_flow.cumsum(gs)
group_starts = concatenate(
[_zeros(gs)[:1], group_ends[:-1]],
dimension=0,
)
group_ends = broadcast_in_dim(group_ends, shape, (0,))
group_starts = broadcast_in_dim(group_starts, shape, (0,))
mask = bitwise_and(group_starts <= iota, iota < group_ends)
x = select(mask, x, _zeros(x))
return x
lhs = ragged_to_dense(lhs, rhs, group_sizes)
def batched_ragged_to_dense(dim, *x_in_axes: int):
if not x_in_axes:
return partial(ragged_to_dense, dim=dim)
x_axis, *rest = x_in_axes
decr = lambda d: d - 1 if d >= x_axis else d
return api.vmap(
batched_ragged_to_dense(decr(dim), *[decr(ax) for ax in rest]),
in_axes=(x_axis, 0),
)
return dot_general(
lhs,
rhs,
dimension_numbers=ragged_dot_dims,
incr = lambda dims: [d + 1 for d in dims]
# Expand the ragged `dim` of `x`, given its batching `axes`.
# The group axis from `gs` becomes the outermost axis of the result.
# Some examples:
# x: [m,k] , gs: [g] ==> expand(x, 0, gs): [g,m,k]
# x: [b1,m,b2,k], gs: [b1,b2,g] ==> expand(x, 1, gs, 0, 2): [g,b1,m,b2,k]
def expand(x, dim, gs, *axes):
expanded = batched_ragged_to_dense(dim, *axes)(x, gs)
unsorted_dims = incr(axes) + [0] + incr(remaining(range(x.ndim), axes))
return transpose(expanded, np.argsort(unsorted_dims))
mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
lhs.ndim, ragged_dot_dimension_numbers
)
(l_contract, r_contract), (l_batch, r_batch) = (
ragged_dot_dimension_numbers.dot_dimension_numbers
)
l_prefix = _ragged_dot_prefix_dims(
mode, lhs.ndim, lhs_ragged_dim, l_batch, l_contract
)
_dot_general = partial(
dot_general,
precision=precision,
preferred_element_type=preferred_element_type,
)
# TODO(pravnar): Permit other broadcastable shapes.
if group_sizes.ndim == 1:
group_sizes = broadcast(group_sizes, [lhs.shape[i] for i in l_prefix])
mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False))
match mode:
case RaggedDotMode.RAGGED_NONCONTRACTING:
rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
assert len(rhs_group_dims) == 1
return _dot_general(
expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
rhs,
dimension_numbers=(
(incr(l_contract) + [0], list(r_contract) + [rhs_group_dims[0]]),
(incr(l_batch), r_batch),
),
)
case RaggedDotMode.RAGGED_CONTRACTING:
rhs_ragged_dim = r_contract[l_contract.index(lhs_ragged_dim)]
r_prefix = _ragged_dot_prefix_dims(
mode, rhs.ndim, rhs_ragged_dim, r_batch, r_contract
)
return _dot_general(
expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
expand(rhs, rhs_ragged_dim, group_sizes, *r_prefix),
dimension_numbers=(
(incr(l_contract), incr(r_contract)),
([0] + incr(l_batch), [0] + incr(r_batch)),
),
)
case RaggedDotMode.RAGGED_BATCH:
return _dot_general(
lhs,
rhs,
dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
)
mlir.register_lowering(ragged_dot_general_p,
mlir.lower_fun(_ragged_dot_general_impl,
multiple_results=False))
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions,
@ -5802,8 +6193,12 @@ def _concatenate_transpose_rule(t, *operands, dimension):
def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
if bdim is not None)
spec = next(core.get_aval(op).sharding.spec[bdim]
for op, bdim in zip(batched_args, batch_dims) if bdim is not None)
operands = [batching.moveaxis(op, bdim, 0) if bdim is not None
else broadcast(op, (size,))
else broadcast(
op, (size,), out_sharding=core.get_aval(op).sharding.with_spec(
(spec, *core.get_aval(op).sharding.spec)))
for op, bdim in zip(batched_args, batch_dims)]
return concatenate(operands, dimension + 1), 0

View File

@ -44,6 +44,10 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lax import utils as lax_utils
from jax._src.lax.lax import _float, _complex, _int
from jax._src.lib import gpu_linalg
from jax._src.lib import gpu_solver
from jax._src.lib import gpu_sparse
from jax._src.lib import lapack
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
@ -51,12 +55,23 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec as P
from jax._src.typing import Array, ArrayLike
# The following imports may be unused but they are needed to register the
# custom call targets defined in each module.
from jax._src.lib import gpu_linalg # pylint:disable=unused-import # noqa: F401
from jax._src.lib import gpu_solver # pylint:disable=unused-import # noqa: F401
from jax._src.lib import gpu_sparse # pylint:disable=unused-import # noqa: F401
from jax._src.lib import lapack # pylint:disable=unused-import # noqa: F401
def register_module_custom_calls(module):
if hasattr(module, "registrations"):
for platform, targets in module.registrations().items():
for name, value, api_version in targets:
ffi.register_ffi_target(
name, value, platform=platform, api_version=api_version
)
if hasattr(module, "batch_partitionable_targets"):
for name in module.batch_partitionable_targets():
ffi.register_ffi_target_as_batch_partitionable(name)
register_module_custom_calls(gpu_linalg)
register_module_custom_calls(gpu_solver)
register_module_custom_calls(gpu_sparse)
register_module_custom_calls(lapack)
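The helper above only relies on a small duck-typed module interface. A hypothetical stand-in showing the expected shape of that interface (illustrative names, not a real backend module):

class _fake_backend_module:
  @staticmethod
  def registrations():
    # platform -> list of (target_name, handler, api_version) triples
    return {"CUDA": []}

  @staticmethod
  def batch_partitionable_targets():
    # names of FFI targets that are batch partitionable
    return []

# register_module_custom_calls(_fake_backend_module) would then register
# nothing, since both lists are empty.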
# Top-level functions in alphabetical order.

View File

@ -1608,8 +1608,9 @@ def _dynamic_update_slice_sharding_rule(operand, update, *start_indices):
if operand.sharding != update.sharding:
raise TypeError(
"dynamic_update_slice update sharding must be equal to operand"
f" sharding, got update sharding {update.sharding} for operand sharding"
f" {operand.sharding}.")
" sharding, got update sharding"
f" {update.str_short(mesh_axis_types=True)} for operand sharding"
f" {operand.str_short(mesh_axis_types=True)}.")
return operand.sharding
def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):

View File

@ -200,7 +200,7 @@ class _BaseMesh:
_mesh_object_dict = {} # type: ignore
MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]
MeshAxisType = dict[AxisTypes, MeshAxisName | tuple[MeshAxisName, ...]]
class Mesh(_BaseMesh, contextlib.ContextDecorator):
"""Declare the hardware resources available in the scope of this manager.

View File

@ -45,8 +45,8 @@ RealNumeric = Any # Scalar jnp array or float
@export
@typing.runtime_checkable
class Initializer(Protocol):
@staticmethod
def __call__(key: Array,
def __call__(self,
key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
raise NotImplementedError
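With `__call__` now an instance method, a callable object satisfies the protocol just as a plain function does. A minimal sketch (hypothetical `ConstantInit`, not part of this change):

import jax.numpy as jnp
from jax import random

class ConstantInit:
  """Initializer that ignores the key and fills with a constant."""

  def __init__(self, value):
    self.value = value

  def __call__(self, key, shape, dtype=jnp.float32):
    del key  # deterministic initializer
    return jnp.full(shape, self.value, dtype)

init = ConstantInit(0.5)
x = init(random.key(0), (2, 3))
assert x.shape == (2, 3) and x.dtype == jnp.float32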

View File

@ -547,6 +547,9 @@ def _einsum(
if not last_contraction:
dot_general_out_sharding = None
elif out_sharding is not None and names != result_names:
if len(result_names) > len(out_sharding.spec):
out_sharding = out_sharding.with_spec(
out_sharding.spec._normalized_spec_for_aval(len(result_names)))
spec = out_sharding.spec
inverse_spec = tuple(spec[result_names.index(name)] for name in names)
dot_general_out_sharding = NamedSharding(

View File

@ -93,6 +93,8 @@ float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
if dtypes.float4_e2m1fn is not None:
float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)

View File

@ -26,7 +26,7 @@ from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import SingleDeviceSharding
from jax._src.util import safe_zip, safe_map
from jax._src.util import safe_zip, safe_map, set_module
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax.sharding import Sharding
@ -35,6 +35,8 @@ import numpy as np
zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
export = set_module('jax.numpy')
_dtype = partial(dtypes.dtype, canonicalize=True)
def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
@ -308,3 +310,124 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin
return SingleDeviceSharding(device)
else:
return device
@export
def ndim(a: ArrayLike) -> int:
"""Return the number of dimensions of an array.
JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function
raises a :class:`TypeError` if the input is a collection such as a list or
tuple.
Args:
a: array-like object.
Returns:
An integer specifying the number of dimensions of ``a``.
Examples:
Number of dimensions for arrays:
>>> x = jnp.arange(10)
>>> jnp.ndim(x)
1
>>> y = jnp.ones((2, 3))
>>> jnp.ndim(y)
2
This also works for scalars:
>>> jnp.ndim(3.14)
0
For arrays, this can also be accessed via the :attr:`jax.Array.ndim` property:
>>> x.ndim
1
"""
# Deprecation warning added 2025-2-20.
check_arraylike("ndim", a, emit_warning=True)
return np.ndim(a) # NumPy dispatches to a.ndim if available.
@export
def shape(a: ArrayLike) -> tuple[int, ...]:
"""Return the shape an array.
JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function
raises a :class:`TypeError` if the input is a collection such as a list or
tuple.
Args:
a: array-like object.
Returns:
A tuple of integers representing the shape of ``a``.
Examples:
Shape for arrays:
>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)
This also works for scalars:
>>> jnp.shape(3.14)
()
For arrays, this can also be accessed via the :attr:`jax.Array.shape` property:
>>> x.shape
(10,)
"""
# Deprecation warning added 2025-2-20.
check_arraylike("shape", a, emit_warning=True)
return np.shape(a) # NumPy dispatches to a.shape if available.
@export
def size(a: ArrayLike, axis: int | None = None) -> int:
"""Return number of elements along a given axis.
JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function
raises a :class:`TypeError` if the input is a collection such as a list or
tuple.
Args:
a: array-like object
axis: optional integer along which to count elements. By default, return
the total number of elements.
Returns:
An integer specifying the number of elements in ``a``.
Examples:
Size for arrays:
>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3
This also works for scalars:
>>> jnp.size(3.14)
1
For arrays, this can also be accessed via the :attr:`jax.Array.size` property:
>>> y.size
6
"""
# Deprecation warning added 2025-2-20.
check_arraylike("size", a, emit_warning=True)
return np.size(a, axis=axis) # NumPy dispatches to a.size if available.

View File

@ -25,8 +25,8 @@ from typing import Any, Callable, Protocol, Sequence
import jax
from jax import lax
from jax._src import api_util
from jax._src import ad_util
from jax._src import api_util
from jax._src import core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
@ -351,7 +351,7 @@ def _pull_block_spec(
jaxpr.constvars,
jaxpr.invars,
needed_invars,
jaxpr.eqns[:jaxpr.eqns.index(eqn)],
jaxpr.eqns[: jaxpr.eqns.index(eqn)],
debug_info=jaxpr.debug_info,
)
scalar_prefetch_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
@ -426,6 +426,7 @@ def make_kernel_function(
return tuple(s for s in shape if s is not None)
_no_aval = object()
def _get_block_aval(bs, aval):
if bs is pallas_core.no_block_spec or bs is None:
return _no_aval
@ -441,10 +442,12 @@ def make_kernel_function(
unflat_arg_usages, unflat_kwarg_usages = tree_util.tree_unflatten(
in_tree, invar_usages
)
def sds_like(x):
if x is _no_aval:
return _no_aval
return jax.ShapeDtypeStruct(x.shape, x.dtype)
kernel_in_type = jax.tree.map(
sds_like, (unflat_in_block_arg_avals, unflat_in_block_kwarg_avals)
)
@ -688,8 +691,10 @@ def _eltwise_eval_rule(prim, ctx, x, **params):
def _eltwise_pull_rule(
prim: core.Primitive, ctx: PullRuleContext, block_spec: pallas_core.BlockSpec,
**params
prim: core.Primitive,
ctx: PullRuleContext,
block_spec: pallas_core.BlockSpec,
**params,
) -> Sequence[pallas_core.BlockSpec]:
del prim, ctx, params
return [block_spec]
@ -702,7 +707,9 @@ def _eltwise_usage_rule(
return [used_out]
def _bcast_block_spec(block_spec: pallas_core.BlockSpec, i: int) -> pallas_core.BlockSpec:
def _bcast_block_spec(
block_spec: pallas_core.BlockSpec, i: int
) -> pallas_core.BlockSpec:
def new_index_map(i, *args):
idx = block_spec.index_map(*args)
assert len(idx) == len(block_spec.block_shape)
@ -710,7 +717,9 @@ def _bcast_block_spec(block_spec: pallas_core.BlockSpec, i: int) -> pallas_core.
return idx
new_block_shape = util.tuple_update(block_spec.block_shape, i, 1)
return pallas_core.BlockSpec(new_block_shape, functools.partial(new_index_map, i))
return pallas_core.BlockSpec(
new_block_shape, functools.partial(new_index_map, i)
)
def _binop_usage_rule(prim, ctx, used_out: set[Usage]):
@ -945,7 +954,9 @@ def _dynamic_slice_rule(
return block_indices
new_block_spec = pallas_core.BlockSpec(block_spec.block_shape, new_index_map)
return [new_block_spec] + [pallas_core.no_block_spec] * (len(ctx.avals_in) - 1)
return [new_block_spec] + [pallas_core.no_block_spec] * (
len(ctx.avals_in) - 1
)
@register_eval_rule(lax.concatenate_p)
@ -1348,7 +1359,8 @@ def _push_block_spec_jaxpr(
return env[atom]
def _write_block_spec(
atom: core.Atom, block_spec: pallas_core.BlockSpec | pallas_core.NoBlockSpec
atom: core.Atom,
block_spec: pallas_core.BlockSpec | pallas_core.NoBlockSpec,
):
if isinstance(atom, core.Literal):
return
@ -1374,7 +1386,9 @@ def _push_block_spec_jaxpr(
util.safe_map(_write_block_spec, eqn.outvars, out_block_specs)
out_block_specs = tuple(util.safe_map(_read_block_spec, jaxpr.outvars))
valid_block_spec = [bs for bs in flat_block_specs if bs is not pallas_core.no_block_spec][0]
valid_block_spec = [
bs for bs in flat_block_specs if bs is not pallas_core.no_block_spec
][0]
out_block_specs = tuple(
valid_block_spec if obs is pallas_core.no_block_spec else obs
for obs in out_block_specs
@ -1491,6 +1505,18 @@ def _convert_element_type_push_rule(
return block_spec
@register_push_block_spec_rule(lax.select_n_p)
def _select_n_push_rule(
ctx: PushRuleContext,
*args: pallas_core.BlockSpec,
):
del ctx
block_specs = [b for b in args if b is not pallas_core.no_block_spec]
if len(block_specs) > 1:
raise NotImplementedError('select_n with multiple inputs not supported yet')
return block_specs[0]
@register_push_block_spec_rule(custom_derivatives.custom_jvp_call_p)
def _custom_jvp_call_push_rule(
ctx, *block_specs, call_jaxpr: core.ClosedJaxpr, **_
@ -1500,9 +1526,7 @@ def _custom_jvp_call_push_rule(
@register_push_block_spec_rule(pjit.pjit_p)
def _pjit_push_rule(
ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_
):
def _pjit_push_rule(ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_):
assert not jaxpr.consts
return _push_block_spec_jaxpr(jaxpr.jaxpr, *block_specs)

View File

@ -32,28 +32,42 @@ def _get_aval(x):
return jax_core.raise_to_shaped(jax_core.get_aval(x))
def fuse(f, *, physicalize: bool = False):
def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
"""Fuses a function into a single fusable.
Args:
f: The function to fuse.
physicalize: (experimental) whether to physicalize the function.
debug: Whether to print debug information.
There should be a single call to a `fusable` inside the body of `f`. `fuse`
returns a transformed function that will fuse the surrounding computation into
the fusable and invoke it.
"""
def wrapper(*args, **kwargs):
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
debug_info = api_util.debug_info('fuse', f, args, kwargs)
flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f, debug_info=debug_info), in_tree
)
flat_avals = [_get_aval(x) for x in flat_args]
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
out_tree = out_tree_thunk()
out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
return tree_util.tree_unflatten(out_tree, out_flat)
if physicalize:
wrapper = fusable_dtype.physicalize(wrapper)
return wrapper
def decorator(f):
def wrapper(*args, **kwargs):
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
debug_info = api_util.debug_info("fuse", f, args, kwargs)
flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f, debug_info=debug_info), in_tree
)
flat_avals = [_get_aval(x) for x in flat_args]
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
if debug:
print("Jaxpr before fusion:")
print(jaxpr)
out_tree = out_tree_thunk()
out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
return tree_util.tree_unflatten(out_tree, out_flat)
if physicalize:
wrapper = fusable_dtype.physicalize(wrapper)
return wrapper
if f is not None:
return decorator(f)
return decorator
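The rewritten ``fuse`` above follows the common optional-argument decorator pattern (``f=None`` plus keyword-only options), which is why both ``@fuse`` and ``@fuse(physicalize=True, debug=True)`` now work. A generic, minimal sketch of that pattern, independent of the Pallas internals:

import functools

def fuse_like(f=None, *, debug=False):
  """Generic sketch of the f=None decorator pattern used above."""
  def decorator(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      if debug:
        print(f"calling {fn.__name__}")
      return fn(*args, **kwargs)
    return wrapper
  if f is not None:       # used as @fuse_like, no parentheses
    return decorator(f)
  return decorator        # used as @fuse_like(debug=True)

@fuse_like
def a(x):
  return x + 1

@fuse_like(debug=True)
def b(x):
  return x * 2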
_fusable: dict[jax_core.Primitive, Any] = {}

View File

@ -159,6 +159,9 @@ py_library(
":core",
":primitives",
"//jax",
"//jax:core",
"//jax:source_info_util",
"//jax:util",
"//jax/_src/lib",
"//jax/_src/pallas",
] + py_deps("numpy"),

File diff suppressed because it is too large.


View File

@ -3222,6 +3222,14 @@ def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x):
lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule
def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx):
if not isinstance(x.type.element_type, ir.F32Type):
raise ValueError("Only float32 is supported.")
return tpu.reciprocal(x, approx=approx)
lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule
def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):
del ty
(out_aval,) = ctx.avals_out

View File

@ -207,6 +207,8 @@ class BufferedRef:
is_accumulator: whether this BufferedRef is an accumulator.
is_input_output: whether this BufferedRef is an input/output without
automatic accumulation.
swap: Tracks whether the BufferedRef slots need to be swapped before the
next copy.
"""
spec: pl.BlockSpec # static metadata
dtype: Any # static metadata
@ -214,9 +216,14 @@ class BufferedRef:
window_ref: REF | None
accum_ref: REF | None
current_slot: ArrayRef | None
# TODO(ramiroleal): Unused by class. Remove argument from
# BufferedRef instantiations.
next_slot: ArrayRef | None
sem_recvs: SemaphoreTuple | None
sem_sends: SemaphoreTuple | None
# TODO(ramiroleal): Improve prefetch/postyeet interface to avoid
# using this ref.
swap: ArrayRef | None
def tree_flatten(self):
return (
@ -227,6 +234,7 @@ class BufferedRef:
self.next_slot,
self.sem_recvs,
self.sem_sends,
self.swap,
),
(self.spec, self.dtype, self.buffer_type),
)
@ -240,7 +248,7 @@ class BufferedRef:
return BufferType
@classmethod
def create(cls, spec, dtype, buffer_type) -> BufferedRef:
def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
"""Create a BufferedRef.
Args:
@ -248,6 +256,7 @@ class BufferedRef:
dtype: dtype for buffers.
buffer_type: enum indicating whether this is an input, output, or in/out
accumulator buffered reference.
needs_swap_ref: whether a swap slots tracker needs to be allocated.
Returns:
Initialized BufferedRef
@ -271,6 +280,7 @@ class BufferedRef:
next_slot=None,
sem_recvs=None,
sem_sends=None,
swap=None,
)
else:
memory_space = SMEM if spec.memory_space == SMEM else VMEM
@ -281,7 +291,7 @@ class BufferedRef:
window_ref=memory_space((2,) + block_shape, dtype),
accum_ref=accum_ref,
current_slot=SMEM((1,), jnp.int32),
next_slot=SMEM((1,), jnp.int32),
next_slot=None,
sem_recvs=(
None
if buffer_type is BufferType.OUTPUT
@ -292,23 +302,24 @@ class BufferedRef:
if buffer_type is BufferType.INPUT
else SemaphoreType.DMA((2,))
),
swap=SMEM((1,), jnp.bool) if needs_swap_ref else None,
)
@classmethod
def input(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.INPUT)
def input(cls, spec, dtype, needs_swap_ref=True):
return cls.create(spec, dtype, BufferType.INPUT, needs_swap_ref)
@classmethod
def output(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.OUTPUT)
def output(cls, spec, dtype, needs_swap_ref=True):
return cls.create(spec, dtype, BufferType.OUTPUT, needs_swap_ref)
@classmethod
def accumulator(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.ACCUMULATOR)
def accumulator(cls, spec, dtype, needs_swap_ref=True):
return cls.create(spec, dtype, BufferType.ACCUMULATOR, needs_swap_ref)
@classmethod
def input_output(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT)
def input_output(cls, spec, dtype, needs_swap_ref=True):
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, needs_swap_ref)
@property
def block_shape(self):
@ -329,7 +340,7 @@ class BufferedRef:
if self.memory_space == VMEM:
return self.window_ref.at[buffer_slice]
else:
return self.window_ref.at[(self.current_slot[0], *buffer_slice)]
return self.window_ref.at[(self.current_slot_index, *buffer_slice)]
@property
def is_input(self):
@ -355,6 +366,14 @@ class BufferedRef:
def is_input_output(self):
return self.buffer_type == BufferType.INPUT_OUTPUT
@property
def current_slot_index(self):
return self.current_slot[0]
@property
def next_slot_index(self):
return lax.rem(self.current_slot_index + 1, 2)
def bind_existing_ref(self, window_ref, indices):
"""For handling VMEM references, the pipeline aliases the existing ref."""
if self.memory_space == VMEM:
@ -373,12 +392,15 @@ class BufferedRef:
"""Initialize slot indices."""
if self.memory_space == VMEM: return
self.current_slot[0] = 0
self.next_slot[0] = 0
if self.swap is not None:
self.swap[0] = False
def swap_slots(self):
"""Switch to the next slot."""
if self.memory_space == VMEM: return
self.current_slot[0] = self.next_slot[0]
self.current_slot[0] = self.next_slot_index
if self.swap is not None:
self.swap[0] = False
def get_dma_slice(self, src_shape, src_dtype, grid_indices):
# We need to handle blocks that might go OOB in the src array. An in bounds
@ -441,8 +463,9 @@ class BufferedRef:
"""Starts copy of HBM dma slice into the current slot."""
assert self.is_input
if self.memory_space == VMEM: return
next_slot = lax.rem(self.current_slot[0] + 1, 2)
self.next_slot[0] = next_slot
if self.swap is not None:
self.swap[0] = True
next_slot = self.next_slot_index
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
tpu_primitives.make_async_copy(
@ -455,8 +478,9 @@ class BufferedRef:
"""Starts copy of HBM dma slice from the current slot."""
assert self.is_output
if self.memory_space == VMEM: return
slot = self.current_slot[0]
self.next_slot[0] = lax.rem(slot + 1, 2)
if self.swap is not None:
self.swap[0] = True
slot = self.current_slot_index
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
@ -471,7 +495,7 @@ class BufferedRef:
if self.memory_space == VMEM: return
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
current_slot = self.current_slot[0]
current_slot = self.current_slot_index
tpu_primitives.make_async_copy(
src_ref.at[src_slice], # nb: doesn't matter
self.window_ref.at[current_slot].at[
@ -484,7 +508,8 @@ class BufferedRef:
"""Waits for output copy to finish."""
assert self.is_output
if self.memory_space == VMEM: return
prev_slot = lax.rem(self.current_slot[0] + 1, 2)
# In a double buffer, the previous slot is the same as the next slot.
prev_slot = self.next_slot_index
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
@ -671,10 +696,7 @@ class Scheduler:
def _start():
if buffered_ref.is_input:
buffered_ref.copy_in(src_ref, self.indices)
# In the prologue this makes it so we wait on the prologue copy to finish.
# In other iterations this is the regular swap.
buffered_ref.swap_slots()
buffered_ref.swap_slots()
def wait_in(self, buffered_ref, src_ref, schedule=None):
if schedule is None:
@ -780,9 +802,32 @@ class Scheduler:
@self._named_scope("ep_finalize")
def _end():
if buffered_ref.is_output:
buffered_ref.swap_slots() # formally correct, not actually necessary.
buffered_ref.wait_out(dst_ref, self.indices)
def swap_slots(self, buffered_ref, hbm_ref, schedule=None):
if buffered_ref.swap is not None:
swap = buffered_ref.swap[0]
else:
# If we are not using an SMEM `swap` tensor to keep track of
# swaps needed, then all the copies into and out of BufferedRefs
# are done by direct calls to the `copy_in` and `copy_out`
# methods in the pipeline loop. To determine if the BufferedRef
# needs a swap of slots, we recalculate the copy-in/copy-out
# conditions.
if schedule is None:
schedule = _default_schedule
pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref)
pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref)
copied_in = pred_in & buffered_ref.is_input & ~self.last_step
copied_out = pred_out & buffered_ref.is_output
swap = copied_in | copied_out
@pl.when(swap)
@self._named_scope("ep_swap")
def _swap():
buffered_ref.swap_slots()
# END SCHEDULE --------------------------------------------------------------
@ -875,6 +920,7 @@ def make_pipeline_allocations(
in_specs=None,
out_specs=None,
should_accumulate_out=False,
needs_swap_ref=True,
):
"""Create BufferedRefs for the pipeline.
@ -887,6 +933,7 @@ def make_pipeline_allocations(
out_specs: output pallas block specs
should_accumulate_out: booleans to indicate which outputs should be treated
as accumulators.
needs_swap_ref: whether a swap slots tracker needs to be allocated.
Returns:
A list of BufferedRefs, one corresponding to each ref specified in the
@ -905,12 +952,12 @@ def make_pipeline_allocations(
in_refs = refs[:num_in_specs]
out_refs = refs[num_in_specs:]
def make_input_bref(in_spec, in_ref):
return BufferedRef.input(in_spec, in_ref.dtype)
return BufferedRef.input(in_spec, in_ref.dtype, needs_swap_ref)
in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs)
def make_output_bref(out_spec, out_ref, accumulate):
if accumulate:
return BufferedRef.accumulator(out_spec, out_ref.dtype)
return BufferedRef.output(out_spec, out_ref.dtype)
return BufferedRef.accumulator(out_spec, out_ref.dtype, needs_swap_ref)
return BufferedRef.output(out_spec, out_ref.dtype, needs_swap_ref)
out_brefs = jax.tree.map(
make_output_bref, out_specs, out_refs, should_accumulate_out)
return (*in_brefs, *out_brefs)
@ -1109,6 +1156,14 @@ def emit_pipeline(
scratches = ()
if allocations is None:
# run with inline scoped allocations
# Prefetch and postyeet are arbitrary functions that can copy
# into or out of any of the BufferedRefs. Thus, we need a ref
# for the scheduler to mark when the prefetch or postyeet
# functions perform a copy and the slots need to be
# swapped. Without prefetch and postyeet, the swapping logic can
# be performed without the need for state.
needs_swap_ref = prefetch is not None or postyeet is not None
return primitives.run_scoped(
lambda allocations: pipeline(
*refs,
@ -1125,7 +1180,9 @@ def emit_pipeline(
*refs,
in_specs=in_specs,
out_specs=out_specs,
should_accumulate_out=should_accumulate_out),
should_accumulate_out=should_accumulate_out,
needs_swap_ref=needs_swap_ref,
),
)
if isinstance(allocations, list):
allocations = tuple(allocations)
@ -1184,6 +1241,8 @@ def emit_pipeline(
lax.cond(step == 0,
lambda: postyeet(*brefs, scheduler),
lambda: None)
map_brefs(scheduler.swap_slots, brefs, refs, schedule)
map_brefs(scheduler.finalize, brefs, refs, schedule)
return _next_index(indices, grid)

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import collections
from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence
from collections.abc import Callable, Hashable, MutableMapping, MutableSequence, Sequence
import contextlib
import dataclasses
import functools
@ -25,6 +25,7 @@ import math
from typing import Any, Protocol, cast
import jax
from jax import api_util
from jax import lax
from jax._src import core as jax_core
from jax._src import linear_util as lu
@ -36,6 +37,7 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.mlir.dialects import math as math_dialect
from jax._src.lib.mlir.dialects import memref as memref_dialect
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
from jax._src.lib.mlir.dialects import scf as scf_dialect
@ -837,6 +839,29 @@ def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value:
)
def _lower_fun(
fun: Callable[..., Any], *, multiple_results: bool
) -> Callable[..., Any]:
def lowering_rule(ctx: LoweringRuleContext, *args, **params):
wrapped_fun = lu.wrap_init(
fun
if multiple_results
else lambda *args, **params: (fun(*args, **params),),
params,
debug_info=api_util.debug_info(
"Pallas Mosaic GPU lower_fun", fun, args, params
),
)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
out = lower_jaxpr_to_mosaic_gpu(
ctx.module_ctx, ctx.launch_ctx, jaxpr, args, consts
)
return out if multiple_results else out[0]
return lowering_rule
@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup)
def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis):
@ -1162,6 +1187,9 @@ def _convert_element_type_lowering_rule_wg(
cur_dtype = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
new_dtype = mgpu_utils.dtype_to_ir_type(new_dtype)
if cur_dtype == new_dtype:
return x
if 1 < mgpu_utils.bitwidth(cur_dtype) < 8 or 1 < mgpu_utils.bitwidth(new_dtype) < 8:
raise NotImplementedError("Conversion involving sub-byte types unsupported")
@ -1170,7 +1198,29 @@ def _convert_element_type_lowering_rule_wg(
from_integer = ir.IntegerType.isinstance(cur_dtype)
to_integer = ir.IntegerType.isinstance(new_dtype)
if from_float and to_float:
if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
cur_ty_width = ir.FloatType(cur_dtype).width
new_ty_width = ir.FloatType(new_dtype).width
if cur_ty_width == new_ty_width:
# There is no instruction to perform conversions between two float types
# of the same width. Go through the next-larger standard type.
# TODO(bchetioui): support conversions between float types of width 8.
# Which larger type to pick will depend on the number of bits in the
# smallest exponent.
if cur_ty_width != 16:
raise NotImplementedError(
"Conversion between float types of width other than 16 not"
" supported"
)
larger_ty = ir.F32Type.get()
if x_aval.shape:
upcast_ty = ir.VectorType.get(x_aval.shape, larger_ty)
else:
upcast_ty = larger_ty
def convert(ty, x):
return arith_dialect.truncf(ty, arith_dialect.extf(upcast_ty, x))
elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
convert = arith_dialect.truncf
else:
convert = arith_dialect.extf
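A hedged, pure-jnp illustration of the "go through the next-larger type" idea used above for same-width float conversions (e.g. bfloat16 to float16). This only sketches the numerics; it is not the Mosaic GPU lowering itself.

import jax.numpy as jnp

x = jnp.array([1.5, -2.25, 3.0], dtype=jnp.bfloat16)
# Same-width conversion expressed as extend-to-f32 then truncate-to-f16,
# mirroring the extf followed by truncf emitted by the lowering above.
y = x.astype(jnp.float32).astype(jnp.float16)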
@ -1190,10 +1240,26 @@ def _convert_element_type_lowering_rule_wg(
else:
convert = arith_dialect.uitofp
elif from_float and to_integer:
dst_width = mgpu_utils.bitwidth(new_dtype)
# We clamp the float value to the min/max integer destination value
# in order to match JAX/XLA casting behavior. Note that this differs
# from numpy casting behavior.
if mgpu_utils.is_signed(y_aval.dtype):
maxint = 2 ** (dst_width - 1) - 1
minint = -(2 ** (dst_width - 1))
convert = arith_dialect.fptosi
else:
maxint = 2**dst_width - 1
minint = 0
convert = arith_dialect.fptoui
maxint = _ir_constant(maxint, cur_dtype)
minint = _ir_constant(minint, cur_dtype)
if x_aval.shape:
maxint = vector_dialect.splat(x.type, maxint)
minint = vector_dialect.splat(x.type, minint)
x = arith_dialect.minimumf(x, maxint)
x = arith_dialect.maximumf(x, minint)
else:
raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}")
@ -1206,6 +1272,13 @@ mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({
lax.not_p: lambda ctx, x: ~x,
})
mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({
lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False),
lax.not_p: _lower_fun(
lambda x: jnp.bitwise_xor(x, -1), multiple_results=False
),
})
def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
@ -1342,48 +1415,98 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Warpgroup)
def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
[x_aval] = ctx.avals_in
x = _ensure_fa(x, x_aval.dtype)
if y == 2:
return x * x
return NotImplementedError
if y != 2:
raise NotImplementedError
return _square_lowering_rule(ctx, x)
@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Warpgroup)
def _square_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
x = _ensure_fa(x, x_aval.dtype)
return x * x
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
x = _ensure_fa(x, x_aval.dtype)
return x * x
if jnp.issubdtype(x_aval.dtype, jnp.integer):
return arith_dialect.muli(x, x)
if jnp.issubdtype(x_aval.dtype, jnp.floating):
return arith_dialect.mulf(x, x)
raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}")
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup)
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
fastmath = (
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
)
return math_dialect.rsqrt(
_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath
)
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup)
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math)
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math)
fastmath = (
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
)
return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
@register_lowering_rule(lax.logistic_p, mgpu.ThreadSemantics.Lane)
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
a = _ensure_fa(x, x_aval.dtype)
return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math))
def _logistic(x):
return 1.0 / (1 + lax.exp(-x))
mosaic_lowering_rules[mgpu.ThreadSemantics.Lane][lax.logistic_p] = _lower_fun(
_logistic, multiple_results=False
)
mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = (
_lower_fun(_logistic, multiple_results=False)
)
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup)
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
a = _ensure_fa(x, x_aval.dtype)
return a.exp(approx=ctx.module_ctx.approx_math)
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math)
fastmath = (
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
)
return math_dialect.exp(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane)
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
a = _ensure_fa(x, x_aval.dtype)
return a.exp2(approx=ctx.module_ctx.approx_math)
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math)
fastmath = (
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
)
return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup)
def _log_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math)
fastmath = (
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
)
return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane)
@ -1484,11 +1607,7 @@ def _debug_print_lowering_rule(
)
elif len(ctx.avals_in) == 1:
[arg] = args
@arg.foreach
def _(val, idx):
idx_fmt = ", ".join(["{}"] * len(idx))
fmt_str = fmt.format(f"[{idx_fmt}]/{list(arg.shape)}: {{}}")
mgpu.debug_print(fmt_str, *idx, val, uniform=False)
arg.debug_print(fmt)
else:
raise NotImplementedError(
"debug_print only supports printing of scalar values, or a single array"
@ -1901,27 +2020,36 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(
lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup
)
def _bitcast_convert_type_lowering_rule(
ctx: LoweringRuleContext, operand, *, new_dtype
ctx: LoweringRuleContext, x, *, new_dtype
):
# TODO(petebu) Handle case where src and dst types have different bitwidths
[operand_aval] = ctx.avals_in
operand = _ensure_fa(operand, operand_aval.dtype)
src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype)
[x_aval] = ctx.avals_in
src_elem_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype)
assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType))
assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType))
if src_elem_type.width != dst_elem_type.width:
raise NotImplementedError(
f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they"
f"Cannot bitcast from {x_aval.dtype} to {new_dtype} because they"
" have different widths"
)
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup:
x = _ensure_ir_value(x, x_aval.dtype)
return arith_dialect.bitcast(
ir.VectorType.get(x_aval.shape, dst_elem_type), x
)
x = _ensure_fa(x, x_aval.dtype)
if ir.IntegerType.isinstance(dst_elem_type):
output_is_signed = mgpu_utils.is_signed(new_dtype)
else:
output_is_signed = None
return mgpu.FragmentedArray.bitcast(
operand, dst_elem_type, output_is_signed=output_is_signed
x, dst_elem_type, output_is_signed=output_is_signed
)
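Since the rule above only allows bitcasts between types of equal bit width, a small hedged example of the corresponding user-facing operation, outside of Pallas:

import jax.numpy as jnp
from jax import lax

# Same-width bitcast: reinterpret float32 bits as uint32 (no value conversion).
bits = lax.bitcast_convert_type(jnp.float32(1.0), jnp.uint32)
print(hex(int(bits)))  # 0x3f800000, the IEEE-754 encoding of 1.0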

View File

@ -91,6 +91,7 @@ class BufferedRef:
self.smem_ref.at[slot], # pytype: disable=unsupported-operands
self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands
predicate=predicate,
commit_group=False,
)
@ -299,6 +300,8 @@ def emit_pipeline(
predicate=lax.bitwise_or(slices_changed, is_last_step),
)
gpu_primitives.commit_smem_to_gmem_group()
fetch_step = step + (max_concurrent_steps - delay_release)
fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
@ -344,6 +347,8 @@ def emit_pipeline(
if bref.is_index_invariant:
bref.copy_out(last_slot, last_indices, predicate=None)
gpu_primitives.commit_smem_to_gmem_group()
# Finalize the pipeline.
gpu_primitives.wait_smem_to_gmem(0)
@ -578,6 +583,7 @@ def emit_pipeline_warp_specialized(
bref.copy_out(_get_slot(slot, ~bref.is_index_invariant),
indices,
predicate=slices_changed)
gpu_primitives.commit_smem_to_gmem_group()
next_indices = _inc_grid_by_1(indices, grid)
return (next_indices, new_store_slices, next_body_carry)
init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
@ -619,6 +625,8 @@ def emit_pipeline_warp_specialized(
if bref.is_index_invariant:
bref.copy_out(last_slot, last_indices, predicate=None)
gpu_primitives.commit_smem_to_gmem_group()
# Finalize the pipeline.
gpu_primitives.wait_smem_to_gmem(0)

View File

@ -86,6 +86,7 @@ def _copy_smem_to_gmem_lowering(
src_transforms_treedef,
dst_transforms_treedef,
has_user_predicate,
commit_group,
):
predicate = ctx.module_ctx.single_wg_lane_predicate
if has_user_predicate:
@ -106,6 +107,7 @@ def _copy_smem_to_gmem_lowering(
src_ref=src,
dst_ref=dst,
predicate=predicate,
arrive=commit_group,
**copy_params,
)
return ()
@ -119,7 +121,12 @@ def _copy_smem_to_gmem_lowering(
assert copy_params.get("swizzle") is None
assert not copy_params.get("gmem_transform")
mgpu.dialect.async_store(
src, dst, indices, slice_lengths, predicate=predicate
src,
dst,
indices,
slice_lengths,
predicate=predicate,
commit_group=commit_group, # type: ignore[call-arg]
)
return ()
@ -174,7 +181,11 @@ def _extract_smem_copy_params(transforms):
def copy_smem_to_gmem(
src: _Ref, dst: _Ref, predicate: jax.Array | None = None
src: _Ref,
dst: _Ref,
predicate: jax.Array | None = None,
*,
commit_group: bool = True,
) -> None:
"""Asynchronously copies a SMEM reference to a GMEM reference.
@ -183,6 +194,9 @@ def copy_smem_to_gmem(
dst: The GMEM reference to copy to.
predicate: A boolean indicating whether the copy should be performed. If
``None``, the copy is always performed.
commit_group: If ``True``, this and any previously uncommitted copies
are committed to a group and can be awaited jointly via
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`.
See also:
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
@ -209,6 +223,7 @@ def copy_smem_to_gmem(
src_transforms_treedef=src_transforms_treedef,
dst_transforms_treedef=dst_transforms_treedef,
has_user_predicate=predicate is not None,
commit_group=commit_group,
)
return None
@ -475,6 +490,28 @@ def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None:
wait_smem_to_gmem_p.bind(n, wait_read_only=wait_read_only)
commit_group_p = jax_core.Primitive("commit_group")
commit_group_p.multiple_results = True
@commit_group_p.def_effectful_abstract_eval
def _commit_group_abstract_eval():
return (), {gpu_core._memory_effect}
@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Lane)
@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Warpgroup)
def _commit_group_lowering(ctx: lowering.LoweringRuleContext):
del ctx # Unused.
nvvm_dialect.cp_async_bulk_commit_group()
return ()
def commit_smem_to_gmem_group() -> None:
"""Commits all issued but uncommited SMEM->GMEM copies to a group."""
commit_group_p.bind()
# WGMMA on an accumulator reference
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True

View File

@ -695,14 +695,44 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False,
if precision is not None:
raise ValueError("Only one of allow_tf32 and precision can be specified")
precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST
dtype = jnp.promote_types(a.dtype, b.dtype)
out_dtype = jnp.int32 if jnp.issubdtype(dtype, jnp.integer) else jnp.float32
return jax.lax.dot_general(
a,
b,
dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
precision=precision,
preferred_element_type=jnp.float32,
preferred_element_type=out_dtype,
)
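The change above makes integer inputs accumulate in int32 and keeps float32 accumulation otherwise. A hedged sketch of the same ``preferred_element_type`` choice expressed with plain ``lax.dot_general``:

import jax.numpy as jnp
from jax import lax

a = jnp.ones((4, 8), dtype=jnp.int8)
b = jnp.ones((8, 4), dtype=jnp.int8)
dtype = jnp.promote_types(a.dtype, b.dtype)
out_dtype = jnp.int32 if jnp.issubdtype(dtype, jnp.integer) else jnp.float32
out = lax.dot_general(a, b,
                      dimension_numbers=(((1,), (0,)), ((), ())),
                      preferred_element_type=out_dtype)
assert out.dtype == jnp.int32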
reciprocal_p = jax_core.Primitive("reciprocal")
def reciprocal(x, *, approx=False):
return reciprocal_p.bind(x, approx=approx)
@reciprocal_p.def_abstract_eval
def _reciprocal_abstract_eval(x, *, approx):
del approx
return x
def _reciprocal_lowering_rule(
ctx: mlir.LoweringRuleContext, x, *, approx=False
):
def _reciprocal(x, *, approx=False):
if approx:
return jnp.reciprocal(x.astype(jnp.bfloat16)).astype(jnp.float32)
return jnp.reciprocal(x)
return mlir.lower_fun(_reciprocal, multiple_results=False)(
ctx, x, approx=approx
)
mlir.register_lowering(reciprocal_p, _reciprocal_lowering_rule)
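The lowering registered above implements ``approx=True`` as a bfloat16 round-trip around ``jnp.reciprocal``. A hedged sketch of just that numeric behavior, outside of any Pallas kernel:

import jax.numpy as jnp

x = jnp.float32(3.0)
exact = jnp.reciprocal(x)
# The approximate path loses mantissa bits to bfloat16 before converting back.
approx = jnp.reciprocal(x.astype(jnp.bfloat16)).astype(jnp.float32)
print(exact, approx)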
class PrintEffect(effects.Effect):
__str__ = lambda self: "Print"

View File

@ -12,35 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol
import logging
import os
import pathlib
import importlib.util
__all__ = ["Path"]
logger = logging.getLogger(__name__)
epath_installed: bool
class PathProtocol(Protocol):
"""A factory that creates a PurePath."""
def __call__(self, *pathsegments: str | os.PathLike) -> pathlib.Path:
...
Path: PathProtocol
# If etils.epath (aka etils[epath] to pip) is present, we prefer it because it
# can read and write to, e.g., GCS buckets. Otherwise we use the builtin
# pathlib and can only read/write to the local filesystem.
epath_installed = bool(
importlib.util.find_spec("etils") and
importlib.util.find_spec("etils.epath")
)
if epath_installed:
logger.debug("etils.epath found. Using etils.epath for file I/O.")
def __dir__():
return ["Path"]
def __getattr__(name):
if name != "Path":
raise AttributeError(f"module '{__name__}' has no attribute '{name}")
global Path
from etils import epath
Path = epath.Path
return Path
else:
try:
from etils import epath # type: ignore
except ImportError:
logger.debug("etils.epath was not found. Using pathlib for file I/O.")
Path = pathlib.Path
epath_installed = False
else:
logger.debug("etils.epath found. Using etils.epath for file I/O.")
# Ultimately, epath.Path implements pathlib.Path. See:
# https://github.com/google/etils/blob/2083f3d932a88d8a135ef57112cd1f9aff5d559e/etils/epath/abstract_path.py#L47
Path = epath.Path
epath_installed = True
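The rewritten module above defers the ``etils.epath`` import to the first attribute access via module-level ``__getattr__`` and ``__dir__`` (PEP 562). A generic sketch of that pattern follows; the pathlib fallback inside ``__getattr__`` is an illustrative choice and not necessarily how the module above arranges its fallback.

# lazy_path.py -- generic sketch of PEP 562 lazy attribute loading.
import pathlib

__all__ = ["Path"]

def __dir__():
  return __all__

def __getattr__(name):
  if name != "Path":
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
  global Path  # cache the result as a real module attribute
  try:
    from etils import epath  # optional dependency, imported on first access
    Path = epath.Path
  except ImportError:
    Path = pathlib.Path
  return Path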

View File

@ -29,7 +29,6 @@ import warnings
import numpy as np
from jax._src import api
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
@ -199,10 +198,9 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
profiler = None
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, p.arg_names)
fun_name, fails, args_flat, 'jit', p.arg_names)
raise ValueError(msg) from None
except xla.InvalidInputException as e:
arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names
@ -359,7 +357,6 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
in_layouts_leaves=jit_info.in_layouts_leaves,
out_layouts_treedef=jit_info.out_layouts_treedef,
out_layouts_leaves=jit_info.out_layouts_leaves,
use_resource_env=jit_info.use_resource_env,
compiler_options_kvs=jit_info.compiler_options_kvs)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
@ -546,8 +543,7 @@ class PjitParams(NamedTuple):
def _infer_params_impl(
fun: Callable,
ji: PjitInfo,
pjit_mesh: mesh_lib.Mesh | None,
resource_env: mesh_lib.ResourceEnv | None,
ctx_mesh: mesh_lib.Mesh | None,
dbg: core.DebugInfo,
args: tuple[Any, ...],
kwargs: dict[str, Any],
@ -559,8 +555,8 @@ def _infer_params_impl(
raise ValueError(
"pjit does not support kwargs when in_shardings is specified.")
if pjit_mesh is not None:
if (ji.backend or ji.device) and not pjit_mesh.empty:
if ctx_mesh is not None:
if (ji.backend or ji.device) and not ctx_mesh.empty:
raise ValueError(
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit.")
@ -591,13 +587,12 @@ def _infer_params_impl(
in_shardings_leaves = out_shardings_leaves = tuple(leaves)
in_shardings_treedef = out_shardings_treedef = treedef
else:
jit_name = 'pjit' if pjit_mesh is not None else 'jit'
in_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
_create_sharding_for_array(ctx_mesh, x, 'in_shardings', 'jit')
for x in ji.in_shardings_leaves)
in_shardings_treedef = ji.in_shardings_treedef
out_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name)
_create_sharding_for_array(ctx_mesh, x, 'out_shardings', 'jit')
for x in ji.out_shardings_leaves)
out_shardings_treedef = ji.out_shardings_treedef
@ -655,8 +650,8 @@ def _infer_params_impl(
out_shardings=out_shardings_flat,
in_layouts=in_layouts_flat,
out_layouts=out_layouts_flat,
resource_env=resource_env,
donated_invars=donated_invars,
ctx_mesh=ctx_mesh,
name=fun_qual_name(flat_fun),
keep_unused=ji.keep_unused,
inline=ji.inline,
@ -686,38 +681,30 @@ def _infer_params_cached(
jit_info: PjitInfo,
signature: jax_jit.ArgumentSignature,
in_avals: tuple[core.AbstractValue, ...],
pjit_mesh: mesh_lib.Mesh | None,
resource_env: mesh_lib.ResourceEnv | None,
ctx_mesh: mesh_lib.Mesh | None,
) -> InferParamsCacheEntry:
return InferParamsCacheEntry()
def disallow_use_mesh_and_legacy_mesh_ctx_mgr_together():
if (not mesh_lib.thread_resources.env.physical_mesh.empty and
mesh_lib.get_concrete_mesh() is not None):
raise ValueError(
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
' together is not allowed.')
def _infer_params(
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[PjitParams, list[Any]]:
disallow_use_mesh_and_legacy_mesh_ctx_mgr_together()
) -> tuple[PjitParams, list[Any]]:
if ji.use_resource_env:
# We need to fetch the mesh from inside the wrapped function, because
# meshes are dynamically scoped (i.e., with a context manager).
resource_env = mesh_lib.thread_resources.env
pjit_mesh = resource_env.physical_mesh
else:
resource_env = None
pjit_mesh = None
with mesh_lib.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
return _infer_params_internal(fun, ji, args, kwargs)
return _infer_params_internal(fun, ji, args, kwargs)
def _infer_params_internal(
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[PjitParams, list[Any]]:
ctx_mesh = mesh_lib.get_concrete_mesh()
dbg = debug_info(
'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature)
if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache
p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, dbg,
p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg,
args, kwargs, in_avals=None)
return p, p.consts + args_flat
@ -725,10 +712,11 @@ def _infer_params(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
ji.static_argnames, tree_util.default_registry)
avals = _infer_input_type(fun, dbg, dynargs)
entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env)
entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh)
if entry.pjit_params is None:
p, args_flat = _infer_params_impl(
fun, ji, pjit_mesh, resource_env, dbg, args, kwargs, in_avals=avals)
fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
if p.attrs_tracked: # if attrs, don't populate the cache
return p, p.consts + args_flat
entry.pjit_params = p
@ -1619,7 +1607,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
def _resolve_and_lower(
args, jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline,
lowering_platforms, lowering_parameters, pgle_profiler,
compiler_options_kvs):
in_shardings = _resolve_in_shardings(args, in_shardings)
@ -1627,8 +1615,8 @@ def _resolve_and_lower(
jaxpr.in_avals)
out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals)
return _pjit_lower(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, name, keep_unused, inline, compiler_options_kvs,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs,
lowering_platforms=lowering_platforms,
lowering_parameters=lowering_parameters,
pgle_profiler=pgle_profiler)
@ -1637,7 +1625,7 @@ _pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore
def _pjit_call_impl_python(
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
pgle_compile_options, pgle_profiler = {}, None
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
@ -1662,8 +1650,8 @@ def _pjit_call_impl_python(
compiled = _resolve_and_lower(
args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, in_layouts=in_layouts,
out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
out_layouts=out_layouts, donated_invars=donated_invars,
ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
inline=inline, lowering_platforms=None,
lowering_parameters=mlir.LoweringParameters(),
pgle_profiler=pgle_profiler,
@ -1694,7 +1682,7 @@ def _pjit_call_impl_python(
@weakref_lru_cache
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name,
out_layouts, donated_invars, ctx_mesh, name,
keep_unused, inline, compiler_options_kvs):
# The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so
# returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to
@ -1708,14 +1696,14 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
def call_impl_cache_miss(*args_, **kwargs_):
out_flat, compiled, pgle_profiler = _pjit_call_impl_python(
*args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, in_layouts=in_layouts,
out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
out_layouts=out_layouts, donated_invars=donated_invars,
ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
inline=inline, compiler_options_kvs=compiler_options_kvs)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
@ -1724,7 +1712,7 @@ def _pjit_call_impl(*args, jaxpr,
f = _get_jaxpr_as_fun(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs)
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
cache_key = pxla.JitGlobalCppCacheKeys(
@ -1733,8 +1721,7 @@ def _pjit_call_impl(*args, jaxpr,
in_shardings_treedef=None, in_shardings_leaves=in_shardings,
out_shardings_treedef=None, out_shardings_leaves=out_shardings,
in_layouts_treedef=None, in_layouts_leaves=in_layouts,
out_layouts_treedef=None, out_layouts_leaves=out_layouts,
use_resource_env=resource_env is not None)
out_layouts_treedef=None, out_layouts_leaves=out_layouts)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], cache_key,
tree_util.dispatch_registry, pxla.cc_shard_arg,
@ -1749,8 +1736,8 @@ def _pjit_lower(
out_shardings,
in_layouts: pxla.MaybeLayout,
out_layouts: pxla.MaybeLayout,
resource_env,
donated_invars,
ctx_mesh,
name: str,
keep_unused: bool,
inline: bool,
@ -1760,14 +1747,10 @@ def _pjit_lower(
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
util.test_event("pjit_lower")
if resource_env is not None:
mesh, api_name = resource_env.physical_mesh, 'pjit'
else:
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
jaxpr, 'jit', name, in_shardings, out_shardings,
in_layouts, out_layouts, tuple(donated_invars),
keep_unused=keep_unused, context_mesh=mesh,
keep_unused=keep_unused, context_mesh=ctx_mesh,
compiler_options_kvs=compiler_options_kvs,
lowering_platforms=lowering_platforms,
lowering_parameters=lowering_parameters,
@ -1919,8 +1902,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
jaxpr: core.ClosedJaxpr, in_shardings,
out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, keep_unused, inline, compiler_options_kvs):
out_shardings, in_layouts, out_layouts, donated_invars,
ctx_mesh, keep_unused, inline, compiler_options_kvs):
effects = list(ctx.tokens_in.effects())
output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
output_types = [mlir.token_type()] * len(effects) + output_types
@ -1929,7 +1912,7 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
func = _pjit_cached_lower_jaxpr_to_fun(
ctx, name, jaxpr, tuple(effects), in_shardings,
out_shardings, in_layouts, out_layouts,
api_name=('jit' if resource_env is None else 'pjit'))
api_name='jit')
tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
args = (*ctx.dim_var_values, *tokens_in, *args)
@ -1950,23 +1933,20 @@ def _pjit_batcher(axis_data, vals_in,
dims_in: tuple[int, ...],
jaxpr: core.ClosedJaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in)
if resource_env is not None:
mesh = resource_env.physical_mesh
else:
mesh = None
# TODO(axch): prepend with Nones (?) to account for new segment_lens inputs
in_shardings = tuple(
_pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim)
_pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, ctx_mesh,
aval.ndim)
if axis_in is not None else i
for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals))
out_shardings = tuple(
_pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim)
_pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, ctx_mesh,
aval.ndim)
if axis_out is not None else o
for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals))
# TODO(yashkatariya): Figure out layouts should change under vmap.
@ -1982,8 +1962,8 @@ def _pjit_batcher(axis_data, vals_in,
out_shardings=out_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
resource_env=resource_env,
donated_invars=donated_invars,
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2005,8 +1985,8 @@ def _insert_axis_partitions(spec, dim, val):
def _pjit_batcher_for_sharding(
s: Sharding | UnspecifiedValue,
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh,
ndim: int):
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None,
mesh, ndim: int):
if isinstance(s, UnspecifiedValue):
return s
hlo_s = s._to_xla_hlo_sharding(ndim)
@ -2045,20 +2025,8 @@ def _pjit_batcher_for_sharding(
def _pjit_jvp(primals_in, tangents_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
if any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr)
mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals)
primals_in = [*primals_in, *mut_primals]
tangents_in = [*tangents_in, *mut_tangents]
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut_primals)
in_layouts = (*in_layouts,) + (None,) * len(mut_primals)
donated_invars = (*donated_invars,) + (False,) * len(mut_primals)
tangents_in = [ad_util.zeros_like_aval(a) if isinstance(a, AbstractRef) else x
for x, a in zip(tangents_in, jaxpr.in_avals)]
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
jaxpr, is_nz_tangents_in, instantiate=False)
@ -2074,8 +2042,8 @@ def _pjit_jvp(primals_in, tangents_in,
out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)),
in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)),
out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)),
resource_env=resource_env,
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2091,7 +2059,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
def _pjit_linearization(nzs, *primals_in, jaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs)
# constvars will become residuals. Move them to the end of the ordinary args.
@ -2107,8 +2075,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
out_shardings=_filter_zeros(nzs_out, out_shardings),
in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts,
out_layouts=_filter_zeros(nzs_out, out_layouts),
resource_env=resource_env,
donated_invars=_filter_zeros(nzs, donated_invars) + res_donated,
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2127,8 +2095,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
out_shardings=(*res_shardings, *out_shardings),
in_layouts=in_layouts,
out_layouts=(*res_layouts, *out_layouts),
resource_env=resource_env,
donated_invars=donated_invars,
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2143,7 +2111,7 @@ ad.primitive_linearizations[pjit_p] = _pjit_linearization
def _pjit_partial_eval(trace: pe.JaxprTrace,
*in_tracers,
jaxpr: core.ClosedJaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, resource_env, donated_invars,
in_layouts, out_layouts, donated_invars, ctx_mesh,
name, keep_unused, inline, compiler_options_kvs):
in_pvals = [t.pval for t in in_tracers]
@ -2210,8 +2178,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
out_shardings=known_out_shardings,
in_layouts=keep_where(in_layouts, known_ins),
out_layouts=known_out_layouts, resource_env=resource_env,
out_layouts=known_out_layouts,
donated_invars=keep_where(donated_invars, known_ins),
ctx_mesh=ctx_mesh,
name=name, keep_unused=keep_unused, inline=inline,
compiler_options_kvs=compiler_options_kvs)
assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
@ -2242,9 +2211,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
out_shardings=keep_where(out_shardings, unknown_outs),
in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts),
out_layouts=keep_where(out_layouts, unknown_outs),
resource_env=resource_env,
donated_invars=(keep_where(donated_invars, unknown_ins) +
(False,) * num_residuals),
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2330,7 +2299,7 @@ def _pjit_transpose_trace(fun: lu.WrappedFun,
def _pjit_transpose(cts_in, *primals_in,
jaxpr: core.ClosedJaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
def prune_type(ty, xs, maybe_zeros):
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
@ -2379,8 +2348,8 @@ def _pjit_transpose(cts_in, *primals_in,
out_shardings=transpose_out_shardings,
in_layouts=transpose_in_layouts,
out_layouts=transpose_out_layouts,
resource_env=resource_env,
donated_invars=(False,) * len(primals_and_nz_cts_in),
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2464,9 +2433,8 @@ def _pjit_pp_rule(eqn: core.JaxprEqn,
del params['out_layouts']
if not params['keep_unused']:
del params['keep_unused']
if (params['resource_env'] is None or
params['resource_env'].physical_mesh.empty):
del params['resource_env']
if params['ctx_mesh'] is None or params['ctx_mesh'].empty:
del params['ctx_mesh']
if not params['compiler_options_kvs']:
del params['compiler_options_kvs']
@ -2536,6 +2504,11 @@ def with_sharding_constraint(x, shardings):
This is a strict constraint for the GSPMD partitioner and not a hint. For examples
of how to use this function, see `Distributed arrays and automatic parallelization`_.
Inside of a jitted computation, with_sharding_constraint makes it possible to
constrain intermediate values to an uneven sharding. However, if such an
unevenly sharded value is output by the jitted computation, it will come out
as fully replicated, no matter the sharding annotation given.
Args:
x: PyTree of jax.Arrays which will have their shardings constrained
shardings: PyTree of sharding specifications. Valid values are the same as for
@ -2561,8 +2534,6 @@ def with_sharding_constraint(x, shardings):
flatten_axes("with_sharding_constraint layouts", tree, layouts))
del layouts
disallow_use_mesh_and_legacy_mesh_ctx_mgr_together()
context_mesh = (
mesh_lib.get_abstract_mesh() if mesh_lib.get_concrete_mesh() is not None
else mesh_lib.thread_resources.env.physical_mesh)
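A minimal usage sketch for the docstring note above, assuming a 1-D device mesh named "x" and a device count that divides the array size (all names and shapes here are illustrative):
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over whatever devices are available.
mesh = Mesh(jax.devices(), ("x",))

@jax.jit
def f(x):
  # Constrain an intermediate to be sharded along "x". The value returned from
  # the jitted function may still use a different (or replicated) sharding.
  y = jax.lax.with_sharding_constraint(x * 2, NamedSharding(mesh, P("x")))
  return y + 1

out = f(jnp.arange(16.0))  # assumes the device count divides 16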

View File

@ -100,6 +100,9 @@ if _dtypes.float8_e4m3 is not None:
if _dtypes.float8_e8m0fnu is not None:
_default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
if _dtypes.float4_e2m1fn is not None:
_default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
if _dtypes.float8_e8m0fnu is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
if _dtypes.float4_e2m1fn is not None:
custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)
def maybe_upcast(x):
if x.dtype in custom_float_dtypes:

View File

@ -2106,7 +2106,7 @@ def expi_jvp(primals, tangents):
return expi(x), jnp.exp(x) / x * x_dot
def _expn1(n: Array, x: Array) -> Array:
def _expn1(x: Array, n: Array) -> Array:
# exponential integral En
_c = _lax_const
MACHEP = jnp.finfo(x.dtype).eps
@ -2143,7 +2143,7 @@ def _expn1(n: Array, x: Array) -> Array:
return d["z"] ** r * psi / jnp.exp(gammaln(t)) - d["ans"]
def _expn2(n: Array, x: Array) -> Array:
def _expn2(x: Array, n: Array) -> Array:
# x > 1.
_c = _lax_const
BIG = _c(x, 1.44115188075855872e17)
@ -2194,7 +2194,7 @@ def _expn2(n: Array, x: Array) -> Array:
return d["ans"] * jnp.exp(-x)
def _expn3(n: Array, x: Array) -> Array:
def _expn3(x: Array, n: Array) -> Array:
# n >= 5000
_c = _lax_const
one = _c(x, 1.0)
@ -2248,11 +2248,11 @@ def expn(n: ArrayLike, x: ArrayLike) -> Array:
jnp.inf,
one / n1, # prevent div by zero
jnp.exp(-x) / x,
partial(_expn3, n),
partial(_expn2, n),
partial(_expn1, n),
_expn3,
_expn2,
_expn1,
]
ret = jnp.piecewise(x, conds, vals)
ret = jnp.piecewise(x, conds, vals, n=n)
return ret
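The signature swap above works because jnp.piecewise, like numpy.piecewise, forwards extra keyword arguments to every callable in the list and passes the selected elements of x as the first positional argument. A small sketch of the same pattern (the branch functions below are made up for illustration):
import jax.numpy as jnp

def small(x, n):  # called with the elements of x selected by its condition
  return x * n

def large(x, n):
  return x + n

x = jnp.array([0.5, 2.0, 3.0])
out = jnp.piecewise(x, [x < 1.0, x >= 1.0], [small, large], n=10.0)
# out == [5., 12., 13.]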

View File

@ -305,7 +305,16 @@ class SetNameStackContextManager(contextlib.ContextDecorator):
set_name_stack = SetNameStackContextManager
reset_name_stack = lambda: SetNameStackContextManager(NameStack())
# TODO(mattjj,phawkins): figure out why the commented-out reset_name_stack
# implementation doesn't work. Luckily this context manager isn't called much so
# the performance shouldn't matter. See blame commit message for repro.
# reset_name_stack = lambda: SetNameStackContextManager(NameStack())
@contextlib.contextmanager
def reset_name_stack() -> Iterator[None]:
with set_name_stack(NameStack()):
yield
class TransformNameStackContextManager(contextlib.ContextDecorator):

View File

@ -259,3 +259,26 @@ class NDIndexer:
def transform_dtype(self, dtype):
return dtype
def transform_sharding(self, sharding):
# If there are no explicit axes, do nothing.
if all(p is None for p in sharding.spec):
return sharding
# If there are explicit axes, we don't support changing the shape, so we
# don't support int indexers and instead require all slices.
if (self.int_indexer_shape or
not all(isinstance(idx, Slice) for idx in self.indices)):
raise TypeError("sharded ref (array reference) can only be indexed by "
"slices, not integers")
# Moreover, only allow trivial slice(None) slices on explicitly sharded
# axes. Then the sharding stays the same.
_, slice_indexers, _ = unpack_ndindexer(self)
for i, (d, sl, s) in enumerate(zip(self.shape, slice_indexers, sharding.spec)):
if s is None: continue
if not (type(sl.start) is int and sl.start == 0 and
type(sl.size) is int and sl.size == d and
type(sl.stride) is int and sl.stride == 1):
raise ValueError("sharded ref (array reference) can only be sliced "
f"along unsharded axes, but ref of shape {self.shape} "
f"was sliced on axis {i}, which is sharded like {s}")
return sharding
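A plain-Python sketch of the rule enforced above, with a hypothetical helper standing in for the real indexer machinery: an axis with an explicit sharding may only be sliced with the trivial slice (start 0, full size, stride 1).
def check_sharded_ref_indexing(shape, slices, spec):
  """slices: per-axis (start, size, stride); spec: per-axis sharding or None."""
  for axis, (dim, (start, size, stride), s) in enumerate(zip(shape, slices, spec)):
    if s is None:
      continue  # unsharded axes may be sliced arbitrarily
    if (start, size, stride) != (0, dim, 1):
      raise ValueError(
          f"axis {axis} is sharded like {s}; only the trivial slice is allowed")

check_sharded_ref_indexing((8, 128), [(0, 8, 1), (0, 64, 1)], ("x", None))   # ok
# check_sharded_ref_indexing((8, 128), [(0, 4, 1), (0, 128, 1)], ("x", None))  # raises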

View File

@ -206,6 +206,13 @@ def _dtype_after_transforming(
return dtype
def _sharding_after_transforming(sharding, transforms):
for transform in transforms:
sharding = transform.transform_sharding(sharding)
assert sharding is not None
return sharding
def _get_abstract_eval(ref_aval: AbstractRef, *args,
tree):
transforms = tree_util.tree_unflatten(tree, args)
@ -214,10 +221,9 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
if isinstance(ref_aval.inner_aval, core.ShapedArray):
out_shape = _shape_after_transforming(ref_aval.shape, transforms)
out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
# TODO(yashkatariya): Transform the sharding too instead of setting it to
# None.
out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype,
sharding=core.get_cur_mesh_sharding())
out_sharding = _sharding_after_transforming(ref_aval.sharding, transforms)
out_aval = ref_aval.inner_aval.update(
shape=out_shape, dtype=out_dtype, sharding=out_sharding)
else:
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
@ -437,6 +443,8 @@ def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, x_primal, *idx = primals
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, x_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, AbstractRef)
x_tangent = ad_util.instantiate(x_tangent)
return (swap_p.bind(ref_primal, x_primal, *idx, **params),
@ -657,5 +665,14 @@ mlir.register_lowering(
# === AD rules for mutable arrays ===
ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g))
def _mut_jvp(primals, tangents):
(init_val,), (init_val_dot,) = primals, tangents
primal_out = core.mutable_array_p.bind(init_val)
if type(init_val_dot) is ad_util.Zero:
tangent_out = core.mutable_array_p.bind(ad_util.zeros_like_aval(init_val_dot.aval))
else:
tangent_out = core.mutable_array_p.bind(init_val_dot)
return primal_out, tangent_out
ad.primitive_jvps[core.mutable_array_p] = _mut_jvp
ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))

View File

@ -119,6 +119,12 @@ class RefBitcaster:
del dtype # Unused
return self.dtype
def transform_sharding(self, sharding):
# If there are no explicit axes, do nothing.
if all(p is None for p in sharding.spec):
return sharding
raise NotImplementedError
@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
@ -166,6 +172,12 @@ class RefReshaper:
del dtype # Unused
return self.dtype
def transform_sharding(self, sharding):
# If there are no explicit axes, do nothing.
if all(p is None for p in sharding.spec):
return sharding
raise NotImplementedError
class Transform(Protocol):
@ -189,6 +201,10 @@ class Transform(Protocol):
"""
return dtype
def transform_sharding(self, sharding):
if all(p is None for p in sharding.spec): return sharding # no explicit axes
raise NotImplementedError
@dataclasses.dataclass
class RefIndexer:

View File

@ -1640,6 +1640,8 @@ class _LazyDtypes:
float_dtypes += [_dtypes.float8_e4m3]
if _dtypes.float8_e8m0fnu is not None:
float_dtypes += [_dtypes.float8_e8m0fnu]
if _dtypes.float4_e2m1fn is not None:
float_dtypes += [_dtypes.float4_e2m1fn]
return self.supported(float_dtypes)
@_cached_property

View File

@ -14,7 +14,7 @@
"""Colocated Python API."""
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
# pylint: disable=useless-import-alias
from jax.experimental.colocated_python.api import (

View File

@ -201,7 +201,7 @@ def _make_output_specs_and_push_result_fun(
devices = specialization.devices
def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]:
def lowered_fun(*args, **kwargs) -> jax.Array:
result = info.fun(*args, **kwargs)
result_leaves, out_treedef = tree_util.tree_flatten(result)
out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)

View File

@ -13,11 +13,9 @@
# limitations under the License.
"""Colocated Python serialization utilities."""
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
from __future__ import annotations
import base64
import collections
import functools
import io
@ -37,12 +35,6 @@ import numpy as np
DeviceList = xc.DeviceList
# Hard-coded limit for serialized specs size.
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
_MAX_SERIALIZED_SPECS_SIZE = 1048576
@jax.util.cache(max_size=None)
def _get_cpu_device_map() -> dict[int, jax.Device]:
"""Returns a map from a device id to a matching device."""
@ -185,23 +177,14 @@ def _deserialize(serialized: bytes) -> Any:
def _make_specs_for_serialized_specs(
devices: DeviceList,
) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]:
) -> api.ShapeDtypeStruct:
"""Makes output specs for serialized specs."""
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
mesh = jax.sharding.Mesh(tuple(devices), ("x",))
replicated_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec()
)
return (
api.ShapeDtypeStruct(
shape=(), dtype=np.int32, sharding=replicated_sharding
),
api.ShapeDtypeStruct(
shape=(_MAX_SERIALIZED_SPECS_SIZE,),
dtype=np.uint8,
sharding=replicated_sharding,
),
return api.ShapeDtypeStruct(
shape=(), dtype=np.dtypes.StringDType(), sharding=replicated_sharding # type: ignore
)
@ -209,49 +192,49 @@ def _serialize_specs(
specs_treedef: tree_util.PyTreeDef,
specs_leaves: tuple[api.ShapeDtypeStruct, ...],
devices: DeviceList,
) -> tuple[jax.Array, ...]:
"""Serializes the output specs into a tuple of arrays.
) -> jax.Array:
"""Serializes the output specs into a jax.Array of string type.
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
colocated_python. See serialize() for details.
"""
s = _serialize((specs_treedef, specs_leaves))
assert (
len(s) <= _MAX_SERIALIZED_SPECS_SIZE
), f"Too large serialized spec size: {len(s)}"
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
mesh = jax.sharding.Mesh(tuple(devices), ("x",))
if not hasattr(np.dtypes, "StringDType"):
raise TypeError(
"Serializing Colocated Python requires StringDType. Please use"
" numpy to 2.0.0 or later, or explicityly provide an output spec"
" function."
)
s_bytes = _serialize((specs_treedef, specs_leaves))
s_str = base64.b64encode(s_bytes).decode("ascii")
s_np_array = np.array(s_str, dtype=np.dtypes.StringDType()) # type: ignore
# TODO(jmudigonda): Revisit this when JAX supports HLO sharding for making
# jax.Array via make_array_from_single_device_arrays. We should then use a
# sharding that spans all the execution devices - not just the addressable
# ones.
addressable_devices = devices.addressable_device_list
mesh = jax.sharding.Mesh(tuple(addressable_devices), ("x",))
replicated_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec()
)
len_array = jax.make_array_from_callback(
shape=(),
sharding=replicated_sharding,
data_callback=lambda _: np.array(len(s), dtype=np.int32),
out_arrays = [
jax.device_put(s_np_array, device) for device in addressable_devices
]
return jax.make_array_from_single_device_arrays(
arrays=out_arrays, sharding=replicated_sharding, shape=(),
)
data_array = jax.make_array_from_callback(
shape=(_MAX_SERIALIZED_SPECS_SIZE,),
sharding=replicated_sharding,
data_callback=lambda _: np.frombuffer(
s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)),
dtype=np.uint8,
),
)
return len_array, data_array
def _deserialize_specs(
serialized_specs: tuple[jax.Array, ...],
serialized_specs: jax.Array,
) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]:
"""Deserializes the specs from the serialized specs.
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
colocated_python. See serialize() for details.
"""
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
len_array, data_array = serialized_specs
length = int(len_array.addressable_shards[0].data)
data = np.asarray(data_array.addressable_shards[0].data).tobytes()
return _deserialize(data[:length])
data_array = serialized_specs.addressable_shards[0].data
data = base64.b64decode(data_array.item().encode("ascii"))
return _deserialize(data)
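A standalone sketch of the new round trip, with pickle standing in for the internal _serialize/_deserialize helpers (requires NumPy 2.0+ for StringDType):
import base64
import pickle
import numpy as np

specs = {"out": ("f32", (4, 8))}  # stand-in for (treedef, spec leaves)
payload = base64.b64encode(pickle.dumps(specs)).decode("ascii")
arr = np.array(payload, dtype=np.dtypes.StringDType())  # 0-d string array

recovered = pickle.loads(base64.b64decode(arr.item().encode("ascii")))
assert recovered == specs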

View File

@ -1541,10 +1541,12 @@ tf_not_yet_impl = [
"assert_consumed_value",
"consume",
"ragged_dot",
"ragged_dot_general",
"cholesky_update",
"symmetric_product",
"from_edtype",
"to_edtype",
"reciprocal",
# Pallas TPU primitives
"bitcast",
"repeat",
@ -3571,8 +3573,8 @@ def _pjit(*args: TfVal,
in_shardings: Sequence[sharding.Sharding],
out_shardings: Sequence[sharding.Sharding],
in_layouts, out_layouts,
resource_env: mesh.ResourceEnv,
donated_invars,
ctx_mesh,
name: str,
keep_unused: bool,
inline: bool,

View File

@ -28,6 +28,7 @@ from jax._src.lib.mlir.dialects import builtin
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
from jax._src.lib.mlir.dialects import math as mlir_math
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import nvvm
from jax._src.lib.mlir.dialects import scf
@ -234,11 +235,6 @@ def _vector_load_op_lowering_rule(
ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
)
if not layouts.is_strided_fragmented_layout(out_layout_attr):
raise ValueError(
f"{vector_load_op} has an unsupported layout: {out_layout_attr}"
)
for i in vector_load_op.indices:
index_defining_op = i.owner.opview
if (
@ -254,9 +250,28 @@ def _vector_load_op_lowering_rule(
element_type = vector_load_op.result.type.element_type
is_signed = False if ir.IntegerType.isinstance(element_type) else None
fragmented_array = fa.FragmentedArray.load_strided(
vector_load_op.base, is_signed=is_signed
)
if layouts.is_strided_fragmented_layout(out_layout_attr):
strided_layout = layouts.from_strided_fragmented_layout_attr(
out_layout_attr
)
fragmented_array = fa.FragmentedArray.load_strided(
vector_load_op.base,
is_signed=is_signed,
vec_size=strided_layout.vec_size,
)
elif layouts.is_wgmma_fragmented_layout(out_layout_attr):
layout = ir.MemRefType(vector_load_op.base.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
transformed_ref = transform_memref(vector_load_op.base, transforms)
fragmented_array = fa.FragmentedArray.load_tiled(
transformed_ref,
swizzle=swizzle,
is_signed=is_signed
)
else:
raise ValueError(
f"{vector_load_op} has an unsupported layout: {out_layout_attr}"
)
return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)]
@ -424,10 +439,77 @@ def _mgpu_async_store_op_lowering_rule(
gmem_transform=transforms,
uniform=True,
predicate=ctx.single_thread_per_warpgroup_predicate,
arrive=store_op.commit_group,
)
return []
def _conversion_op_lowering_rule(
_: LoweringContext,
op: ir.OpView,
source_is_signed: bool | None,
target_is_signed: bool | None,
) -> Sequence[ir.Value]:
[in_layout] = inference_utils.in_layouts(op)
[layout] = inference_utils.out_layouts(op)
if in_layout != layout:
raise ValueError("Layout mismatch")
target_ty = op.result.type.element_type # pytype: disable=attribute-error
operand = _fragmented_array_from_ir(op.operands[0], layout, source_is_signed)
converted = operand.astype(target_ty, is_signed=target_is_signed)
return [_fragmented_array_to_ir(converted, op.result.type)]
for op, source_is_signed, target_is_signed in [
(arith.ExtFOp, None, None),
(arith.ExtSIOp, True, True),
(arith.ExtUIOp, False, False),
(arith.FPToSIOp, None, True),
(arith.FPToUIOp, None, False),
(arith.SIToFPOp, True, None),
(arith.TruncFOp, None, None),
(arith.TruncIOp, False, False),
(arith.UIToFPOp, False, None),
]:
_lowerings[op.OPERATION_NAME] = functools.partial(
_conversion_op_lowering_rule,
source_is_signed=source_is_signed,
target_is_signed=target_is_signed,
)
def _unary_op_lowering_rule(
_: LoweringContext,
op: Any,
impl: Callable[[fa.FragmentedArray], fa.FragmentedArray],
is_signed: bool | None = None,
) -> Sequence[ir.Value]:
in_layouts = inference_utils.in_layouts(op)
[layout] = inference_utils.out_layouts(op)
if any(in_layout != layout for in_layout in in_layouts):
raise ValueError("Layout mismatch")
kwargs = {}
if hasattr(op, "fastmath"):
kwargs = dict(
approx=op.fastmath == ir.Attribute.parse("#arith.fastmath<afn>")
)
a = _fragmented_array_from_ir(op.operand, layout, is_signed)
return [_fragmented_array_to_ir(impl(a, **kwargs), op.result.type)]
for op, impl, is_signed in [
(mlir_math.RsqrtOp, fa.FragmentedArray.rsqrt, None),
(mlir_math.ExpOp, fa.FragmentedArray.exp, None),
(mlir_math.Exp2Op, fa.FragmentedArray.exp2, None),
(mlir_math.LogOp, fa.FragmentedArray.log, None),
(mlir_math.TanhOp, fa.FragmentedArray.tanh, None),
]:
_lowerings[op.OPERATION_NAME] = functools.partial(
_unary_op_lowering_rule, impl=impl, is_signed=is_signed
)
def _binary_op_lowering_rule(
_: LoweringContext,
op: Any,
@ -525,6 +607,25 @@ def _cmpf_op_lowering_rule(
return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
@_register_lowering(arith.BitcastOp)
def _bitcast_op_lowering_rule(
_: LoweringContext, op: arith.BitcastOp
) -> Sequence[ir.Value]:
in_layouts = inference_utils.in_layouts(op)
[layout] = inference_utils.out_layouts(op)
if any(in_layout != layout for in_layout in in_layouts):
raise ValueError("Layout mismatch")
in_ = _fragmented_array_from_ir(op.in_, layout)
out_element_type = ir.VectorType(op.result.type).element_type
out = in_.bitcast(
out_element_type,
output_is_signed=False
if ir.IntegerType.isinstance(out_element_type)
else None,
)
return [_fragmented_array_to_ir(out, op.result.type)]
@_register_lowering(mgpu.WGMMAOp)
def _mgpu_wgmma_op_lowering_rule(
_: LoweringContext, wgmma_op: mgpu.WGMMAOp
@ -689,8 +790,12 @@ def _for_op_lowering_rule(
new_args = (new_for_op.induction_variable, *recreated_carry)
for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True):
old_carry.replace_all_uses_with(new_carry)
for op in ops_to_lower:
for op in ops_to_lower:
with ir.InsertionPoint(op):
ctx.lower_op(op)
with ir.InsertionPoint(new_for_op.body):
new_yield_operands = lower_carry(yield_op.operands)
yield_op.erase()
scf.yield_(new_yield_operands)

View File

@ -53,13 +53,12 @@ def build_kernel(
index = ir.IndexType.get()
swizzle = 128
tile_k = swizzle // 2
swizzle_elems = tile_k = swizzle // 2
tiling = (8, swizzle_elems)
in_dtype = jnp.float16
k_loop_iter = k // tile_k
max_concurrent_steps = min(max_concurrent_steps, k_loop_iter)
tma_tile_m = 128
tma_tile_kn = 64
block_tile_m = tile_m
block_tile_n = tile_n
@ -123,17 +122,14 @@ def build_kernel(
src_ref=a,
dst_ref=mgpu.memref_slice(a_smem, slot),
gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)),
gmem_transform=mgpu.TileTransform((tma_tile_m, tma_tile_kn)),
gmem_transform=mgpu.TileTransform(tiling),
**common_args,
)
ctx.async_copy(
src_ref=b,
dst_ref=mgpu.memref_slice(b_smem, slot),
gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)),
gmem_transform=(
mgpu.TileTransform((tma_tile_kn, tma_tile_kn)),
mgpu.TransposeTransform((1, 0, 2, 3)),
),
gmem_transform=mgpu.TileTransform(tiling),
**common_args,
)
@ -145,7 +141,7 @@ def build_kernel(
tcgen05.mma(
acc,
mgpu.memref_slice(a_smem, slot),
mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)),
a_swizzle=swizzle,
b_swizzle=swizzle,
accumulate=accumulate,
@ -172,26 +168,23 @@ def build_kernel(
src_ref=d_smem,
dst_ref=d,
gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)),
gmem_transform=mgpu.TileTransform((128, 64)),
gmem_transform=mgpu.TileTransform((128, swizzle_elems)),
swizzle=swizzle,
)
ctx.await_async_copy(0)
compute_buffers = (
jax.ShapeDtypeStruct(
mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k),
(tma_tile_m, tma_tile_kn)),
mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k), tiling),
jnp.float16),
jax.ShapeDtypeStruct(
mgpu.tile_shape((max_concurrent_steps, tile_k, block_tile_n),
(tma_tile_kn, tma_tile_kn)),
mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling),
jnp.float16),
)
epilogue_buffer = jax.ShapeDtypeStruct(
mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)),
mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)),
jnp.float16)
smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
assert block_tile_m == 128
smem = (
smem_buffers,
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,

View File

@ -22,6 +22,7 @@ import math
from collections.abc import Callable
from typing import Iterable, Protocol, Sequence, TypeVar
import itertools
import jax
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
@ -115,6 +116,65 @@ class Tiling:
strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled)
return strides
def tile_nested_shape_strides(
self,
shape: tuple[tuple[int, ...], ...],
strides: tuple[tuple[int, ...], ...],
) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]:
"""A fused version of `tile_shape` and `tile_strides` for nested shapes.
By nested shape we mean that each logical dimension (i.e. each element of
shape/strides) is actually composed of multiple physical dimensions.
For example, a row-major array of logical shape (128, 128) that is tiled
into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each
dim is split into two sub-dims) and nested strides of
((2 * 64 * 64, 64), (64 * 64, 1)).
"""
if len(shape) != len(strides):
raise ValueError(
f"Shape {shape} and strides {strides} must have the same length"
)
def fail_if(cond, shape=shape): # Capture shape now.
if cond:
raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}")
for tile in self.tiles:
fail_if(len(tile) > len(shape))
untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):]
untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):]
major_dim_shapes, major_dim_strides = [], []
minor_dim_shapes, minor_dim_strides = [], []
for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides):
major_dim_shape_rev, major_dim_stride_rev = [], []
minor_dim_shape_rev, minor_dim_stride_rev = [], []
for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True):
if d < t: # We will need to tile more dims
fail_if(t % d != 0)
t //= d
minor_dim_shape_rev.append(d)
minor_dim_stride_rev.append(s)
elif t != 1: # Last dim to tile!
fail_if(d % t != 0)
minor_dim_shape_rev.append(t)
minor_dim_stride_rev.append(s)
if d != t: # No need to insert singleton dims.
major_dim_shape_rev.append(d // t)
major_dim_stride_rev.append(s * t)
t = 1
else: # Done tiling!
major_dim_shape_rev.append(d)
major_dim_stride_rev.append(s)
fail_if(t != 1)
major_dim_shapes.append(major_dim_shape_rev[::-1])
minor_dim_shapes.append(minor_dim_shape_rev[::-1])
major_dim_strides.append(major_dim_stride_rev[::-1])
minor_dim_strides.append(minor_dim_stride_rev[::-1])
shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes)
strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides)
return (
tuple(tuple(d) if d else (1,) for d in shape),
tuple(tuple(d) if d else (1,) for d in strides),
)
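A NumPy check of the docstring example above (logical shape (128, 128), (64, 64) tiles), showing where the nested shapes and strides come from:
import numpy as np

x = np.arange(128 * 128, dtype=np.int32).reshape(128, 128)
# Tile into (64, 64) blocks: physical shape (2, 2, 64, 64), stored contiguously.
tiled = np.ascontiguousarray(x.reshape(2, 64, 2, 64).transpose(0, 2, 1, 3))
elem_strides = tuple(s // tiled.itemsize for s in tiled.strides)
assert tiled.shape == (2, 2, 64, 64)
assert elem_strides == (2 * 64 * 64, 64 * 64, 64, 1)
# Grouped per logical dimension, as in the docstring:
#   dim 0 -> shape (2, 64), strides (2 * 64 * 64, 64)
#   dim 1 -> shape (2, 64), strides (64 * 64, 1)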
def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
for tile in self.tiles:
untiled, tiled = indices[:-len(tile)], indices[-len(tile):]
@ -214,7 +274,7 @@ class TiledLayout:
index = ir.IndexType.get()
contig_strides = utils.get_contiguous_strides(shape)
tile_strides = self.tiling.tile_strides(contig_strides)
dyn_tile_strides = [c(s, i32) for s in tile_strides]
dyn_tile_strides = [c(s, i32) for s in tile_strides[-self.tiled_tiling_rank:]]
warp_offset = utils.dyn_dot(self.warp_indices(), dyn_tile_strides)
lane_offset = utils.dyn_dot(self.lane_indices(), dyn_tile_strides)
dyn_offset = arith.addi(warp_offset, lane_offset)
@ -246,7 +306,12 @@ class TiledLayout:
so the tiled shape always ends with this suffix, no matter what array shape
it's applied to.
"""
return self.tiling.tile_shape(self.base_tile_shape)
base_tile_shape = self.base_tile_shape
return self.tiling.tile_shape(base_tile_shape)[len(base_tile_shape):]
@functools.cached_property
def tiled_tiling_rank(self) -> int:
return len(self.tiled_tiling_shape)
@property
def vector_length(self) -> int:
@ -292,16 +357,12 @@ class TiledLayout:
def warp_indices(self) -> tuple[ir.Value, ...]:
i32 = ir.IntegerType.get_signless(32)
tiled_shape = tuple(
d if i == self.warp_dim else 1
for i, d in enumerate_negative(self.tiled_tiling_shape)
)
assert math.prod(tiled_shape) == WARPS_IN_WARPGROUP
tiled_shape_rank = len(self.tiled_tiling_shape)
warp_idx = arith.remui(
arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)),
c(WARPS_IN_WARPGROUP, i32),
)
indices = [arith.constant(i32, 0)] * len(tiled_shape)
indices = [arith.constant(i32, 0)] * tiled_shape_rank
indices[self.warp_dim] = warp_idx
return tuple(indices)
@ -479,6 +540,33 @@ TILED_LAYOUT_WGMMA = TiledLayout(
lane_dims=(-4, -3),
vector_dim=-1,
)
# This tiled layout is similar to the one above. Above, each warp stores an 8x8
# submatrix in the following way (we only show the first 4 rows for brevity):
#
# 0 0 1 1 2 2 3 3
# 4 4 5 5 6 6 7 7
# 8 8 9 9 10 10 11 11
# 12 12 13 13 14 14 15 15
# ...
#
# This tiled layout stores the same 8x8 submatrix in the following way:
#
# 0 4 1 5 2 6 3 7
# 0 4 1 5 2 6 3 7
# 8 12 9 13 10 14 11 15
# 8 12 9 13 10 14 11 15
# ...
#
# You can see that we have taken 2x2 submatrices from the above layout and
# transposed them. The assignment of lanes to elements is such that in both
# layouts the same two lanes map to a single 2x2 submatrix, making the transpose
# very cheap (one shuffle and one permute suffice to change between the layouts).
WGMMA_TRANSPOSED_LAYOUT = TiledLayout(
Tiling(((64, 8), (16, 8), (8, 8), (2, 2), (2, 1))),
warp_dim=-10,
lane_dims=(-6, -3, -5),
vector_dim=-2,
)
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True)
@ -553,13 +641,22 @@ class FragmentedArray:
raise NotImplementedError
@classmethod
def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None):
def load_strided(
cls,
ref: ir.Value,
*,
is_signed: bool | None = None,
vec_size: int | None = None,
):
if not ir.MemRefType.isinstance(ref.type):
raise TypeError(ref.type)
ref_ty = ir.MemRefType(ref.type)
shape = tuple(ref_ty.shape)
layout = WGStridedFragLayout.from_shaped_type(ref_ty)
if vec_size is None:
layout = WGStridedFragLayout.from_shaped_type(ref_ty)
else:
layout = WGStridedFragLayout(shape=shape, vec_size=vec_size)
vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
try:
# Flattening the reference potentially produces simpler PTX but
@ -647,9 +744,35 @@ class FragmentedArray:
At the moment, only conversions from ``WGSplatFragLayout`` are supported.
"""
i32 = ir.IntegerType.get_signless(32)
if self.layout == new_layout:
return self
shape = self.shape
if (
self.layout == TILED_LAYOUT_WGMMA
and new_layout == WGMMA_TRANSPOSED_LAYOUT
and utils.bitwidth(self.mlir_dtype) == 16
):
is_even_row = arith.cmpi(
arith.CmpIPredicate.eq,
arith.remui(arith.divui(utils.thread_idx(), c(4, i32)), c(2, i32)),
c(0, i32),
)
perm = arith.select(is_even_row, c(0x5410, i32), c(0x3276, i32))
new_regs = []
for reg in self.registers.flat:
reg_ty = reg.type
reg = utils.bitcast(reg, i32)
reg_shfl = utils.shfl_bfly(reg, 4)
new_reg = llvm.inline_asm(
i32, [reg, reg_shfl, perm], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r"
)
new_regs.append(utils.bitcast(new_reg, reg_ty))
return FragmentedArray(
_registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)),
_layout=new_layout,
_is_signed=self.is_signed,
)
if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 8 == 0:
tiled_layout = _tiled_wgmma_layout(shape)
if (self.layout == WGMMA_LAYOUT and new_layout == tiled_layout) or (
@ -966,6 +1089,24 @@ class FragmentedArray:
return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32"))
return self._pointwise(mlir_math.exp2)
def log(self, *, approx: bool = False):
if not ir.FloatType.isinstance(self.mlir_dtype):
raise NotImplementedError
if approx:
dtype = self.mlir_dtype
ln2 = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.6931471805599453))
return self.log2(approx=True) * ln2
return self._pointwise(mlir_math.log)
def log2(self, *, approx: bool = False):
if not ir.FloatType.isinstance(self.mlir_dtype):
raise NotImplementedError(self.mlir_dtype)
if approx:
if not ir.F32Type.isinstance(self.mlir_dtype):
raise NotImplementedError(self.mlir_dtype)
return self._pointwise(self._lift_fast_instr("lg2.approx.ftz.f32"))
return self._pointwise(mlir_math.log2)
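The approx path of log above relies on the identity log(x) = log2(x) * ln(2); a quick numeric check:
import math

x = 8.0
assert math.isclose(math.log(x), math.log2(x) * math.log(2.0))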
def sin(self, *, approx: bool = False):
if not ir.FloatType.isinstance(self.mlir_dtype):
raise NotImplementedError
@ -1190,7 +1331,30 @@ class FragmentedArray:
from_integer = ir.IntegerType.isinstance(cur_dtype)
to_integer = ir.IntegerType.isinstance(new_dtype)
if from_float and to_float:
if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
cur_ty_width = ir.FloatType(cur_dtype).width
new_ty_width = ir.FloatType(new_dtype).width
if cur_ty_width == new_ty_width:
# There is no instruction to perform conversions between two float types
# of the same width. Go through the next-larger standard type.
# TODO(bchetioui): support conversions between float types of width 8.
# Which larger type to pick will depend on the number of bits in the
# smallest exponent.
if cur_ty_width != 16:
raise NotImplementedError(
"Conversion between float types of width other than 16 not"
" supported"
)
larger_ty = ir.F32Type.get()
match self.layout:
case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout():
shape = ir.VectorType(self.registers.flat[0].type).shape
upcast_ty = ir.VectorType.get(shape, larger_ty)
case WGMMARowFragLayout() | WGSplatFragLayout():
upcast_ty = larger_ty
case _:
raise NotImplementedError(f"Unsupported layout {self.layout}")
convert = lambda ty, x: arith.truncf(ty, arith.extf(upcast_ty, x))
elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
convert = arith.truncf
else:
convert = arith.extf
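A host-side analogy of the equal-width case above using ml_dtypes (the package JAX uses for bfloat16); this only illustrates the widen-then-truncate order, not the GPU lowering itself:
import numpy as np
import ml_dtypes

x = np.array([1.5, 2.25], dtype=np.float16)
# No direct f16 <-> bf16 conversion instruction: widen to f32, then truncate.
y = x.astype(np.float32).astype(ml_dtypes.bfloat16)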
@ -1423,19 +1587,34 @@ class FragmentedArray:
if create_array:
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
def store_untiled(self, ref: ir.Value):
def debug_print(self, fmt: str):
idx_fmt = ", ".join(["{}"] * len(self.shape))
@self.foreach
def _(val, idx):
fmt_str = fmt.format(f"[{idx_fmt}]: {{}}")
utils.debug_print(fmt_str, *idx, val, uniform=False)
def store_untiled(self, ref: ir.Value, *, vector_store: bool = True):
if not ir.MemRefType.isinstance(ref.type):
raise ValueError(ref)
def vs_unsupported():
if not vector_store:
raise NotImplementedError(
f"Can't use non-vector stores with layout {self.layout}"
)
match self.layout:
case WGMMAFragLayout():
self._store_untiled_wgmma(ref)
case WGSplatFragLayout():
vs_unsupported()
self._store_untiled_splat(ref)
case WGStridedFragLayout():
vs_unsupported()
self._store_untiled_wg_strided(ref)
case TiledLayout():
self._store_untiled_tiled(ref)
self._store_untiled_tiled(ref, vector_store=vector_store)
case _:
raise NotImplementedError(self.layout)
@ -1502,7 +1681,7 @@ class FragmentedArray:
col = arith.addi(col_base, c(col_tile * 8 + col_idx))
memref.store(value, ref, [row, col])
def _store_untiled_tiled(self, ref: ir.Value):
def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
"""Stores an array with a tiled layout. Not optimized at the moment."""
if utils.bitwidth(self.mlir_dtype) < 8:
raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})")
@ -1510,7 +1689,7 @@ class FragmentedArray:
layout = self.layout
assert isinstance(layout, TiledLayout)
ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset()
if ref_strides[layout.vector_dim] != 1:
if vector_store and ref_strides[layout.vector_dim] != 1:
raise NotImplementedError(
"Can't use vector stores with non-unit minormost stride"
)
@ -1524,16 +1703,30 @@ class FragmentedArray:
raise NotImplementedError(f"Unexpected ref space {ref_space}")
ptr = utils.memref_ptr(ref, memory_space=memory_space)
# Fold warp and lane offsets into the pointer once, since they are dynamic.
dyn_strides = [arith.constant(i32, s) for s in strides]
dyn_strides = [
arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :]
]
warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides)
lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides)
dyn_offset = arith.addi(warp_offset, lane_offset)
ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
# All warp tile offsets are static and can be fused into the store.
for tile_idx, reg in np.ndenumerate(self.registers):
lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True))
reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
llvm.store(reg, reg_ptr)
if vector_store:
elems = [reg]
else:
index = ir.IndexType.get()
elems = [
vector.extractelement(reg, position=c(i, index))
for i in range(ir.VectorType(reg.type).shape[0])
]
for i, e in enumerate(elems):
tile_idx_local = list(tile_idx)
tile_idx_local[layout.vector_dim] += i
tile_idx_local = list(tile_idx_local)
lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True))
reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
llvm.store(e, reg_ptr)
def store_tiled(self, ref, swizzle: int | None):
match self.layout:
@ -1714,31 +1907,42 @@ class FragmentedArray:
ref_ty = ir.MemRefType(ref.type)
dtype = ref_ty.element_type
if ref_ty.rank % 2:
raise ValueError("Tiled refence must have even rank")
ref_tiling_shape = tuple(ref_ty.shape[ref_ty.rank // 2:])
raise ValueError("Tiled reference must have even rank")
ref_logical_rank = ref_ty.rank // 2
ref_tiling_shape = tuple(ref_ty.shape[ref_logical_rank:])
ref_tiling = Tiling((ref_tiling_shape,))
ref_strides, _ = ref_ty.get_strides_and_offset()
if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape:
raise ValueError()
if len(layout.base_tile_shape) > len(ref_tiling_shape):
raise ValueError("Memory tiling must be a multiple of the register tiling")
ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):]
if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)):
raise ValueError(
f"Memory tiling ({ref_tiling_suffix}) must be a multiple of the"
f" register tiling ({layout.base_tile_shape})"
)
nested_ref_shape = tuple(
(ref_ty.shape[i], ref_ty.shape[i + ref_logical_rank])
for i in range(ref_logical_rank)
)
nested_ref_strides = tuple(
(ref_strides[i], ref_strides[i + ref_logical_rank])
for i in range(ref_logical_rank)
)
tiled_nested_shape, tiled_nested_strides = tiling.tile_nested_shape_strides(
nested_ref_shape, nested_ref_strides
)
elem_tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape)))
# We could technically handle this case, but it would be quite complicated.
# If a tiling dimension had to be expanded into multiple dimensions, we'd have to
# adjust the dimension indices in layouts, including expanding some of them
# into multiple indices. Note that for non-tiling dims, we allow the shape
# to be arbitrary, which is why we fix it up below in mem_idx_to_reg_idx.
if any(
len(dim_shape) != 1 for dim_shape in tiled_nested_shape[-layout.tiled_tiling_rank :]
):
raise NotImplementedError("Memory and register tiling incompatible")
tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape))
elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides))
elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims]
lane_shape = [tiled_shape[d] for d in layout.lane_dims]
if elem_tiled_strides[layout.vector_dim] != 1:
raise ValueError("Stride of the vectorized dimension should be 1")
for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim):
tiled_shape[d] = 1
full_tiling = Tiling((ref_tiling_shape, *tiling.tiles))
full_layout = dataclasses.replace(layout, tiling=full_tiling)
element_bits = mgpu.bitwidth(dtype)
if (layout.vector_length * element_bits) % 8 != 0:
@ -1779,9 +1983,11 @@ class FragmentedArray:
)
# All offsets are in units of transfer_dtype.
dyn_tiled_strides = [c(s) for s in transfer_tiled_strides]
lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides)
warp_offset = utils.dyn_dot(full_layout.warp_indices(), dyn_tiled_strides)
dyn_tiled_strides = [
c(s) for s in transfer_tiled_strides[-layout.tiled_tiling_rank :]
]
lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides)
warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides)
dyn_offset = arith.addi(lane_offset, warp_offset)
if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
raise ValueError("Tiled stores can be performed into SMEM")
@ -1809,10 +2015,23 @@ class FragmentedArray:
reg_ptr = utils.getelementptr(ptr, [offset], transfer_dtype)
offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle))
reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], transfer_dtype)
reg_idxs = [
tiling.tile_indices(full_tiling.untile_indices(idx))
for idx in indices.tolist()
]
# Here, registers are organized in an array with shape obtained by tiling
# the logical data bounds. But, the reference was tiled and so each
# logical tiled dimension can map to multiple dims in tiled_shape.
# The transform below maps this potentially higher-rank representation
# back to the lower-rank representation used by the register arrays.
def mem_idx_to_reg_idx(idx):
reg_tiled_idx = []
base_idx = 0
for dim_shape in tiled_nested_shape[:ref_logical_rank]:
dim_strides = utils.get_contiguous_strides(dim_shape)
dim_idxs = idx[base_idx:base_idx + len(dim_shape)]
base_idx += len(dim_shape)
reg_tiled_idx.append(sum(i * s for i, s in zip(dim_idxs, dim_strides)))
# We should have fixed up all but the tiling dims.
assert base_idx == len(idx) - layout.tiled_tiling_rank
return (*reg_tiled_idx, *idx[base_idx:])
reg_idxs = [mem_idx_to_reg_idx(idx) for idx in indices.tolist()]
def get_register(regs, reg_idxs=reg_idxs):
return plan.select([regs[reg_idx] for reg_idx in reg_idxs])
def update_registers(regs, new, reg_idxs=reg_idxs):

View File

@ -430,8 +430,8 @@ class LaunchContext:
gmem_ref, smem_ref = dst_ref, src_ref
if barrier is not None:
raise ValueError("Barriers are unsupported for SMEM -> GMEM copies")
if arrive is not None:
raise ValueError("arrive is unsupported for SMEM -> GMEM copies")
if arrive is None:
arrive = True # Commit this copy to the async group by default
else:
raise ValueError("Only SMEM <-> GMEM copies supported")
# TODO(apaszke): This is a very approximate check. Improve it!
@ -683,7 +683,8 @@ class LaunchContext:
nvvm.cp_async_bulk_tensor_global_shared_cta(
tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate
)
nvvm.cp_async_bulk_commit_group()
if arrive:
nvvm.cp_async_bulk_commit_group()
def await_async_copy(
self, allow_groups: int, await_read_only: bool = False

View File

@ -18,18 +18,23 @@ from collections.abc import Callable, Sequence
import dataclasses
import enum
from functools import partial
import math
from typing import cast
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import math as mlir_math
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
import numpy as np
from . import fragmented_array as fa
from . import inference_utils
from . import layouts as layouts_lib
from . import utils
# mypy: ignore-errors
@ -192,22 +197,44 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
for op in [
arith.AddIOp, arith.AddFOp,
arith.AddIOp,
arith.AddFOp,
arith.AndIOp,
arith.BitcastOp,
arith.CmpFOp,
arith.CmpIOp,
arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp,
arith.ExtFOp,
arith.ExtSIOp,
arith.ExtUIOp,
arith.FPToSIOp,
arith.FPToUIOp,
arith.MaximumFOp,
arith.MaxUIOp, arith.MaxSIOp,
arith.MaxUIOp,
arith.MaxSIOp,
arith.MinimumFOp,
arith.MinUIOp, arith.MinSIOp,
arith.MulIOp, arith.MulFOp,
arith.MinUIOp,
arith.MinSIOp,
arith.MulIOp,
arith.MulFOp,
arith.OrIOp,
arith.FloorDivSIOp, arith.DivUIOp, arith.DivFOp,
arith.RemUIOp, arith.RemSIOp, arith.RemFOp,
arith.SubIOp, arith.SubFOp,
arith.TruncFOp, arith.TruncIOp,
arith.FloorDivSIOp,
arith.DivUIOp,
arith.DivFOp,
arith.RemUIOp,
arith.RemSIOp,
arith.RemFOp,
arith.SIToFPOp,
arith.UIToFPOp,
arith.SubIOp,
arith.SubFOp,
arith.TruncFOp,
arith.TruncIOp,
arith.XOrIOp,
mlir_math.ExpOp,
mlir_math.Exp2Op,
mlir_math.LogOp,
mlir_math.RsqrtOp,
mlir_math.TanhOp,
vector.LoadOp,
vector.StoreOp,
]:
@ -487,11 +514,36 @@ def infer_layout(module: ir.Module):
# propagated. However, it is possible for some operations to remain
# unannotated---for example, if there were no annotations on any operation in
# the module at the start of this function. We annotate all the remaining ops
# that should be annotated with a strided fragmented layout.
# that should be annotated with a strided fragmented layout, whose vector size
# is derived from the narrowest type and vector size used in the program. We
# make sure to derive a single vector size in order to avoid relayouts at
# lowering time.
default_vector_size = math.inf
def update_default_vector_size(op: ir.OpView):
nonlocal default_vector_size
for v in list(op.operands) + list(op.results):
if ir.VectorType.isinstance(v.type):
max_vec_size_for_v = (
np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE
)
desired_vec_size = 8 // utils.bytewidth(v.type.element_type)
default_vector_size = min(
default_vector_size, max_vec_size_for_v, desired_vec_size
)
for op in module.body:
traverse_op(op, update_default_vector_size)
if default_vector_size == math.inf:  # Nothing to annotate.
return
def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
if not ir.VectorType.isinstance(ty):
return None
layout = fa.WGStridedFragLayout.from_shaped_type(ty)
layout = fa.WGStridedFragLayout(
shape=cast(ir.ShapedType, ty).shape, vec_size=default_vector_size
)
return layouts_lib.to_strided_fragmented_layout_attr(layout)
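A worked example of the vector-size heuristic above, assuming a 128-thread warpgroup and an f16 operand of shape (128, 128); the numbers are illustrative:
import math

WARPGROUP_SIZE = 128
shape, bytewidth = (128, 128), 2                     # f16 operand
max_vec_size = math.prod(shape) // WARPGROUP_SIZE    # 128 elements per thread
desired_vec_size = 8 // bytewidth                    # 8-byte transfers -> 4
default_vector_size = min(max_vec_size, desired_vec_size)
assert default_vector_size == 4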
def set_default_layout(op: ir.OpView):

View File

@ -0,0 +1,223 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import enum
import math
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
from . import utils
# mypy: ignore-errors
def tiled_memref_shape(ref: ir.Value):
"""Returns the 2D untiled shape and element type of a tiled 4D memref."""
ref_ty = ir.MemRefType(ref.type)
if ref_ty.rank != 4:
raise ValueError(f"Expected a 4D memref, got: {ref_ty}")
logical_shape = (
ref_ty.shape[0] * ref_ty.shape[2], ref_ty.shape[1] * ref_ty.shape[3]
)
return logical_shape, ref_ty.element_type
class Dim(enum.Enum):
K = enum.auto()
MN = enum.auto()
def create_descriptor(
ref: ir.Value,
swizzle: int,
group_size: tuple[int, int], # Instruction group size on each operand dim.
logical_k_major: bool, # False for LHS, True for RHS.
# Soft deprecated. Use small tiling instead.
large_tile: tuple[int, int] | None = None,
):
ref_ty = ir.MemRefType(ref.type)
element_bytewidth = utils.bytewidth(ref_ty.element_type)
swizzle_elems = swizzle // element_bytewidth
ref_strides, _ = ref_ty.get_strides_and_offset()
ref_byte_strides = [s * element_bytewidth for s in ref_strides]
mn_large_tile = k_large_tile = None
if logical_k_major:
_, mn_tiles, k_tiling, mn_tiling = ref_ty.shape
k_tile_stride, mn_tile_stride, k_tiling_stride, mn_tiling_stride = (
ref_byte_strides
)
k_group_size, mn_group_size = group_size
if large_tile is not None:
k_large_tile, mn_large_tile = large_tile
else:
mn_tiles, _, mn_tiling, k_tiling = ref_ty.shape
mn_tile_stride, k_tile_stride, mn_tiling_stride, k_tiling_stride = (
ref_byte_strides
)
mn_group_size, k_group_size = group_size
if large_tile is not None:
mn_large_tile, k_large_tile = large_tile
IGNORED = 0
MMA_ATOM_ROWS = 8
MMA_BYTEWIDTH_K = 32
mma_width_k = MMA_BYTEWIDTH_K // element_bytewidth
# As far as I can tell (which does not seem to fully align with the way MMA is
# documented in PTX docs), MMA expects the data to be tiled into matrices
# of shape 8 x swizzle_elems, with swizzle_elems dim being the fastest
# changing. I call this submatrix an MMA atom.
#
# The role of the SMEM descriptor is to specify the striding pattern between
# those atoms. The fastest changing dimension is called the "leading"
# dimension and it specifies the stride between consecutive atoms that share
# the same coordinate along that dim. The slower dimension is called a
# "stride" dimension.
if (
large_tile is not None
and k_large_tile == k_tiling
and (mn_large_tile == mn_tiling or mn_tiles == 1 and mn_tiling < mn_large_tile)
# There are configurations where large tiles are the same size as small ones.
# We use the small path since it has fewer restrictions.
and set(large_tile) != {MMA_ATOM_ROWS, swizzle_elems}
): # Large tiles.
if (
k_tiling_stride == element_bytewidth
and mn_tiling_stride == k_tiling * element_bytewidth
):
fastest_dim = Dim.K
leading_byte_offset = IGNORED # TC assumes K to be contiguous here.
# MMA atoms in a group are contiguous, so we increment by the MMA atom
# size. However, we only have one level of striding, and so if the group
# size exceeds a single large tile (and there is more than one tile) then
# that tiled dimension must be contiguous after tiles or else we would
# need another striding level.
if (
mn_tiles > 1
and mn_group_size > mn_tiling
and mn_tile_stride != math.prod(large_tile) * element_bytewidth
):
raise ValueError(
"MMA layout with large tiles that is K-fastest only supports"
" multiple MN tiles when the tiled MN dimension is a contiguous"
" stack of tiles "
f"({mn_tiles}, {mn_tile_stride} != {math.prod(large_tile)} * {element_bytewidth})"
)
stride_byte_offset = MMA_ATOM_ROWS * swizzle
desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous.
elif (
k_tiling_stride == k_tiling * element_bytewidth
and mn_tiling_stride == element_bytewidth
):
if k_large_tile != mn_large_tile:
raise ValueError(
"MMA layout with large tiles that is MN-fastest is only supported"
" when the tiling is square"
)
fastest_dim = Dim.MN
# Next swizzle atom with the same K coordinate is in the next MN tile.
leading_byte_offset = mn_tile_stride
# MMA atoms in a group are contiguous and a group does not exceed a tile.
assert k_large_tile == k_group_size
stride_byte_offset = MMA_ATOM_ROWS * swizzle
# Each row is swizzle bytes wide, and we read mma_width_k rows at a time.
assert mn_large_tile == swizzle // element_bytewidth
desc_k_stride = mma_width_k * swizzle
else:
raise ValueError("MMA tiles must be contiguous")
else: # Small tiles.
if k_tiling_stride > mn_tiling_stride:
slower_tiling, faster_tiling = k_tiling, mn_tiling
else:
faster_tiling, slower_tiling = k_tiling, mn_tiling
if slower_tiling != MMA_ATOM_ROWS or faster_tiling != swizzle_elems:
raise ValueError(
f"Tiling should be ({MMA_ATOM_ROWS}, swizzle_elems) where"
f" swizzle_elems = swizzle // bytewidth(dtype) (= {swizzle} //"
f" {element_bytewidth} = {swizzle_elems}), but got ({slower_tiling},"
f" {faster_tiling})"
)
if k_tiling_stride == element_bytewidth and mn_tiling_stride == swizzle:
fastest_dim = Dim.K
leading_byte_offset = IGNORED # TC assumes K to be contiguous here.
stride_byte_offset = mn_tile_stride
desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous.
elif k_tiling_stride == swizzle and mn_tiling_stride == element_bytewidth:
fastest_dim = Dim.MN
leading_byte_offset = mn_tile_stride
stride_byte_offset = k_tile_stride
k_tiles_per_mma = mma_width_k // MMA_ATOM_ROWS
desc_k_stride = k_tile_stride * k_tiles_per_mma
else:
raise ValueError("MMA tiles must be contiguous")
desc_base = encode_descriptor(
ref,
leading_byte_offset=leading_byte_offset,
stride_byte_offset=stride_byte_offset,
swizzle=swizzle,
)
mn_tiles_per_group, rem = divmod(mn_group_size, mn_tiling)
assert not rem
mn_group_stride = mn_tile_stride * mn_tiles_per_group
k_tiles_per_group, rem = divmod(k_group_size, k_tiling)
assert not rem
k_group_stride = k_tile_stride * k_tiles_per_group
return (
(desc_base, desc_k_stride),
(mn_group_stride, k_group_stride),
fastest_dim,
)
def encode_addr(x: int):
result = (x & 0x3FFFF) >> 4
if result << 4 != x:
raise ValueError(f"Cannot encode value in an MMA descriptor: {x}")
return result
def encode_descriptor(
memref_arg,
leading_byte_offset: int,
stride_byte_offset: int,
swizzle: int | mgpu_dialect.SwizzlingMode | None,
const_init: int = 0,
):
i64 = ir.IntegerType.get_signless(64)
ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, 3))
c = lambda x: arith.constant(i64, x)
if swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
swizzle_encoding = 0
elif swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
swizzle_encoding = 1
elif swizzle == mgpu_dialect.SwizzlingMode.k64ByteSwizzle:
swizzle_encoding = 2
elif swizzle == mgpu_dialect.SwizzlingMode.k32ByteSwizzle:
swizzle_encoding = 3
else:
raise NotImplementedError(swizzle)
encoded_base_addr = llvm.lshr(llvm.and_(ptr_val, c(0x3FFFF)), c(4))
# We ignore the offset
desc_const = (
const_init
| (encode_addr(leading_byte_offset) << 16)
| (encode_addr(stride_byte_offset) << 32)
)
desc = llvm.or_(arith.shli(c(swizzle_encoding), c(62)), c(desc_const))
desc = llvm.or_(encoded_base_addr, desc)
return desc
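A quick check of the address encoding above (the helper is restated verbatim for a standalone run): offsets must fit in 18 bits and be 16-byte aligned, since the descriptor stores them in units of 16 bytes.
def encode_addr(x: int):
  result = (x & 0x3FFFF) >> 4
  if result << 4 != x:
    raise ValueError(f"Cannot encode value in an MMA descriptor: {x}")
  return result

assert encode_addr(1024) == 64   # 1024 bytes == 64 * 16 bytes
assert encode_addr(0) == 0
# encode_addr(8) raises: not 16-byte aligned.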

View File

@ -18,7 +18,6 @@ from __future__ import annotations
import dataclasses
import math
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
@ -27,7 +26,7 @@ import numpy as np
from . import utils
from . import fragmented_array as fa
from . import _wgmma
from . import mma_utils
from .launch_context import LaunchContext
# MyPy does a terrible job with the MLIR API.
@ -37,21 +36,6 @@ from .launch_context import LaunchContext
TMEM_ROWS = 128
TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46
def create_smem_descriptor(
memref_arg,
leading_byte_offset: int,
stride_byte_offset: int,
swizzle: int | mgpu_dialect.SwizzlingMode | None,
):
return _wgmma.create_descriptor(
memref_arg,
leading_byte_offset,
stride_byte_offset,
swizzle,
memory_space=3,
const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
)
def create_instr_descriptor(
m: int,
n: int,
@ -100,70 +84,126 @@ def mma(
collective: bool = False,
):
i64 = ir.IntegerType.get_signless(64)
if isinstance(accumulate, bool):
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
if a_swizzle != b_swizzle:
raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
swizzle = a_swizzle
num_cta = 2 if collective else 1
# Step 1. Establish the shape and element type of the operation.
if not ir.MemRefType.isinstance(a.type):
raise ValueError(f"A must be a memref, got {a.type}")
if not ir.MemRefType.isinstance(b.type):
raise ValueError(f"B must be a memref, got: {b.type}")
if a_swizzle != b_swizzle:
raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
if isinstance(accumulate, bool):
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
(k, n), element_type = mma_utils.tiled_memref_shape(b)
(m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
if k != k2:
raise ValueError(
"MMA requires A and B to have the same contraction dimension (K),"
f" got: {k2} and {k}"
)
if element_type != element_type2:
raise ValueError(
"MMA requires A and B to have the same element type, got:"
f" {element_type2} and {element_type}"
)
if d.shape != (m, n * num_cta):
raise ValueError(
f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}"
)
f32 = ir.F32Type.get()
if element_type == f32 or element_type == ir.BF16Type.get():
if d.dtype != f32:
raise ValueError(
f"MMA with element type {element_type} only supports accumulators"
f" of type f32, but got: {d.dtype}"
)
elif element_type == ir.F16Type.get():
if d.dtype != element_type and d.dtype != f32:
raise ValueError(
"MMA with element type f16 only supports accumulators of type f32"
f" or f16, but got: {d.dtype}"
)
m_group_size = d.layout.elements_in_tile[0]
if m_group_size != 128:
# Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
# instructions must be issued in groups of the same width as the swizzle.
m_group_elems = d.layout.elements_in_tile[0]
if m_group_elems != 128:
raise NotImplementedError("Only 128-row accumulators supported for now")
(
a_desc_base,
b_desc_base,
(m, k, n),
(m_groups, k_groups, n_groups),
(a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride),
mma_params,
) = _wgmma._validate_mma(
a,
b,
a_swizzle,
m_group_size=m_group_size,
descriptor_const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
)
n_group_size = n // n_groups
if n > 512:
raise ValueError(f"N is too big: at most 512 is supported, but got {n}")
num_cta = 2 if collective else 1
k_group_elems = swizzle // utils.bytewidth(element_type)
if n % 8:
raise ValueError(f"N must be a multiple of 8, got: {n}")
elif n > 256 and n != 512:
raise ValueError("Only N below 256 or N=512 are supported")
if num_cta == 2 and n > 256:
raise NotImplementedError(
"N is too big for collective MMA. Only up to 256 is supported."
)
n_group_elems = min(n, 256)
if m % m_group_elems:
raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
if k % k_group_elems:
raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
if n % n_group_elems:
raise ValueError(f"N must be a multiple of {n_group_elems}, got: {n}")
m_groups = m // m_group_elems
k_groups = k // k_group_elems
n_groups = n // n_group_elems
# TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA.
wgmma_element_type = (
ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
)
# TODO(apaszke): Verify that the cluster shape matches the expectation of
# collective MMA.
expected_acc_shape = (m, n * num_cta)
if d.shape != expected_acc_shape:
raise ValueError(
f"Accumulator shape mismatch: expected {expected_acc_shape}, got {d.shape}"
)
# Step 3. Compute the operand descriptors.
(
(a_desc_base, a_k_instr_stride),
(a_m_group_stride, a_k_group_stride),
a_fastest,
) = mma_utils.create_descriptor(
a,
swizzle=swizzle,
group_size=(m_group_elems, k_group_elems),
logical_k_major=False,
)
(
(b_desc_base, b_k_instr_stride),
(b_n_group_stride, b_k_group_stride),
b_fastest,
) = mma_utils.create_descriptor(
b,
swizzle=swizzle,
group_size=(k_group_elems, n_group_elems),
logical_k_major=True,
)
# Step 4. Issue the instructions.
true = arith.constant(ir.IntegerType.get_signless(1), 1)
for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
a_offset = mi * a_m_group_stride + ki * a_k_group_stride
a_mk = arith.addi(a_desc_base, utils.c(_wgmma.wgmma_encode(a_offset), i64))
a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64))
b_offset = ni * b_n_group_stride + ki * b_k_group_stride
b_nk = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(b_offset), i64))
b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64))
if m_groups != 1:
raise NotImplementedError("D needs to be sliced")
acc = accumulate if ki == 0 else true
_do_mma(
d.slice(
slice(None), utils.ds(ni * n_group_size, n_group_size)
slice(None), utils.ds(ni * n_group_elems, n_group_elems)
).address,
a_mk,
b_nk,
d_type=ir.F32Type.get(),
m=m_group_size,
m=m_group_elems,
n=n_group_elems,
collective=collective,
**mma_params,
a_transpose=a_fastest != mma_utils.Dim.K,
b_transpose=b_fastest != mma_utils.Dim.K,
a_k_stride=a_k_instr_stride,
b_k_stride=b_k_instr_stride,
accumulate=acc,
swizzle=swizzle,
element_type=wgmma_element_type,
)
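A worked example of the group-size bookkeeping in step 2 above, for f16 operands with a 128-byte swizzle and an illustrative 128x512x256 problem (m_group_elems = 128 follows from the accumulator restriction stated in the code):
m, k, n = 128, 512, 256
element_bytewidth = 2                            # f16
swizzle = 128

m_group_elems = 128                              # fixed by the accumulator tiling
k_group_elems = swizzle // element_bytewidth     # 64
n_group_elems = min(n, 256)                      # 256
m_groups = m // m_group_elems                    # 1
k_groups = k // k_group_elems                    # 8
n_groups = n // n_group_elems                    # 1
assert (m_groups, k_groups, n_groups) == (1, 8, 1)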

View File

@ -1172,4 +1172,32 @@ def getelementptr(
def dyn_dot(x, y):
assert len(x) == len(y)
return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y)))
def shfl_bfly(x: ir.Value, distance: int | ir.Value):
i32 = ir.IntegerType.get_signless(32)
if isinstance(distance, int):
distance = c(distance, i32)
assert x.type == i32
return nvvm.shfl_sync(
i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly,
)
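# Illustrative NumPy emulation (a sketch, not the NVVM intrinsic): a butterfly
# shuffle with distance d makes lane i read the value held by lane i ^ d, which is
# the access pattern used by warp-level butterfly reductions.
import numpy as np
_lanes = np.arange(32)
_values = _lanes * 10
_shuffled = _values[_lanes ^ 4]    # lane i receives the value from lane i ^ 4
assert _shuffled[0] == 40 and _shuffled[4] == 0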
def bitcast(x: ir.Value, new_type: ir.Type):
if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type):
new_type = ir.IntegerType(new_type)
x_ty = ir.VectorType(x.type)
assert new_type.width == bitwidth(x_ty.element_type) * math.prod(x_ty.shape)
i0 = arith.ConstantOp.create_index(0)
return vector.extractelement(
vector.bitcast(ir.VectorType.get((1,), new_type), x), position=i0
)
if ir.IntegerType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
new_type = ir.VectorType(new_type)
x_ty = ir.IntegerType(x.type)
assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
raise ValueError(f"Can't bitcast {x.type} to {new_type}")

View File

@ -14,13 +14,10 @@
# ==============================================================================
import dataclasses
import enum
import functools
import itertools
from typing import Any
import jax
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
@ -29,6 +26,7 @@ from jaxlib.mlir.dialects import vector
import numpy as np
from . import fragmented_array as fa
from . import mma_utils
from . import utils
# mypy: ignore-errors
@ -84,60 +82,6 @@ class WGMMAAccumulator:
return cls(_value=value[0], _sync=False)
def wgmma_encode(x: int):
result = (x & 0x3FFFF) >> 4
if result << 4 != x:
raise ValueError(f"Cannot encode value in a WGMMA descriptor: {x}")
return result
def llvm_add(x, y):
return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none)
def create_descriptor(
memref_arg,
leading_byte_offset: int,
stride_byte_offset: int,
swizzle: int | mgpu_dialect.SwizzlingMode | None,
memory_space: int | None = None,
const_init: int = 0,
):
i64 = ir.IntegerType.get_signless(64)
ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, memory_space))
if swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
swizzle_encoding = 0
elif swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
swizzle_encoding = 1
elif swizzle == mgpu_dialect.SwizzlingMode.k64ByteSwizzle:
swizzle_encoding = 2
elif swizzle == mgpu_dialect.SwizzlingMode.k32ByteSwizzle:
swizzle_encoding = 3
else:
raise NotImplementedError(swizzle)
encoded_base_addr = llvm.LShrOp(
llvm.AndOp(ptr_val, c(0x3FFFF, i64)).result, c(4, i64)
)
# We ignore the offset
desc_const = (
const_init
| (wgmma_encode(leading_byte_offset) << 16)
| (wgmma_encode(stride_byte_offset) << 32)
)
desc = llvm.or_(
arith.shli(c(swizzle_encoding, i64), c(62, i64)), c(desc_const, i64)
)
desc = llvm.or_(encoded_base_addr.result, desc)
return desc
def _unpack_i32(vec_ty, r):
i32 = ir.IntegerType.get_signless(32)
return vector.bitcast(
vec_ty, vector.splat(ir.VectorType.get((1,), i32), r)
)
def _supported_wgmma_types(dtype, abtype) -> bool:
input_types_are = lambda ty: ty.isinstance(abtype)
if ir.F32Type.isinstance(dtype):
@ -271,14 +215,14 @@ def wgmma_m64(
a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
else:
if i > 0:
a = llvm_add(
a = _llvm_add(
a,
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
)
a_args = [a]
# Advance the B descriptor.
if i > 0:
b_descriptor = llvm_add(
b_descriptor = _llvm_add(
b_descriptor,
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
)
@ -297,262 +241,6 @@ def wgmma_m64(
return to_acc_vec_regs(acc_regs)
class WGMMALayout(enum.Enum):
ROW_MAJOR = enum.auto()
COL_MAJOR = enum.auto()
def _validate_mma(
a: Any,
b: ir.Value,
swizzle: int,
m_group_size: int, # The M used by a single instruction.
descriptor_const_init: int = 0,
):
# We need swizzle >= 32 to ensure that our K tiling is larger than the MMA
# instruction's K width.
if swizzle < 32:
raise ValueError(f"Unsupported swizzle: {swizzle}")
# Get A type.
if a_in_smem := isinstance(a, ir.Value):
if not ir.MemRefType.isinstance(a.type):
raise ValueError(f"When A is an ir.Value, it must be a memref, got: {a.type}")
a_ty = ir.MemRefType(a.type)
a_element_type = a_ty.element_type
a_shape = tuple(a_ty.shape)
if a_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
raise ValueError("A must be in workgroup memory when it's a reference")
if len(a_shape) != 4:
raise ValueError(f"A must be 4D when it's a reference, got rank {len(a_shape)}")
elif hasattr(a, "shape") and hasattr(a, "mlir_dtype"):
a_element_type = a.mlir_dtype
a_shape = a.shape
else:
raise NotImplementedError(f"Unsupported A type: {type(a)}")
# Get B type (always a reference).
b_ty = ir.MemRefType(b.type)
if b_ty.rank != 4:
raise ValueError(f"B must be 4D, got rank {b_ty.rank}")
# Verify element types and compute the tiling.
if (element_type := a_element_type) != b_ty.element_type:
raise ValueError(
f"A and B must have the same element type, got: {a_element_type} and"
f" {b_ty.element_type}"
)
supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
if element_type not in supported_types:
raise ValueError(a_element_type)
element_bytewidth = bytewidth(element_type)
swizzle_elems = swizzle // element_bytewidth
# Verify the shape and strides of B are as expected.
b_k_tiles, n_tiles, b_k_tiling, n_tiling = b_ty.shape
k = b_k_tiles * b_k_tiling
n = n_tiles * n_tiling
b_strides, _ = b_ty.get_strides_and_offset()
b_byte_strides = [s * element_bytewidth for s in b_strides]
b_k_byte_stride, b_n_byte_stride, *b_tile_byte_strides = b_byte_strides
if (
b_byte_strides[1] != n_tiling * b_k_tiling * element_bytewidth
and n_tiles != 1 # When there's only one tile, we never jump between them
):
raise ValueError("B tiles must be contiguous along the N dimension")
if b_tile_byte_strides == [swizzle, element_bytewidth]: # N-fastest
b_order = WGMMALayout.ROW_MAJOR
# This first case (n_tiles == 1) is to allow the somewhat weird case of
# loading a small amount of N-fastest data, that needs to be padded to a
# larger tile due to swizzle. In this case we allow slicing the big tile
# before WGMMA to avoid unnecessary compute on padding.
if n_tiles == 1:
if n_tiling % 8:
raise ValueError("N tile size must be a multiple of 8")
elif n_tiling != swizzle_elems:
raise ValueError(
"Row major RHS (N-fastest) requires the N tile size to be equal to"
f" the swizzle tile size ({swizzle_elems}), but got {n_tiling}"
)
if b_k_tiling not in {8, swizzle_elems}:
raise ValueError(
"Row major RHS (N-fastest) requires the K tile size to be either"
f" the swizzle tile size ({swizzle_elems}) or 8, but got {b_k_tiling}"
)
elif b_tile_byte_strides == [element_bytewidth, swizzle]: # K-fastest
b_order = WGMMALayout.COL_MAJOR
if b_k_tiling != swizzle_elems:
raise ValueError(
"Column major RHS (K-fastest) requires the K tile size to be equal"
f" to the swizzle tile size ({swizzle_elems}), but got {b_k_tiling}"
)
# See the explanation in the N-fastest case when n_tiles == 1.
if n_tiles == 1:
if n_tiling % 8:
raise ValueError("N tile size must be a multiple of 8")
elif n_tiling not in {8, swizzle_elems}:
raise ValueError(
"Column major RHS (K-fastest) requires the N tile size to be either"
f" to the swizzle tile size ({swizzle_elems}) or 8, but got {n_tiling}"
)
else:
raise ValueError(b_byte_strides)
if n > 256 and n % 256:
raise ValueError(
f"N group size must be a multiple of 256 when larger than 256, got: {n}"
)
k_group_size = swizzle_elems
n_group_size = min(n, 256)
b_k_tiles_per_group = k_group_size // b_k_tiling
b_k_group_stride = b_k_byte_stride * b_k_tiles_per_group
n_tiles_per_group = n_group_size // n_tiling
b_n_group_stride = b_n_byte_stride * n_tiles_per_group
# Verify the shape and strides of A are as expected.
if not a_in_smem:
m = a_shape[0]
a_order = a_m_group_stride = a_k_group_stride = None
else:
a_ty = ir.MemRefType(a.type)
m_tiles, a_k_tiles, m_tiling, a_k_tiling = a_ty.shape
m = m_tiles * m_tiling
# TODO(apaszke): I'm not actually convinced that we need this check.
if m_tiling != m_group_size:
raise ValueError(
f"A's row tiling must be equal to {m_group_size}, got: {m_tiling}"
)
if a_k_tiling != swizzle_elems or a_k_tiles * a_k_tiling != k:
raise ValueError(a_ty.shape)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_m_byte_stride, a_k_byte_stride, *a_tile_byte_strides = [
s * element_bytewidth for s in a_strides
]
if a_tile_byte_strides == [swizzle, element_bytewidth]:
a_order = WGMMALayout.ROW_MAJOR
elif a_tile_byte_strides == [element_bytewidth, swizzle]:
a_order = WGMMALayout.COL_MAJOR
else:
raise ValueError(a_strides)
if a_order != WGMMALayout.ROW_MAJOR and m_tiling != swizzle_elems:
# Not sure what the layout is like, since the tiles aren't square.
raise NotImplementedError
a_m_tiles_per_group = m_group_size // m_tiling
a_m_group_stride = a_m_byte_stride * a_m_tiles_per_group
a_k_tiles_per_group = k_group_size // a_k_tiling
a_k_group_stride = a_k_byte_stride * a_k_tiles_per_group
b_k_fastest = b_order == WGMMALayout.COL_MAJOR
a_k_fastest = a_order == WGMMALayout.ROW_MAJOR
# This is the number of rows until consecutive repeats of the swizzle pattern.
swizzle_pattern_rows = swizzle // 16
# A swizzle atom is a 2D matrix with the dimensions below.
swizzle_atom_bytes = swizzle_pattern_rows * 128
# Here "leading" refers to the fastest changing dimension. There are two
# strides we have to define per value:
# Leading byte offset (LBO)
# K-fastest: ignored
# MN-fastest: stride between consecutive swizzle atoms that share the same
# K coordinate.
# Stride byte offset (SBO)
# As far as I can tell this is just the offset between two consecutive
# swizzle atoms along the non-leading dimension.
IGNORED = 0
a_desc_fields = dict(
# I can't fully explain why WGMMA ignores LBO for A. For a_k_fastest, it
# is documented in the PTX docs, and my best explanation for the other
# case is that the instruction has a fixed shape and so it does not care
# about strides. It's possible that it's an artifact of the fact that we
# use tiling of 64.
leading_byte_offset=IGNORED,
stride_byte_offset=swizzle_atom_bytes,
swizzle=swizzle,
memory_space=3,
)
# If B is N-fastest, all swizzle atoms within a tile share the same N
# coordinate, so we simply take the stride between consecutive N tiles.
# If B is K-fastest, all swizzle atoms within a tile share the same K
# coordinate, which forces us to lay out the tiles in N-fastest order or else
# they would have uneven strides.
b_desc_fields = dict(
leading_byte_offset=IGNORED if b_k_fastest else b_n_byte_stride,
# N tiles are contiguous, so the next N swizzle atom follows immediately.
# K tiles are not contiguous, so we take the stride between them.
stride_byte_offset=swizzle_atom_bytes
if b_k_fastest or b_k_tiling == swizzle_elems
else b_k_byte_stride,
swizzle=swizzle,
memory_space=3,
)
# The K strides indicate the stride between the consecutive places where all
# coordinates are 0 except for K being incremented by the instruction width.
# If an input is K-fastest, we increment the descriptor by 32 bytes, since
# that is the K-width of all MMA instructions.
if b_k_fastest:
b_k_wgmma_stride = 32
elif b_k_tiling == swizzle_elems:
# When B is N-fastest and we use the large square tiling, the relevant
# slices all fall within the first tile. A single MMA instruction for 16-bit
# types reads a subtile of shape 16x(swizzle bytes), giving us the necessary
# expression.
assert n_tiling == swizzle_elems or n_tiles == 1
b_k_wgmma_stride = swizzle * 16
else:
# If we use the small non-square tiling and N-fastest layout, each tile only
# contains a single swizzle atom with the K coordinate. But, each tile has
# 8 rows, while the WGMMA K width is 16, so we need to jump over 2 tiles.
b_k_wgmma_stride = b_k_byte_stride * 2
wgmma_params = dict(
a_transpose=not a_k_fastest,
b_transpose=not b_k_fastest,
# TODO(apaszke): This explanation is quite bad. We should better figure
# out how to do LHS transposes.
# We only support swizzle=128 for M-fastest A. In this case the tile is
# swizzle x 64 (= swizzle elems) and so we just take a quarter of its size.
a_k_stride=32 if a_k_fastest else swizzle * 16,
b_k_stride=b_k_wgmma_stride,
swizzle=swizzle,
n=n_group_size,
element_type=ir.FloatTF32Type.get()
if ir.F32Type.isinstance(element_type)
else element_type,
)
if not a_in_smem:
wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
a_desc_base = None
else:
a_desc_base = create_descriptor(
a, **a_desc_fields, const_init=descriptor_const_init
)
b_desc_base = create_descriptor(
b, **b_desc_fields, const_init=descriptor_const_init
)
if m % m_group_size:
raise ValueError(f"m must be a multiple of {m_group_size}, got: {m}")
m_groups = m // m_group_size
if k % k_group_size:
raise ValueError(f"k must be a multiple of {k_group_size}, got: {k}")
k_groups = k // k_group_size
if n % n_group_size:
raise ValueError(f"n must be a multiple of {n_group_size}, got: {n}")
n_groups = n // n_group_size
return (
a_desc_base,
b_desc_base,
(m, k, n),
(m_groups, k_groups, n_groups),
# Group strides are always in bytes!
(a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride),
wgmma_params,
)
# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer
# transpositions from memref strides.
def wgmma(
acc: WGMMAAccumulator,
a: fa.FragmentedArray | ir.Value,
@ -570,61 +258,129 @@ def wgmma(
The refs must be contiguous or be contiguous except for having their two minor
dimensions swapped.
"""
a_in_regs = isinstance(a, fa.FragmentedArray)
if not a_in_regs and not ir.MemRefType.isinstance(a.type):
raise ValueError(f"Unsupported A type: {type(a)}")
# Step 1. Establish the shape and element type of the operation.
if not ir.MemRefType.isinstance(b.type):
raise ValueError(f"B must be a memref, got: {b.type}")
m_group_size = 64 # Hopper has a fixed M instruction shape.
(
a_desc_base,
b_desc_base,
(m, k, n),
(m_groups, k_groups, n_groups),
(a_m_group_stride, a_k_group_stride, b_k_group_stride, _),
wgmma_params,
) = _validate_mma(a, b, swizzle, m_group_size=m_group_size)
if n_groups > 1:
raise ValueError("N is too big for WGMMA. Only up to 256 is supported.")
if a_in_regs:
(k, n), element_type = mma_utils.tiled_memref_shape(b)
if a_in_regs := isinstance(a, fa.FragmentedArray):
m, k2 = a.shape
element_type2 = a.mlir_dtype
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
raise ValueError(
f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}"
)
if a.shape[0] % m_group_size:
raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}")
a_m_group_stride = a_k_group_stride = None
elif ir.MemRefType.isinstance(a.type):
(m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
else:
raise ValueError(f"Unsupported A type: {type(a)}")
if k != k2:
raise ValueError(
"WGMMA requires A and B to have the same contraction dimension (K),"
f" got: {k2} and {k}"
)
if element_type != element_type2:
raise ValueError(
"WGMMA requires A and B to have the same element type, got:"
f" {element_type2} and {element_type}"
)
if acc.value.shape != (m, n):
raise ValueError(
f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}"
)
f32 = ir.F32Type.get()
if element_type == f32 or element_type == ir.BF16Type.get():
if acc.value.mlir_dtype != f32:
raise ValueError(
f"WGMMA with element type {element_type} only supports accumulators"
f" of type f32, but got: {acc.value.mlir_dtype}"
)
elif element_type == ir.F16Type.get():
if acc.value.mlir_dtype != element_type and acc.value.mlir_dtype != f32:
raise ValueError(
"WGMMA with element type f16 only supports accumulators of type f32"
f" or f16, but got: {acc.value.mlir_dtype}"
)
# Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
# instructions must be issued in groups of the same width as the swizzle.
m_group_elems = 64 # Hopper has a fixed M instruction shape.
k_group_elems = swizzle // utils.bytewidth(element_type)
if n > 256 or n % 8:
raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
n_group_elems = n # We assume only one N group below.
if m % m_group_elems:
raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
if k % k_group_elems:
raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
m_groups = m // m_group_elems
k_groups = k // k_group_elems
# TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA.
wgmma_element_type = (
ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
)
# Step 3. Compute the operand descriptors.
if a_in_regs:
a_desc_base = a_m_group_stride = a_k_group_stride = None
a_instr_params = dict(a_transpose=None, a_k_stride=None)
else:
(
(a_desc_base, a_k_instr_stride),
(a_m_group_stride, a_k_group_stride),
a_fastest,
) = mma_utils.create_descriptor(
a,
swizzle=swizzle,
large_tile=(m_group_elems, k_group_elems),
group_size=(m_group_elems, k_group_elems),
logical_k_major=False,
)
a_instr_params = dict(a_transpose=a_fastest != mma_utils.Dim.K,
a_k_stride=a_k_instr_stride)
(
(b_desc_base, b_k_instr_stride),
(b_n_group_stride, b_k_group_stride),
b_fastest,
) = mma_utils.create_descriptor(
b,
swizzle=swizzle,
large_tile=(k_group_elems,) * 2, # It's not a typo that we use k for n.
group_size=(k_group_elems, n_group_elems),
logical_k_major=True,
)
del b_n_group_stride # We only support one N group.
# Step 4. Issue the instructions.
if a_in_regs:
a = wgmma_fence(a) # Make sure the registers are ready.
i64 = ir.IntegerType.get_signless(64)
new_acc_regs = acc.value.registers.copy()
k_group_size = k // k_groups
for mi in range(m_groups):
for ki in range(k_groups):
if a_in_regs:
a_mk = a[
mi * m_group_size : (mi + 1) * m_group_size,
ki * k_group_size : (ki + 1) * k_group_size,
mi * m_group_elems : (mi + 1) * m_group_elems,
ki * k_group_elems : (ki + 1) * k_group_elems,
]
else:
a_mk = llvm_add(
a_desc_base,
c(wgmma_encode(mi * a_m_group_stride + ki * a_k_group_stride), i64),
a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride
a_mk = _llvm_add(
a_desc_base, c(mma_utils.encode_addr(a_group_offset), i64),
)
b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_group_stride), i64))
b_k = _llvm_add(
b_desc_base, c(mma_utils.encode_addr(ki * b_k_group_stride), i64)
)
new_acc_regs[mi : mi + 1] = wgmma_m64(
new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
new_acc_regs[mi : mi + 1],
a_mk,
b_k,
swizzle=swizzle,
n=n_group_elems,
element_type=wgmma_element_type,
b_transpose=b_fastest != mma_utils.Dim.K,
b_k_stride=b_k_instr_stride,
**a_instr_params,
)
return WGMMAAccumulator(
_value=fa.FragmentedArray(
@ -668,3 +424,14 @@ def _as_i32_reg(v):
def _lc(x):
i32 = ir.IntegerType.get_signless(32)
return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
def _llvm_add(x, y):
return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none)
def _unpack_i32(vec_ty, r):
i32 = ir.IntegerType.get_signless(32)
return vector.bitcast(
vec_ty, vector.splat(ir.VectorType.get((1,), i32), r)
)

View File

@ -53,6 +53,7 @@ from jax._src.pallas.primitives import max_contiguous as max_contiguous
from jax._src.pallas.primitives import multiple_of as multiple_of
from jax._src.pallas.primitives import num_programs as num_programs
from jax._src.pallas.primitives import program_id as program_id
from jax._src.pallas.primitives import reciprocal as reciprocal
from jax._src.pallas.primitives import run_scoped as run_scoped
from jax._src.pallas.primitives import store as store
from jax._src.pallas.primitives import swap as swap

View File

@ -35,6 +35,7 @@ from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arri
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota
from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem
from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout

View File

@ -0,0 +1,721 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TPU-Friendly Ragged Paged Attention kernel.
This kernel offers a highly optimized implementation of ragged paged attention,
specifically designed for TPU and compatible with a wide range of model
specifications. It supports mixed prefill and decoding, enhancing throughput
during inference.
"""
import functools
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
class MultiPageAsyncCopyDescriptor:
"""Descriptor for async copy of multiple K/V pages from HBM."""
def __init__(
self,
pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim]
vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim]
sem,
page_indices_ref, # i32[max_num_seqs, pages_per_seq]
offset, # [seq_idx, kv_pages_start]
):
self._vmem_buf = vmem_buf
seq_id, kv_pages_start = offset
self._async_copies = [
pltpu.make_async_copy(
pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]],
vmem_buf.at[i],
sem,
)
for i in range(vmem_buf.shape[0])
]
def start(self):
"""Starts the async copies."""
for async_copy in self._async_copies:
async_copy.start()
def wait(self):
for async_copy in self._async_copies:
async_copy.wait()
return self._vmem_buf
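# Illustrative sketch (assumed values, independent of the class above): with
# offset == (seq_id=1, kv_pages_start=2) and a vmem buffer holding 4 pages, the
# descriptor gathers HBM pages page_indices[1, 2:6], one async copy per vmem slot.
import numpy as np
_page_indices = np.arange(32, dtype=np.int32).reshape(4, 8)   # [max_num_seqs, pages_per_seq]
assert [_page_indices[1, 2 + i] for i in range(4)] == [10, 11, 12, 13]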
def ref_ragged_paged_attention(
queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
kv_lens: jax.Array, # i32[max_num_seqs]
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
num_seqs: jax.Array, # i32[1],
*,
sm_scale: float = 1.0,
mask_value: float = DEFAULT_MASK_VALUE,
):
_, _, num_kv_heads, head_dim = k_pages.shape
num_q_heads = queries.shape[1]
assert num_q_heads % num_kv_heads == 0
num_query_per_kv = num_q_heads // num_kv_heads
outputs = []
for i in range(num_seqs[0]):
q_start = cu_q_lens[i]
q_end = cu_q_lens[i + 1]
q_len = q_end - q_start
kv_len = kv_lens[i]
indices = page_indices[i]
q = queries[q_start:q_end]
k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
k = jnp.repeat(k, num_query_per_kv, axis=1)
v = jnp.repeat(v, num_query_per_kv, axis=1)
attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32)
attn *= sm_scale
q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
jnp.int32, attn.shape, 1
)
kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
attn += jnp.where(q_span < kv_span, mask_value, 0.0)
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
outputs.append(out)
return jnp.concatenate(outputs, axis=0)
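# CPU-runnable sketch of the ragged metadata layout (all sizes here are
# illustrative assumptions): one prefill sequence (4 query tokens over 4 kv
# tokens) plus one decode sequence (1 query token over 6 kv tokens), page_size=2.
import numpy as np
_rng = np.random.default_rng(0)
_num_q_heads, _num_kv_heads, _head_dim, _page_size = 4, 2, 8, 2
_q = _rng.normal(size=(8, _num_q_heads, _head_dim)).astype(np.float32)          # 8 padded query tokens
_k_pages = _rng.normal(size=(8, _page_size, _num_kv_heads, _head_dim)).astype(np.float32)
_v_pages = _rng.normal(size=(8, _page_size, _num_kv_heads, _head_dim)).astype(np.float32)
_kv_lens = np.array([4, 6], dtype=np.int32)                   # kv tokens per sequence
_cu_q_lens = np.array([0, 4, 5], dtype=np.int32)              # prefill uses q rows 0:4, decode uses 4:5
_page_indices = np.array([[0, 1, 0, 0],                       # seq 0 lives in pages 0-1
                          [2, 3, 4, 0]], dtype=np.int32)      # seq 1 lives in pages 2-4
_num_seqs = np.array([2], dtype=np.int32)
_out = ref_ragged_paged_attention(_q, _k_pages, _v_pages, _kv_lens, _page_indices,
                                  _cu_q_lens, _num_seqs)
assert _out.shape == (5, _num_q_heads, _head_dim)             # one output row per valid query token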
# Expect to run these checks at runtime.
def validate_inputs_on_runtime(
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
kv_lens: jax.Array, # i32[max_num_seqs]
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
num_seqs, # i32[1]
):
check_inputs_shapes(
q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs
)
max_num_batched_tokens = q.shape[0]
page_size = k_pages.shape[1]
max_num_seqs, pages_per_seq = page_indices.shape
if num_seqs[0] > max_num_seqs:
raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}")
max_kv_len = jnp.max(kv_lens)
min_pages_per_seq = ceil_div(max_kv_len, page_size)
if pages_per_seq < min_pages_per_seq:
raise ValueError(
f"{pages_per_seq=} must be greater or equal to"
f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}."
)
if cu_q_lens[num_seqs[0]] > max_num_batched_tokens:
raise ValueError(
f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to"
f" {max_num_batched_tokens=}."
)
for i in range(num_seqs[0]):
q_len = cu_q_lens[i + 1] - cu_q_lens[i]
kv_len = kv_lens[i]
if q_len > kv_len:
raise ValueError(
f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
)
# Expect to run these checks at compile time.
def check_inputs_shapes(
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
kv_lens: jax.Array, # i32[max_num_seqs]
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
num_seqs, # i32[1]
):
_, num_q_heads, head_dim = q.shape
_, _, num_kv_heads, head_dim_k = k_pages.shape
max_num_seqs, _ = page_indices.shape
if num_seqs.shape != (1,):
raise ValueError(f"{num_seqs.shape=} must be (1,)")
if k_pages.shape != v_pages.shape:
raise ValueError(
f"{k_pages.shape=} and {v_pages.shape=} must have the same shape."
)
if head_dim_k != head_dim:
raise ValueError(
f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}."
)
if kv_lens.shape != (max_num_seqs,):
raise ValueError(
f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where"
" `max_num_seqs` is `page_indices.shape[0]`."
)
if cu_q_lens.shape != (max_num_seqs + 1,):
raise ValueError(
f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where"
" `max_num_seqs` is `page_indices.shape[0]`."
)
if (
kv_lens.dtype != jnp.int32
or page_indices.dtype != jnp.int32
or cu_q_lens.dtype != jnp.int32
):
raise ValueError(
"The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be"
f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=},"
f" {cu_q_lens.dtype=}."
)
if num_q_heads % num_kv_heads != 0:
raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
def ragged_paged_attention_kernel(
# Prefetch
kv_lens_ref, # [max_num_seqs]
page_indices_ref, # [max_num_seqs, pages_per_seq]
cu_q_lens_ref, # [max_num_seqs + 1]
seq_buf_idx_ref,
# TODO(jevinjiang): if we OOM in SMEM, consider packing these into other scalar refs.
num_seqs_ref,
# Input
q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim]
v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim]
# Output
o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
# Scratch
k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim]
v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim]
sems, # [2, 2]
l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
*,
sm_scale: float,
mask_value: float,
):
num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
num_seqs = num_seqs_ref[0]
_, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape
num_kv_per_blk = num_kv_pages_per_blk * page_size
num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk
heads_blk_idx, q_blk_idx = (
pl.program_id(0),
pl.program_id(1),
)
num_heads_blks = pl.num_programs(0)
init_seq_idx = seq_buf_idx_ref[0]
init_buf_idx = seq_buf_idx_ref[1]
q_len_start = q_blk_idx * num_q_per_blk
q_len_end = q_len_start + num_q_per_blk
def create_kv_async_copy_descriptors(
heads_blk_idx, seq_idx, kv_blk_idx, buf_idx
):
offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk)
heads_start = heads_blk_idx * num_kv_heads_per_blk
async_copy_k = MultiPageAsyncCopyDescriptor(
k_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :],
k_bufs.at[buf_idx],
sems.at[buf_idx, 0],
page_indices_ref,
offset,
)
async_copy_v = MultiPageAsyncCopyDescriptor(
v_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :],
v_bufs.at[buf_idx],
sems.at[buf_idx, 1],
page_indices_ref,
offset,
)
return async_copy_k, async_copy_v
# TODO(jevinjiang): Add these to Mosaic:
# 1. Support arbitrary strided load/store for any dtype.
# 2. Support arbitrary strided load/store for any last dimension.
def strided_load_kv(ref, start, step):
if ref.dtype == jnp.float32:
return ref[start::step, :]
packing = get_dtype_packing(ref.dtype)
assert ref.dtype == jnp.bfloat16
assert step % packing == 0
b_start = start // packing
b_offset = start % packing
b_step = step // packing
b_ref = ref.bitcast(jnp.int32)
b = b_ref[b_start::b_step, :]
bw = 32 // packing
b = jnp.right_shift(b, bw * b_offset)
b = jnp.left_shift(b, bw * (packing - 1))
return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16)
def fold_on_2nd_minor(vec):
assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32
assert len(vec.shape) >= 2
last_dim = vec.shape[-1]
packing = get_dtype_packing(vec.dtype)
if vec.shape[-2] % packing != 0:
vec = vec.astype(jnp.float32)
return vec.reshape(-1, last_dim)
@pl.when(heads_blk_idx + q_blk_idx == 0)
def prefetch_first_kv_blk():
async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
heads_blk_idx, init_seq_idx, 0, init_buf_idx
)
async_copy_k.start()
async_copy_v.start()
def is_cur_q_blk_needed(q_states):
done, cur_seq_idx, _ = q_states
return jnp.logical_and(done == 0, cur_seq_idx < num_seqs)
def compute_with_cur_q_blk(q_states):
done, cur_seq_idx, cur_buf_idx = q_states
q_start = cu_q_lens_ref[cur_seq_idx]
q_end = cu_q_lens_ref[cur_seq_idx + 1]
q_len = q_end - q_start
kv_len = kv_lens_ref[cur_seq_idx]
def get_next_prefetch_ids(
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
):
next_kv_blk_idx = kv_blk_idx + 1
is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len
next_kv_blk_idx = lax.select(
is_last_kv_blk,
0,
next_kv_blk_idx,
)
is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end
next_seq_idx = lax.select(
is_last_kv_blk,
lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1, cur_seq_idx),
cur_seq_idx,
)
is_last_seq = next_seq_idx == num_seqs
next_seq_idx = lax.select(
is_last_seq,
0,
next_seq_idx,
)
next_heads_blk_idx = lax.select(
is_last_seq,
heads_blk_idx + 1,
heads_blk_idx,
)
next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0)
return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx
def flash_attention(
q, # [num_q_per_blk * num_q_heads_per_kv_head, head_dim]
k, # [num_kv_per_blk, head_dim]
v, # [num_kv_per_blk, head_dim]
head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128]
head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128]
head_o_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
*,
kv_blk_idx,
):
assert q.shape == (
num_q_per_blk * num_q_heads_per_kv_head,
head_dim,
)
assert k.shape == (
num_kv_per_blk,
head_dim,
), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}"
assert v.shape == (num_kv_per_blk, head_dim)
assert head_m_ref.shape == (
num_q_per_blk * num_q_heads_per_kv_head,
128,
)
assert head_l_ref.shape == (
num_q_per_blk * num_q_heads_per_kv_head,
128,
)
assert head_o_ref.shape == (
num_q_per_blk,
num_q_heads_per_kv_head,
head_dim,
)
kv_len_start = kv_blk_idx * num_kv_per_blk
def masked_store(ref, val, start, end, group=1):
iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group
mask = jnp.logical_and(iota >= start, iota < end)
pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask)
qk = (
jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32)
* sm_scale
)
store_start = jnp.maximum(q_start - q_len_start, 0)
store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk)
@pl.when(kv_blk_idx == 0)
def init_scratch_ref():
masked_store(
head_m_ref,
jnp.full_like(head_m_ref, -jnp.inf),
store_start,
store_end,
num_q_heads_per_kv_head,
)
masked_store(
head_l_ref,
jnp.zeros_like(head_l_ref),
store_start,
store_end,
num_q_heads_per_kv_head,
)
masked_store(
head_o_ref,
jnp.zeros_like(head_o_ref),
store_start,
store_end,
)
row_ids = (
(kv_len - q_len)
+ q_len_start
- q_start
+ jax.lax.broadcasted_iota(
jnp.int32,
(num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
0,
)
// num_q_heads_per_kv_head
)
col_ids = kv_len_start + jax.lax.broadcasted_iota(
jnp.int32,
(num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
1,
)
causal_mask = row_ids < col_ids
qk += jnp.where(causal_mask, mask_value, 0.0)
m_curr = jnp.max(qk, axis=1, keepdims=True)
s_curr = jnp.exp(qk - m_curr)
qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32)
lm_store_shape = head_m_ref.shape
m_curr = jnp.broadcast_to(m_curr, lm_store_shape)
l_curr = jnp.broadcast_to(
s_curr.sum(axis=1, keepdims=True), lm_store_shape
)
m_prev = head_m_ref[...]
l_prev = head_l_ref[...]
m_next = jnp.maximum(m_prev, m_curr)
masked_store(
head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head
)
alpha = jnp.exp(m_prev - m_next)
beta = jnp.exp(m_curr - m_next)
l_alpha = alpha * l_prev
l_next = l_alpha + beta * l_curr
l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)
masked_store(
head_l_ref,
l_next_safe,
store_start,
store_end,
num_q_heads_per_kv_head,
)
def broadcast_to_shape(arr, shape):
if arr.shape == shape:
return arr
assert len(arr.shape) == len(shape)
assert arr.shape[0] == shape[0]
assert shape[1] % arr.shape[1] == 0
# no-op concatenation.
return jnp.concatenate(
[arr for _ in range(shape[1] // arr.shape[1])], axis=1
)
o_curr = head_o_ref[...].reshape(-1, head_dim)
l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
beta = broadcast_to_shape(beta, qkv.shape)
l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
out = lax.div(
l_alpha * o_curr + beta * qkv,
l_next_safe,
).astype(head_o_ref.dtype)
masked_store(
head_o_ref,
out.reshape(head_o_ref.shape),
store_start,
store_end,
)
def is_valid_kv_blk_in_cur_seq(kv_states):
kv_blk_idx, _ = kv_states
return kv_blk_idx * num_kv_per_blk < kv_len
def compute_with_kv_blk_in_cur_seq(kv_states):
kv_blk_idx, cur_buf_idx = kv_states
next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = (
get_next_prefetch_ids(
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
)
)
@pl.when(next_heads_blk_idx < num_heads_blks)
def prefetch_next_kv_blk():
# TODO(jevinjiang): reuse the same buffer if it is already prefetched!
# TODO(jevinjiang): only fetch the effective dynamic size needed to hold kv_len
# and DMA it into a fixed-size buffer!
next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors(
next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx
)
next_async_copy_k.start()
next_async_copy_v.start()
cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors(
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
)
kv_to_load_shape = (
num_kv_pages_per_blk * page_size * num_kv_heads_per_blk,
head_dim,
)
k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape)
v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape)
for kv_head_idx in range(num_kv_heads_per_blk):
q_head_idx = kv_head_idx * num_q_heads_per_kv_head
# TODO(jevinjiang): extra handling for packed types that can start at an
# unaligned position!
q = fold_on_2nd_minor(
q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :]
)
k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk)
v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk)
flash_attention(
q,
k,
v,
l_ref.at[kv_head_idx],
m_ref.at[kv_head_idx],
o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :],
kv_blk_idx=kv_blk_idx,
)
return kv_blk_idx + 1, next_buf_idx
_, next_buf_idx = lax.while_loop(
is_valid_kv_blk_in_cur_seq,
compute_with_kv_blk_in_cur_seq,
(0, cur_buf_idx), # (kv_blk_idx, buf_idx)
)
next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx)
done = lax.select(q_end < q_len_end, done, 1)
return done, next_seq_idx, next_buf_idx
_, seq_idx, buf_idx = lax.while_loop(
is_cur_q_blk_needed,
compute_with_cur_q_blk,
(0, init_seq_idx, init_buf_idx), # (done, seq_idx, buf_idx)
)
# Reset seq_idx for the next kv_heads_blk if we run out of seqs!
seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
seq_buf_idx_ref[1] = buf_idx
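# Illustrative NumPy sketch (independent of the kernel): the inner flash_attention
# above uses the standard online-softmax recurrence (running max m, normalizer l,
# rescaled partial output). The kernel renormalizes its stored output on every kv
# block; the variant below defers the division to the end, which is mathematically
# equivalent.
import numpy as np

def _online_softmax_attention(scores, values, block):
    m, l, o = -np.inf, 0.0, np.zeros(values.shape[1])
    for s in range(0, len(scores), block):
        qk, vb = scores[s:s + block], values[s:s + block]
        m_curr = qk.max()
        p = np.exp(qk - m_curr)
        m_next = max(m, m_curr)
        alpha, beta = np.exp(m - m_next), np.exp(m_curr - m_next)
        l = alpha * l + beta * p.sum()       # running softmax normalizer
        o = alpha * o + beta * (p @ vb)      # running (unnormalized) output
        m = m_next
    return o / l

_rng = np.random.default_rng(0)
_scores, _values = _rng.normal(size=16), _rng.normal(size=(16, 4))
_weights = np.exp(_scores - _scores.max())
np.testing.assert_allclose(_online_softmax_attention(_scores, _values, block=8),
                           (_weights / _weights.sum()) @ _values, rtol=1e-6)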
def ceil_div(a, b):
assert b != 0
return (a + b - 1) // b
def get_dtype_packing(dtype):
if dtype == jnp.float32:
return 1
if dtype == jnp.bfloat16:
return 2
if dtype == jnp.int8:
return 4
if dtype == jnp.int4:
return 8
raise ValueError(f"Not implemented: unsupported {dtype=}")
def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype):
q_packing = get_dtype_packing(q_dtype)
kv_packing = get_dtype_packing(kv_dtype)
def can_be_xla_fully_tiled(x, packing):
if x % packing != 0:
return False
x //= packing
return x in (1, 2, 4, 8) or x % 8 == 0
# TODO(jevinjiang): support unaligned number of heads!
if not can_be_xla_fully_tiled(num_kv_heads, kv_packing):
raise ValueError(
f"Not implemented: {num_kv_heads=} can not be XLA fully tiled."
)
assert num_q_heads % num_kv_heads == 0
ratio = num_q_heads // num_kv_heads
# TODO(jevinjiang): we can choose a smaller tiling for packed types if large
# second-minor tiling is not enabled.
max_kv_tiling = 8 * kv_packing
min_kv_heads = (
max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads
)
min_q_heads = min_kv_heads * ratio
if can_be_xla_fully_tiled(min_q_heads, q_packing):
return min_q_heads, min_kv_heads
return num_q_heads, num_kv_heads
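# Worked examples (assumed head counts): with bfloat16 Q/KV, kv_packing=2 gives
# max_kv_tiling=16, so 64 query heads over 16 kv heads fit in a single heads block;
# with float32 KV (max_kv_tiling=8) the same counts split into blocks of (32, 8).
assert get_min_heads_per_blk(64, 16, jnp.bfloat16, jnp.bfloat16) == (64, 16)
assert get_min_heads_per_blk(64, 16, jnp.float32, jnp.float32) == (32, 8)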
@functools.partial(
jax.jit,
static_argnames=[
"sm_scale",
"mask_value",
"num_kv_pages_per_block",
"num_queries_per_block",
"vmem_limit_bytes",
],
)
def ragged_paged_attention(
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
# TODO(jevinjiang): create a write_to_kv_cache kernel!
k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim]
kv_lens: jax.Array, # i32[max_num_seqs]
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
num_seqs: jax.Array, # i32[1]
*,
sm_scale: float = 1.0,
mask_value: float = DEFAULT_MASK_VALUE,
num_kv_pages_per_block: int = 16,
num_queries_per_block: int = 128,
vmem_limit_bytes: int | None = None,
):
"""Ragged paged attention that supports mixed prefill and decode.
Args:
q: concatenated all sequences' queries.
k_pages: paged K cache. Normally in HBM.
v_pages: paged V cache. Normally in HBM.
kv_lens: padded kv lengths. Only the first num_seqs values are valid.
page_indices: for each sequence (the first axis), the indices of the kv-cache
pages used by that sequence. Only the first num_seqs rows are valid.
cu_q_lens: the cumulative sum of the effective query lengths. Similar to
kv_lens, only the first num_seqs+1 values are valid.
num_seqs: the dynamic number of sequences.
sm_scale: the softmax scale which will be applied to the Q@K^T.
mask_value: mask value for causal mask.
num_kv_pages_per_block: number of kv pages to be processed in one flash
attention block in the pallas kernel.
num_queries_per_block: number of query tokens to be processed in one flash
attention block in the pallas kernel.
vmem_limit_bytes: the vmem limit for the pallas kernel.
Returns:
The output of the attention.
"""
check_inputs_shapes(
q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs
)
_, num_q_heads, head_dim = q.shape
_, page_size, num_kv_heads, _ = k_pages.shape
num_q_per_blk = num_queries_per_block
num_kv_pages_per_blk = num_kv_pages_per_block
num_q_heads_per_kv_head = num_q_heads // num_kv_heads
num_q_blks = ceil_div(cu_q_lens[num_seqs[0]], num_q_per_blk)
num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk(
num_q_heads, num_kv_heads, q.dtype, k_pages.dtype
)
assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0
num_heads_blks = num_q_heads // num_q_heads_per_blk
grid = (num_heads_blks, num_q_blks)
def q_index_map(heads_blk_idx, q_blk_idx, *_):
return (q_blk_idx, heads_blk_idx, 0)
q_block_spec = pl.BlockSpec(
(num_q_per_blk, num_q_heads_per_blk, head_dim),
q_index_map,
)
in_specs = [
q_block_spec,
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
]
out_specs = q_block_spec
lm_scratch = pltpu.VMEM(
# TODO(jevinjiang): we use 128 instead of 1 because Mosaic does not support
# unaligned slicing!
(num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
jnp.float32,
)
double_buf_scratch = pltpu.VMEM(
(
2, # For double buffering during DMA copies.
num_kv_pages_per_blk,
page_size,
num_kv_heads_per_blk,
head_dim,
),
k_pages.dtype,
)
scratch_shapes = [
double_buf_scratch, # k_bufs
double_buf_scratch, # v_bufs
pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem]
lm_scratch, # l_ref
lm_scratch, # m_ref
]
scalar_prefetches = (
kv_lens,
page_indices,
cu_q_lens,
jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx
num_seqs,
)
kernel = pl.pallas_call(
functools.partial(
ragged_paged_attention_kernel,
sm_scale=sm_scale,
mask_value=mask_value,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=len(scalar_prefetches),
in_specs=in_specs,
out_specs=out_specs,
grid=grid,
scratch_shapes=scratch_shapes,
),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=(
"arbitrary",
"arbitrary",
),
vmem_limit_bytes=vmem_limit_bytes,
),
out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32),
name="ragged_paged_attention_kernel",
)
# TODO(jevinjiang): use an f32 accumulator scratch for the output, so we only
# need to transfer the output in the desired dtype back to HBM.
return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype)
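# Grid-sizing sketch for the tiny illustrative batch used after the reference
# implementation above (assumed values): 5 valid query tokens, 4 query heads and
# 2 kv heads in float32, with the default num_queries_per_block=128.
_total_q_tokens = 5                                            # cu_q_lens[num_seqs[0]]
_num_q_blks = ceil_div(_total_q_tokens, 128)
_q_heads_per_blk, _kv_heads_per_blk = get_min_heads_per_blk(4, 2, jnp.float32, jnp.float32)
_num_heads_blks = 4 // _q_heads_per_blk
assert (_num_heads_blks, _num_q_blks) == (1, 1)                # grid == (1, 1): a single kernel block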

View File

@ -55,6 +55,57 @@ for prim in it.chain(
roofline.register_standard_roofline(prim)
def _unary_p_roofline(
ctx: roofline.RooflineRuleContext,
*args,
**kw,
) -> roofline.RooflineResult:
(x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
return roofline.RooflineResult(
unfused_flops=x.size,
unfused_hbm_bytes=(
x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
),
)
roofline.register_roofline(lax.abs_p)(_unary_p_roofline)
roofline.register_roofline(lax.acos_p)(_unary_p_roofline)
roofline.register_roofline(lax.asin_p)(_unary_p_roofline)
roofline.register_roofline(lax.atan_p)(_unary_p_roofline)
roofline.register_roofline(lax.cbrt_p)(_unary_p_roofline)
roofline.register_roofline(lax.ceil_p)(_unary_p_roofline)
roofline.register_roofline(lax.conj_p)(_unary_p_roofline)
roofline.register_roofline(lax.cos_p)(_unary_p_roofline)
roofline.register_roofline(lax.cosh_p)(_unary_p_roofline)
roofline.register_roofline(lax.exp_p)(_unary_p_roofline)
roofline.register_roofline(lax.expm1_p)(_unary_p_roofline)
roofline.register_roofline(lax.floor_p)(_unary_p_roofline)
roofline.register_roofline(lax.imag_p)(_unary_p_roofline)
roofline.register_roofline(lax.integer_pow_p)(_unary_p_roofline)
roofline.register_roofline(lax.is_finite_p)(_unary_p_roofline)
roofline.register_roofline(lax.log_p)(_unary_p_roofline)
roofline.register_roofline(lax.log1p_p)(_unary_p_roofline)
roofline.register_roofline(lax.logistic_p)(_unary_p_roofline)
roofline.register_roofline(lax.neg_p)(_unary_p_roofline)
roofline.register_roofline(lax.not_p)(_unary_p_roofline)
roofline.register_roofline(lax.real_p)(_unary_p_roofline)
roofline.register_roofline(lax.round_p)(_unary_p_roofline)
roofline.register_roofline(lax.rsqrt_p)(_unary_p_roofline)
roofline.register_roofline(lax.sign_p)(_unary_p_roofline)
roofline.register_roofline(lax.sin_p)(_unary_p_roofline)
roofline.register_roofline(lax.sinh_p)(_unary_p_roofline)
roofline.register_roofline(lax.sqrt_p)(_unary_p_roofline)
roofline.register_roofline(lax.square_p)(_unary_p_roofline)
roofline.register_roofline(lax.tan_p)(_unary_p_roofline)
roofline.register_roofline(special.bessel_i0e_p)(_unary_p_roofline)
roofline.register_roofline(special.bessel_i1e_p)(_unary_p_roofline)
roofline.register_roofline(special.digamma_p)(_unary_p_roofline)
roofline.register_roofline(special.erf_inv_p)(_unary_p_roofline)
roofline.register_roofline(special.erf_p)(_unary_p_roofline)
roofline.register_roofline(special.erfc_p)(_unary_p_roofline)
roofline.register_roofline(special.lgamma_p)(_unary_p_roofline)
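# Worked instance of the unary rule above (illustrative operand): for lax.exp on an
# f32[2048, 2048] array, the roofline is one flop per element plus one read and one
# write of the array in HBM.
import numpy as np
_shape, _dtype = (2048, 2048), np.dtype(np.float32)
_size = int(np.prod(_shape))
assert _size == 4_194_304                                   # unfused_flops
assert _dtype.itemsize * _size * 2 == 33_554_432            # unfused_hbm_bytes (input + output)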
def _binary_p_roofline(
ctx: roofline.RooflineRuleContext,
*args,

View File

@ -544,6 +544,8 @@ def _shard_map_staging(
return out_tracers
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
# TODO add underscore version, for direct-linearize to consume
def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
assert isinstance(aval, core.ShapedArray)
return aval
@ -742,9 +744,8 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
out_avals_ = [x.aval for x in jaxpr.outvars]
in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
in_avals_, in_nodes)
new_axis_context = sharding_impls.SPMDAxisContext(
mesh, frozenset(mesh.axis_names) - auto
)
manual_axes = frozenset(mesh.axis_names) - auto
new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes)
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
with _extend_axis_env(mesh, auto):
out_nodes_, tokens_out = mlir.call_lowering(
@ -895,7 +896,6 @@ def _match_spec(mesh: Mesh, check_rep: bool,
def _match(mesh, check_rep, pspec, x):
src = P(mesh.axis_names)
# TODO put back (?) needed for rep checking in eager? for now test rewrite
return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x)
def _rem_singleton(x): return jnp.squeeze(x, axis=0)
@ -914,6 +914,7 @@ class ShardMapTrace(core.Trace):
__slots__ = ("mesh", "auto", "check", "context_mesh")
mesh: Mesh
auto: frozenset[AxisName]
check: bool
context_mesh: AbstractMesh
@ -927,7 +928,7 @@ class ShardMapTrace(core.Trace):
if isinstance(val, ShardMapTracer):
return val.val, val.rep
elif isinstance(val, Tracer):
raise Exception("Shouldn't have any non-shard_map tracers")
raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
else:
val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh)
return val_, None
@ -1609,34 +1610,40 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
out_names_thunk, check_rep, rewrite, auto):
primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in,
f.debug_info)
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info)
f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz]
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
res_names = _all_newly_manual_mesh_names(mesh, auto, trace)
@as_hashable_function(closure=(linearize_outs_thunk))
@as_hashable_function(closure=linearize_outs_thunk)
def primal_out_names_thunk():
residual_avals, _, _ = linearize_outs_thunk()
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
out_names = out_names_thunk()
# This is incorrect so we set `check_rep=False` as we do in the JVP rule.
return (*({0: all_names} for _ in residual_avals), *out_names)
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
# This is incorrect so we set `check_rep=False` in the tangent (as in JVP).
return (*({0: res_names} for _ in range(num_res_out)), *out_names)
primal_params = dict(
mesh=mesh, in_names=in_names,
out_names_thunk=primal_out_names_thunk, check_rep=check_rep,
rewrite=rewrite, auto=auto)
all_primal_results = shard_map_p.bind_with_trace(
trace.parent_trace, (f_primal,) + tuple(primals), primal_params)
residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
num_residuals = len(residual_avals)
residuals = all_primal_results[:num_residuals]
primals_out = all_primal_results[num_residuals:]
args_to_promote = [getattr(aval, 'shape', ()) == () for aval in residual_avals]
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
trace.parent_trace, (f_primal, *primals), primal_params)
residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
non_fwd_res = all_primal_results[:num_res_out]
primals_out = all_primal_results[num_res_out:]
residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None
for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)]
with core.extend_axis_env_nd(mesh.shape.items()):
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
out_names = out_names_thunk()
new_in_names = (*({0: all_names} for _ in residual_avals),
residual_names = [in_names[f1] if f1 is not None else
out_names[f2] if f2 is not None else
{0: res_names} for f1, f2 in zip(in_fwd, out_fwd)]
new_in_names = (*residual_names, *({} for _ in range(len(env))),
*(ax for ax, nz in zip(in_names, nzs_in) if nz))
new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),)
new_out_names = tuple(ax for ax, nz in zip(out_names, nzs_out) if nz)
@as_hashable_function(closure=(new_out_names))
def tangent_out_names_thunk():
return new_out_names
@ -1645,15 +1652,14 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
out_names_thunk=tangent_out_names_thunk, check_rep=False,
rewrite=rewrite, auto=auto)
# TODO TODO don't round-trip
def f_tangent(*args):
residuals = args[:num_residuals]
nz_tangents = args[num_residuals:]
return core.eval_jaxpr(lin_jaxpr, (), *residuals, *nz_tangents)
return core.eval_jaxpr(lin_jaxpr, (), *args)
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace,
(lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
*residuals, *nz_tangents_in), tangent_params)
*residuals, *env, *nz_tangents_in), tangent_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
@ -1663,13 +1669,13 @@ ad.LinearizeTrace.process_shard_map = _shard_map_linearize
@lu.transformation2
def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs):
ans = f(*args, **kwargs)
residual_avals, _, _ = linearize_outs_thunk()
num_residuals = len(residual_avals)
residuals = ans[:num_residuals]
primals = ans[num_residuals:]
residuals = tuple(jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
for x in residuals)
return residuals + primals
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
residuals = ans[:num_res_out]
primals = ans[num_res_out:]
residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
for x in residuals]
return *residuals, *primals
@lu.transformation2
def _promote_scalar_residuals(f: Callable, *args, **kwargs):
@ -1798,10 +1804,10 @@ def _partial_eval_jaxpr_custom_rule(
_, ins_staged = partition_list(inst_in, eqn.invars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym()
params_known, params_staged, all_names = _pe_custom_params(
params_known, params_staged, res_names = _pe_custom_params(
unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval))
residuals = [newvar(_unshard_aval(mesh, {0: res_names}, var.aval))
for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects,
@ -1853,10 +1859,10 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
# prune inputs to jaxpr_known according to unks_in
mesh = params_known['mesh']
auto = params_known['auto']
all_names = _all_newly_manual_mesh_names(mesh, auto)
res_names_ = _all_newly_manual_mesh_names(mesh, auto)
in_names_known, _ = partition_list(unks_in, params_known['in_names'])
_, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
out_names_known = out_names_known + [{0: all_names}] * sum(which)
out_names_known = out_names_known + [{0: res_names_}] * sum(which)
new_params_known = dict(params_known, in_names=tuple(in_names_known),
out_names=tuple(out_names_known))
@ -1864,12 +1870,12 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
_, in_names_staged = partition_list(inst_in, params_staged['in_names'])
res_names = [in_names_known[f1] if f1 is not None else
out_names_known[f2] if f2 is not None else
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
{0: res_names_} for f1, f2 in zip(in_fwd, out_fwd)]
in_names_staged = res_names + in_names_staged
_, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
out_names=tuple(out_names_staged), check_rep=False)
return new_params_known, new_params_staged, all_names
return new_params_known, new_params_staged, res_names_
# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names_except_spmd(
@ -1880,15 +1886,21 @@ def _all_mesh_names_except_spmd(
return tuple(name for name in mesh.axis_names if name not in spmd_names and
name not in auto)
# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_newly_manual_mesh_names(
mesh: Mesh, auto: frozenset[AxisName], trace=None
) -> tuple[AxisName, ...]:
axis_env = core.get_axis_env()
spmd_names = axis_env.spmd_axis_names
axis_sizes = axis_env.axis_sizes
return tuple(name for name in mesh.axis_names if name not in spmd_names and
name not in auto and name not in axis_sizes)
if not (ctx_mesh := get_abstract_mesh()).empty:
del mesh
already_manual_names = set(ctx_mesh.axis_types.get(AxisTypes.Manual, ()))
return tuple(name for name in ctx_mesh.axis_names
if name not in auto | already_manual_names)
else:
# TODO(mattjj): remove this mechanism when we revise mesh scopes
axis_env = core.get_axis_env()
vmap_spmd_names = set(axis_env.spmd_axis_names)
already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names
return tuple(name for name in mesh.axis_names
if name not in auto | vmap_spmd_names | already_manual_names)
# DCE

View File

@ -19,8 +19,18 @@ import math
import jax
from jax._src import core
from jax._src import ffi
from jax._src import util
from jax._src.typing import Array
from jax._src.lib import gpu_sparse
if hasattr(gpu_sparse, "registrations"):
for platform, targets in gpu_sparse.registrations().items():
for name, value, api_version in targets:
ffi.register_ffi_target(
name, value, platform=platform, api_version=api_version
)
class JAXSparse(util.StrictABC):

View File

@ -775,7 +775,7 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse
def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, resource_env, donated_invars, name,
in_layouts, out_layouts, donated_invars, ctx_mesh, name,
keep_unused, inline, compiler_options_kvs):
if any(donated_invars):
raise NotImplementedError("sparse xla_call with donated_invars")
@ -808,8 +808,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
out_shardings=out_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
resource_env=resource_env,
donated_invars=donated_invars,
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,

View File

@ -17,6 +17,7 @@
from jax._src.lax.lax import (
DotDimensionNumbers as DotDimensionNumbers,
RaggedDotDimensionNumbers as RaggedDotDimensionNumbers,
Precision as Precision,
PrecisionLike as PrecisionLike,
DotAlgorithm as DotAlgorithm,
@ -158,6 +159,7 @@ from jax._src.lax.lax import (
pow as pow,
pow_p as pow_p,
ragged_dot as ragged_dot,
ragged_dot_general as ragged_dot_general,
real as real,
real_p as real_p,
reciprocal as reciprocal,

Some files were not shown because too many files have changed in this diff.