diff --git a/.bazelrc b/.bazelrc index e199ffa4a..2fb51664b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -54,6 +54,12 @@ build:macos --apple_platform_type=macos build:macos --linkopt=-Wl,-undefined,dynamic_lookup build:macos --host_linkopt=-Wl,-undefined,dynamic_lookup +# Use cc toolchains from apple_support for Apple builds. +# https://github.com/bazelbuild/apple_support/tree/master?tab=readme-ov-file#bazel-6-setup +build:macos --apple_crosstool_top=@local_config_apple_cc//:toolchain +build:macos --crosstool_top=@local_config_apple_cc//:toolchain +build:macos --host_crosstool_top=@local_config_apple_cc//:toolchain + # Windows has a relatively short command line limit, which JAX has begun to hit. # See https://docs.bazel.build/versions/main/windows.html build:windows --features=compiler_param_file diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 3600ad134..2b97c5a05 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'd982a6245400295477f5da5afa1c4a2a5e641ea4' # Latest commit as of 2025-01-30 + ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 3d153f6bc..e64f81809 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -29,11 +29,6 @@ on: type: string required: true default: "0" - install-jax-current-commit: - description: "Should the 'jax' package be installed from the current commit?" - type: string - required: true - default: "1" gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -62,7 +57,6 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" - JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -88,7 +82,7 @@ jobs: # `*-cp-cp-*`, while free-threaded wheels use # `*-cp-cpt-*`. echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV - - name: Download jaxlib wheel from GCS (non-Windows runs) + - name: Download wheels from GCS (non-Windows runs) id: download-wheel-artifacts-nw # Set continue-on-error to true to prevent actions from failing the workflow if this step # fails. 
Instead, we verify the outcome in the step below so that we can print a more @@ -96,14 +90,10 @@ jobs: continue-on-error: true if: ${{ !contains(inputs.runner, 'windows-x86') }} run: | - mkdir -p $(pwd)/dist && + mkdir -p $(pwd)/dist + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ - - # Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1 - if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - fi - - name: Download jaxlib wheel from GCS (Windows runs) + - name: Download wheels from GCS (Windows runs) id: download-wheel-artifacts-w # Set continue-on-error to true to prevent actions from failing the workflow if this step # fails. Instead, we verify the outcome in step below so that we can print a more @@ -115,12 +105,8 @@ jobs: mkdir dist @REM Use `call` so that we can run sequential gsutil commands on Windows @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652 + call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ - - @REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1 - if not "${{ inputs.install-jax-current-commit }}"=="1" ( - call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ - ) - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure' run: | diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index fde109f9e..3dbd5bb0a 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -34,11 +34,6 @@ on: type: string required: true default: "0" - install-jax-current-commit: - description: "Should the 'jax' package be installed from the current commit?" - type: string - required: true - default: "1" gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -66,7 +61,6 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" - JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -86,7 +80,7 @@ jobs: # `*-cp-cp-*`, while free-threaded wheels use # `*-cp-cpt-*`. echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV - - name: Download the wheel artifacts from GCS + - name: Download wheels from GCS id: download-wheel-artifacts # Set continue-on-error to true to prevent actions from failing the workflow if this step # fails. 
Instead, we verify the outcome in the next step so that we can print a more @@ -94,14 +88,10 @@ jobs: continue-on-error: true run: | mkdir -p $(pwd)/dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ - - # Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1 - if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions.txt index 71542ea5d..7b713b2da 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions.txt @@ -26,8 +26,6 @@ race_top:PyMember_GetOne # https://github.com/python/cpython/issues/129547 race:type_get_annotations -# https://github.com/python/cpython/issues/130547 -race:split_keys_entry_added # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx @@ -64,3 +62,6 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/130571 # race:_PyObject_GetMethod + +# https://github.com/python/cpython/issues/130547 +# race:split_keys_entry_added diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 1de765df0..2940d3dd2 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -35,7 +35,7 @@ jobs: apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ - libffi-dev liblzma-dev + libffi-dev liblzma-dev file zip - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index fe1304c14..5c818bf56 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -27,6 +27,16 @@ concurrency: cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} jobs: + build-jax-artifact: + uses: ./.github/workflows/build_artifacts.yml + with: + # Note that since jax is a pure python package, the runner OS and Python values do not + # matter. In addition, cloning main XLA also has no effect. + runner: "linux-x86-n2-16" + artifact: "jax" + upload_artifacts_to_gcs: true + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + build-jaxlib-artifact: uses: ./.github/workflows/build_artifacts.yml strategy: @@ -66,7 +76,7 @@ jobs: # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we # still want to run the tests for other platforms. 
if: ${{ !cancelled() }} - needs: build-jaxlib-artifact + needs: [build-jax-artifact, build-jaxlib-artifact] uses: ./.github/workflows/pytest_cpu.yml strategy: fail-fast: false # don't cancel all jobs on failure @@ -80,7 +90,6 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - install-jax-current-commit: 1 gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} run-pytest-cuda: @@ -88,7 +97,7 @@ jobs: # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we # still want to run the tests for other platforms. if: ${{ !cancelled() }} - needs: [build-jaxlib-artifact, build-cuda-artifacts] + needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] uses: ./.github/workflows/pytest_cuda.yml strategy: fail-fast: false # don't cancel all jobs on failure @@ -111,7 +120,6 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - install-jax-current-commit: 1 # GCS upload URI is the same for both artifact build jobs gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 33d62db4f..b88b000e4 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -40,9 +40,6 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - # Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the - # GCS bucket. - install-jax-current-commit: 0 gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-cuda: @@ -61,7 +58,4 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - # Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the - # GCS bucket. - install-jax-current-commit: 0 gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file diff --git a/BUILD.bazel b/BUILD.bazel index 617e39e73..441f689e3 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") +load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") load( "//jaxlib:jax.bzl", "jax_wheel", diff --git a/CHANGELOG.md b/CHANGELOG.md index d5a01b780..fd65ae848 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. true, matching the current behavior. If set to false, JAX does not need to emit code clamping negative indices, which improves code size. +## jax 0.5.2 (Mar 4, 2025) + +Patch release of 0.5.1 + +* Bug fixes + * Fixes TPU metric logging and `tpu-info`, which was broken in 0.5.1 + ## jax 0.5.1 (Feb 24, 2025) * New Features @@ -54,6 +61,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. A downstream effect of this several other internal functions need debug info. This change does not affect public APIs. See https://github.com/jax-ml/jax/issues/26480 for more detail. + * In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`, + non-arraylike inputs (such as lists, tuples, etc.) are now deprecated. 
* Bug fixes * TPU runtime startup and shutdown time should be significantly improved on @@ -169,8 +178,6 @@ to signify this. This is a patch release of jax 0.4.36. Only "jax" was released at this version. -## jax 0.4.37 - * Bug fixes * Fixed a bug where `jit` would error if an argument was named `f` (#25329). * Fix a bug that will throw `index out of range` error in diff --git a/WORKSPACE b/WORKSPACE index 8c4f49ecf..129488281 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -70,7 +70,7 @@ jax_python_wheel_repository( ) load( - "@tsl//third_party/py:python_wheel.bzl", + "@xla//third_party/py:python_wheel.bzl", "python_wheel_version_suffix_repository", ) python_wheel_version_suffix_repository( @@ -78,7 +78,7 @@ python_wheel_version_suffix_repository( ) load( - "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", ) @@ -90,7 +90,7 @@ load( "CUDNN_REDISTRIBUTIONS", ) load( - "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", "cuda_redist_init_repositories", "cudnn_redist_init_repository", ) @@ -104,21 +104,21 @@ cudnn_redist_init_repository( ) load( - "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure", ) cuda_configure(name = "local_config_cuda") load( - "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", "nccl_redist_init_repository", ) nccl_redist_init_repository() load( - "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "@xla//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure", ) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 5e2555769..c3be27f4a 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -475,7 +475,7 @@ def bench_pjit_check_aval_sharding(state): aval = jax.core.ShapedArray((8, 2), np.int32) while state: - pjit_check_aval_sharding([s] * 100, [aval] * 100, None, 'benchmark', False) + pjit_check_aval_sharding([s] * 100, [aval] * 100, [''] * 100, 'benchmark', False) @google_benchmark.register diff --git a/build/build.py b/build/build.py index dc6bc30cc..5bcbdd862 100755 --- a/build/build.py +++ b/build/build.py @@ -63,6 +63,17 @@ WHEEL_BUILD_TARGET_DICT = { "jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", } +# Dictionary with the new wheel build rule. Note that when JAX migrates to the +# new wheel build rule fully, the build CLI will switch to the new wheel build +# rule as the default. +WHEEL_BUILD_TARGET_DICT_NEW = { + "jax": "//:jax_wheel", + "jaxlib": "//jaxlib/tools:jaxlib_wheel", + "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel", + "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel", + "jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel", + "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", +} def add_global_arguments(parser: argparse.ArgumentParser): """Adds all the global arguments that applies to all the CLI subcommands.""" @@ -147,6 +158,16 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): """, ) + parser.add_argument( + "--use_new_wheel_build_rule", + action="store_true", + help= + """ + Whether to use the new wheel build rule. Temporary flag and will be + removed once JAX migrates to the new wheel build rule fully. 
+ """, + ) + parser.add_argument( "--editable", action="store_true", @@ -386,7 +407,10 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - bazel_command_base.append("run") + if not args.use_new_wheel_build_rule or args.command == "requirements_update": + bazel_command_base.append("run") + else: + bazel_command_base.append("build") if args.python_version: # Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version @@ -592,13 +616,19 @@ async def main(): wheel_build_command_base.append("--config=cuda_libraries_from_stubs") with open(".jax_configure.bazelrc", "w") as f: - jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list()) + jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule) if not jax_configure_options: logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") sys.exit(1) f.write(jax_configure_options) logging.info("Bazel options written to .jax_configure.bazelrc") + if args.use_new_wheel_build_rule: + logging.info("Using new wheel build rule") + wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW + else: + wheel_build_targets = WHEEL_BUILD_TARGET_DICT + if args.configure_only: logging.info("--configure_only is set so not running any Bazel commands.") else: @@ -611,7 +641,7 @@ async def main(): if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: wheel = "jax-" + wheel - if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): + if wheel not in wheel_build_targets.keys(): logging.error( "Incorrect wheel name provided, valid choices are jaxlib," " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," @@ -629,32 +659,33 @@ async def main(): ) # Append the build target to the Bazel command. 
- build_target = WHEEL_BUILD_TARGET_DICT[wheel] + build_target = wheel_build_targets[wheel] wheel_build_command.append(build_target) - wheel_build_command.append("--") + if not args.use_new_wheel_build_rule: + wheel_build_command.append("--") - if args.editable: - logger.info("Building an editable build") - output_path = os.path.join(output_path, wheel) - wheel_build_command.append("--editable") + if args.editable: + logger.info("Building an editable build") + output_path = os.path.join(output_path, wheel) + wheel_build_command.append("--editable") - wheel_build_command.append(f'--output_path="{output_path}"') - wheel_build_command.append(f"--cpu={target_cpu}") + wheel_build_command.append(f'--output_path="{output_path}"') + wheel_build_command.append(f"--cpu={target_cpu}") - if "cuda" in wheel: - wheel_build_command.append("--enable-cuda=True") - if args.cuda_version: - cuda_major_version = args.cuda_version.split(".")[0] - else: - cuda_major_version = args.cuda_major_version - wheel_build_command.append(f"--platform_version={cuda_major_version}") + if "cuda" in wheel: + wheel_build_command.append("--enable-cuda=True") + if args.cuda_version: + cuda_major_version = args.cuda_version.split(".")[0] + else: + cuda_major_version = args.cuda_major_version + wheel_build_command.append(f"--platform_version={cuda_major_version}") - if "rocm" in wheel: - wheel_build_command.append("--enable-rocm=True") - wheel_build_command.append(f"--platform_version={args.rocm_version}") + if "rocm" in wheel: + wheel_build_command.append("--enable-rocm=True") + wheel_build_command.append(f"--platform_version={args.rocm_version}") - wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") + wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) # Exit with error if any wheel build fails. 
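A minimal sketch of the verb/target selection that the `build/build.py` changes above introduce: with `--use_new_wheel_build_rule` the CLI issues `bazel build` against the new wheel targets and skips the trailing `--` flags, otherwise it keeps the legacy `bazel run` path. The new target name comes from `WHEEL_BUILD_TARGET_DICT_NEW` in this diff; the legacy `jaxlib` target shown here is an assumption for illustration.

```python
# Illustrative sketch of the verb/target selection in build/build.py above.
# The legacy jaxlib target name is assumed; the new one is taken from
# WHEEL_BUILD_TARGET_DICT_NEW in this diff.
WHEEL_BUILD_TARGET_DICT = {"jaxlib": "//jaxlib/tools:build_wheel"}        # legacy (assumed)
WHEEL_BUILD_TARGET_DICT_NEW = {"jaxlib": "//jaxlib/tools:jaxlib_wheel"}   # new wheel rule

def bazel_verb_and_target(wheel: str, use_new_wheel_build_rule: bool) -> tuple[str, str]:
    if use_new_wheel_build_rule:
        # The new rule builds the wheel directly, so `bazel build` is enough and
        # no "--" CLI flags (output path, cpu, CUDA/ROCm options) are appended.
        return "build", WHEEL_BUILD_TARGET_DICT_NEW[wheel]
    # Legacy path: `bazel run` a wheel-building binary and pass it flags after "--".
    return "run", WHEEL_BUILD_TARGET_DICT[wheel]

assert bazel_verb_and_target("jaxlib", True) == ("build", "//jaxlib/tools:jaxlib_wheel")
assert bazel_verb_and_target("jaxlib", False) == ("run", "//jaxlib/tools:build_wheel")
```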
diff --git a/build/requirements.in b/build/requirements.in index e122aaa4a..d4e13d943 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -18,5 +18,4 @@ ml_dtypes>=0.4.0 opt_einsum zstandard etils[epath] -# TODO(ybaturina): remove setuptools version -setuptools<71.0.0 +setuptools diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index ccffa247f..290c7e732 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -634,9 +634,9 @@ zstandard==0.22.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==69.5.1 \ - --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ - --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 +setuptools==76.0.0 \ + --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ + --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 # via # -r build/requirements.in # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 7f3ee61ff..f73065950 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -623,9 +623,9 @@ zstandard==0.22.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==69.5.1 \ - --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ - --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 +setuptools==76.0.0 \ + --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ + --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 # via # -r build/requirements.in # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index bf22c3623..feebc33dc 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -623,9 +623,9 @@ zstandard==0.22.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==69.5.1 \ - --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ - --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 +setuptools==76.0.0 \ + --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ + --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 # via # -r build/requirements.in # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 9fa78c062..0a32888f6 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -747,9 +747,9 @@ zstandard==0.23.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==70.3.0 \ - --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ - --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc +setuptools==76.0.0 \ + --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ + --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 # via # -r build/requirements.in # -r build/test-requirements.txt diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 19d713532..3b36900c0 100644 --- a/build/test-requirements.txt +++ 
b/build/test-requirements.txt @@ -12,8 +12,7 @@ portpicker; python_version<"3.13" pytest-xdist wheel rich -# TODO(ybaturina): remove setuptools version -setuptools<71.0.0 +setuptools # matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement # below. matplotlib~=3.8.4; python_version=="3.10" diff --git a/build/tools/utils.py b/build/tools/utils.py index e91b2d424..7e3751698 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -213,11 +213,15 @@ def get_gcc_major_version(gcc_path: str): return major_version -def get_jax_configure_bazel_options(bazel_command: list[str]): +def get_jax_configure_bazel_options(bazel_command: list[str], use_new_wheel_build_rule: bool): """Returns the bazel options to be written to .jax_configure.bazelrc.""" # Get the index of the "run" parameter. Build options will come after "run" so - # we find the index of "run" and filter everything after it. - start = bazel_command.index("run") + # we find the index of "run" and filter everything after it. If we are using + # the new wheel build rule, we will find the index of "build" instead. + if use_new_wheel_build_rule: + start = bazel_command.index("build") + else: + start = bazel_command.index("run") jax_configure_bazel_options = "" try: for i in range(start + 1, len(bazel_command)): diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 3cc1fa0c5..84b8d35a2 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -45,52 +45,82 @@ if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then arch="amd64" fi +# Determine the artifact tag flags based on the artifact type. A release +# wheel is tagged with the release version (e.g. 0.5.1), a nightly wheel is +# tagged with the release version and a nightly suffix that contains the +# current date (e.g. 0.5.2.dev20250227), and a default wheel is tagged with +# the git commit hash of the HEAD of the current branch and the date of the +# commit (e.g. 0.5.1.dev20250128+3e75e20c7). +if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then + artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release" +elif [[ "$JAXCI_ARTIFACT_TYPE" == "nightly" ]]; then + current_date=$(date +%Y%m%d) + artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly" +elif [[ "$JAXCI_ARTIFACT_TYPE" == "default" ]]; then + artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=custom --bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --bazel_options=--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)" +else + echo "Error: Invalid artifact type: $JAXCI_ARTIFACT_TYPE. Allowed values are: release, nightly, default" + exit 1 +fi + if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then + # Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_" + # flags in the .bazelrc depending upon the platform we are building for. + bazelrc_config="${os}_${arch}" - # Build the jax artifact - if [[ "$artifact" == "jax" ]]; then - python -m build --outdir $JAXCI_OUTPUT_DIR + # On platforms with no RBE support, we can use the Bazel remote cache. Set + # it to be empty by default to avoid unbound variable errors. + bazel_remote_cache="" + + if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then + bazelrc_config="rbe_${bazelrc_config}" else + bazelrc_config="ci_${bazelrc_config}" - # Figure out the bazelrc config to use. 
We will use one of the "rbe_"/"ci_" - # flags in the .bazelrc depending upon the platform we are building for. - bazelrc_config="${os}_${arch}" - - # On platforms with no RBE support, we can use the Bazel remote cache. Set - # it to be empty by default to avoid unbound variable errors. - bazel_remote_cache="" - - if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then - bazelrc_config="rbe_${bazelrc_config}" + # Set remote cache flags. Pushes to the cache bucket is limited to JAX's + # CI system. + if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then + bazel_remote_cache="--bazel_options=--config=public_cache_push" else - bazelrc_config="ci_${bazelrc_config}" - - # Set remote cache flags. Pushes to the cache bucket is limited to JAX's - # CI system. - if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then - bazel_remote_cache="--bazel_options=--config=public_cache_push" - else - bazel_remote_cache="--bazel_options=--config=public_cache" - fi + bazel_remote_cache="--bazel_options=--config=public_cache" fi + fi - # Use the "_cuda" configs when building the CUDA artifacts. - if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then - bazelrc_config="${bazelrc_config}_cuda" - fi + # Use the "_cuda" configs when building the CUDA artifacts. + if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then + bazelrc_config="${bazelrc_config}_cuda" + fi - # Build the artifact. + # Build the artifact. + python build/build.py build --wheels="$artifact" \ + --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ + --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ + --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + $artifact_tag_flags + + # If building release artifacts, we also build a release candidate ("rc") + # tagged wheel. + if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then python build/build.py build --wheels="$artifact" \ --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ - --verbose --detailed_timestamped_log + --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + $artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION" + fi - # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we - # run `auditwheel show` to verify manylinux compliance. - if [[ "$os" == "linux" ]]; then - ./ci/utilities/run_auditwheel.sh - fi + # Move the built artifacts from the Bazel cache directory to the output + # directory. + if [[ "$artifact" == "jax" ]]; then + mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR" + mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR" + else + mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR" + fi + # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we + # run `auditwheel show` to verify manylinux compliance. + if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then + ./ci/utilities/run_auditwheel.sh fi else diff --git a/ci/envs/default.env b/ci/envs/default.env index 72646113e..7a2448944 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -50,6 +50,15 @@ export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} # flag is enabled only for CI builds. export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0} +# Type of artifacts to build. Valid values are "default", "release", "nightly". +# This affects the wheel naming/tag. 
+export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"} + +# When building release artifacts, we build a release candidate wheel ("rc" +# tagged wheel) in addition to the release wheel. This environment variable +# sets the version of the release candidate ("RC") artifact to build. +export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-} + # ############################################################################# # Test script specific environment variables. # ############################################################################# @@ -65,9 +74,4 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} # JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels # on the system. By default, it is set to match the version of the hermetic # Python used by Bazel for building the wheels. -export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} - -# Installs the JAX package in editable mode at the current commit. Enabled by -# default. Nightly/Release builds disable this flag in the Github action -# workflow files. -export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-"1"} \ No newline at end of file +export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index f0e245e14..41274b95f 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -19,7 +19,7 @@ # avoid using the Windows version of `find` on Msys. WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) ) -if [[ -z "$WHEELS" ]]; then +if [[ -z "${WHEELS[@]}" ]]; then echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" exit 1 fi @@ -38,10 +38,4 @@ if [[ $(uname -s) =~ "MSYS_NT" ]]; then "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") else "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" -fi - -if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then - echo "Installing the JAX package at the current commit..." - # Install JAX package at the current commit. - "$JAXCI_PYTHON" -m uv pip install . fi \ No newline at end of file diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index 8e4039414..114acf247 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -98,3 +98,6 @@ function retry { # Retry "bazel --version" 3 times to avoid flakiness when downloading bazel. retry "bazel --version" + +# Create the output directory if it doesn't exist. +mkdir -p "$JAXCI_OUTPUT_DIR" \ No newline at end of file diff --git a/docs/about.md b/docs/about.md index c4bc93140..58e170384 100644 --- a/docs/about.md +++ b/docs/about.md @@ -10,7 +10,7 @@ DeepMind](https://deepmind.google/), Alphabet more broadly, and elsewhere. At the heart of the project is the [JAX -core](http://github.com/google/jax) library, which focuses on the +core](http://github.com/jax-ml/jax) library, which focuses on the fundamentals of machine learning and numerical computing, at scale. 
When [developing](#development) the core, we want to maintain agility diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index 73ac02628..85580120c 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -91,8 +91,8 @@ def f(x): jax.debug.print("x: {}", x) return x jax.pmap(f)(xs) -# Prints: x: 1.0 -# x: 0.0 +# Prints: x: 0.0 +# x: 1.0 # OR # Prints: x: 1.0 # x: 0.0 diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index fe84fc0d7..30c626bec 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -195,7 +195,7 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params: # and then use `updates` instead of `grad` to actually update the params. # (And we'd include `new_optimizer_state` in the output, naturally.) - new_params = jax.tree_map( + new_params = jax.tree.map( lambda param, g: param - g * LEARNING_RATE, params, grad) return new_params diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index f1c0078cd..c2868cf7c 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -69,7 +69,7 @@ def dots_saveable(prim, *_, **__) -> bool: lax_convolution.conv_general_dilated_p} checkpoint_dots = dots_saveable -def dot_with_no_batch_dims_saveable(prim, *_, **params) -> bool: +def dots_with_no_batch_dims_saveable(prim, *_, **params) -> bool: # This is a useful heuristic for transformers. if prim is lax_internal.dot_general_p: (_, _), (lhs_b, rhs_b) = params['dimension_numbers'] @@ -160,8 +160,8 @@ checkpoint_policies = types.SimpleNamespace( nothing_saveable=nothing_saveable, dots_saveable=dots_saveable, checkpoint_dots=dots_saveable, - dots_with_no_batch_dims_saveable=dot_with_no_batch_dims_saveable, - checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims_saveable, + dots_with_no_batch_dims_saveable=dots_with_no_batch_dims_saveable, + checkpoint_dots_with_no_batch_dims=dots_with_no_batch_dims_saveable, offload_dot_with_no_batch_dims=offload_dot_with_no_batch_dims, save_anything_except_these_names=save_anything_except_these_names, save_any_names_but_these=save_any_names_but_these, @@ -355,7 +355,7 @@ def _remat_static_argnums(fun, static_argnums, args): raise ValueError("the `static_argnums` argument to `jax.checkpoint` / " "`jax.remat` can only take integer values greater than or " "equal to `-len(args)` and less than `len(args)`, but got " - f"{static_argnums}") + f"{static_argnums}, while `len(args)` = {len(args)}") if not static_argnums: return fun, args diff --git a/jax/_src/api.py b/jax/_src/api.py index baf2af6e9..4b14d8096 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1094,6 +1094,9 @@ def _mapped_axis_size(fn, tree, vals, dims, name): return f"args{keystr(key_path)}" # args is a tuple, so key_path[0].idx is the index into args. 
i = key_path[0].idx + # This can happen with star arguments (*args) + if i >= len(signature_parameters): + return f"args{keystr(key_path)}" res = f"argument {signature_parameters[i]}" if len(key_path) > 1: res += keystr(key_path[1:]) @@ -1135,8 +1138,8 @@ def pmap( fun: Callable, axis_name: AxisName | None = None, *, - in_axes=0, - out_axes=0, + in_axes: int | None | Sequence[Any] = 0, + out_axes: Any = 0, static_broadcasted_argnums: int | Iterable[int] = (), devices: Sequence[xc.Device] | None = None, # noqa: F811 backend: str | None = None, @@ -2002,8 +2005,8 @@ def vjp( raise NotImplementedError("reduce_axes argument to vjp is deprecated") del reduce_axes check_callable(fun) - wrapped_fun = lu.wrap_init(fun, - debug_info=debug_info("vjp", fun, primals, {})) + wrapped_fun = lu.wrap_init( + fun, debug_info=debug_info("vjp", fun, primals, {})) return _vjp(wrapped_fun, *primals, has_aux=has_aux) def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 1fd371034..a42141b96 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -382,6 +382,9 @@ def is_hashable(arg): return False +SENTINEL = object() + + def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): # given an axis spec tree axis_tree (a pytree with integers and Nones at the # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of @@ -389,7 +392,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): # and return the flattened result # TODO(mattjj,phawkins): improve this implementation proxy = object() - dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves) + dummy = tree_unflatten(treedef, [SENTINEL] * treedef.num_leaves) axes = [] add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0])) try: @@ -564,8 +567,9 @@ def resolve_kwargs(fun: Callable, args, kwargs) -> tuple[Any, ...]: passed_kwargs = [k for k in ba.kwargs if k in kwargs] if passed_kwargs: raise TypeError( - f"keyword arguments ({passed_kwargs}) could not be resolved to " - "positions") + "The following keyword arguments could not be resolved to positions: " + f"{', '.join(passed_kwargs)}" + ) return ba.args diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 5a6561afa..74ea53714 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -901,7 +901,7 @@ error_checks[lax.while_p] = while_loop_error_check def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, inline, keep_unused, + donated_invars, ctx_mesh, name, inline, keep_unused, compiler_options_kvs): # jaxpr to checked_jaxpr err_vals, err_tree = jtu.tree_flatten(error) @@ -928,8 +928,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, out_shardings=new_out_shardings, in_layouts=new_in_layouts, out_layouts=new_out_layouts, - resource_env=resource_env, donated_invars=new_donated_invars, + ctx_mesh=ctx_mesh, name=name, inline=inline, keep_unused=keep_unused, diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 82219886b..0539e4253 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -15,6 +15,7 @@ import datetime import os import re +import warnings from jax import version from jax._src import config from jax._src import hardware_utils @@ -72,7 +73,19 @@ def cloud_tpu_init() -> None: # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. 
libtpu_path = get_tpu_library_path() - num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0] + num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id() + if ( + tpu_id is not None + and tpu_id >= hardware_utils.TpuVersion.v5e + and not hardware_utils.transparent_hugepages_enabled() + ): + warnings.warn( + 'Transparent hugepages are not enabled. TPU runtime startup and' + ' shutdown time should be significantly improved on TPU v5e and newer.' + ' If not already set, you may need to enable transparent hugepages in' + ' your VM image (sudo sh -c "echo always >' + ' /sys/kernel/mm/transparent_hugepage/enabled")' + ) if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init(): return diff --git a/jax/_src/compilation_cache_interface.py b/jax/_src/compilation_cache_interface.py index 480457871..e0241d54b 100644 --- a/jax/_src/compilation_cache_interface.py +++ b/jax/_src/compilation_cache_interface.py @@ -15,8 +15,8 @@ from __future__ import annotations import abc +import pathlib -from jax._src import path as pathlib from jax._src import util diff --git a/jax/_src/core.py b/jax/_src/core.py index 210f8cb68..9d8edeb8b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1815,6 +1815,7 @@ def _make_lengths_same(sharding, ndim): if ndim > len(sharding.spec): return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) if ndim < len(sharding.spec): + assert all(s is None for s in sharding.spec[ndim:]) return sharding.with_spec(sharding.spec[:ndim]) assert False, "unreachable" @@ -1840,6 +1841,8 @@ def _maybe_modify_sharding(sharding, ndim): return sharding if sharding.mesh._are_all_axes_explicit: + if ndim > len(sharding.spec): + return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) return sharding out = sharding.with_spec(modify_spec_for_auto_manual( @@ -1849,9 +1852,22 @@ def _maybe_modify_sharding(sharding, ndim): out = _make_lengths_same(out, ndim) return out +def _check_divisibility(sharding, shape): + mesh = sharding.mesh + for dim, (spec, sh) in enumerate(zip(sharding.spec, shape)): + if spec is None: + continue + spec = spec if isinstance(spec, tuple) else (spec,) + size = math.prod(mesh.shape[s] for s in spec) + _, remainder = divmod(sh, size) + if remainder != 0: + raise ValueError( + f"Sharding spec {spec} implies that array axis {dim} is partitioned" + f" {size} times, but does not evenly divide the dimension size {sh}." + f" Got shape: {shape} and sharding {sharding}") @cache(max_size=4096, trace_context_in_key=True) -def get_sharding(sharding, ndim): +def get_sharding(sharding, shape): """Modifies and checks the sharding. Some modifications/checks include: @@ -1860,6 +1876,7 @@ def get_sharding(sharding, ndim): * Checking for len(spec)-ndim match * Checking if the mesh is an AbstractMesh. """ + ndim = len(shape) if sharding is None: return NamedSharding(mesh_lib.empty_abstract_mesh, P(*[None] * ndim)) @@ -1871,6 +1888,7 @@ def get_sharding(sharding, ndim): if not isinstance(out_s.mesh, mesh_lib.AbstractMesh): raise ValueError("Mesh of an aval must be an AbstractMesh. 
" f"Got {out_s.mesh} of type {type(out_s.mesh)}") + _check_divisibility(out_s, shape) return out_s @@ -1882,7 +1900,7 @@ class ShapedArray(UnshapedArray): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type - self.sharding = get_sharding(sharding, len(self.shape)) + self.sharding = get_sharding(sharding, self.shape) def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -2489,8 +2507,8 @@ class MapPrimitive(Primitive): def get_bind_params(self, params): new_params = dict(params) jaxpr: Jaxpr = new_params.pop('call_jaxpr') - subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, - debug_info=jaxpr.debug_info), jaxpr, ()) + subfun = lu.hashable_partial( + lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ()) axes = new_params.pop('out_axes') new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes) return [subfun], new_params diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 1b386cc4d..c7e7c83f3 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -333,7 +333,7 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, def check_is_flash_attention( - query, key, layout: int, cudnn_version, has_bias, is_training, is_packed, + query, key, layout: int, cudnn_version, has_bias, is_training, is_packed=False, is_fp8=False): # Extract sequence length (T) and head dim (H) based on layout if layout == AttentionLayout.BNTH.value: diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index a6ca6479c..338074837 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -143,7 +143,15 @@ class custom_vmap: def __call__(self, *args, **kwargs): debug_fun = api_util.debug_info("custom_vmap fun", self.fun, args, kwargs) - args = api_util.resolve_kwargs(self.fun, args, kwargs) + try: + args = api_util.resolve_kwargs(self.fun, args, kwargs) + except TypeError as e: + raise TypeError( + "The input arguments to the custom_vmap-decorated function " + f"{debug_fun.func_name} could not be resolved to positional-only " + f"arguments. Binding failed with the error:\n{e}" + ) from e + if not self.vmap_rule: raise AttributeError( f"No batching rule defined for custom_vmap function {debug_fun.func_name} " diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py index 9166965b5..d336c969a 100644 --- a/jax/_src/custom_dce.py +++ b/jax/_src/custom_dce.py @@ -133,7 +133,15 @@ class custom_dce: debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule, args, {}, static_argnums=self.static_argnums) - args = api_util.resolve_kwargs(self.fun, args, kwargs) + try: + args = api_util.resolve_kwargs(self.fun, args, kwargs) + except TypeError as e: + raise TypeError( + "The input arguments to the custom_dce-decorated function " + f"{debug.func_name} could not be resolved to positional-only " + f"arguments. Binding failed with the error:\n{e}" + ) from e + if self.static_argnums: static_argnums = set(self.static_argnums) for i in static_argnums: diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 1cea84110..32856106a 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -250,7 +250,15 @@ class custom_jvp(Generic[ReturnValue]): msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp." 
raise AttributeError(msg) - args = resolve_kwargs(self.fun, args, kwargs) + try: + args = resolve_kwargs(self.fun, args, kwargs) + except TypeError as e: + raise TypeError( + "The input arguments to the custom_jvp-decorated function " + f"{primal_name} could not be resolved to positional-only arguments. " + f"Binding failed with the error:\n{e}" + ) from e + if self.nondiff_argnums: nondiff_argnums = set(self.nondiff_argnums) args = tuple(_stop_gradient(x) if i in nondiff_argnums else x @@ -634,7 +642,16 @@ class custom_vjp(Generic[ReturnValue]): if not self.fwd or not self.bwd: msg = f"No VJP defined for custom_vjp function {debug_fun.func_name} using defvjp." raise AttributeError(msg) - args = resolve_kwargs(self.fun, args, kwargs) + + try: + args = resolve_kwargs(self.fun, args, kwargs) + except TypeError as e: + raise TypeError( + "The input arguments to the custom_vjp-decorated function " + f"{debug_fun.func_name} could not be resolved to positional-only " + f"arguments. Binding failed with the error:\n{e}" + ) from e + debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs, static_argnums=self.nondiff_argnums) # TODO(necula): figure out how to construct the debug_bwd args @@ -1238,7 +1255,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]: def _maybe_perturbed(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value # with a nontrivial tangent attached), up to heuristics, and True otherwise. - # See https://github.com/google/jax/issues/6415 for motivation. + # See https://github.com/jax-ml/jax/issues/6415 for motivation. if not isinstance(x, core.Tracer): # If x is not a Tracer, it can't be perturbed. return False diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index b835c4c83..658a6f7a2 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -181,7 +181,8 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( *tiled_args ) - if closed_jaxpr.out_avals != tiled_results: + if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] != + [(t.shape, t.dtype) for t in tiled_results]): raise ValueError( "Mismatch in result shapes. 
%s vs %s" % (repr(closed_jaxpr.out_avals), repr(tiled_results)) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 808d129ba..01500c008 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -109,6 +109,12 @@ _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) _float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2) _float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz) +# fp4 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float4_e2m1fn: type[np.generic] | None = None + +_float4_e2m1fn_dtype: np.dtype | None = None + def supports_inf(dtype: DTypeLike) -> bool: """Return true if the dtype supports infinity, else return False.""" typ = np.dtype(dtype).type @@ -144,6 +150,8 @@ _float8_dtypes = [ _float8_e5m2fnuz_dtype, ] +_float4_dtypes: list[np.dtype] = [] + # TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 if hasattr(ml_dtypes, "float8_e4m3"): float8_e4m3 = ml_dtypes.float8_e4m3 @@ -163,6 +171,12 @@ if hasattr(ml_dtypes, "float8_e8m0fnu"): _custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type] _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype) _float8_dtypes.insert(0, _float8_e8m0fnu_dtype) +if hasattr(ml_dtypes, "float4_e2m1fn"): + float4_e2m1fn = ml_dtypes.float4_e2m1fn + _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn) + _custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype) + _float4_dtypes.insert(0, _float4_e2m1fn_dtype) # 2-bit integer support int2: type[np.generic] | None = None @@ -716,6 +730,12 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy "promotion path. To avoid unintended promotion, 8-bit floats do not support " "implicit promotion. If you'd like your inputs to be promoted to another type, " "you can do so explicitly using e.g. x.astype('float32')") + elif any(n in _float4_dtypes for n in nodes): + msg = ( + f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype " + "promotion path. To avoid unintended promotion, 4-bit floats do not support " + "implicit promotion. If you'd like your inputs to be promoted to another type, " + "you can do so explicitly using e.g. x.astype('float32')") elif any(n in _intn_dtypes for n in nodes): msg = ( f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype " diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 1c009224d..11e65a7dd 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -33,12 +33,11 @@ class JaxValueError(ValueError): """Exception raised for failed runtime error checks in JAX.""" +#: The default error code for no error. +#: +#: This value is chosen because we can use `jnp.min()` to obtain the +#: first error when performing reductions. _NO_ERROR = jnp.iinfo(jnp.uint32).max -"""The default error code for no error. - -We choose this value because when performing reductions, we can use `min` to -obtain the smallest error code. -""" _error_list_lock = threading.Lock() @@ -62,7 +61,7 @@ def _initialize_error_code_ref() -> None: def set_error_if(pred: jax.Array, msg: str) -> None: - """Set error if pred is true. + """Set error if any element of pred is true. If the error is already set, the new error will be ignored. It will not override the existing error. 
@@ -74,7 +73,7 @@ def set_error_if(pred: jax.Array, msg: str) -> None: traceback = source_info_util.current().traceback assert traceback is not None with _error_list_lock: - new_error_code = len(_error_list) + new_error_code = jnp.uint32(len(_error_list)) _error_list.append((msg, traceback)) pred = pred.any() @@ -86,18 +85,26 @@ def set_error_if(pred: jax.Array, msg: str) -> None: def raise_if_error() -> None: - """Raise error if an error is set.""" - if _error_storage.ref is None: - return # if not initialized, do nothing + """Raise error if an error is set. + + This function should be called after the computation is finished. It should + not be called within a traced context, such as within a jitted function." + """ + if _error_storage.ref is None: # if not initialized, do nothing + return error_code = _error_storage.ref[...] + if isinstance(error_code, core.Tracer): + raise ValueError( + "raise_if_error() should not be called within a traced context, such as" + " within a jitted function." + ) if error_code == jnp.uint32(_NO_ERROR): return - try: - msg, traceback = _error_list[error_code] - exc = JaxValueError(msg) - traceback = traceback.as_python_traceback() - filtered_traceback = traceback_util.filter_traceback(traceback) - raise exc.with_traceback(filtered_traceback) - finally: - _error_storage.ref[...] = jnp.uint32(_NO_ERROR) + _error_storage.ref[...] = jnp.uint32(_NO_ERROR) + + msg, traceback = _error_list[error_code] + exc = JaxValueError(msg) + traceback = traceback.as_python_traceback() + filtered_traceback = traceback_util.filter_traceback(traceback) + raise exc.with_traceback(filtered_traceback) diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index dd0ae3edc..7d3e342f1 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -75,6 +75,7 @@ enum DType: byte { f8_e5m2 = 20, f8_e5m2fnuz = 21, f8_e8m0fnu = 25, + f4_e2m1fn = 26, } table AbstractValue { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 7707670f1..ac97c11d1 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -365,6 +365,8 @@ if dtypes._float8_e4m3_dtype is not None: _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 if dtypes._float8_e8m0fnu_dtype is not None: _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu +if dtypes._float4_e2m1fn_dtype is not None: + _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 69092cd7e..b1fc13333 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -62,6 +62,7 @@ class DType(object): f8_e5m2fnuz = 21 f0 = 22 f8_e8m0fnu = 25 + f4_e2m1fn = 26 class ShardingKind(object): diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py index 81ef07a71..84ad9edf9 100644 --- a/jax/_src/hardware_utils.py +++ b/jax/_src/hardware_utils.py @@ -12,25 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import enum import os import pathlib import glob _GOOGLE_PCI_VENDOR_ID = '0x1ae0' -_TPU_PCI_DEVICE_IDS = [ - # TPU v2, v3 - '0x0027', - # No public name (plc) - '0x0056', - # TPU v4 - '0x005e', - # TPU v5p - '0x0062', - # TPU v5e - '0x0063', - # TPU v6e - '0x006f', -] _NVIDIA_GPU_DEVICES = [ '/dev/nvidia0', @@ -38,10 +25,36 @@ _NVIDIA_GPU_DEVICES = [ '/dev/dxg', # WSL2 ] + +class TpuVersion(enum.IntEnum): + # TPU v2, v3 + v2 = 0 + v3 = 1 + # No public name (plc) + plc = 2 + # TPU v4 + v4 = 3 + # TPU v5p + v5p = 4 + # TPU v5e + v5e = 5 + # TPU v6e + v6e = 6 + + +_TPU_PCI_DEVICE_IDS = { + '0x0027': TpuVersion.v3, + '0x0056': TpuVersion.plc, + '0x005e': TpuVersion.v4, + '0x0062': TpuVersion.v5p, + '0x0063': TpuVersion.v5e, + '0x006f': TpuVersion.v6e, +} + def num_available_tpu_chips_and_device_id(): """Returns the device id and number of TPU chips attached through PCI.""" num_chips = 0 - device_id = '' + tpu_version = None for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'): vendor_id = pathlib.Path(vendor_path).read_text().strip() if vendor_id != _GOOGLE_PCI_VENDOR_ID: @@ -50,12 +63,20 @@ def num_available_tpu_chips_and_device_id(): device_path = os.path.join(os.path.dirname(vendor_path), 'device') device_id = pathlib.Path(device_path).read_text().strip() if device_id in _TPU_PCI_DEVICE_IDS: + tpu_version = _TPU_PCI_DEVICE_IDS[device_id] num_chips += 1 - return num_chips, device_id + return num_chips, tpu_version def has_visible_nvidia_gpu() -> bool: """True if there's a visible nvidia gpu available on device, False otherwise.""" return any(os.path.exists(d) for d in _NVIDIA_GPU_DEVICES) + + +def transparent_hugepages_enabled() -> bool: + # See https://docs.kernel.org/admin-guide/mm/transhuge.html for more + # information about transparent huge pages. 
+ path = pathlib.Path('/sys/kernel/mm/transparent_hugepage/enabled') + return path.exists() and path.read_text().strip() == '[always] madvise never' diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 37ad40d22..2f835ab83 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -39,7 +39,7 @@ from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal) from jax._src.dtypes import dtype, float0 from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name, as_hashable_function, weakref_lru_cache, - partition_list) + partition_list, subs_list2) zip = safe_zip map = safe_map @@ -91,6 +91,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, *primals, **params): with core.take_current_trace() as parent_trace: tangent_trace = pe.DynamicJaxprTrace(debug_info) + tangent_trace.tag = _tag linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag) tracers = [LinearizeTracer(linearize_trace, p, tangent_trace.new_arg(get_aval(p).to_tangent_aval())) @@ -104,11 +105,23 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment] jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) - residual_avals = map(get_aval, consts) if attrs_tracked: raise NotImplementedError("TODO: attrs") - _store.store((residual_avals, nzs_out, jaxpr)) - return tuple(consts) + tuple(out_primals) + which_env = [(isinstance(c, pe.DynamicJaxprTracer) and + getattr(c._trace, 'tag', None) is _tag) for c in consts] + jaxpr = pe.move_envvars(jaxpr, tuple(which_env)) + res, env = partition_list(which_env, consts) + residual_avals = map(get_aval, res) + # Which residuals are just forwarded inputs? Check object id. + id_map = {id(p): i for i, p in enumerate(primals)} + in_fwd: list[int | None] = [id_map.get(id(r)) for r in res] + # Which residuals are already primal outputs? Check object id. + id_map = {id(p): i for i, p in enumerate(out_primals)} + out_fwd: list[int | None] = [id_map.get(id(r)) for r in res] + # Prune residuals not to include forwarded primal inputs or outputs. 
+ res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd) if f1 is None and f2 is None] + _store.store((residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd)) + return *res, *out_primals @lu.transformation2 def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents): @@ -157,6 +170,7 @@ def _linearize_jaxpr( primal_trace = pe.DynamicJaxprTrace(dbg) tangent_trace = pe.DynamicJaxprTrace(dbg) lin_trace = LinearizeTrace(primal_trace, tangent_trace) + tangent_trace.tag = lin_trace.tag def new_arg(trace, primal_aval, nz): primal = primal_trace.new_arg(primal_aval) @@ -197,6 +211,7 @@ def direct_linearize(traceable: lu.WrappedFun, tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents] linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag) + tangent_trace.tag = linearize_trace.tag tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] tracers = [t.full_lower() for t in tracers] with (core.set_current_trace(linearize_trace, check_leaks=True), @@ -217,6 +232,10 @@ def direct_linearize(traceable: lu.WrappedFun, out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) tangent_trace.invalidate() + jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr, [True] * len(jaxpr.outvars), + [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) + consts = [c for c, used in zip(consts, used_consts) if used] out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else pe.PartialVal.known(zeros_like_aval(t.aval)) for t, nz in zip(out_tangents, out_nzs)] @@ -330,12 +349,34 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack, # forces primal_in to contain UndefinedPrimals for tangent values! map(write_primal, jaxpr.invars, primals_in) + # Start with a forward pass to evaluate any side-effect-free JaxprEqns that + # only operate on primals. This is required to support primitives with + # linearization rules that include computations on the residuals. + lin_eqns = [] + for eqn in jaxpr.eqns: + # TODO (dfm): The effects check is probably stricter than necessary. + # Consider adding an allowlist of effects here. 
+ if jaxpr.effects or any( + type(x) is not Literal and x not in primal_env for x in eqn.invars): + lin_eqns.append(eqn) + continue + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack + traceback = eqn.source_info.traceback + with source_info_util.user_context( + traceback, name_stack=name_stack), eqn.ctx.manager: + ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params) + if eqn.primitive.multiple_results: + map(write_primal, eqn.outvars, ans) + else: + write_primal(eqn.outvars[0], ans) + ct_env: dict[Any, Any] = {} ctx = (source_info_util.transform_name_stack('transpose') if transform_stack else contextlib.nullcontext()) with ctx: map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) - for eqn in jaxpr.eqns[::-1]: + for eqn in lin_eqns[::-1]: if eqn.primitive.ref_primitive: if eqn.primitive is core.mutable_array_p: val_var, = eqn.invars @@ -586,7 +627,7 @@ def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: primal_aval = get_aval(primal).strip_weak_type() tangent_aval = get_aval(tangent).strip_weak_type() - assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape) + assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape) expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype) assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) @@ -641,6 +682,7 @@ class LinearizeTrace(Trace): return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in), dict(symbolic_zeros=symbolic_zeros)) + @partial(lu.wrap_init, debug_info=f_jvp.debug_info) def _f_jvp(primals, tangents): outs = f_jvp.call_wrapped(*primals, *tangents) primals_out, tangents_out = split_list(outs, [len(outs) // 2]) @@ -651,7 +693,7 @@ class LinearizeTrace(Trace): nonzeros_in = [type(t) is not Zero for t in tangents_in] primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp( _f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros, - f_jvp.debug_info, primals_in, {}) + primals_in, {}) with core.set_current_trace(self.tangent_trace): tangents_out = linearized(residuals, *tangents_in) @@ -690,53 +732,64 @@ class LinearizeTrace(Trace): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not Zero for t in tangents) - f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in, - f.debug_info) + f_primal, linearize_outs_thunk = linearize_subtrace( + f, self.tag, nzs_in, f.debug_info) if isinstance(call_primitive, core.MapPrimitive): - @as_hashable_function(closure=(linearize_outs_thunk)) + out_axes_thunk = params['out_axes_thunk'] + @as_hashable_function(closure=out_axes_thunk) def new_out_axes_thunk(): - residual_avals, _, _ = linearize_outs_thunk() - out_axes = params['out_axes_thunk']() - return (*(0 for _ in residual_avals), *out_axes) + _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + out_axes = out_axes_thunk() + return (*(0 for _ in range(num_res_out)), *out_axes) primal_params = dict(params, out_axes_thunk=new_out_axes_thunk) else: primal_params = params all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params) - residual_avals, nzs_out, lin_jaxpr = 
linearize_outs_thunk() - num_residuals = len(residual_avals) - residuals = all_primal_results[:num_residuals] - primals_out = all_primal_results[num_residuals:] + residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + non_fwd_res = all_primal_results[:num_res_out] + primals_out = all_primal_results[num_res_out:] + residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) if isinstance(call_primitive, core.MapPrimitive): in_axes = params['in_axes'] out_axes = params['out_axes_thunk']() residual_avals = map(get_aval, residuals) - new_in_axes = (*(0 for _ in residual_avals), + residual_axes = [in_axes[f1] if f1 is not None else + out_axes[f2] if f2 is not None else + 0 for f1, f2 in zip(in_fwd, out_fwd)] + new_in_axes = (*residual_axes, *(None for _ in range(len(env))), *(ax for ax, nz in zip(in_axes, nzs_in) if nz)) new_out_axes = (*(ax for ax, nz in zip(out_axes, nzs_out) if nz),) # NOTE: This assumes that the output tangents being zero is a # deterministic function of which input tangents were zero. - @as_hashable_function(closure=(new_out_axes)) + @as_hashable_function(closure=new_out_axes) def new_out_axes_thunk(): return new_out_axes - params = dict(params, - in_axes=new_in_axes, - out_axes_thunk=new_out_axes_thunk) + params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk) update_params = call_linearize_param_updaters.get(call_primitive) - new_params = update_params(params, residual_avals, nzs_in) if update_params else params + num_new_args = len(residuals) + len(env) + new_params = update_params(params, num_new_args, nzs_in) if update_params else params + num_residuals = len(residual_avals) + @as_hashable_function(closure=(num_residuals, lin_jaxpr)) def f_tangent(*args): - residuals = args[:num_residuals] + consts = args[:num_residuals] nz_tangents = args[num_residuals:] - return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents) + return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents) + # TODO(mattjj,dougalm): this tag is read by DynamicJaxprTrace.process_map to + # avoid round-tripping the jaxpr and thus getting grad-of-pmap cache misses. + # Remove when we replace the pmap implementation. 
+ f_tangent._pmap_tag = isinstance(call_primitive, core.MapPrimitive) nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] nz_tangents_out = call_primitive.bind_with_trace( - self.tangent_trace, (lu.wrap_init(f_tangent, - debug_info=lin_jaxpr.debug_info), - *residuals, *nz_tangents_in), new_params) + self.tangent_trace, + (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), + *residuals, *env, *nz_tangents_in), new_params) nz_tangents_out_iter = iter(nz_tangents_out) tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal) for nz, primal in zip(nzs_out, primals_out)] @@ -762,14 +815,14 @@ def fallback_linearize_rule(_prim: core.Primitive, msg = f"Differentiation rule for '{_prim}' not implemented" raise NotImplementedError(msg) debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params) - return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False, - debug_jvp, primals, params) + return linearize_from_jvp(lu.wrap_init(jvp, debug_info=debug_jvp), + _prim.multiple_results, _nonzeros, False, False, + primals, params) -def linearize_from_jvp(jvp: Callable, +def linearize_from_jvp(jvp: lu.WrappedFun, multiple_results: bool, nonzeros: Sequence[bool], user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool, - debug_info: core.DebugInfo, primals, params): current_name_stack = source_info_util.current_name_stack() with core.take_current_trace() as parent_trace: @@ -792,13 +845,18 @@ def linearize_from_jvp(jvp: Callable, tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval) for aval, nz in zip(tangent_avals, nonzeros)) with core.set_current_trace(trace): - out_primals, out_tangents = jvp(primals, tangent_args, **params) + out_primals, out_tangents = jvp.call_wrapped(primals, tangent_args, **params) if not multiple_results: out_primals = [out_primals] out_tangents = [out_tangents] out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals] + if any(p is None for p in out_primals): + raise ValueError( + "Linearization failed to produce known values for all output primals. 
" + "This is typically caused by attempting to differentiate a function " + "uses an operation that does not support reverse-mode autodiff.") out_nzs = [type(t) is not zero_type and not trace.to_jaxpr_tracer(t).is_known() for t in out_tangents] @@ -806,7 +864,7 @@ def linearize_from_jvp(jvp: Callable, out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] - jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, debug_info) + jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info) def linearized(residuals, *tangents): nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] @@ -973,9 +1031,8 @@ def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): else: consts = () all_args, in_tree_def = tree_flatten((consts, args, ct)) - fun = lu.hashable_partial(lu.wrap_init(backward_pass, - debug_info=call_jaxpr.debug_info), - call_jaxpr, False) + fun = lu.hashable_partial(lu.wrap_init( + backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) update_params = call_transpose_param_updaters.get(primitive) if update_params: @@ -1013,9 +1070,8 @@ def map_transpose(primitive: core.Primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts # TODO(necula): use the right debug_info for the backwards pass - fun = lu.hashable_partial(lu.wrap_init(backward_pass, - debug_info=call_jaxpr.debug_info), - call_jaxpr, False) + fun = lu.hashable_partial(lu.wrap_init( + backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False) fun, nz_arg_cts = nonzero_outputs(fun) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) # Preserve axis for primal arguments, skip tangents (represented as undefined primals). 
@@ -1083,8 +1139,8 @@ def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, assert len(jaxpr.in_avals) == len(nonzeros) f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info) - f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False), - nonzeros) + f_jvp, out_nonzeros = f_jvp_traceable( + jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic( diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 7c10c7b8d..695044252 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -199,6 +199,9 @@ if dtypes.float8_e4m3 is not None: if dtypes.float8_e8m0fnu is not None: _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get +if dtypes.float4_e2m1fn is not None: + _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get + def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): # TODO Support different-size underlying dtypes to take advantage of the @@ -939,7 +942,7 @@ def sharded_aval(aval: core.AbstractValue, return aval if not isinstance(aval, (core.ShapedArray, core.DShapedArray)): raise NotImplementedError - return aval.update(sharding.shard_shape(aval.shape)) # type: ignore + return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore def eval_dynamic_shape(ctx: LoweringRuleContext, diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 6fde73705..ef8e02dda 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -46,7 +46,8 @@ from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, - as_hashable_function, weakref_lru_cache, subs_list) + as_hashable_function, weakref_lru_cache, subs_list, + HashableFunction) map, unsafe_map = safe_map, map @@ -837,6 +838,11 @@ def tracers_to_jaxpr( # del getvar # needed to avoid cyclic-reference closure, apparently! 
return jaxpr, const_vals, env_vals +@weakref_lru_cache +def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr: + constvars, envvars = partition_list(which, jaxpr.constvars) + return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars]) + @weakref_lru_cache def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: """Moves the constvars to the start of invars.""" @@ -1840,7 +1846,7 @@ def _inline_literals( class DynamicJaxprTrace(core.Trace): - __slots__ = ("frame",) + __slots__ = ("frame", "tag") def __init__(self, debug_info: core.DebugInfo): self.frame = JaxprStackFrame(debug_info) @@ -1972,17 +1978,18 @@ class DynamicJaxprTrace(core.Trace): self.frame.add_eqn(eqn) return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] - def process_map(self, map_primitive, f: lu.WrappedFun, - tracers: Sequence[core.Tracer], params): + def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) if in_axis is not None else a for a, in_axis in zip(in_avals, params['in_axes'])] + with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]): jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic( f, reduced_in_avals) + jaxpr, consts = _linearize_of_pmap_hack(f, jaxpr, consts) ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) if ordered_effects: raise ValueError("Ordered effects not supported for " @@ -2582,3 +2589,13 @@ def inline_jaxpr_into_trace( return tracer return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env else new_tracer(x) for x in jaxpr.outvars] + +# TODO(mattjj,dougalm): this special handling is to avoid round-tripping the +# jaxpr when we do grad-of-pmap. The tag is set by LinearizeTrace.process_call's +# handling of pmap. Remove when we replace the pmap implementation. +def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, list]: + if (not f.transforms and type(f.f) is HashableFunction and + getattr(f.f, '_pmap_tag', None)): + _, jaxpr = f.f.closure + return convert_constvars_jaxpr(jaxpr), [] + return jaxpr, consts diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9a409e4cb..f97ee5414 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1394,9 +1394,9 @@ def xla_call_jvp_update_params(params, nz_tangents): new_donated_invars = (*donated_invars, *donated_tangents) return dict(params, donated_invars=new_donated_invars) -def _xla_call_linearize_update_params(params, residual_avals, nz_tangents): +def _xla_call_linearize_update_params(params, num_new_inputs, nz_tangents): donated_invars_prev = params['donated_invars'] - donated_invars = (*(False for _ in residual_avals), + donated_invars = (*(False for _ in range(num_new_inputs)), *(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz)) return dict(params, donated_invars=donated_invars) @@ -1663,7 +1663,7 @@ class MismatchType(enum.Enum): elif self.name == 'OUT_SHARDING': return 'explicit output sharding' elif self.name == 'CONTEXT_DEVICES': - return 'devices' + return 'context mesh' return f'{self.name}' @@ -3060,7 +3060,6 @@ class JitGlobalCppCacheKeys: in_layouts_leaves: tuple[Any, ...] | None = None out_layouts_treedef: PyTreeDef | None = None out_layouts_leaves: tuple[Any, ...] 
| None = None - use_resource_env: bool = False compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None @functools.cached_property diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index cecd1cdc5..b75cbf6ac 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -27,7 +27,6 @@ from jax._src.lax import lax from jax._src import effects from jax._src import ad_util from jax._src import state -from jax._src import util from jax._src.util import weakref_lru_cache, safe_map, partition_list from jax._src.interpreters import partial_eval as pe from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef @@ -144,52 +143,55 @@ def _initial_style_jaxprs_with_common_consts( # b[] <- 2.0 # in () } canonical_ref_indices = [] + canonical_non_ref_indices = [] canonical_refs: list[Any] = [] - tracer_id_to_canonical_id = {} - all_nonref_consts = [] + canonical_non_refs: list[Any] = [] + tracer_id_to_canonical_ref_id = {} + tracer_id_to_canonical_non_ref_id = {} canonical_ref_avals = [] - all_nonref_const_avals = [] + canonical_non_ref_avals = [] for consts, consts_avals in zip(all_consts, all_const_avals): ref_indices = [] - nonref_consts = [] - nonref_const_avals = [] + non_ref_indices = [] for c, aval in zip(consts, consts_avals): + tracer_id = id(c) if isinstance(aval, state.AbstractRef): - tracer_id = id(c) - if tracer_id not in tracer_id_to_canonical_id: + if tracer_id not in tracer_id_to_canonical_ref_id: canonical_id = len(canonical_refs) canonical_refs.append(c) - tracer_id_to_canonical_id[tracer_id] = canonical_id + tracer_id_to_canonical_ref_id[tracer_id] = canonical_id canonical_ref_avals.append(aval) - canonical_id = tracer_id_to_canonical_id[tracer_id] + canonical_id = tracer_id_to_canonical_ref_id[tracer_id] ref_indices.append(canonical_id) else: - nonref_consts.append(c) - nonref_const_avals.append(aval) - all_nonref_consts.append(nonref_consts) - all_nonref_const_avals.append(tuple(nonref_const_avals)) + if tracer_id not in tracer_id_to_canonical_non_ref_id: + canonical_id = len(canonical_non_refs) + canonical_non_refs.append(c) + tracer_id_to_canonical_non_ref_id[tracer_id] = canonical_id + canonical_non_ref_avals.append(aval) + canonical_id = tracer_id_to_canonical_non_ref_id[tracer_id] + non_ref_indices.append(canonical_id) canonical_ref_indices.append(tuple(ref_indices)) + canonical_non_ref_indices.append(tuple(non_ref_indices)) - consts = [*canonical_refs, *util.concatenate(all_nonref_consts)] - jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*all_nonref_const_avals,)) + consts = [*canonical_refs, *canonical_non_refs] + jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,)) for i, jaxpr in enumerate(jaxprs)) return jaxprs, consts, all_out_trees @weakref_lru_cache def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices, - all_nonref_const_avals): + canonical_non_ref_avals, canonical_non_ref_indices): is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars] nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars) newvar = core.gensym(suffix='_') - unused_const_vars = [tuple(map(newvar, const_avals)) - for const_avals in all_nonref_const_avals] padded_ref_constvars = map(newvar, canonical_ref_avals) + padded_non_ref_constvars = map(newvar, canonical_non_ref_avals) for canonical_id, ref_var in 
zip(canonical_ref_indices[i], ref_constvars): padded_ref_constvars[canonical_id] = ref_var - const_prefix = util.concatenate(unused_const_vars[:i]) - const_suffix = util.concatenate(unused_const_vars[i + 1:]) - constvars = [*padded_ref_constvars, *const_prefix, *nonref_constvars, - *const_suffix] + for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars): + padded_non_ref_constvars[canonical_id] = non_ref_var + constvars = [*padded_ref_constvars, *padded_non_ref_constvars] jaxpr = jaxpr.replace(constvars=constvars) effects = pe.make_jaxpr_effects(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, jaxpr.eqns) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index e2ad6ced1..63896cc2a 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -281,8 +281,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, num_consts = len(consts) out_ = iter(out) + all_inputs = [*consts, *ops] out = [ - next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts]) + next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) for fwd in in_fwd ] assert next(out_, None) is None diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 4c0bc6e6f..c7dee3e71 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -443,42 +443,26 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, consts, carry, xs_ = split_list(args, [num_consts, num_carry]) _, y_avals = split_list(jaxpr.out_avals, [num_carry]) num_trips, remainder = divmod(length, unroll) + if unroll != 1 and num_trips == 1 and remainder == 0: # In that case, we explicitly want to fully unroll the loop. Put everything # into the remainder block and avoid lowering to a while loop. 
num_trips, remainder = 0, length if unroll == 1: xss = xs_ - yss = _map(partial(_empty_array, (length,), None), y_avals) + yss = _map(partial(_empty_array, (length,), (None,)), y_avals) else: if remainder: if not reverse: xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_)) else: xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_)) - xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] - yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals) + if num_trips: + xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] + yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals) + else: + yss = _map(partial(_empty_array, (num_trips * unroll,), (None,)), y_avals) - def cond_fun(while_carry): - i, _, _ = while_carry - return i < num_trips - def body_fun(while_carry): - i_, carry, yss = while_carry - i = num_trips - i_ - 1 if reverse else i_ - xs = [ - slicing.dynamic_index_in_dim( - xs, i, keepdims=False, allow_negative_indices=False - ) - for xs in xss - ] - carry, ys = inner(unroll, carry, xs) - yss = [ - slicing.dynamic_update_index_in_dim( - ys, upd, i, 0, allow_negative_indices=False - ) - for ys, upd in zip(yss, ys) - ] - return i_ + 1, carry, yss def inner(n, carry, xs): ys = [] if unroll == 1: @@ -493,10 +477,26 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, ys = list(reversed(ys)) if reverse else ys return carry, _map(_stack, zip(*ys)) + def body_fun(while_carry): + i_, carry, yss = while_carry + i = num_trips - i_ - 1 if reverse else i_ + xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False, + allow_negative_indices=False) + for xs in xss] + carry, ys = inner(unroll, carry, xs) + yss = [slicing.dynamic_update_index_in_dim(y, upd, i, 0, + allow_negative_indices=False) + for y, upd in zip(yss, ys)] + return i_ + 1, carry, yss + + def cond_fun(while_carry): + i, _, _ = while_carry + return i < num_trips + if num_trips: i = lax._const(num_trips, 0) _, carry, yss = while_loop(cond_fun, body_fun, (i, carry, yss)) - if unroll != 1: + if unroll != 1 and num_trips != 0: ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss] else: ys = yss @@ -512,7 +512,7 @@ def _split_leading(sz, x): def _concat(a, b): return lax.concatenate([a, b], 0) def _empty_array(prefix, length_spec, aval): - sharding = aval.sharding.with_spec((length_spec, *aval.sharding.spec)) + sharding = aval.sharding.with_spec((*length_spec, *aval.sharding.spec)) return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), out_sharding=sharding) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index adc632a82..99760099d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -16,6 +16,7 @@ from __future__ import annotations import builtins from collections.abc import Callable, Sequence +import dataclasses import enum import functools from functools import partial @@ -2227,10 +2228,122 @@ def ragged_dot( Results: (m, n) shaped array with preferred_element_type element type. 
""" - return ragged_dot_p.bind(lhs, rhs, group_sizes, - precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type, - group_offset=group_offset) + return ragged_dot_general( + lhs, + rhs, + group_sizes, + ragged_dot_dimension_numbers=_BASIC_RAGGED_DOT_DIMENSION_NUMBERS, + precision=canonicalize_precision(precision), + preferred_element_type=preferred_element_type, + group_offset=group_offset, + ) + + +@dataclasses.dataclass(frozen=True) +class RaggedDotDimensionNumbers(): + """Describes ragged, group, and dot dimensions for ragged dot general. + + Args: + dot_dimension_numbers: a tuple of tuples of sequences of ints of the form + `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, + rhs_batch_dims))`. + lhs_ragged_dimensions: a sequence of ints indicating the 'lhs' ragged + dimensions. + rhs_group_dimensions: a sequence of ints indicating the 'rhs' group + dimensions. + """ + dot_dimension_numbers: DotDimensionNumbers + lhs_ragged_dimensions: Sequence[int] + rhs_group_dimensions: Sequence[int] + + def __init__( + self, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions + ): + super().__setattr__( + 'dot_dimension_numbers', + tuple(tuple(map(tuple, t)) for t in dot_dimension_numbers), + ) + super().__setattr__('lhs_ragged_dimensions', tuple(lhs_ragged_dimensions)) + super().__setattr__('rhs_group_dimensions', tuple(rhs_group_dimensions)) + + +def _from_maybe_ragged( + dot_dimension_numbers: RaggedDotDimensionNumbers | DotDimensionNumbers, +) -> DotDimensionNumbers: + return ( + dot_dimension_numbers.dot_dimension_numbers + if isinstance(dot_dimension_numbers, RaggedDotDimensionNumbers) + else dot_dimension_numbers + ) + + +# RaggedDotDimensionNumbers that specify the simple case (i.e., lax.ragged_dot.) +_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], +) + + +def ragged_dot_general( + lhs: Array, + rhs: Array, + group_sizes: Array, + ragged_dot_dimension_numbers: RaggedDotDimensionNumbers, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, + group_offset: Array | None = None, +) -> Array: + """Ragged matrix multiplication. + + Ragged dot takes three arrays---``lhs``, ``rhs``, and ``group_sizes``---and + a ``ragged_dot_dimension_numbers`` argument. Like `dot_general`, ``lhs`` and + ``rhs`` are allowed arbitrary batch and contracting dimensions. Additionally, + ``lhs`` is required to have one ragged dimension, and ``rhs`` may have at + most one group dimension. + + Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has + three modes, depending on the kind of the lhs ragged dimension: + 1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`. + Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``, + and `x...` are the lhs non-contracting dims outer to the ragged dim. + 2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`. + Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and + ``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim. + 3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`. + Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and + ``rhs``, and `x...` are the lhs batch dims outer to the ragged dim. + If ``group_sizes`` is passed-in with shape `[g]`, it is broadcasted according + to the rules above. 
+ + Args: + lhs: an array + rhs: an array + group_sizes: an array with integer element type + ragged_dot_dimension_numbers: a ``RaggedDotDimensionNumbers`` object to + specify the dot dimension numbers, lhs ragged dimension, and rhs group + dimension. + precision: Optional. Consistent with precision argument for + :func:`jax.lax.dot`. + preferred_element_type: Optional. Consistent with precision argument for + :func:`jax.lax.dot`. + group_offset: Optional. (1,) shaped array that indicates the group in + group_sizes to start computing from. If not specified, defaults to [0]. + + Results: + An array whose shape is the same as that produced by `dot_general`, with an + extra leading dimension of size `g` in the case where the lhs ragged + dimension is a contracting dimension. + """ + return ragged_dot_general_p.bind( + lhs, + rhs, + group_sizes, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, + precision=canonicalize_precision(precision), + preferred_element_type=preferred_element_type, + group_offset=group_offset, + ) def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None @@ -3723,10 +3836,11 @@ def _sin_complex(x): # 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2 a, b = real(x), imag(x) a_is_zero = eq(a, _const(a, 0)) + two = _const(a, 2) sn, cs = sin(a), cos(a) - e1m, e2m = expm1(b), expm1(-b) - snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2 - re, im = sn * csh, cs * snh + e1m, e2m = expm1(b), expm1(neg(b)) + snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two) + re, im = mul(sn, csh), mul(cs, snh) # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf return select(a_is_zero, complex(_const(a, 0), im), complex(re, im)) @@ -3736,14 +3850,14 @@ def _sin_lowering(ctx, x): return sine(ctx, x) return _nary_lower_hlo(hlo.sine, ctx, x) -def _sin_p_lin(nzs, x): +def _sin_lin(nzs, x): nz, = nzs cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -ad.primitive_linearizations[sin_p] = _sin_p_lin +ad.primitive_linearizations[sin_p] = _sin_lin mlir.register_lowering(sin_p, _sin_lowering) batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule @@ -3752,10 +3866,11 @@ def _cos_complex(x): # see also _sin_complex a, b = real(x), imag(x) a_is_zero = eq(a, _const(a, 0)) + two = _const(a, 2) sn, cs = sin(a), cos(a) - e1m, e2m = expm1(b), expm1(-b) - snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2 - re, im = cs * csh, -sn * snh + e1m, e2m = expm1(b), expm1(neg(b)) + snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two) + re, im = mul(cs, csh), mul(neg(sn), snh) return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im)) def _cos_lowering(ctx, x): @@ -3769,28 +3884,28 @@ ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) mlir.register_lowering(cos_p, _cos_lowering) tan_p = standard_unop(_float | _complex, 'tan') -ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) +ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) asin_p = standard_unop(_float | _complex, 'asin') -ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x)))) +ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x))))) mlir.register_lowering(asin_p, 
partial(_nary_lower_hlo, chlo.asin)) acos_p = standard_unop(_float | _complex, 'acos') -ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x)))) +ad.defjvp(acos_p, lambda g, x: mul(g, neg(rsqrt(sub(_const(x, 1), square(x)))))) mlir.register_lowering(acos_p, partial(_nary_lower_hlo, chlo.acos)) def atan_impl(x): return atan2(x, _const(x, 1)) atan_p = standard_unop(_float | _complex, 'atan') -ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x))) +ad.defjvp(atan_p, lambda g, x: div(g, add(_const(x, 1), square(x)))) mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan)) atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2') ad.defjvp(atan2_p, - lambda g, x, y: g * (y / (square(x) + square(y))), - lambda g, x, y: g * -x / (square(x) + square(y))) + lambda g, x, y: mul(g, div(y, add(square(x), square(y)))), + lambda g, x, y: mul(g, div(neg(x), add(square(x), square(y))))) mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2)) sinh_p = standard_unop(_float | _complex, 'sinh') @@ -3802,17 +3917,17 @@ ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x))) mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh)) asinh_p = standard_unop(_float | _complex, 'asinh') -ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x)))) +ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(add(square(x), _one(x))))) mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh)) acosh_p = standard_unop(_float | _complex, 'acosh') ad.defjvp(acosh_p, - lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x))))) + lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x)))))) mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh)) atanh_p = standard_unop(_float | _complex, 'atanh') ad.defjvp(atanh_p, - lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x)))) + lambda g, x: mul(reciprocal(add(_one(x), x)), div(g, sub(_one(x), x)))) mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh)) real_p = unop(_complex_basetype, _complex, 'real') @@ -3906,11 +4021,11 @@ def _square_complex(x): a, b = real(x), imag(x) # zero square(x).real is handled explicitly for abs(a)==abs(b) cases # where for finite a, 2 * a is non-finite: - zero_re = is_finite(a) & (eq(a, b) | eq(a, -b)) + zero_re = is_finite(a) & (eq(a, b) | eq(a, neg(b))) # equivalent to a**2 - b**2 but avoids overflow errors for large a # and large b cases: - re = (a - b) * (a + b) - im = a * b * 2 + re = mul(sub(a, b), add(a, b)) + im = mul(mul(a, b), _const(a, 2)) return select(zero_re, complex(_const(a, 0), im), complex(re, im)) def _square_lower_hlo(ctx, x): @@ -4591,7 +4706,7 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, out_sharding): if out_sharding is not None and not isinstance(out_sharding, NamedSharding): raise NotImplementedError - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers) if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): msg = ("dot_general requires lhs dimension numbers to be nonnegative and " @@ -4652,12 +4767,17 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers) def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = 
dimension_numbers + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers) batch_shape = tuple(lhs_shape[i] for i in lhs_batch) lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch))) lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch) - rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch))) - rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch) + rhs_group = () + if isinstance(dimension_numbers, RaggedDotDimensionNumbers): + rhs_group = tuple(dimension_numbers.rhs_group_dimensions) + rhs_contract_or_batch_or_group = tuple( + sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group) + ) + rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch_or_group) return batch_shape + lhs_tensored_shape + rhs_tensored_shape @@ -4721,7 +4841,7 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, - out_sharding): + out_sharding, name: str = 'lax.dot_general'): if out_sharding is not None and not isinstance(out_sharding, NamedSharding): raise NotImplementedError del dimension_numbers # unused @@ -4742,8 +4862,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, result_dtype = rhs.dtype else: if lhs.dtype != rhs.dtype: - raise TypeError( - f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}") + raise TypeError(f'{name} argument type error: {lhs.dtype}, {rhs.dtype}') result_dtype = lhs.dtype has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)) return _maybe_upcast(result_dtype, preferred_element_type, @@ -4882,8 +5001,9 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): # explicitly present dimensions that this dot_general is zipping together. lbd, rbd = batch_dims assert lbd is not None or rbd is not None - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers) + is_ragged_dot = isinstance(dimension_numbers, RaggedDotDimensionNumbers) def bump_dims(dims, b): return tuple(np.add(dims, np.greater_equal(dims, b))) @@ -4906,8 +5026,14 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): elif (type(rbd) is int and lbd is None): # The right vmapped dimension becomes an additional tensor dimension in the # batched dot_general. 
- rhs_tensor = [d for d in range(rhs_ndim) - if d not in rhs_batch and d not in rhs_contract] + rhs_tensor = list( + remaining( + range(rhs_ndim), + rhs_batch, + rhs_contract, + dimension_numbers.rhs_group_dimensions if is_ragged_dot else [], + ) + ) result_batch_dim = (lhs_ndim - len(lhs_contract) + int(sum(np.less(rhs_tensor, rbd)))) rhs_batch = bump_dims(rhs_batch, rbd) @@ -4917,6 +5043,16 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): assert False new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) + if is_ragged_dot: + new_dimension_numbers = RaggedDotDimensionNumbers( + dot_dimension_numbers=new_dimension_numbers, + lhs_ragged_dimensions=bump_dims( + dimension_numbers.lhs_ragged_dimensions, lbd + ), + rhs_group_dimensions=bump_dims( + dimension_numbers.rhs_group_dimensions, rbd + ), + ) return new_dimension_numbers, result_batch_dim def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *, @@ -5008,15 +5144,6 @@ def _dot_general_batch_unpack_dims(batch_dims): lbd, rbd = batch_dims return (lbd, rbd) -# DotDimensionNumbers used in the dot_general call for ragged_dot(). -_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = ( - ([2, 0], [1, 0]), - ([], []), -) -_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = ( - ([3, 1], [2, 1]), - ([0], [0]), -) ad.defbilinear(dot_general_p, _dot_general_transpose_lhs, _dot_general_transpose_rhs) @@ -5184,58 +5311,181 @@ for platform in ["cpu", "tpu"]: platform=platform) -def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape: - if len(lhs.shape) == 3: - # Batched case - b, m, k = lhs.shape - b2, group_count, rk, n = rhs.shape - b3 = group_sizes.shape[0] - if b != b2: - raise TypeError( - f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and' - f' {b2}.' - ) - if b3 != b: - raise TypeError( - 'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got' - f' {b3} and {b}.' - ) - if k != rk: - raise TypeError( - f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and' - f' {rk}.' - ) - num_groups = group_sizes.shape[1] - if group_count != num_groups: - raise TypeError( - 'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got' - f' {group_count} and {num_groups}.' 
- ) - return (b, m, n) +class RaggedDotMode(enum.Enum): + RAGGED_NONCONTRACTING = 1 # [b,m,k], [g,b,k,n], [b,g] -> [b,m,n] + RAGGED_CONTRACTING = 2 # [b,m,k], [b,k,n], [b,g] -> [g,b,m,n] + RAGGED_BATCH = 3 # [b,m,k], [b,k,n], [g] -> [b,m,n] - m, k = lhs.shape - group_count, rk, n = rhs.shape - if k != rk: - raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.") - num_groups = group_sizes.shape[0] - if group_count != num_groups: - raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.") - return (m, n) -def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, - precision, preferred_element_type: DTypeLike | None, - **_) -> np.dtype: +def _ragged_dot_mode_and_dim( + lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers +) -> tuple[RaggedDotMode, int]: + assert len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) == 1 + lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0] + (lhs_contracting, _), (lhs_batch, _) = ragged_dot_dimension_numbers.dot_dimension_numbers + lhs_noncontracting = remaining(range(lhs_rank), lhs_contracting, lhs_batch) + if lhs_ragged_dim in lhs_noncontracting: + mode = RaggedDotMode.RAGGED_NONCONTRACTING + elif lhs_ragged_dim in lhs_contracting: + mode = RaggedDotMode.RAGGED_CONTRACTING + elif lhs_ragged_dim in lhs_batch: + mode = RaggedDotMode.RAGGED_BATCH + else: + raise TypeError( + f'lhs_ragged_dim {lhs_ragged_dim} not found in ' + f'lhs_noncontracting {lhs_noncontracting}, ' + f'lhs_contracting {lhs_contracting}, or ' + f'lhs_batch {lhs_batch}.' + ) + return mode, lhs_ragged_dim + + +def _ragged_dot_mode( + lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers +) -> RaggedDotMode: + return _ragged_dot_mode_and_dim(lhs_rank, ragged_dot_dimension_numbers)[0] + + +def _is_ragged_contracting( + lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers +) -> bool: + return ( + _ragged_dot_mode(lhs_rank, ragged_dot_dimension_numbers) + == RaggedDotMode.RAGGED_CONTRACTING + ) + + +def _ragged_dot_prefix_dims(mode, rank, ragged_dim, batch, contract): + batch, contract = map(list, (batch, contract)) + noncontract = remaining(range(rank), contract, batch) + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + return batch + noncontract[: noncontract.index(ragged_dim)] + case RaggedDotMode.RAGGED_CONTRACTING: + return batch + contract[: contract.index(ragged_dim)] + case RaggedDotMode.RAGGED_BATCH: + return batch[: batch.index(ragged_dim)] + + +def _ragged_dot_general_shape_rule( + lhs, + rhs, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: DTypeLike | None, + **_, +): + def _check_in_range(dim, rank, dim_name, arg_name): + if dim < 0 or dim >= rank: + raise TypeError( + f'ragged_dot_general requires {dim_name} numbers to be nonnegative ' + f'and less than the number of axes of the {arg_name} value, ' + f'got {dim} for {arg_name} of rank {rank}.' + ) + + # Validate the lhs ragged dimension, and find out which mode we're in. + if len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) != 1: + raise TypeError( + 'ragged_dot_general expects exactly one lhs ragged dimension.' 
+ ) + lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0] + _check_in_range(lhs_ragged_dim, lhs.ndim, 'lhs ragged dimension', 'lhs') + mode = _ragged_dot_mode(lhs.ndim, ragged_dot_dimension_numbers) + + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + + # Validate the shape of group_sizes, if it is something other than [g]. + if group_sizes.ndim == 0: + raise TypeError('expected rank of group_sizes to be >=1.') + if group_sizes.ndim != 1: + # Construct the expected shape [b...,x...,g] of group_sizes. + prefix_dims = _ragged_dot_prefix_dims( + mode, lhs.ndim, lhs_ragged_dim, lhs_batch, lhs_contracting + ) + expected_gs_shape = tuple(lhs.shape[i] for i in prefix_dims) + expected_gs_shape += (group_sizes.shape[-1],) + # TODO(pravnar): Permit other broadcastable shapes. + if not core.definitely_equal_shape(group_sizes.shape, expected_gs_shape): + raise TypeError( + 'expected group_sizes to have shape ' + f'{expected_gs_shape}, got {group_sizes.shape}.' + ) + num_groups = group_sizes.shape[-1] + + # Validate properties of the rhs group dimension(s). + rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions + match mode: + case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: + if len(rhs_group_dims) != 0: + raise TypeError( + 'ragged_dot_general requires zero group dimensions in the rhs ' + 'when lhs ragged dimension is contracting or batch.' + ) + case RaggedDotMode.RAGGED_NONCONTRACTING: + if len(rhs_group_dims) != 1: + raise TypeError( + 'ragged_dot_general requires exactly one rhs group dimension ' + 'when lhs ragged dimension is noncontracting.' + ) + rhs_group_dim = rhs_group_dims[0] + _check_in_range(rhs_group_dim, rhs.ndim, 'rhs group dimension', 'rhs') + if rhs_group_dim in rhs_batch or rhs_group_dim in rhs_contracting: + raise TypeError( + 'ragged_dot_general requires rhs group dimension numbers to be ' + 'distinct from contracting and batch dimensions.' + ) + if rhs.shape[rhs_group_dim] != num_groups: + raise TypeError( + 'expected rhs group dimension size to be ' + f'{num_groups}, got {rhs.shape[rhs_group_dim]}.' + ) + + out_shape = _dot_general_shape_rule( + lhs, + rhs, + dimension_numbers=ragged_dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + out_sharding=None, + ) + if mode == RaggedDotMode.RAGGED_CONTRACTING: + out_shape = (num_groups,) + out_shape + return out_shape + + +def _ragged_dot_general_dtype_rule( + lhs: Array, + rhs: Array, + group_sizes: Array, + ragged_dot_dimension_numbers: RaggedDotDimensionNumbers, + precision, + preferred_element_type: DTypeLike | None, + **_, +) -> np.dtype: if not dtypes.issubdtype(group_sizes.dtype, np.integer): - raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") - # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. + raise TypeError( + 'ragged_dot_general requires that ' + 'group_sizes.dtype is subtype of np.integer.' + ) + # defer the output dtype to dot_general, which is part of the _ragged_dot_general_impl. 
return _dot_general_dtype_rule( - lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, - precision=precision, preferred_element_type=preferred_element_type, - out_sharding=None) + lhs, + rhs, + dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + out_sharding=None, + name='lax.ragged_dot_general', + ) -def _ragged_dot_jvp_rule( - primals, tangents, precision, preferred_element_type, group_offset +def _ragged_dot_general_jvp_rule( + primals, tangents, ragged_dot_dimension_numbers, + precision, preferred_element_type, group_offset ): # note - we could ostensibly just get this by passing on the # value to ragged_dot below, but, this feels cleaner. @@ -5245,20 +5495,22 @@ def _ragged_dot_jvp_rule( dx, dy, _ = tangents # no tan on the gs # primal - primal_out = ragged_dot( + primal_out = ragged_dot_general( x, y, gs, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) # tangent dx_out = ( - ragged_dot( + ragged_dot_general( dx, y, gs, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) @@ -5266,73 +5518,127 @@ def _ragged_dot_jvp_rule( else _zeros(primal_out) ) dy_out = ( - ragged_dot( + ragged_dot_general( x, dy, gs, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) if type(dy) is not ad_util.Zero else _zeros(primal_out) ) - tangent_out = dx_out + dy_out + tangent_out = add(dx_out, dy_out) return primal_out, tangent_out -def _ragged_to_dense(x, y, group_sizes): - from jax._src.lax import control_flow # avoid circular imports - shape = (y.shape[0], x.shape[0], x.shape[1]) - x = broadcast_in_dim(x, shape, [1, 2]) - iota = broadcasted_iota(group_sizes.dtype, shape, 1) - group_ends = control_flow.cumsum(group_sizes) - group_starts = concatenate( - [_zeros(group_sizes)[:1], group_ends[:-1]], - dimension=0, - ) - group_ends = broadcast_in_dim(group_ends, shape, (0,)) - group_starts = broadcast_in_dim(group_starts, shape, (0,)) - mask = bitwise_and(group_starts <= iota, iota < group_ends) - x = select(mask, x, _zeros(x)) - return x - - -def _ragged_dot_transpose_rule( - ct, *operands, precision, preferred_element_type, group_offset +def _ragged_dot_general_transpose_rule( + ct, + x, + y, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: DTypeLike | None, + group_offset: Array | None, ): - x, y, gs = operands if group_offset is not None: raise NotImplementedError('Unimplemented group_offset support.') - if ad.is_undefined_primal(y): - grad_x = None - else: - y_t = _matrix_transpose(y) - grad_x = ragged_dot( - ct, - y_t, - gs, - precision=precision, - preferred_element_type=preferred_element_type, - ) + (x_contract, y_contract), (x_batch, y_batch) = ragged_dot_dimension_numbers.dot_dimension_numbers + x_ndim = x.aval.ndim if ad.is_undefined_primal(x) else np.ndim(x) + y_ndim = y.aval.ndim if ad.is_undefined_primal(y) else np.ndim(y) + x_kept = remaining(range(x_ndim), x_contract, x_batch) + y_group = ragged_dot_dimension_numbers.rhs_group_dimensions + y_kept = remaining(range(y_ndim), y_contract, y_batch, y_group) + mode, lhs_ragged_dim = _ragged_dot_mode_and_dim( + x_ndim, ragged_dot_dimension_numbers + ) - if ad.is_undefined_primal(x): - grad_y = None - else: - y = y.aval if ad.is_undefined_primal(y) else y - x_dense = 
_ragged_to_dense(x, y, group_sizes=gs) - ct_dense = _ragged_to_dense(ct, y, group_sizes=gs) - dimension_numbers = (([1], [1]), ([0], [0])) - grad_y = dot_general( - x_dense, - ct_dense, - dimension_numbers, - precision=precision, - preferred_element_type=preferred_element_type, - ) + unimplemented = lambda fn_name, ragged_dot_mode: NotImplementedError( + f'Unimplemented {fn_name} for ragged dot general in mode ' + f'{ragged_dot_mode.name}.' + ) - return grad_x, grad_y, None + # This is a hack to ensure we continue to emit the `_matrix_transpose` for the + # grad_x case. This isn't strictly necessary since we have dot_dim_nums. + # TODO(pravnar): Remove this once we no longer care to emit the transpose. + _is_basic_ragged_dot = ( + x_ndim == 2 + and y_ndim == 3 + and ragged_dot_dimension_numbers == _BASIC_RAGGED_DOT_DIMENSION_NUMBERS + ) + + def grad_x_dims(): + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept) + dims = ( + ragged_dot_dimension_numbers + if _is_basic_ragged_dot + else RaggedDotDimensionNumbers( + dot_dimension_numbers=((ans_y, y_kept), (ans_batch, y_batch)), + lhs_ragged_dimensions=[ + len(x_batch) + x_kept.index(lhs_ragged_dim) + ], + rhs_group_dimensions=y_group, + ) + ) + x_contract_sorted_by_y = list( + np.take(x_contract, np.argsort(y_contract)) + ) + unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y + case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: + raise unimplemented('grad_x_dims', mode) + return dims, unsorted_axes + + def grad_y_dims(): + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + ans_batch, ans_x, _ = ranges_like(x_batch, x_kept, y_kept) + dims = RaggedDotDimensionNumbers( + dot_dimension_numbers=((x_kept, ans_x), (x_batch, ans_batch)), + lhs_ragged_dimensions=[lhs_ragged_dim], + rhs_group_dimensions=[], + ) + y_contract_sorted_by_x = list( + np.take(y_contract, np.argsort(x_contract)) + ) + unsorted_axes = ( + list(y_group) + list(y_batch) + y_contract_sorted_by_x + y_kept + ) + case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: + raise unimplemented('grad_y_dims', mode) + return dims, unsorted_axes + + def _ragged_dot_grad(lhs, rhs, dims_fn, aval): + dims, unsorted_axes = dims_fn() + ragged_dot_general_out = ragged_dot_general( + lhs, rhs, group_sizes, dims, precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset) + result = transpose(ragged_dot_general_out, tuple(np.argsort(unsorted_axes))) + if result.dtype != aval.dtype: + result = _convert_element_type(result, aval.dtype, aval.weak_type) + return result + + x_bar = ( + None + if ad.is_undefined_primal(y) + else _ragged_dot_grad(ct, + _matrix_transpose(y) if _is_basic_ragged_dot else y, + grad_x_dims, + x.aval) + ) + y_bar = ( + None + if ad.is_undefined_primal(x) + else _ragged_dot_grad(x, ct, grad_y_dims, y.aval) + ) + return x_bar, y_bar, None def _ragged_dot_batch_unpack_args(batched_args): @@ -5347,62 +5653,71 @@ def _ragged_dot_batch_unpack_dims(batch_dims): return (lbd, rbd) -def _ragged_dot_invoke_prim( +def _ragged_dot_general_invoke_prim( group_sizes, lhs, rhs, - new_dimension_numbers, + new_ragged_dot_dimension_numbers, precision, preferred_element_type, out_sharding, ): del out_sharding - return ragged_dot( + return ragged_dot_general( lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=new_ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) -def _ragged_dot_batch_rule( +def 
_ragged_dot_general_batch_rule( axis_data, batched_args, batch_dims, *, + ragged_dot_dimension_numbers, precision, preferred_element_type: DTypeLike | None, **_, ): - invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2]) - - return _dot_batch_rule( + invoke = partial(_ragged_dot_general_invoke_prim, batched_args[2]) + batched_out, result_batch_dim = _dot_batch_rule( _ragged_dot_batch_unpack_args, _ragged_dot_batch_unpack_dims, invoke, axis_data, batched_args, batch_dims, - dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, out_sharding=None, ) + if _is_ragged_contracting(batched_args[0].ndim - 1, + ragged_dot_dimension_numbers): + result_batch_dim += 1 + return batched_out, result_batch_dim -ragged_dot_p = standard_primitive(_ragged_dot_shape_rule, - _ragged_dot_dtype_rule, 'ragged_dot') -ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p)) -ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule -ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule -batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule -batching.skippable_batchers[ragged_dot_p] = lambda _: () +ragged_dot_general_p = standard_primitive( + _ragged_dot_general_shape_rule, + _ragged_dot_general_dtype_rule, + 'ragged_dot_general', +) +ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule +ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule +batching.fancy_primitive_batchers[ragged_dot_general_p] = _ragged_dot_general_batch_rule +batching.skippable_batchers[ragged_dot_general_p] = lambda _: () -def _ragged_dot_impl( + +def _ragged_dot_general_impl( lhs: Array, rhs: Array, group_sizes: Array, + ragged_dot_dimension_numbers: RaggedDotDimensionNumbers, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, group_offset: Array | None = None, @@ -5410,24 +5725,100 @@ def _ragged_dot_impl( if group_offset is not None: raise NotImplementedError("Unimplemented group_offset support.") - if len(lhs.shape) == 3: - ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS - ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0)) - else: - ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS - ragged_to_dense = _ragged_to_dense + def ragged_to_dense(x: Array, gs: Array, *, dim: int): + from jax._src.lax import control_flow # avoid circular imports + assert gs.ndim == 1 + shape = gs.shape + x.shape + x = broadcast_in_dim(x, shape, list(range(1, len(shape)))) + iota = broadcasted_iota(gs.dtype, shape, dim+1) + group_ends = control_flow.cumsum(gs) + group_starts = concatenate( + [_zeros(gs)[:1], group_ends[:-1]], + dimension=0, + ) + group_ends = broadcast_in_dim(group_ends, shape, (0,)) + group_starts = broadcast_in_dim(group_starts, shape, (0,)) + mask = bitwise_and(group_starts <= iota, iota < group_ends) + x = select(mask, x, _zeros(x)) + return x - lhs = ragged_to_dense(lhs, rhs, group_sizes) + def batched_ragged_to_dense(dim, *x_in_axes: int): + if not x_in_axes: + return partial(ragged_to_dense, dim=dim) + x_axis, *rest = x_in_axes + decr = lambda d: d - 1 if d >= x_axis else d + return api.vmap( + batched_ragged_to_dense(decr(dim), *[decr(ax) for ax in rest]), + in_axes=(x_axis, 0), + ) - return dot_general( - lhs, - rhs, - dimension_numbers=ragged_dot_dims, + incr = lambda dims: [d + 1 for d in dims] + + # Expand the ragged `dim` of `x`, given its batching `axes`. 
+ # The group axis from `gs` becomes the outermost axis of the result. + # Some examples: + # x: [m,k] , gs: [g] ==> expand(x, 0, gs): [g,m,k] + # x: [b1,m,b2,k], gs: [b1,b2,g] ==> expand(x, 1, gs, 0, 2): [g,b1,m,b2,k] + def expand(x, dim, gs, *axes): + expanded = batched_ragged_to_dense(dim, *axes)(x, gs) + unsorted_dims = incr(axes) + [0] + incr(remaining(range(x.ndim), axes)) + return transpose(expanded, np.argsort(unsorted_dims)) + + mode, lhs_ragged_dim = _ragged_dot_mode_and_dim( + lhs.ndim, ragged_dot_dimension_numbers + ) + (l_contract, r_contract), (l_batch, r_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + l_prefix = _ragged_dot_prefix_dims( + mode, lhs.ndim, lhs_ragged_dim, l_batch, l_contract + ) + + _dot_general = partial( + dot_general, precision=precision, preferred_element_type=preferred_element_type, ) + # TODO(pravnar): Permit other broadcastable shapes. + if group_sizes.ndim == 1: + group_sizes = broadcast(group_sizes, [lhs.shape[i] for i in l_prefix]) -mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False)) + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions + assert len(rhs_group_dims) == 1 + return _dot_general( + expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix), + rhs, + dimension_numbers=( + (incr(l_contract) + [0], list(r_contract) + [rhs_group_dims[0]]), + (incr(l_batch), r_batch), + ), + ) + case RaggedDotMode.RAGGED_CONTRACTING: + rhs_ragged_dim = r_contract[l_contract.index(lhs_ragged_dim)] + r_prefix = _ragged_dot_prefix_dims( + mode, rhs.ndim, rhs_ragged_dim, r_batch, r_contract + ) + return _dot_general( + expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix), + expand(rhs, rhs_ragged_dim, group_sizes, *r_prefix), + dimension_numbers=( + (incr(l_contract), incr(r_contract)), + ([0] + incr(l_batch), [0] + incr(r_batch)), + ), + ) + case RaggedDotMode.RAGGED_BATCH: + return _dot_general( + lhs, + rhs, + dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers, + ) + + +mlir.register_lowering(ragged_dot_general_p, + mlir.lower_fun(_ragged_dot_general_impl, + multiple_results=False)) def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, @@ -5802,8 +6193,12 @@ def _concatenate_transpose_rule(t, *operands, dimension): def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims) if bdim is not None) + spec = next(core.get_aval(op).sharding.spec[bdim] + for op, bdim in zip(batched_args, batch_dims) if bdim is not None) operands = [batching.moveaxis(op, bdim, 0) if bdim is not None - else broadcast(op, (size,)) + else broadcast( + op, (size,), out_sharding=core.get_aval(op).sharding.with_spec( + (spec, *core.get_aval(op).sharding.spec))) for op, bdim in zip(batched_args, batch_dims)] return concatenate(operands, dimension + 1), 0 diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index abd104293..c674401fb 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -44,6 +44,10 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import svd as lax_svd from jax._src.lax import utils as lax_utils from jax._src.lax.lax import _float, _complex, _int +from jax._src.lib import gpu_linalg +from jax._src.lib import gpu_solver +from jax._src.lib import gpu_sparse +from jax._src.lib import lapack from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from 
jax._src.lib.mlir.dialects import chlo @@ -51,12 +55,23 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec as P from jax._src.typing import Array, ArrayLike -# The following imports may be unused but they are needed to register the -# custom call targets defined in each module. -from jax._src.lib import gpu_linalg # pylint:disable=unused-import # noqa: F401 -from jax._src.lib import gpu_solver # pylint:disable=unused-import # noqa: F401 -from jax._src.lib import gpu_sparse # pylint:disable=unused-import # noqa: F401 -from jax._src.lib import lapack # pylint:disable=unused-import # noqa: F401 + +def register_module_custom_calls(module): + if hasattr(module, "registrations"): + for platform, targets in module.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + if hasattr(module, "batch_partitionable_targets"): + for name in module.batch_partitionable_targets(): + ffi.register_ffi_target_as_batch_partitionable(name) + + +register_module_custom_calls(gpu_linalg) +register_module_custom_calls(gpu_solver) +register_module_custom_calls(gpu_sparse) +register_module_custom_calls(lapack) # Top-level functions in alphabetical order. diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index bb1647d7e..c26de99c7 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1608,8 +1608,9 @@ def _dynamic_update_slice_sharding_rule(operand, update, *start_indices): if operand.sharding != update.sharding: raise TypeError( "dynamic_update_slice update sharding must be equal to operand" - f" sharding, got update sharding {update.sharding} for operand sharding" - f" {operand.sharding}.") + " sharding, got update sharding" + f" {update.str_short(mesh_axis_types=True)} for operand sharding" + f" {operand.str_short(mesh_axis_types=True)}.") return operand.sharding def _dynamic_update_slice_dtype_rule(operand, update, *start_indices): diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index db0799c5a..c2e39a818 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -200,7 +200,7 @@ class _BaseMesh: _mesh_object_dict = {} # type: ignore -MeshAxisType = dict[AxisTypes, str | tuple[str, ...]] +MeshAxisType = dict[AxisTypes, MeshAxisName | tuple[MeshAxisName, ...]] class Mesh(_BaseMesh, contextlib.ContextDecorator): """Declare the hardware resources available in the scope of this manager. 
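For readers unfamiliar with the FFI registration pattern introduced in jax/_src/lax/linalg.py above: register_module_custom_calls expects a module exposing registrations() (a mapping from platform name to (target_name, capsule, api_version) triples) and, optionally, batch_partitionable_targets(). The sketch below uses a hypothetical stand-in module and a dummy recorder instead of the real ffi.register_ffi_target call, purely to illustrate the interface the helper assumes.

from types import SimpleNamespace

def _record_registration(name, value, platform, api_version):
  # Dummy stand-in for ffi.register_ffi_target; just records the call.
  print(f"would register {name!r} for platform={platform} (api_version={api_version})")

# Hypothetical module-like object with the interface the helper assumes.
fake_backend = SimpleNamespace(
    registrations=lambda: {
        "cpu": [("demo_lapack_getrf", object(), 1)],
        "CUDA": [("demo_cusolver_getrf", object(), 1)],
    },
    batch_partitionable_targets=lambda: ["demo_cusolver_getrf"],
)

for platform, targets in fake_backend.registrations().items():
  for name, value, api_version in targets:
    _record_registration(name, value, platform=platform, api_version=api_version)
for name in fake_backend.batch_partitionable_targets():
  print(f"would mark {name!r} as batch-partitionable")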
diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 03740b659..287e8f039 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -45,8 +45,8 @@ RealNumeric = Any # Scalar jnp array or float @export @typing.runtime_checkable class Initializer(Protocol): - @staticmethod - def __call__(key: Array, + def __call__(self, + key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: raise NotImplementedError diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 8aa118959..9d745643b 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -547,6 +547,9 @@ def _einsum( if not last_contraction: dot_general_out_sharding = None elif out_sharding is not None and names != result_names: + if len(result_names) > len(out_sharding.spec): + out_sharding = out_sharding.with_spec( + out_sharding.spec._normalized_spec_for_aval(len(result_names))) spec = out_sharding.spec inverse_spec = tuple(spec[result_names.index(name)] for name in names) dot_general_out_sharding = NamedSharding( diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 5d20b73af..585a5484a 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -93,6 +93,8 @@ float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz) float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz) +if dtypes.float4_e2m1fn is not None: + float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) bfloat16 = _make_scalar_type(dtypes.bfloat16) float16 = _make_scalar_type(np.float16) float32 = single = _make_scalar_type(np.float32) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 1db2e0bde..e281c63ae 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -26,7 +26,7 @@ from jax._src import dtypes from jax._src.lax import lax from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding -from jax._src.util import safe_zip, safe_map +from jax._src.util import safe_zip, safe_map, set_module from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape from jax.sharding import Sharding @@ -35,6 +35,8 @@ import numpy as np zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map +export = set_module('jax.numpy') + _dtype = partial(dtypes.dtype, canonicalize=True) def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: @@ -308,3 +310,124 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin return SingleDeviceSharding(device) else: return device + + +@export +def ndim(a: ArrayLike) -> int: + """Return the number of dimensions of an array. + + JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function + raises a :class:`TypeError` if the input is a collection such as a list or + tuple. + + Args: + a: array-like object. + + Returns: + An integer specifying the number of dimensions of ``a``. + + Examples: + Number of dimensions for arrays: + + >>> x = jnp.arange(10) + >>> jnp.ndim(x) + 1 + >>> y = jnp.ones((2, 3)) + >>> jnp.ndim(y) + 2 + + This also works for scalars: + + >>> jnp.ndim(3.14) + 0 + + For arrays, this can also be accessed via the :attr:`jax.Array.ndim` property: + + >>> x.ndim + 1 + """ + # Deprecation warning added 2025-2-20. + check_arraylike("ndim", a, emit_warning=True) + return np.ndim(a) # NumPy dispatches to a.ndim if available. 
+
+
+@export
+def shape(a: ArrayLike) -> tuple[int, ...]:
+  """Return the shape of an array.
+
+  JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function
+  raises a :class:`TypeError` if the input is a collection such as a list or
+  tuple.
+
+  Args:
+    a: array-like object.
+
+  Returns:
+    A tuple of integers representing the shape of ``a``.
+
+  Examples:
+    Shape for arrays:
+
+    >>> x = jnp.arange(10)
+    >>> jnp.shape(x)
+    (10,)
+    >>> y = jnp.ones((2, 3))
+    >>> jnp.shape(y)
+    (2, 3)
+
+    This also works for scalars:
+
+    >>> jnp.shape(3.14)
+    ()
+
+    For arrays, this can also be accessed via the :attr:`jax.Array.shape` property:
+
+    >>> x.shape
+    (10,)
+  """
+  # Deprecation warning added 2025-2-20.
+  check_arraylike("shape", a, emit_warning=True)
+  return np.shape(a)  # NumPy dispatches to a.shape if available.
+
+
+@export
+def size(a: ArrayLike, axis: int | None = None) -> int:
+  """Return the number of elements along a given axis.
+
+  JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function
+  raises a :class:`TypeError` if the input is a collection such as a list or
+  tuple.
+
+  Args:
+    a: array-like object
+    axis: optional integer along which to count elements. By default, return
+      the total number of elements.
+
+  Returns:
+    An integer specifying the number of elements in ``a``.
+
+  Examples:
+    Size for arrays:
+
+    >>> x = jnp.arange(10)
+    >>> jnp.size(x)
+    10
+    >>> y = jnp.ones((2, 3))
+    >>> jnp.size(y)
+    6
+    >>> jnp.size(y, axis=1)
+    3
+
+    This also works for scalars:
+
+    >>> jnp.size(3.14)
+    1
+
+    For arrays, this can also be accessed via the :attr:`jax.Array.size` property:
+
+    >>> y.size
+    6
+  """
+  # Deprecation warning added 2025-2-20.
+  check_arraylike("size", a, emit_warning=True)
+  return np.size(a, axis=axis)  # NumPy dispatches to a.size if available.
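A usage note on the non-array path shared by ndim, shape, and size above: check_arraylike(..., emit_warning=True) currently emits a warning, rather than raising, for list or tuple inputs and then defers to NumPy, so the functions still return a value today. A minimal sketch, assuming the deprecation is still in its warning phase and that the warning is a DeprecationWarning (the exact category and message may differ):

import warnings
import jax.numpy as jnp

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  print(jnp.ndim([1, 2, 3]))     # 1
  print(jnp.shape([[1, 2, 3]]))  # (1, 3)
  print(jnp.size([1, 2, 3]))     # 3

# One warning per call, asking callers to pass arrays instead of lists/tuples.
print(len(caught), caught[0].category if caught else None)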
diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 913ef09cd..83b485107 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -25,8 +25,8 @@ from typing import Any, Callable, Protocol, Sequence import jax from jax import lax -from jax._src import api_util from jax._src import ad_util +from jax._src import api_util from jax._src import core from jax._src import custom_derivatives from jax._src import linear_util as lu @@ -351,7 +351,7 @@ def _pull_block_spec( jaxpr.constvars, jaxpr.invars, needed_invars, - jaxpr.eqns[:jaxpr.eqns.index(eqn)], + jaxpr.eqns[: jaxpr.eqns.index(eqn)], debug_info=jaxpr.debug_info, ) scalar_prefetch_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts( @@ -426,6 +426,7 @@ def make_kernel_function( return tuple(s for s in shape if s is not None) _no_aval = object() + def _get_block_aval(bs, aval): if bs is pallas_core.no_block_spec or bs is None: return _no_aval @@ -441,10 +442,12 @@ def make_kernel_function( unflat_arg_usages, unflat_kwarg_usages = tree_util.tree_unflatten( in_tree, invar_usages ) + def sds_like(x): if x is _no_aval: return _no_aval return jax.ShapeDtypeStruct(x.shape, x.dtype) + kernel_in_type = jax.tree.map( sds_like, (unflat_in_block_arg_avals, unflat_in_block_kwarg_avals) ) @@ -688,8 +691,10 @@ def _eltwise_eval_rule(prim, ctx, x, **params): def _eltwise_pull_rule( - prim: core.Primitive, ctx: PullRuleContext, block_spec: pallas_core.BlockSpec, - **params + prim: core.Primitive, + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + **params, ) -> Sequence[pallas_core.BlockSpec]: del prim, ctx, params return [block_spec] @@ -702,7 +707,9 @@ def _eltwise_usage_rule( return [used_out] -def _bcast_block_spec(block_spec: pallas_core.BlockSpec, i: int) -> pallas_core.BlockSpec: +def _bcast_block_spec( + block_spec: pallas_core.BlockSpec, i: int +) -> pallas_core.BlockSpec: def new_index_map(i, *args): idx = block_spec.index_map(*args) assert len(idx) == len(block_spec.block_shape) @@ -710,7 +717,9 @@ def _bcast_block_spec(block_spec: pallas_core.BlockSpec, i: int) -> pallas_core. 
return idx new_block_shape = util.tuple_update(block_spec.block_shape, i, 1) - return pallas_core.BlockSpec(new_block_shape, functools.partial(new_index_map, i)) + return pallas_core.BlockSpec( + new_block_shape, functools.partial(new_index_map, i) + ) def _binop_usage_rule(prim, ctx, used_out: set[Usage]): @@ -945,7 +954,9 @@ def _dynamic_slice_rule( return block_indices new_block_spec = pallas_core.BlockSpec(block_spec.block_shape, new_index_map) - return [new_block_spec] + [pallas_core.no_block_spec] * (len(ctx.avals_in) - 1) + return [new_block_spec] + [pallas_core.no_block_spec] * ( + len(ctx.avals_in) - 1 + ) @register_eval_rule(lax.concatenate_p) @@ -1348,7 +1359,8 @@ def _push_block_spec_jaxpr( return env[atom] def _write_block_spec( - atom: core.Atom, block_spec: pallas_core.BlockSpec | pallas_core.NoBlockSpec + atom: core.Atom, + block_spec: pallas_core.BlockSpec | pallas_core.NoBlockSpec, ): if isinstance(atom, core.Literal): return @@ -1374,7 +1386,9 @@ def _push_block_spec_jaxpr( util.safe_map(_write_block_spec, eqn.outvars, out_block_specs) out_block_specs = tuple(util.safe_map(_read_block_spec, jaxpr.outvars)) - valid_block_spec = [bs for bs in flat_block_specs if bs is not pallas_core.no_block_spec][0] + valid_block_spec = [ + bs for bs in flat_block_specs if bs is not pallas_core.no_block_spec + ][0] out_block_specs = tuple( valid_block_spec if obs is pallas_core.no_block_spec else obs for obs in out_block_specs @@ -1491,6 +1505,18 @@ def _convert_element_type_push_rule( return block_spec +@register_push_block_spec_rule(lax.select_n_p) +def _select_n_push_rule( + ctx: PushRuleContext, + *args: pallas_core.BlockSpec, +): + del ctx + block_specs = [b for b in args if b is not pallas_core.no_block_spec] + if len(block_specs) > 1: + raise NotImplementedError('select_n with multiple inputs not supported yet') + return block_specs[0] + + @register_push_block_spec_rule(custom_derivatives.custom_jvp_call_p) def _custom_jvp_call_push_rule( ctx, *block_specs, call_jaxpr: core.ClosedJaxpr, **_ @@ -1500,9 +1526,7 @@ def _custom_jvp_call_push_rule( @register_push_block_spec_rule(pjit.pjit_p) -def _pjit_push_rule( - ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_ -): +def _pjit_push_rule(ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_): assert not jaxpr.consts return _push_block_spec_jaxpr(jaxpr.jaxpr, *block_specs) diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index f98175510..3d36b8f3e 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -32,28 +32,42 @@ def _get_aval(x): return jax_core.raise_to_shaped(jax_core.get_aval(x)) -def fuse(f, *, physicalize: bool = False): +def fuse(f=None, *, physicalize: bool = False, debug: bool = False): """Fuses a function into a single fusable. + Args: + f: The function to fuse. + physicalize: (experimental) whether to physicalize the function. + debug: Whether to print debug information. + There should be a single call to a `fusable` inside the body of `f`. `fuse` returns a transformed function that will fuse the surrounding computation into the fusable and invoke it. 
""" - def wrapper(*args, **kwargs): - flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - debug_info = api_util.debug_info('fuse', f, args, kwargs) - flat_fun, out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(f, debug_info=debug_info), in_tree - ) - flat_avals = [_get_aval(x) for x in flat_args] - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree_thunk() - out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args) - return tree_util.tree_unflatten(out_tree, out_flat) - if physicalize: - wrapper = fusable_dtype.physicalize(wrapper) - return wrapper + def decorator(f): + def wrapper(*args, **kwargs): + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + debug_info = api_util.debug_info("fuse", f, args, kwargs) + flat_fun, out_tree_thunk = api_util.flatten_fun( + lu.wrap_init(f, debug_info=debug_info), in_tree + ) + flat_avals = [_get_aval(x) for x in flat_args] + jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + if debug: + print("Jaxpr before fusion:") + print(jaxpr) + out_tree = out_tree_thunk() + out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args) + return tree_util.tree_unflatten(out_tree, out_flat) + + if physicalize: + wrapper = fusable_dtype.physicalize(wrapper) + return wrapper + + if f is not None: + return decorator(f) + return decorator _fusable: dict[jax_core.Primitive, Any] = {} diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 3ef324372..24e834104 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -159,6 +159,9 @@ py_library( ":core", ":primitives", "//jax", + "//jax:core", + "//jax:source_info_util", + "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", ] + py_deps("numpy"), diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index b0ada86cd..a731bfdfd 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -17,6 +17,7 @@ from collections.abc import Iterable, Sequence import dataclasses import enum import functools +import itertools import math import threading from typing import Any, Literal @@ -68,33 +69,140 @@ class TPUInterpretParams: Attributes: dma_execution_mode: If "eager", DMAs are executed as soon as they are - issued. If "on_wait", DMA reads or writes are only executed when a - device is waiting on a DMA semaphore that will be signaled when the read - or write is complete. + issued. If "on_wait", DMA reads or writes are only executed when a device + is waiting on a DMA semaphore that will be signaled when the read or write + is complete. Default: "on_wait". + detect_races: If True, a dynamic, happens-before race detector will be + used to detect data races during kernel interpretation. If any races are + detected, a message will be printed and `races.races_found` will be set + to True. + Default: False. + skip_floating_point_ops: If True, operations that produce only floating + point values will not be interpreted; instead, their results will be + replaced with arrays all of `jnp.inf`. Additionaly any floating point + operands to any operation will be replaced with (arrays of) `jnp.inf`. + Default: False. """ dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" + detect_races: bool = False + skip_floating_point_ops: bool = False + + +VectorClock = np.ndarray + +# Conceptually, each DMA runs on its own, independent device. Representing +# this precisely would require vector clocks to have sizes linear in the number +# of DMAs. 
+# +# Instead, we use approximate vector clocks of fixed size. We assign each DMA +# a virtual device ID in the range [num_devices + 1, NUM_VIRTUAL_DEVICES] -- +# and each operation of a DMA increments the corresponding coordinate in its +# vector clock. (So the "virtual" part of a vector clock is effectively +# counting, for each virtual device, the number of DMAs that happened-before +# the vector clock and were assigned to that virtual device.) +# +# If two approximate clocks are unordered, then their corresponding events are +# not ordered by the happens-before relation. So this approximation will not +# introduce any false positives in detecting data races. But we may fail to +# detect some true data races because there can be cases where two approximate +# clocks are ordered, and we will treat the corresponding events as ordered +# by the happens-before relation, but the corresponding events are not +# actually ordered. +NUM_VIRTUAL_DEVICES = 32 + +def make_vector_clock(num_devices: int) -> VectorClock: + del num_devices + return np.zeros(NUM_VIRTUAL_DEVICES, dtype=np.int32) + +def copy_vector_clock(x: VectorClock) -> VectorClock: + if x is None: + return None + return x.copy() + +def update_vector_clock(x: VectorClock, y: VectorClock): + x[:] = np.maximum(x, y) + +def lt(x: VectorClock, y: VectorClock) -> bool: + return bool((x <= y).all() & (x < y).any()) + +def ordered(x: VectorClock, y: VectorClock) -> bool: + return lt(x, y) | lt(y, x) + +def inc_vector_clock(x: VectorClock, device_id: int): + if device_id >= len(x): + raise ValueError(f'device_id={device_id} is out of range for x={x}') + assert device_id < len(x) + x[device_id] += 1 + class Semaphore: def __init__(self, semaphore_id=None): + shared_memory = _get_shared_memory() + self.id = semaphore_id + + # TODO(jburnim): Use one Condition variable per device. (Which will be + # easier to do when we're using single integer device IDs.) self.cv = threading.Condition() - # TODO(jburnim): Make this an array. - self.counts = collections.defaultdict(int) + self.counts = np.zeros(shared_memory.num_devices, dtype=np.int32) - def signal(self, inc, device_id): + self.interpret_params = shared_memory.interpret_params + if self.interpret_params.detect_races: + # We associate a vector clock with each count in self.counts. Whenever + # self.counts[i] is signaled, self.clocks[i] is updated with the vector + # clock of the signaling device. Whenever device i successfully waits on + # self.counts[i], the vector clock of device i is updated with + # self.clocks[i]. + # + # TODO(jburnim): Model happens-before more precisely for the case where + # semaphores are over-signaled. + self.clocks = [None] * shared_memory.num_devices + + def signal(self, inc, device_id, clock): + """Signal the semaphore on `device_id` by `inc`. + + Args: + inc: A positive integer. The amount by which to increment the semaphore + on the target device. + device_id: The ID of the target device. + clock: The vector clock of the signaling device at the time of the signal. 
+ """ + device_id = int(device_id) with self.cv: self.counts[device_id] += inc + if self.interpret_params.detect_races: + if self.clocks[device_id] is None: + self.clocks[device_id] = copy_vector_clock(clock) + else: + update_vector_clock(self.clocks[device_id], clock) self.cv.notify_all() - def wait(self, value, device_id, *, is_dma=False, interpret_params=None): + def read(self, device_id): + with self.cv: + return self.counts[device_id] + + def wait(self, value, device_id, *, is_dma=False): + device_id = int(device_id) + shared_memory = _get_shared_memory() + + # TODO(jburnim): + # - If the count is larger than value, raise an error? + # - If the count is equal to value, but there DMAs waiting to signal us, + # raise an error? + # Simple implementation for non-DMA semaphores. - if not is_dma or (interpret_params.dma_execution_mode == "eager"): + if not is_dma or (self.interpret_params.dma_execution_mode == "eager"): with self.cv: while self.counts[device_id] < value: self.cv.wait() self.counts[device_id] -= value + if self.interpret_params.detect_races: + clock = copy_vector_clock(self.clocks[device_id]) + if self.interpret_params.detect_races: + with shared_memory.lock: + update_vector_clock(shared_memory.clocks[device_id], clock) return # For DMA semaphores (when dma_execution_mode=='on_wait'), while our count @@ -106,10 +214,18 @@ class Semaphore: # up separate threads to handle executing DMAs. shared_memory = _get_shared_memory() while True: + clock = None with self.cv: if self.counts[device_id] >= value: self.counts[device_id] -= value - return + if self.interpret_params.detect_races: + clock = copy_vector_clock(self.clocks[device_id]) + else: + return + if clock is not None: + with shared_memory.lock: + update_vector_clock(shared_memory.clocks[device_id], clock) + return with shared_memory.lock: dma_queue = shared_memory.dmas_by_sem[self.id] @@ -121,14 +237,27 @@ class Semaphore: # Only execute the DMA as far as necessary to signal us. assert (dma.src_sem is self) or (dma.dst_sem is self) with dma.lock: + if dma.virtual_device_id is None: + dma.virtual_device_id = np.random.randint( + shared_memory.num_devices, NUM_VIRTUAL_DEVICES) + if dma.state == DmaState.STARTED: # Do the read. - dma.data = get(dma.src_device_id, dma.src_memory_space, - dma.src_buffer_id, dma.src_transforms) + if self.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + dma.data = get(dma.src_device_id, + dma.src_memory_space, + dma.src_buffer_id, + dma.src_transforms, + clock=copy_vector_clock(dma.clock), + src_device_id=dma.id, + source_info=dma.source_info) + if self.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) if dma.src_sem is not None: data_size = dma.data.itemsize * dma.data.size dma.src_sem.signal( - data_size, device_id=dma.src_device_id) + data_size, device_id=dma.src_device_id, clock=dma.clock) dma.state = DmaState.READ if dma.src_sem is self: @@ -138,11 +267,22 @@ class Semaphore: assert dma.state == DmaState.READ # Do the write. 
- store(dma.dst_device_id, dma.dst_memory_space, dma.dst_buffer_id, - dma.dst_transforms, dma.data) assert dma.dst_sem is self + if self.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + store(dma.dst_device_id, + dma.dst_memory_space, + dma.dst_buffer_id, + dma.dst_transforms, + dma.data, + clock=copy_vector_clock(dma.clock), + src_device_id=dma.id, + source_info=dma.source_info) + if self.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) data_size = dma.data.itemsize * dma.data.size - dma.dst_sem.signal(data_size, device_id=dma.dst_device_id) + dma.dst_sem.signal( + data_size, device_id=dma.dst_device_id, clock=dma.clock) dma.data = None dma.state = DmaState.COMPLETED @@ -168,17 +308,146 @@ class DMA: src_sem: Semaphore dst_sem: Semaphore + clock: VectorClock + source_info: source_info_util.SourceInfo | None = None state: DmaState = DmaState.STARTED data: np.ndarray | None = None + virtual_device_id: int | None = None lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) @dataclasses.dataclass -class SharedMemory: +class RaceDetectionState: num_devices: int + # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] + reads: dict = dataclasses.field( + default_factory=lambda: collections.defaultdict(list)) + + # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] + writes: dict = dataclasses.field( + default_factory=lambda: collections.defaultdict(list)) + + lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) + + races_found: bool = False + +def _is_empty_slice(slice_or_idx: slice | int): + if isinstance(slice_or_idx, int) or (slice_or_idx == slice(None)): + return False + + # NOTE: All slices here will have known size. + start = int(slice_or_idx.start) if slice_or_idx.start is not None else 0 + stop = int(slice_or_idx.stop) + return (start < stop) + +def slices_overlap(slice_or_idx1: slice | int, slice_or_idx2: slice | int): + if isinstance(slice_or_idx1, int): + slice_or_idx1 = slice(slice_or_idx1, slice_or_idx1 + 1) + if isinstance(slice_or_idx2, int): + slice_or_idx2 = slice(slice_or_idx2, slice_or_idx2 + 1) + + if slice_or_idx1 == slice(None): + return _is_empty_slice(slice_or_idx2) + if slice_or_idx2 == slice(None): + return _is_empty_slice(slice_or_idx1) + + # TODO(jburnim): Handle non-zero steps. + assert (slice_or_idx1.step == 1) or (slice_or_idx1.step is None) + assert (slice_or_idx2.step == 1) or (slice_or_idx2.step is None) + + # NOTE: We are only comparing slices with known stops (and sizes). + # Do we need to handle zero-length slices? 
+ return ((slice_or_idx1.start <= slice_or_idx2.start < slice_or_idx1.stop) + | (slice_or_idx2.start <= slice_or_idx1.start < slice_or_idx2.stop)) + +def ranges_overlap(range1: tuple[slice | int, ...], + range2: tuple[slice | int, ...]) -> bool: + return all(slices_overlap(r1, r2) for r1, r2 + in itertools.zip_longest(range1, range2, fillvalue=slice(None))) + +def check_read(device_id, clock, buffer_key, rnge, source_info=None): + if source_info is not None: + user_frame = source_info_util.summarize(source_info) + else: + user_frame = 'pallas_call' + + with races.lock: + writes = races.writes[buffer_key] + num_writes = len(writes) + races.reads[buffer_key].append((device_id, clock, rnge, user_frame)) + + for i in range(num_writes): + write_device_id, write_clock, write_range, write_frame = writes[i] + if ordered(write_clock, clock): + continue + if not ranges_overlap(rnge, write_range): + continue + # TODO(jburnim): When printing device IDs for reads/writes, distinguish + # between real device IDs vs. DMA IDs. + print('RACE DETECTED\n' + f' read of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' + f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') + with races.lock: + races.races_found = True + return + +def check_write(device_id, clock, buffer_key, rnge, source_info=None): + if source_info is not None: + user_frame = source_info_util.summarize(source_info) + else: + user_frame = 'pallas_call' + + with races.lock: + writes = races.writes[buffer_key] + reads = races.reads[buffer_key] + num_writes = len(writes) + num_reads = len(reads) + races.writes[buffer_key].append((device_id, clock, rnge, user_frame)) + + # TODO(jburnim): For performance, we should also probably remove any + # conflicting reads and writes that happened-before the current write. + + for i in range(num_writes): + write_device_id, write_clock, write_range, write_frame = writes[i] + if ordered(write_clock, clock): + continue + if not ranges_overlap(rnge, write_range): + continue + # TODO(jburnim): When printing device IDs for reads/writes, distinguish + # between real device IDs vs. DMA IDs. + print('RACE DETECTED\n' + f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' + f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') + with races.lock: + races.races_found = True + break + + for i in range(num_reads): + read_device_id, read_clock, read_range, read_frame = reads[i] + if ordered(read_clock, clock): + continue + if not ranges_overlap(rnge, read_range): + continue + # TODO(jburnim): When printing device IDs for reads/writes, distinguish + # between real device IDs vs. DMA IDs. + print('RACE DETECTED\n' + f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' + f' read of {buffer_key}[{read_range}] from {read_device_id}, {read_frame}') + with races.lock: + races.races_found = True + return + + +@dataclasses.dataclass +class SharedMemory: + interpret_params: TPUInterpretParams + num_devices: int + clocks: list[VectorClock] + barrier: threading.Barrier + # (memory_space, buffer_id, device_id) -> NumPy array # TODO(jburnim): Handle Megacore. mem: dict[tuple[int, int, int], np.ndarray] = dataclasses.field( @@ -208,6 +477,7 @@ class SharedMemory: # Maybe for running multiple distinct interpreted computations in parallel? 
_shared_memory : SharedMemory | None = None _shared_memory_init_lock = threading.Lock() +races : RaceDetectionState | None = None def _get_shared_memory() -> SharedMemory: assert _shared_memory is not None @@ -218,15 +488,29 @@ def _clear_shared_memory(): with _shared_memory_init_lock: _shared_memory = None -def _initialize_shared_memory(device_id, num_devices): +def _initialize_shared_memory(device_id, num_devices, *, interpret_params): global _shared_memory del device_id num_devices = int(num_devices) with _shared_memory_init_lock: if _shared_memory is None: - _shared_memory = SharedMemory(num_devices=num_devices) + _shared_memory = SharedMemory( + interpret_params=interpret_params, + num_devices=num_devices, + clocks=[make_vector_clock(num_devices) for _ in range(num_devices)], + barrier=threading.Barrier(num_devices)) assert _shared_memory.num_devices == num_devices + global races + races = RaceDetectionState(num_devices=num_devices) + +def _clean_up_shared_memory(device_id): + device_id = int(device_id) + shared_memory = _get_shared_memory() + shared_memory.barrier.wait() + if device_id == 0: + _clear_shared_memory() + def _validate(device_id): device_id = int(device_id) @@ -235,7 +519,9 @@ def _validate(device_id): for sem in shared_memory.sem.values(): with sem.cv: if sem.counts[device_id] != 0: - raise ValueError( + # TODO(jburnim): Make this raise an error, but in a way that doesn't + # cause other devices to hang later in `_clean_up_shared_memory`. + print( f'Semaphore {sem.id} has non-zero count for {device_id} at ' f'kernel exit: {sem.counts[device_id]}') @@ -248,6 +534,8 @@ def _allocate_buffer(device_id, memory_space, val): with shared_memory.lock: buffer_id = shared_memory.next_buffer_id[device_id] shared_memory.next_buffer_id[device_id] = buffer_id + 1 + # TODO(jburnim): Add options for initializing memory (e.g., with NaNs, + # with zeros, or with the buffer ID). shared_memory.mem[(memory_space, buffer_id, device_id)] = val # TODO(jburnim): Raise an error if buffer_id is too big for int16. 
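To make the happens-before machinery above concrete, here is a standalone sketch of how the race detector decides whether two accesses are ordered. It mirrors the clock helpers defined earlier (make_vector_clock, update_vector_clock, ordered) but uses local toy re-implementations and a made-up two-device scenario rather than the interpreter's real state; a write and a read on overlapping ranges whose clocks remain unordered is exactly what check_read/check_write report as a race.

import numpy as np

NUM_SLOTS = 4  # toy clock width; the interpreter uses NUM_VIRTUAL_DEVICES

def make_clock():
  return np.zeros(NUM_SLOTS, dtype=np.int32)

def lt(x, y):
  return bool((x <= y).all() and (x < y).any())

def ordered(x, y):
  return lt(x, y) or lt(y, x)

dev0, dev1 = make_clock(), make_clock()
dev0[0] += 1                         # device 0 writes a buffer
sem_clock = dev0.copy()              # ...and signals a semaphore, publishing its clock
dev1[1] += 1                         # device 1 does unrelated local work
print(ordered(dev0, dev1))           # False: a read by device 1 now would race the write
dev1 = np.maximum(dev1, sem_clock)   # device 1 waits on the semaphore, joining clocks
print(ordered(dev0, dev1))           # True: the write now happens-before device 1's read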
@@ -273,7 +561,7 @@ def _allocate_semaphores(device_id, shape): semaphore_id = shared_memory.next_semaphore_id[device_id] shared_memory.next_semaphore_id[device_id] = semaphore_id + num_semaphores for i in range(semaphore_id, semaphore_id + num_semaphores): - if not i in shared_memory.sem: + if i not in shared_memory.sem: shared_memory.sem[i] = Semaphore(i) # NOTE: For now, we use a relatively uncommon datatype (int16) for @@ -305,7 +593,7 @@ def get_barrier_semaphore(device_id, collective_id): shared_memory = _get_shared_memory() with shared_memory.lock: semaphore_id = collective_id - if not semaphore_id in shared_memory.sem: + if semaphore_id not in shared_memory.sem: shared_memory.sem[semaphore_id] = Semaphore() return np.int16(semaphore_id) @@ -314,10 +602,9 @@ def _transform_slice_or_index(slice_or_idx): if isinstance(slice_or_idx, int): return slice_or_idx else: - start, size, stride = ( - int(slice_or_idx.start), - int(slice_or_idx.size), - int(slice_or_idx.stride)) + start = int(slice_or_idx.start) + size = int(slice_or_idx.size) + stride = int(slice_or_idx.stride) return slice(start, start + size * stride, stride) def _compose_slice_or_index(slice_or_idx1, slice_or_idx2): @@ -355,7 +642,8 @@ def _to_range(transforms) -> tuple[slice | int, ...]: ret, tuple(_transform_slice_or_index(i) for i in transform.indices)) return ret -def get(device_id, memory_space, buffer_id, transforms): +def get(device_id, memory_space, buffer_id, transforms, *, + src_device_id=None, clock=None, source_info=None): device_id = int(device_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) @@ -367,6 +655,10 @@ def get(device_id, memory_space, buffer_id, transforms): shared_memory = _get_shared_memory() with shared_memory.lock: read_range = _to_range(transforms) + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + if clock is None: + clock = copy_vector_clock(shared_memory.clocks[device_id]) buffer = shared_memory.mem[(memory_space, buffer_id, device_id)] ret = buffer[read_range].copy() if transforms: @@ -377,9 +669,17 @@ def get(device_id, memory_space, buffer_id, transforms): raise ValueError( f'Out-of-bounds read of ({device_id} {memory_space} {buffer_id}): ' f'reading [{read_range}] but bufer has shape {buffer.shape} .') - return ret -def store(device_id, memory_space, buffer_id, transforms, val): + if shared_memory.interpret_params.detect_races: + if src_device_id is None: + src_device_id = device_id + check_read(src_device_id, clock, (memory_space, buffer_id, device_id), + read_range, source_info=source_info) + + return ret + +def store(device_id, memory_space, buffer_id, transforms, val, *, + src_device_id=None, clock=None, source_info=None): device_id = int(device_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) @@ -391,6 +691,11 @@ def store(device_id, memory_space, buffer_id, transforms, val): shared_memory = _get_shared_memory() with shared_memory.lock: + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + if clock is None: + clock = copy_vector_clock(shared_memory.clocks[device_id]) + buff = shared_memory.mem[(memory_space, buffer_id, device_id)] assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. 
write_range = _to_range(transforms) @@ -402,7 +707,14 @@ def store(device_id, memory_space, buffer_id, transforms, val): f'writing [{write_range}] but buffer has shape {buff.shape} .') buff[write_range] = val -def swap(device_id, memory_space, buffer_id, transforms, val, mask): + if shared_memory.interpret_params.detect_races: + if src_device_id is None: + src_device_id = device_id + check_write(src_device_id, clock, (memory_space, buffer_id, device_id), + write_range, source_info=source_info) + +def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, + source_info=None): device_id = int(device_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) @@ -417,6 +729,9 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask): shared_memory = _get_shared_memory() with shared_memory.lock: + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + clock = copy_vector_clock(shared_memory.clocks[device_id]) buff = shared_memory.mem[(memory_space, buffer_id, device_id)] assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. read_write_range = _to_range(transforms) @@ -447,31 +762,64 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask): mask[in_bounds_idx], raw_result, val[in_bounds_idx]) buff[read_write_range] = np.where( mask[in_bounds_idx], val[in_bounds_idx], raw_result) - return result + + if shared_memory.interpret_params.detect_races: + check_write(device_id, clock, (memory_space, buffer_id, device_id), + read_write_range, source_info=source_info) + return result def execute_dma(dma): + # TODO(jburnim) Eliminate duplicate code here and in Semaphore.wait. + shared_memory = _get_shared_memory() with dma.lock: assert dma.state == DmaState.STARTED - # Do the read. - dma.data = get(dma.src_device_id, dma.src_memory_space, - dma.src_buffer_id, dma.src_transforms) - data_size = dma.data.itemsize * dma.data.size + if dma.virtual_device_id is None: + # See comment in Semaphore.wait . + dma.virtual_device_id = np.random.randint( + shared_memory.num_devices, NUM_VIRTUAL_DEVICES) - # Signal the send semaphore. - if dma.src_sem is not None: - dma.src_sem.signal(data_size, device_id=dma.src_device_id) + # Do the read. + if shared_memory.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + dma.data = get(dma.src_device_id, + dma.src_memory_space, + dma.src_buffer_id, + dma.src_transforms, + clock=copy_vector_clock(dma.clock), + src_device_id=dma.id, + source_info=dma.source_info) + data_size = dma.data.itemsize * dma.data.size - # Do the write. - store(dma.dst_device_id, dma.dst_memory_space, dma.dst_buffer_id, - dma.dst_transforms, dma.data) + # Signal the send semaphore. + if shared_memory.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + if dma.src_sem is not None: + dma.src_sem.signal( + data_size, device_id=dma.src_device_id, clock=dma.clock) + dma.state = DmaState.READ - # Signal the receive semaphore. - if dma.dst_sem is not None: - dma.dst_sem.signal(data_size, device_id=dma.dst_device_id) + # Do the write. 
+ if shared_memory.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + store(dma.dst_device_id, + dma.dst_memory_space, + dma.dst_buffer_id, + dma.dst_transforms, + dma.data, + clock=copy_vector_clock(dma.clock), + src_device_id=dma.id, + source_info=dma.source_info) - dma.data = None - dma.state = DmaState.COMPLETED + # Signal the receive semaphore. + if shared_memory.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + if dma.dst_sem is not None: + dma.dst_sem.signal( + data_size, device_id=dma.dst_device_id, clock=dma.clock) + + dma.data = None + dma.state = DmaState.COMPLETED def print_memory(device_id): device_id = int(device_id) @@ -483,7 +831,7 @@ def print_memory(device_id): def dma_start(device_id, src_memory_space, src_id, src_transforms, dst_memory_space, dst_id, dst_transforms, dst_sem_id, src_sem_id, dst_device_id, - *, interpret_params, source_info=None): + source_info=None): device_id = int(device_id) src_memory_space, src_id = int(src_memory_space), int(src_id) src_transforms = jax.tree.map(int, src_transforms) @@ -501,6 +849,10 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms, dst_sem = shared_memory.sem[dst_sem_id] src_sem = shared_memory.sem[src_sem_id] if src_sem_id is not None else None + clock = None + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + clock = copy_vector_clock(shared_memory.clocks[device_id]) dma_id = shared_memory.next_dma_id shared_memory.next_dma_id += 1 @@ -510,31 +862,35 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms, dst_device_id, dst_memory_space, dst_id, dst_transforms, src_sem, dst_sem, + clock=clock, source_info=source_info, ) - if interpret_params.dma_execution_mode == 'on_wait': + if shared_memory.interpret_params.dma_execution_mode == 'on_wait': shared_memory.dmas_by_sem[dst_sem_id].append(dma) if src_sem_id is not None: shared_memory.dmas_by_sem[src_sem_id].append(dma) return - assert interpret_params.dma_execution_mode == 'eager' + assert shared_memory.interpret_params.dma_execution_mode == 'eager' execute_dma(dma) -def dma_wait(device_id, sem, size, *, interpret_params): +def dma_wait(device_id, sem_id, size): device_id = int(device_id) - sem = int(sem) + sem_id = int(sem_id) size = int(size) shared_memory = _get_shared_memory() with shared_memory.lock: - sem = shared_memory.sem[sem] - sem.wait(size, device_id, is_dma=True, interpret_params=interpret_params) + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + sem = shared_memory.sem[sem_id] + sem.wait(size, device_id, is_dma=True) -def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index): +def semaphore_signal(device_id, sem_id, inc, target_device_id, + target_core_index): device_id = int(device_id) - sem = int(sem) + sem_id = int(sem_id) inc = int(inc) if target_device_id is None: target_device_id = device_id @@ -542,21 +898,28 @@ def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index): target_device_id = int(target_device_id) if target_core_index is not None: - raise NotImplementedError('semaphore_signal with target_core_index') + if int(target_core_index) != 0: + raise NotImplementedError('semaphore_signal with target_core_index != 0') shared_memory = _get_shared_memory() with shared_memory.lock: - sem = shared_memory.sem[sem] - sem.signal(inc, target_device_id) + clock = None + if 
shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + clock = copy_vector_clock(shared_memory.clocks[device_id]) + sem = shared_memory.sem[sem_id] + sem.signal(inc, target_device_id, clock) -def semaphore_wait(device_id, sem, value): +def semaphore_wait(device_id, sem_id, value): device_id = int(device_id) - sem = int(sem) + sem_id = int(sem_id) value = int(value) shared_memory = _get_shared_memory() with shared_memory.lock: - sem = shared_memory.sem[sem] + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + sem = shared_memory.sem[sem_id] sem.wait(value, device_id) def _compute_transformed_shape_and_dtype(shape, dtype, transforms): @@ -597,16 +960,32 @@ def _is_any(memory_space): return ((memory_space == mosaic_core.TPUMemorySpace.ANY) or (memory_space == pallas_core.MemorySpace.ANY)) +def _is_float(dtype): + return jnp.issubdtype(dtype, jnp.floating) + +_SENTINEL = jnp.inf + +@dataclasses.dataclass(frozen=True) +class Placeholder: + """Placeholder for use in `_interpret_jaxpr` below instead of putting a concrete value into `env`.""" + shape: tuple[int, ...] + dtype: jnp.dtype + def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): env = {} def read(var): if isinstance(var, jax_core.Literal): - return var.val + result = var.val else: - return env[var] + result = env[var] + if isinstance(result, Placeholder): + result = jax.lax.full(result.shape, _SENTINEL, result.dtype) + return result def write(var, value): + if interpret_params.skip_floating_point_ops and _is_float(value.dtype): + value = Placeholder(value.shape, value.dtype) env[var] = value jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) @@ -627,236 +1006,274 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): _interpret_jaxpr, compiler_params=compiler_params, interpret_params=interpret_params) for eqn in jaxpr.eqns: - prim = eqn.primitive - invals = jax.util.safe_map(read, eqn.invars) + with source_info_util.user_context( + eqn.source_info.traceback, name_stack=eqn.source_info.name_stack): + prim = eqn.primitive + # We defer reading the values for `eqn.invars` into each of the branches + # of the if-elif-else statement below. This is because the else branch may + # not need to do any reads if `interpret_params.skip_floating_point_ops` + # is True. If this is the case, we want to avoid materializing the read + # array into the jaxpr when this function is traced. 
+ deferred_invals = functools.partial(jax.util.safe_map, read, eqn.invars) - if prim is primitives.load_p: - (ref, transforms, mask, _) = jax.tree.unflatten( - eqn.params['args_tree'], invals) - if mask is not None: - raise NotImplementedError('masked load_p') - out = callback.io_callback( - get, - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - ref, - transforms, - ordered=True) + if prim is primitives.load_p: + (ref, transforms, mask, _) = jax.tree.unflatten( + eqn.params['args_tree'], deferred_invals()) + if mask is not None: + raise NotImplementedError('masked load_p') + out = callback.io_callback( + functools.partial(get, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + ref, + transforms, + ordered=True) - elif prim is primitives.swap_p: - (ref, transforms, val, mask) = jax.tree.unflatten( - eqn.params['args_tree'], invals) - out = callback.io_callback( - swap, - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - ref, - transforms, - val, - mask, - ordered=True) + elif prim is primitives.swap_p: + (ref, transforms, val, mask) = jax.tree.unflatten( + eqn.params['args_tree'], deferred_invals()) + out = callback.io_callback( + functools.partial(swap, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + ref, + transforms, + val, + mask, + ordered=True) - elif prim is lax.cond_p: - def _make_branch(jaxpr): - return lambda *args: _interpret(jaxpr, *args) - out = lax.switch( - invals[0], - [_make_branch(branch_jaxpr.jaxpr) - for branch_jaxpr in eqn.params['branches']], - *invals[1:]) + elif prim is mosaic_primitives.delay_p: + out = [] - elif prim is lax.scan_p: - consts, init_carry, xs = split_list( - invals, [eqn.params['num_consts'], eqn.params['num_carry']]) - def _scan_body(c, a): - return split_list( - _interpret(eqn.params['jaxpr'].jaxpr, *consts, *c, *a), - [eqn.params['num_carry']]) - carry, out = lax.scan(_scan_body, init_carry, xs=xs, - length=eqn.params.get('length', None)) - out = carry + out + elif prim is lax.cond_p: + def _make_branch(jaxpr): + return lambda *args: _interpret(jaxpr, *args) + invals = deferred_invals() + out = lax.switch( + invals[0], + [_make_branch(branch_jaxpr.jaxpr) + for branch_jaxpr in eqn.params['branches']], + *invals[1:]) - elif prim is lax.while_p: - cond_consts, body_consts, init_vals = split_list( - invals, [eqn.params['cond_nconsts'], eqn.params['body_nconsts']]) - out = lax.while_loop( - lambda args: _interpret( - eqn.params['cond_jaxpr'].jaxpr, *cond_consts, *args)[0], - lambda args: _interpret( - eqn.params['body_jaxpr'].jaxpr, *body_consts, *args), - init_vals) + elif prim is lax.scan_p: + consts, init_carry, xs = split_list( + deferred_invals(), + [eqn.params['num_consts'], eqn.params['num_carry']], + ) + def _scan_body(c, a): + return split_list( + _interpret(eqn.params['jaxpr'].jaxpr, *consts, *c, *a), + [eqn.params['num_carry']]) + carry, out = lax.scan(_scan_body, init_carry, xs=xs, + length=eqn.params.get('length', None)) + out = carry + out - elif prim is for_loop.for_p: - raise NotImplementedError('for_p') + elif prim is lax.while_p: + cond_consts, body_consts, init_vals = split_list( + deferred_invals(), + [eqn.params['cond_nconsts'], eqn.params['body_nconsts']], + ) + out = lax.while_loop( + lambda args: _interpret( + eqn.params['cond_jaxpr'].jaxpr, *cond_consts, *args)[0], + lambda args: 
_interpret( + eqn.params['body_jaxpr'].jaxpr, *body_consts, *args), + init_vals) - elif prim is pjit.pjit_p: - def f(*args, jaxpr): - return _interpret(jaxpr.jaxpr, *jaxpr.consts, *args) - in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) - new_jaxpr = _to_jaxpr( - lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']), - debug_info=eqn.params['jaxpr'].jaxpr.debug_info), - in_avals) - out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) + elif prim is for_loop.for_p: + raise NotImplementedError('for_p') - elif prim is primitives.run_scoped_p: - # Allocate a buffer or semaphore for each element of - # eqn.params['jaxpr'].invars . - allocs = [] - for v in eqn.params['jaxpr'].invars: - if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: - allocs.append(callback.io_callback( - _allocate_semaphores, - jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), - device_id, - v.aval.shape, - ordered=True)) + elif prim is pjit.pjit_p: + def f(*args, jaxpr): + return _interpret(jaxpr.jaxpr, *jaxpr.consts, *args) + invals = deferred_invals() + in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) + new_jaxpr = _to_jaxpr( + lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']), + debug_info=eqn.params['jaxpr'].jaxpr.debug_info), + in_avals) + out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) + + elif prim is primitives.run_scoped_p: + # Allocate a buffer or semaphore for each element of + # eqn.params['jaxpr'].invars . + allocs = [] + for v in eqn.params['jaxpr'].invars: + if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: + allocs.append(callback.io_callback( + _allocate_semaphores, + jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), + device_id, + v.aval.shape, + ordered=True)) + else: + allocs.append(callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + primitives.uninitialized_value(v.aval.shape, v.aval.dtype), + ordered=True)) + + out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) + + for a in allocs: + if isinstance(a, tuple): + callback.io_callback( + _deallocate_buffer, + None, + device_id, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + a, + ordered=True) + else: + # TODO(jburnim): De-allocate semaphores. 
+ # callback.io_callback( + # _deallocate_semaphores, + # None, + # device_id, + # a, + # ordered=True) + pass + + elif prim is state_primitives.get_p: + invals = deferred_invals() + out = callback.io_callback( + functools.partial(get, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + invals[0], + jax.tree.unflatten(eqn.params['tree'], invals[1:]), + ordered=True) + + elif prim is state_primitives.swap_p: + invals = deferred_invals() + out = callback.io_callback( + functools.partial(swap, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + invals[0], + jax.tree.unflatten(eqn.params['tree'], invals[2:]), + invals[1], + None, + ordered=True) + + elif prim is mosaic_primitives.dma_start_p: + ( + src, + src_transforms, + dst, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + target_device_id, + ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) + target_device_id = _device_id_to_logical( + target_device_id, eqn.params['device_id_type'], axis_sizes) + (orig_src_ref, _, orig_dst_ref, *_ + ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars) + callback.io_callback( + functools.partial(dma_start, source_info=eqn.source_info), + (), + device_id, + TPU_MEMORY_SPACE_IDXS[getattr(orig_src_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], + src, src_transforms, + TPU_MEMORY_SPACE_IDXS[getattr(orig_dst_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], + dst, dst_transforms, + state_discharge.transform_array(dst_sem, dst_sem_transforms), + state_discharge.transform_array(src_sem, src_sem_transforms), + target_device_id, + ordered=True) + out = [] + + elif prim is mosaic_primitives.dma_wait_p: + ( + src, + src_transforms, + dst, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + target_device_id, + ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) + read_shape, read_dtype = _compute_transformed_shape_and_dtype( + eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms) + callback.io_callback( + dma_wait, + (), + device_id, + state_discharge.transform_array(dst_sem, dst_sem_transforms), + math.prod(read_shape) * read_dtype.itemsize, + ordered=True) + out = [] + + elif prim is mosaic_primitives.get_barrier_semaphore_p: + out = callback.io_callback( + get_barrier_semaphore, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + compiler_params['mosaic']['collective_id'], + ordered=True) + + elif prim is mosaic_primitives.semaphore_signal_p: + sem, sem_transforms, inc, target_device_id, core_index = ( + jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) + target_device_id = _device_id_to_logical( + target_device_id, eqn.params['device_id_type'], axis_sizes) + callback.io_callback( + semaphore_signal, + (), + device_id, + state_discharge.transform_array(sem, sem_transforms), + inc, + target_device_id, + core_index, + ordered=True) + out = [] + + elif prim is mosaic_primitives.semaphore_wait_p: + sem, sem_transforms, value = ( + jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) + callback.io_callback( + semaphore_wait, + (), + device_id, + state_discharge.transform_array(sem, sem_transforms), + value, + ordered=True) + out = [] + + elif prim is primitives.atomic_rmw_p: + raise NotImplementedError('atomic_rmw_p') + + elif prim is primitives.atomic_cas_p: + raise NotImplementedError('atomic_cas_p') + + else: + if 
interpret_params.skip_floating_point_ops and all( + _is_float(ovar.aval.dtype) for ovar in eqn.outvars + ): + # Skip `prim.bind` since `prim` only produces floating-point values. + # It is safe to populate `out` with avals since mapping `write` over + # `out` below only relies on the shape and dtype (for writing + # `Placeholder`s). + out = [ovar.aval for ovar in eqn.outvars] + if not prim.multiple_results: + out = out[0] else: - allocs.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - primitives.uninitialized_value(v.aval.shape, v.aval.dtype), - ordered=True)) + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + out = prim.bind(*subfuns, *deferred_invals(), **bind_params) - out = _interpret(eqn.params['jaxpr'], *invals, *allocs) - - for a in allocs: - if isinstance(a, tuple): - callback.io_callback( - _deallocate_buffer, - None, - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - a, - ordered=True) - else: - # TODO(jburnim): Delete semaphores. - # callback.io_callback( - # _deallocate_semaphores, - # None, - # device_id, - # a, - # ordered=True) - pass - - elif prim is state_primitives.get_p: - out = callback.io_callback( - get, - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - invals[0], - jax.tree.unflatten(eqn.params['tree'], invals[1:]), - ordered=True) - - elif prim is state_primitives.swap_p: - out = callback.io_callback( - swap, - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - invals[0], - jax.tree.unflatten(eqn.params['tree'], invals[2:]), - invals[1], - None, - ordered=True) - - elif prim is mosaic_primitives.dma_start_p: - (src, src_transforms, - dst, dst_transforms, - dst_sem, dst_sem_transforms, - src_sem, src_sem_transforms, - target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals) - target_device_id = _device_id_to_logical( - target_device_id, eqn.params['device_id_type'], axis_sizes) - (orig_src_ref, _, orig_dst_ref, *_ - ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars) - callback.io_callback( - functools.partial(dma_start, interpret_params=interpret_params, - source_info=eqn.source_info), - (), - device_id, - TPU_MEMORY_SPACE_IDXS[orig_src_ref.aval.memory_space], - src, src_transforms, - TPU_MEMORY_SPACE_IDXS[orig_dst_ref.aval.memory_space], - dst, dst_transforms, - state_discharge.transform_array(dst_sem, dst_sem_transforms), - state_discharge.transform_array(src_sem, src_sem_transforms), - target_device_id, - ordered=True) - out = [] - - elif prim is mosaic_primitives.dma_wait_p: - (src, src_transforms, - dst, dst_transforms, - dst_sem, dst_sem_transforms, - src_sem, src_sem_transforms, - target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals) - read_shape, read_dtype = _compute_transformed_shape_and_dtype( - eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms) - callback.io_callback( - functools.partial(dma_wait, interpret_params=interpret_params), - (), - device_id, - state_discharge.transform_array(dst_sem, dst_sem_transforms), - math.prod(read_shape) * read_dtype.itemsize, - ordered=True) - out = [] - - elif prim is mosaic_primitives.get_barrier_semaphore_p: - out = callback.io_callback( - get_barrier_semaphore, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - compiler_params['mosaic']['collective_id'], - ordered=True) - - elif prim is mosaic_primitives.semaphore_signal_p: - sem, sem_transforms, inc, 
target_device_id, core_index = ( - jax.tree.unflatten(eqn.params['args_tree'], invals)) - target_device_id = _device_id_to_logical( - target_device_id, eqn.params['device_id_type'], axis_sizes) - callback.io_callback( - semaphore_signal, - (), - device_id, - state_discharge.transform_array(sem, sem_transforms), - inc, - target_device_id, - core_index, - ordered=True) - out = [] - - elif prim is mosaic_primitives.semaphore_wait_p: - sem, sem_transforms, value = ( - jax.tree.unflatten(eqn.params['args_tree'], invals)) - callback.io_callback( - semaphore_wait, - (), - device_id, - state_discharge.transform_array(sem, sem_transforms), - value, - ordered=True) - out = [] - - elif prim is primitives.atomic_rmw_p: - raise NotImplementedError('atomic_rmw_p') - - elif prim is primitives.atomic_cas_p: - raise NotImplementedError('atomic_cas_p') - - else: - # TODO(jburnim): Add special handling for nested pallas_call_p. - # (For example, so that buffers can be shared with nested Pallas calls.) - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - out = prim.bind(*subfuns, *invals, **bind_params) - - out = out if prim.multiple_results else [out] - jax.util.safe_map(write, eqn.outvars, out) + out = out if prim.multiple_results else [out] + jax.util.safe_map(write, eqn.outvars, out) return jax.util.safe_map(read, jaxpr.outvars) @@ -961,7 +1378,8 @@ def interpret_pallas_call( tuple(lax.axis_index(s) for s in axis_sizes.keys()), axis_sizes) callback.io_callback( - _initialize_shared_memory, + functools.partial( + _initialize_shared_memory, interpret_params=interpret_params), (), device_id, num_devices, @@ -986,7 +1404,9 @@ def interpret_pallas_call( output_buffer_ids = [] output_buffer_shapes = [] output_vals = _initialize_output_vals( - grid_mapping.block_mappings_output, args, input_output_aliases) + grid_mapping.block_mappings_output, + scalars + input_args, + input_output_aliases) num_outputs = grid_mapping.num_outputs output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] for out_val, bs in zip(output_vals, output_block_shapes): @@ -999,13 +1419,12 @@ def interpret_pallas_call( TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], padded_val, ordered=True)) - # Allocate buffers for all kernel arguments (e.g., scalars, inputs, # outputs, scratch). io_alias_map = dict(input_output_aliases) oi_alias_map = {v: k for k, v in input_output_aliases} kernel_buffer_ids = [] - for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): + for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): kernel_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), @@ -1107,6 +1526,8 @@ def interpret_pallas_call( input_args[j], is_indexing_dim[j]) assert(sliced_val.shape == var.aval.shape) callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. store, (), device_id, @@ -1129,6 +1550,8 @@ def interpret_pallas_call( if _is_any(var.aval.memory_space): continue kernel_output_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # get is involved in a data race. get, var.aval, device_id, @@ -1144,6 +1567,8 @@ def interpret_pallas_call( shape=output_vals[j].shape, int_indexer_shape=()) callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. 
store, (), device_id, @@ -1165,6 +1590,8 @@ def interpret_pallas_call( # Read the output from the allocated output buffers. ret = [ callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # get is involved in a data race. get, val, device_id, @@ -1178,33 +1605,22 @@ def interpret_pallas_call( output_vals, output_buffer_ids, output_buffer_shapes) ] - for buffer_id in output_buffer_ids: - callback.io_callback( - _deallocate_buffer, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - buffer_id, - ordered=True) - for buffer_id, var in zip(kernel_buffer_ids, jaxpr.invars): - if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: - pass - else: - callback.io_callback( - _deallocate_buffer, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - buffer_id, - ordered=True) + callback.io_callback( + _validate, + (), + device_id, + ordered=True) - # TODO(jburnim): Either validate just the semaphores allocated for this - # pallas_call, or only do validation if we are exiting a top-level - # (i.e., not nested) pallas_call. - # callback.io_callback( - # _validate, - # (), - # device_id, - # ordered=True) + # For now, when we're done with a pallas_call, we delete the shared memory. + # We use a barrier to ensure that all devices are done running the kernel. + # + # TODO(jburnim): Get rid of this barrier. And figure out how this should + # work if we want to invoke successive pallas_calls that use the same + # shared memory. + callback.io_callback( + _clean_up_shared_memory, + (), + device_id, + ordered=True) return ret diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 46294a766..9bc20ed2c 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3222,6 +3222,14 @@ def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule +def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx): + if not isinstance(x.type.element_type, ir.F32Type): + raise ValueError("Only float32 is supported.") + return tpu.reciprocal(x, approx=approx) + + +lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule + def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): del ty (out_aval,) = ctx.avals_out diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 0a79de771..2044d3d18 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -207,6 +207,8 @@ class BufferedRef: is_accumulator: whether this BufferedRef is an accumulator. is_input_output: whether this BufferedRef is an input/output without automatic accumulation. + swap: Tracks whether the BufferedRef slots need to be swapped before next + copy. """ spec: pl.BlockSpec # static metadata dtype: Any # static metadata @@ -214,9 +216,14 @@ class BufferedRef: window_ref: REF | None accum_ref: REF | None current_slot: ArrayRef | None + # TODO(ramiroleal): Unused by class. Remove argument from + # BufferedRef instantiations. next_slot: ArrayRef | None sem_recvs: SemaphoreTuple | None sem_sends: SemaphoreTuple | None + # TODO(ramiroleal): Improve prefetch/postyeet interface to avoid + # using this ref. 
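For context on the new swap field described above: the window ref is double-buffered, so the next slot is always (current + 1) % 2, and the boolean swap ref only records that a copy was issued and the slots must flip before the next one. A host-side Python model of that bookkeeping (illustrative names; in the kernel these live in SMEM and the flip is guarded by pl.when, not a Python if):

current_slot = [0]     # stands in for the SMEM current_slot ref
swap = [False]         # stands in for the new SMEM swap ref

def next_slot_index(i):
  # With exactly two slots, "next" and "previous" coincide.
  return (i + 1) % 2

def start_copy():
  # copy_in/copy_out set the flag instead of writing next_slot directly.
  swap[0] = True

def swap_slots():
  # The scheduler flips the slot and clears the flag before the next copy.
  current_slot[0] = next_slot_index(current_slot[0])
  swap[0] = False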
+ swap: ArrayRef | None def tree_flatten(self): return ( @@ -227,6 +234,7 @@ class BufferedRef: self.next_slot, self.sem_recvs, self.sem_sends, + self.swap, ), (self.spec, self.dtype, self.buffer_type), ) @@ -240,7 +248,7 @@ class BufferedRef: return BufferType @classmethod - def create(cls, spec, dtype, buffer_type) -> BufferedRef: + def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef: """Create a BufferedRef. Args: @@ -248,6 +256,7 @@ class BufferedRef: dtype: dtype for buffers. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. + needs_swap_ref: whether a swap slots tracker needs to be allocated. Returns: Initialized BufferedRef @@ -271,6 +280,7 @@ class BufferedRef: next_slot=None, sem_recvs=None, sem_sends=None, + swap=None, ) else: memory_space = SMEM if spec.memory_space == SMEM else VMEM @@ -281,7 +291,7 @@ class BufferedRef: window_ref=memory_space((2,) + block_shape, dtype), accum_ref=accum_ref, current_slot=SMEM((1,), jnp.int32), - next_slot=SMEM((1,), jnp.int32), + next_slot=None, sem_recvs=( None if buffer_type is BufferType.OUTPUT @@ -292,23 +302,24 @@ class BufferedRef: if buffer_type is BufferType.INPUT else SemaphoreType.DMA((2,)) ), + swap=SMEM((1,), jnp.bool) if needs_swap_ref else None, ) @classmethod - def input(cls, spec, dtype): - return cls.create(spec, dtype, BufferType.INPUT) + def input(cls, spec, dtype, needs_swap_ref=True): + return cls.create(spec, dtype, BufferType.INPUT, needs_swap_ref) @classmethod - def output(cls, spec, dtype): - return cls.create(spec, dtype, BufferType.OUTPUT) + def output(cls, spec, dtype, needs_swap_ref=True): + return cls.create(spec, dtype, BufferType.OUTPUT, needs_swap_ref) @classmethod - def accumulator(cls, spec, dtype): - return cls.create(spec, dtype, BufferType.ACCUMULATOR) + def accumulator(cls, spec, dtype, needs_swap_ref=True): + return cls.create(spec, dtype, BufferType.ACCUMULATOR, needs_swap_ref) @classmethod - def input_output(cls, spec, dtype): - return cls.create(spec, dtype, BufferType.INPUT_OUTPUT) + def input_output(cls, spec, dtype, needs_swap_ref=True): + return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, needs_swap_ref) @property def block_shape(self): @@ -329,7 +340,7 @@ class BufferedRef: if self.memory_space == VMEM: return self.window_ref.at[buffer_slice] else: - return self.window_ref.at[(self.current_slot[0], *buffer_slice)] + return self.window_ref.at[(self.current_slot_index, *buffer_slice)] @property def is_input(self): @@ -355,6 +366,14 @@ class BufferedRef: def is_input_output(self): return self.buffer_type == BufferType.INPUT_OUTPUT + @property + def current_slot_index(self): + return self.current_slot[0] + + @property + def next_slot_index(self): + return lax.rem(self.current_slot_index + 1, 2) + def bind_existing_ref(self, window_ref, indices): """For handling VMEM references, the pipeline aliases the existing ref.""" if self.memory_space == VMEM: @@ -373,12 +392,15 @@ class BufferedRef: """Initialize slot indices.""" if self.memory_space == VMEM: return self.current_slot[0] = 0 - self.next_slot[0] = 0 + if self.swap is not None: + self.swap[0] = False def swap_slots(self): """Switch to the next slot.""" if self.memory_space == VMEM: return - self.current_slot[0] = self.next_slot[0] + self.current_slot[0] = self.next_slot_index + if self.swap is not None: + self.swap[0] = False def get_dma_slice(self, src_shape, src_dtype, grid_indices): # We need to handle blocks that might go OOB in the src array. 
An in bounds @@ -441,8 +463,9 @@ class BufferedRef: """Starts copy of HBM dma slice into the current slot.""" assert self.is_input if self.memory_space == VMEM: return - next_slot = lax.rem(self.current_slot[0] + 1, 2) - self.next_slot[0] = next_slot + if self.swap is not None: + self.swap[0] = True + next_slot = self.next_slot_index src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( @@ -455,8 +478,9 @@ class BufferedRef: """Starts copy of HBM dma slice from the current slot.""" assert self.is_output if self.memory_space == VMEM: return - slot = self.current_slot[0] - self.next_slot[0] = lax.rem(slot + 1, 2) + if self.swap is not None: + self.swap[0] = True + slot = self.current_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( @@ -471,7 +495,7 @@ class BufferedRef: if self.memory_space == VMEM: return src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) - current_slot = self.current_slot[0] + current_slot = self.current_slot_index tpu_primitives.make_async_copy( src_ref.at[src_slice], # nb: doesn't matter self.window_ref.at[current_slot].at[ @@ -484,7 +508,8 @@ class BufferedRef: """Waits for output copy to finish.""" assert self.is_output if self.memory_space == VMEM: return - prev_slot = lax.rem(self.current_slot[0] + 1, 2) + # In a double buffer, previous slot is the same as next slot. + prev_slot = self.next_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( @@ -671,10 +696,7 @@ class Scheduler: def _start(): if buffered_ref.is_input: buffered_ref.copy_in(src_ref, self.indices) - - # In the prologue this makes it so we wait on the prologue copy to finish. - # In other iterations this is the regular swap. - buffered_ref.swap_slots() + buffered_ref.swap_slots() def wait_in(self, buffered_ref, src_ref, schedule=None): if schedule is None: @@ -780,9 +802,32 @@ class Scheduler: @self._named_scope("ep_finalize") def _end(): if buffered_ref.is_output: - buffered_ref.swap_slots() # formally correct, not actually necessary. buffered_ref.wait_out(dst_ref, self.indices) + def swap_slots(self, buffered_ref, hbm_ref, schedule=None): + if buffered_ref.swap is not None: + swap = buffered_ref.swap[0] + else: + # If we are not using an SMEM `swap` tensor to keep track of + # swaps needed, then all the copies into and out of BufferedRefs + # are done by direct calls to the `copy_in` and `copy_out` + # methods in the pipeline loop. To determine if the BufferedRef + # needs a swap of slots, we recalculate the copy-in/copy-out + # conditions. 
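The swap step added below guards the flip with pl.when, which runs a block only when a traced predicate is true. A small self-contained reminder of that construct (hypothetical kernel, not part of this change):

import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  step = pl.program_id(0)

  @pl.when(step == 0)
  def _init():
    # Runs only on the first grid step, analogous to the guarded swap below.
    o_ref[...] = jnp.zeros_like(o_ref)

  o_ref[...] += x_ref[...]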
+ if schedule is None: + schedule = _default_schedule + pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref) + pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref) + + copied_in = pred_in & buffered_ref.is_input & ~self.last_step + copied_out = pred_out & buffered_ref.is_output + swap = copied_in | copied_out + + @pl.when(swap) + @self._named_scope("ep_swap") + def _swap(): + buffered_ref.swap_slots() + # END SCHEDULE -------------------------------------------------------------- @@ -875,6 +920,7 @@ def make_pipeline_allocations( in_specs=None, out_specs=None, should_accumulate_out=False, + needs_swap_ref=True, ): """Create BufferedRefs for the pipeline. @@ -887,6 +933,7 @@ def make_pipeline_allocations( out_specs: output pallas block specs should_accumulate_out: booleans to indicate which outputs should be treated as accumulators. + needs_swap_ref: whether a swap slots tracker needs to be allocated. Returns: A list of BufferedRefs, one corresponding to each ref specified in the @@ -905,12 +952,12 @@ def make_pipeline_allocations( in_refs = refs[:num_in_specs] out_refs = refs[num_in_specs:] def make_input_bref(in_spec, in_ref): - return BufferedRef.input(in_spec, in_ref.dtype) + return BufferedRef.input(in_spec, in_ref.dtype, needs_swap_ref) in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs) def make_output_bref(out_spec, out_ref, accumulate): if accumulate: - return BufferedRef.accumulator(out_spec, out_ref.dtype) - return BufferedRef.output(out_spec, out_ref.dtype) + return BufferedRef.accumulator(out_spec, out_ref.dtype, needs_swap_ref) + return BufferedRef.output(out_spec, out_ref.dtype, needs_swap_ref) out_brefs = jax.tree.map( make_output_bref, out_specs, out_refs, should_accumulate_out) return (*in_brefs, *out_brefs) @@ -1109,6 +1156,14 @@ def emit_pipeline( scratches = () if allocations is None: # run with inline scoped allocations + + # Prefetch and postyeet are arbitrary functions that can copy + # into or out of any of the BufferedRefs. Thus, we need a ref + # for the scheduler to mark when the prefetch or postyeet + # functions perform a copy and the slots need to be + # swapped. Without prefetch and postyeet, the swapping logic can + # be performed without the need for state. 
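As the comment above explains, the swap ref is only needed when prefetch/postyeet hooks may issue their own copies; a plain pipeline has no hooks, so the flag set just below stays False and no SMEM tracker is allocated. A minimal usage sketch of emit_pipeline under those default settings — grid and block sizes are arbitrary, and in a real kernel the outer refs come from a pallas_call with ANY memory space:

import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def inner_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] * 2.0

def outer_kernel(x_hbm_ref, o_hbm_ref):
  # No prefetch/postyeet hooks, so needs_swap_ref resolves to False here.
  pipeline = pltpu.emit_pipeline(
      inner_kernel,
      grid=(4,),
      in_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
      out_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
  )
  pipeline(x_hbm_ref, o_hbm_ref)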
+ needs_swap_ref = prefetch is not None or postyeet is not None return primitives.run_scoped( lambda allocations: pipeline( *refs, @@ -1125,7 +1180,9 @@ def emit_pipeline( *refs, in_specs=in_specs, out_specs=out_specs, - should_accumulate_out=should_accumulate_out), + should_accumulate_out=should_accumulate_out, + needs_swap_ref=needs_swap_ref, + ), ) if isinstance(allocations, list): allocations = tuple(allocations) @@ -1184,6 +1241,8 @@ def emit_pipeline( lax.cond(step == 0, lambda: postyeet(*brefs, scheduler), lambda: None) + + map_brefs(scheduler.swap_slots, brefs, refs, schedule) map_brefs(scheduler.finalize, brefs, refs, schedule) return _next_index(indices, grid) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index dcdfe62bb..43b9008bb 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -17,7 +17,7 @@ from __future__ import annotations import collections -from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence +from collections.abc import Callable, Hashable, MutableMapping, MutableSequence, Sequence import contextlib import dataclasses import functools @@ -25,6 +25,7 @@ import math from typing import Any, Protocol, cast import jax +from jax import api_util from jax import lax from jax._src import core as jax_core from jax._src import linear_util as lu @@ -36,6 +37,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect +from jax._src.lib.mlir.dialects import math as math_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.lib.mlir.dialects import scf as scf_dialect @@ -837,6 +839,29 @@ def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value: ) +def _lower_fun( + fun: Callable[..., Any], *, multiple_results: bool +) -> Callable[..., Any]: + + def lowering_rule(ctx: LoweringRuleContext, *args, **params): + wrapped_fun = lu.wrap_init( + fun + if multiple_results + else lambda *args, **params: (fun(*args, **params),), + params, + debug_info=api_util.debug_info( + "Pallas Mosaic GPU lower_fun", fun, args, params + ), + ) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + out = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr, args, consts + ) + return out if multiple_results else out[0] + + return lowering_rule + + @register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup) def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): @@ -1162,6 +1187,9 @@ def _convert_element_type_lowering_rule_wg( cur_dtype = mgpu_utils.dtype_to_ir_type(x_aval.dtype) new_dtype = mgpu_utils.dtype_to_ir_type(new_dtype) + if cur_dtype == new_dtype: + return x + if 1 < mgpu_utils.bitwidth(cur_dtype) < 8 or 1 < mgpu_utils.bitwidth(new_dtype) < 8: raise NotImplementedError("Conversion involving sub-byte types unsupported") @@ -1170,7 +1198,29 @@ def _convert_element_type_lowering_rule_wg( from_integer = ir.IntegerType.isinstance(cur_dtype) to_integer = ir.IntegerType.isinstance(new_dtype) if from_float and to_float: - if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: + cur_ty_width = ir.FloatType(cur_dtype).width + new_ty_width = ir.FloatType(new_dtype).width + 
if cur_ty_width == new_ty_width: + # There is no instruction to perform conversions between two float types + # of the same width. Go through the next-larger standard type. + # TODO(bchetioui): support conversions between float types of width 8. + # Which larger type to pick will depend on the number of bits in the + # smallest exponent. + if cur_ty_width != 16: + raise NotImplementedError( + "Conversion between float types of width other than 16 not" + " supported" + ) + larger_ty = ir.F32Type.get() + if x_aval.shape: + upcast_ty = ir.VectorType.get(x_aval.shape, larger_ty) + else: + upcast_ty = larger_ty + + def convert(ty, x): + return arith_dialect.truncf(ty, arith_dialect.extf(upcast_ty, x)) + + elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: convert = arith_dialect.truncf else: convert = arith_dialect.extf @@ -1190,10 +1240,26 @@ def _convert_element_type_lowering_rule_wg( else: convert = arith_dialect.uitofp elif from_float and to_integer: + dst_width = mgpu_utils.bitwidth(new_dtype) + # We clamp the float value to the min/max integer destination value + # in order to match JAX/XLA casting behavior. Note that this differs + # from numpy casting behavior. if mgpu_utils.is_signed(y_aval.dtype): + maxint = 2 ** (dst_width - 1) - 1 + minint = -(2 ** (dst_width - 1)) convert = arith_dialect.fptosi else: + maxint = 2**dst_width - 1 + minint = 0 convert = arith_dialect.fptoui + + maxint = _ir_constant(maxint, cur_dtype) + minint = _ir_constant(minint, cur_dtype) + if x_aval.shape: + maxint = vector_dialect.splat(x.type, maxint) + minint = vector_dialect.splat(x.type, minint) + x = arith_dialect.minimumf(x, maxint) + x = arith_dialect.maximumf(x, minint) else: raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}") @@ -1206,6 +1272,13 @@ mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ lax.not_p: lambda ctx, x: ~x, }) +mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ + lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), + lax.not_p: _lower_fun( + lambda x: jnp.bitwise_xor(x, -1), multiple_results=False + ), +}) + def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) @@ -1342,48 +1415,98 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): @register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Warpgroup) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): - [x_aval] = ctx.avals_in - x = _ensure_fa(x, x_aval.dtype) - if y == 2: - return x * x - return NotImplementedError + if y != 2: + raise NotImplementedError + return _square_lowering_rule(ctx, x) + @register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Warpgroup) def _square_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - x = _ensure_fa(x, x_aval.dtype) - return x * x + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + x = _ensure_fa(x, x_aval.dtype) + return x * x + if jnp.issubdtype(x_aval.dtype, jnp.integer): + return arith_dialect.muli(x, x) + if jnp.issubdtype(x_aval.dtype, jnp.floating): + return arith_dialect.mulf(x, x) + raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") + @register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): 
[x_aval] = ctx.avals_in - return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.rsqrt( + _ensure_ir_value(x, x_aval.dtype), fastmath=fastmath + ) + @register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) def _tanh_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.logistic_p, mgpu.ThreadSemantics.Lane) -def _logistic_lowering_rule(ctx: LoweringRuleContext, x): - [x_aval] = ctx.avals_in - a = _ensure_fa(x, x_aval.dtype) - return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math)) +def _logistic(x): + return 1.0 / (1 + lax.exp(-x)) + + +mosaic_lowering_rules[mgpu.ThreadSemantics.Lane][lax.logistic_p] = _lower_fun( + _logistic, multiple_results=False +) +mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = ( + _lower_fun(_logistic, multiple_results=False) +) + @register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) def _exp_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - a = _ensure_fa(x, x_aval.dtype) - return a.exp(approx=ctx.module_ctx.approx_math) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.exp(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) @register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) def _exp2_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - a = _ensure_fa(x, x_aval.dtype) - return a.exp2(approx=ctx.module_ctx.approx_math) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) + + +@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) +def _log_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) @register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane) @@ -1484,11 +1607,7 @@ def _debug_print_lowering_rule( ) elif len(ctx.avals_in) == 1: [arg] = args - @arg.foreach - def _(val, idx): - idx_fmt = ", ".join(["{}"] * len(idx)) - fmt_str = 
fmt.format(f"[{idx_fmt}]/{list(arg.shape)}: {{}}") - mgpu.debug_print(fmt_str, *idx, val, uniform=False) + arg.debug_print(fmt) else: raise NotImplementedError( "debug_print only supports printing of scalar values, or a single array" @@ -1901,27 +2020,36 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): @register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule( + lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup +) def _bitcast_convert_type_lowering_rule( - ctx: LoweringRuleContext, operand, *, new_dtype + ctx: LoweringRuleContext, x, *, new_dtype ): - # TODO(petebu) Handle case where src and dst types have different bitwidths - [operand_aval] = ctx.avals_in - operand = _ensure_fa(operand, operand_aval.dtype) - src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype) + [x_aval] = ctx.avals_in + src_elem_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype) dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype) assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType)) assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType)) if src_elem_type.width != dst_elem_type.width: raise NotImplementedError( - f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they" + f"Cannot bitcast from {x_aval.dtype} to {new_dtype} because they" " have different widths" ) + + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + x = _ensure_ir_value(x, x_aval.dtype) + return arith_dialect.bitcast( + ir.VectorType.get(x_aval.shape, dst_elem_type), x + ) + + x = _ensure_fa(x, x_aval.dtype) if ir.IntegerType.isinstance(dst_elem_type): output_is_signed = mgpu_utils.is_signed(new_dtype) else: output_is_signed = None return mgpu.FragmentedArray.bitcast( - operand, dst_elem_type, output_is_signed=output_is_signed + x, dst_elem_type, output_is_signed=output_is_signed ) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index c1ecd47bd..c2080b9c6 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -91,6 +91,7 @@ class BufferedRef: self.smem_ref.at[slot], # pytype: disable=unsupported-operands self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands predicate=predicate, + commit_group=False, ) @@ -299,6 +300,8 @@ def emit_pipeline( predicate=lax.bitwise_or(slices_changed, is_last_step), ) + gpu_primitives.commit_smem_to_gmem_group() + fetch_step = step + (max_concurrent_steps - delay_release) fetch_slot = lax.rem(fetch_step, max_concurrent_steps) @@ -344,6 +347,8 @@ def emit_pipeline( if bref.is_index_invariant: bref.copy_out(last_slot, last_indices, predicate=None) + gpu_primitives.commit_smem_to_gmem_group() + # Finalize the pipeline. gpu_primitives.wait_smem_to_gmem(0) @@ -578,6 +583,7 @@ def emit_pipeline_warp_specialized( bref.copy_out(_get_slot(slot, ~bref.is_index_invariant), indices, predicate=slices_changed) + gpu_primitives.commit_smem_to_gmem_group() next_indices = _inc_grid_by_1(indices, grid) return (next_indices, new_store_slices, next_body_carry) init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) @@ -619,6 +625,8 @@ def emit_pipeline_warp_specialized( if bref.is_index_invariant: bref.copy_out(last_slot, last_indices, predicate=None) + gpu_primitives.commit_smem_to_gmem_group() + # Finalize the pipeline. 
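The pipeline epilogues above now issue their SMEM->GMEM copies with commit_group=False and then commit them in a single batch, so wait_smem_to_gmem can await the whole group at once. A hedged sketch of the pattern, assuming these helpers are re-exported under jax.experimental.pallas.mosaic_gpu (the plgpu alias) like the existing copy/wait primitives:

import jax.experimental.pallas.mosaic_gpu as plgpu

def flush_outputs(a_smem, a_gmem, b_smem, b_gmem):
  # Issue several async copies without committing each one individually...
  plgpu.copy_smem_to_gmem(a_smem, a_gmem, commit_group=False)
  plgpu.copy_smem_to_gmem(b_smem, b_gmem, commit_group=False)
  # ...then commit them as one group and wait for it to drain.
  plgpu.commit_smem_to_gmem_group()
  plgpu.wait_smem_to_gmem(0)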
gpu_primitives.wait_smem_to_gmem(0) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 547a20451..428e2925f 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -86,6 +86,7 @@ def _copy_smem_to_gmem_lowering( src_transforms_treedef, dst_transforms_treedef, has_user_predicate, + commit_group, ): predicate = ctx.module_ctx.single_wg_lane_predicate if has_user_predicate: @@ -106,6 +107,7 @@ def _copy_smem_to_gmem_lowering( src_ref=src, dst_ref=dst, predicate=predicate, + arrive=commit_group, **copy_params, ) return () @@ -119,7 +121,12 @@ def _copy_smem_to_gmem_lowering( assert copy_params.get("swizzle") is None assert not copy_params.get("gmem_transform") mgpu.dialect.async_store( - src, dst, indices, slice_lengths, predicate=predicate + src, + dst, + indices, + slice_lengths, + predicate=predicate, + commit_group=commit_group, # type: ignore[call-arg] ) return () @@ -174,7 +181,11 @@ def _extract_smem_copy_params(transforms): def copy_smem_to_gmem( - src: _Ref, dst: _Ref, predicate: jax.Array | None = None + src: _Ref, + dst: _Ref, + predicate: jax.Array | None = None, + *, + commit_group: bool = True, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. @@ -183,6 +194,9 @@ def copy_smem_to_gmem( dst: The GMEM reference to copy to. predicate: A boolean indicating whether the copy should be performed. If ``None``, the copy is always performed. + commit_group: If ``True``, this and any previously uncommitted copies + are committed to a group and can be awaited jointly via + :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`. See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` @@ -209,6 +223,7 @@ def copy_smem_to_gmem( src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, has_user_predicate=predicate is not None, + commit_group=commit_group, ) return None @@ -475,6 +490,28 @@ def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None: wait_smem_to_gmem_p.bind(n, wait_read_only=wait_read_only) +commit_group_p = jax_core.Primitive("commit_group") +commit_group_p.multiple_results = True + + +@commit_group_p.def_effectful_abstract_eval +def _commit_group_abstract_eval(): + return (), {gpu_core._memory_effect} + + +@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Warpgroup) +def _commit_group_lowering(ctx: lowering.LoweringRuleContext): + del ctx # Unused. 
+ nvvm_dialect.cp_async_bulk_commit_group() + return () + + +def commit_smem_to_gmem_group() -> None: + """Commits all issued but uncommited SMEM->GMEM copies to a group.""" + commit_group_p.bind() + + # WGMMA on an accumulator reference wgmma_ref_p = jax_core.Primitive("wgmma_ref") wgmma_ref_p.multiple_results = True diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 07a6fcf0a..9a850d847 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -695,14 +695,44 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False, if precision is not None: raise ValueError("Only one of allow_tf32 and precision can be specified") precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST + dtype = jnp.promote_types(a.dtype, b.dtype) + out_dtype = jnp.int32 if jnp.issubdtype(dtype, jnp.integer) else jnp.float32 return jax.lax.dot_general( a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())), precision=precision, - preferred_element_type=jnp.float32, + preferred_element_type=out_dtype, ) +reciprocal_p = jax_core.Primitive("reciprocal") + + +def reciprocal(x, *, approx=False): + return reciprocal_p.bind(x, approx=approx) + + +@reciprocal_p.def_abstract_eval +def _reciprocal_abstract_eval(x, *, approx): + del approx + return x + + +def _reciprocal_lowering_rule( + ctx: mlir.LoweringRuleContext, x, *, approx=False +): + def _reciprocal(x, *, approx=False): + if approx: + return jnp.reciprocal(x.astype(jnp.bfloat16)).astype(jnp.float32) + return jnp.reciprocal(x) + + return mlir.lower_fun(_reciprocal, multiple_results=False)( + ctx, x, approx=approx + ) + + +mlir.register_lowering(reciprocal_p, _reciprocal_lowering_rule) + class PrintEffect(effects.Effect): __str__ = lambda self: "Print" diff --git a/jax/_src/path.py b/jax/_src/path.py index 8c46c5560..03a15e42e 100644 --- a/jax/_src/path.py +++ b/jax/_src/path.py @@ -12,35 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Protocol import logging +import os import pathlib -import importlib.util __all__ = ["Path"] - logger = logging.getLogger(__name__) +epath_installed: bool + +class PathProtocol(Protocol): + """A factory that creates a PurePath.""" + def __call__(self, *pathsegments: str | os.PathLike) -> pathlib.Path: + ... + +Path: PathProtocol # If etils.epath (aka etils[epath] to pip) is present, we prefer it because it # can read and write to, e.g., GCS buckets. Otherwise we use the builtin # pathlib and can only read/write to the local filesystem. -epath_installed = bool( - importlib.util.find_spec("etils") and - importlib.util.find_spec("etils.epath") -) -if epath_installed: - logger.debug("etils.epath found. Using etils.epath for file I/O.") - - def __dir__(): - return ["Path"] - - def __getattr__(name): - if name != "Path": - raise AttributeError(f"module '{__name__}' has no attribute '{name}") - - global Path - from etils import epath - Path = epath.Path - return Path -else: +try: + from etils import epath # type: ignore +except ImportError: logger.debug("etils.epath was not found. Using pathlib for file I/O.") Path = pathlib.Path + epath_installed = False +else: + logger.debug("etils.epath found. Using etils.epath for file I/O.") + # Ultimately, epath.Path implements pathlib.Path. 
See: + # https://github.com/google/etils/blob/2083f3d932a88d8a135ef57112cd1f9aff5d559e/etils/epath/abstract_path.py#L47 + Path = epath.Path + epath_installed = True diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index eb443f572..06892aa9f 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -29,7 +29,6 @@ import warnings import numpy as np from jax._src import api -from jax._src import ad_util from jax._src import api_util from jax._src import config from jax._src import core @@ -199,10 +198,9 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): profiler = None except pxla.DeviceAssignmentMismatchError as e: fails, = e.args - api_name = 'jit' if p.params['resource_env'] is None else 'pjit' fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( - fun_name, fails, args_flat, api_name, p.arg_names) + fun_name, fails, args_flat, 'jit', p.arg_names) raise ValueError(msg) from None except xla.InvalidInputException as e: arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names @@ -359,7 +357,6 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo): in_layouts_leaves=jit_info.in_layouts_leaves, out_layouts_treedef=jit_info.out_layouts_treedef, out_layouts_leaves=jit_info.out_layouts_leaves, - use_resource_env=jit_info.use_resource_env, compiler_options_kvs=jit_info.compiler_options_kvs) cpp_pjit_f = xc._xla.pjit( fun_name(fun), fun, cache_miss, jit_info.static_argnums, @@ -546,8 +543,7 @@ class PjitParams(NamedTuple): def _infer_params_impl( fun: Callable, ji: PjitInfo, - pjit_mesh: mesh_lib.Mesh | None, - resource_env: mesh_lib.ResourceEnv | None, + ctx_mesh: mesh_lib.Mesh | None, dbg: core.DebugInfo, args: tuple[Any, ...], kwargs: dict[str, Any], @@ -559,8 +555,8 @@ def _infer_params_impl( raise ValueError( "pjit does not support kwargs when in_shardings is specified.") - if pjit_mesh is not None: - if (ji.backend or ji.device) and not pjit_mesh.empty: + if ctx_mesh is not None: + if (ji.backend or ji.device) and not ctx_mesh.empty: raise ValueError( "Mesh context manager should not be used with jit when backend or " "device is also specified as an argument to jit.") @@ -591,13 +587,12 @@ def _infer_params_impl( in_shardings_leaves = out_shardings_leaves = tuple(leaves) in_shardings_treedef = out_shardings_treedef = treedef else: - jit_name = 'pjit' if pjit_mesh is not None else 'jit' in_shardings_leaves = tuple( - _create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name) + _create_sharding_for_array(ctx_mesh, x, 'in_shardings', 'jit') for x in ji.in_shardings_leaves) in_shardings_treedef = ji.in_shardings_treedef out_shardings_leaves = tuple( - _create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name) + _create_sharding_for_array(ctx_mesh, x, 'out_shardings', 'jit') for x in ji.out_shardings_leaves) out_shardings_treedef = ji.out_shardings_treedef @@ -655,8 +650,8 @@ def _infer_params_impl( out_shardings=out_shardings_flat, in_layouts=in_layouts_flat, out_layouts=out_layouts_flat, - resource_env=resource_env, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=fun_qual_name(flat_fun), keep_unused=ji.keep_unused, inline=ji.inline, @@ -686,38 +681,30 @@ def _infer_params_cached( jit_info: PjitInfo, signature: jax_jit.ArgumentSignature, in_avals: tuple[core.AbstractValue, ...], - pjit_mesh: mesh_lib.Mesh | None, - resource_env: mesh_lib.ResourceEnv | None, + ctx_mesh: mesh_lib.Mesh | None, ) -> InferParamsCacheEntry: return InferParamsCacheEntry() -def 
disallow_use_mesh_and_legacy_mesh_ctx_mgr_together(): - if (not mesh_lib.thread_resources.env.physical_mesh.empty and - mesh_lib.get_concrete_mesh() is not None): - raise ValueError( - 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' - ' together is not allowed.') def _infer_params( fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] -) -> tuple[PjitParams, list[Any]]: - disallow_use_mesh_and_legacy_mesh_ctx_mgr_together() + ) -> tuple[PjitParams, list[Any]]: if ji.use_resource_env: - # We need to fetch the mesh from inside the wrapped function, because - # meshes are dynamically scoped (i.e., with a context manager). - resource_env = mesh_lib.thread_resources.env - pjit_mesh = resource_env.physical_mesh - else: - resource_env = None - pjit_mesh = None + with mesh_lib.use_mesh(mesh_lib.thread_resources.env.physical_mesh): + return _infer_params_internal(fun, ji, args, kwargs) + return _infer_params_internal(fun, ji, args, kwargs) +def _infer_params_internal( + fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[PjitParams, list[Any]]: + ctx_mesh = mesh_lib.get_concrete_mesh() dbg = debug_info( 'jit', fun, args, kwargs, static_argnums=ji.static_argnums, static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, signature=ji.fun_signature) if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache - p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, dbg, + p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=None) return p, p.consts + args_flat @@ -725,10 +712,11 @@ def _infer_params( args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, ji.static_argnames, tree_util.default_registry) avals = _infer_input_type(fun, dbg, dynargs) - entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env) + entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh) + if entry.pjit_params is None: p, args_flat = _infer_params_impl( - fun, ji, pjit_mesh, resource_env, dbg, args, kwargs, in_avals=avals) + fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) if p.attrs_tracked: # if attrs, don't popoulate the cache return p, p.consts + args_flat entry.pjit_params = p @@ -1619,7 +1607,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] def _resolve_and_lower( args, jaxpr, in_shardings, out_shardings, in_layouts, - out_layouts, resource_env, donated_invars, name, keep_unused, inline, + out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, lowering_platforms, lowering_parameters, pgle_profiler, compiler_options_kvs): in_shardings = _resolve_in_shardings(args, in_shardings) @@ -1627,8 +1615,8 @@ def _resolve_and_lower( jaxpr.in_avals) out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals) return _pjit_lower( - jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, name, keep_unused, inline, compiler_options_kvs, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) @@ -1637,7 +1625,7 @@ _pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, 
keep_unused, inline, compiler_options_kvs): pgle_compile_options, pgle_profiler = {}, None if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: @@ -1662,8 +1650,8 @@ def _pjit_call_impl_python( compiled = _resolve_and_lower( args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, - out_layouts=out_layouts, resource_env=resource_env, - donated_invars=donated_invars, name=name, keep_unused=keep_unused, + out_layouts=out_layouts, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(), pgle_profiler=pgle_profiler, @@ -1694,7 +1682,7 @@ def _pjit_call_impl_python( @weakref_lru_cache def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, - out_layouts, resource_env, donated_invars, name, + out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): # The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so # returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to @@ -1708,14 +1696,14 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, def _pjit_call_impl(*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): def call_impl_cache_miss(*args_, **kwargs_): out_flat, compiled, pgle_profiler = _pjit_call_impl_python( *args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, - out_layouts=out_layouts, resource_env=resource_env, - donated_invars=donated_invars, name=name, keep_unused=keep_unused, + out_layouts=out_layouts, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, @@ -1724,7 +1712,7 @@ def _pjit_call_impl(*args, jaxpr, f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs) donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) cache_key = pxla.JitGlobalCppCacheKeys( @@ -1733,8 +1721,7 @@ def _pjit_call_impl(*args, jaxpr, in_shardings_treedef=None, in_shardings_leaves=in_shardings, out_shardings_treedef=None, out_shardings_leaves=out_shardings, in_layouts_treedef=None, in_layouts_leaves=in_layouts, - out_layouts_treedef=None, out_layouts_leaves=out_layouts, - use_resource_env=resource_env is not None) + out_layouts_treedef=None, out_layouts_leaves=out_layouts) return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], cache_key, tree_util.dispatch_registry, pxla.cc_shard_arg, @@ -1749,8 +1736,8 @@ def _pjit_lower( out_shardings, in_layouts: pxla.MaybeLayout, out_layouts: pxla.MaybeLayout, - resource_env, donated_invars, + ctx_mesh, name: str, keep_unused: bool, inline: bool, @@ -1760,14 +1747,10 @@ def _pjit_lower( lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): util.test_event("pjit_lower") - if resource_env is not None: - mesh, api_name = resource_env.physical_mesh, 'pjit' - else: - mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit' return pxla.lower_sharding_computation( - jaxpr, api_name, name, in_shardings, 
out_shardings, + jaxpr, 'jit', name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), - keep_unused=keep_unused, context_mesh=mesh, + keep_unused=keep_unused, context_mesh=ctx_mesh, compiler_options_kvs=compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, @@ -1919,8 +1902,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str, jaxpr: core.ClosedJaxpr, in_shardings, - out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, keep_unused, inline, compiler_options_kvs): + out_shardings, in_layouts, out_layouts, donated_invars, + ctx_mesh, keep_unused, inline, compiler_options_kvs): effects = list(ctx.tokens_in.effects()) output_types = map(mlir.aval_to_ir_type, ctx.avals_out) output_types = [mlir.token_type()] * len(effects) + output_types @@ -1929,7 +1912,7 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str, func = _pjit_cached_lower_jaxpr_to_fun( ctx, name, jaxpr, tuple(effects), in_shardings, out_shardings, in_layouts, out_layouts, - api_name=('jit' if resource_env is None else 'pjit')) + api_name='jit') tokens_in = [ctx.tokens_in.get(eff) for eff in effects] args = (*ctx.dim_var_values, *tokens_in, *args) @@ -1950,23 +1933,20 @@ def _pjit_batcher(axis_data, vals_in, dims_in: tuple[int, ...], jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) - if resource_env is not None: - mesh = resource_env.physical_mesh - else: - mesh = None - # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs in_shardings = tuple( - _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, ctx_mesh, + aval.ndim) if axis_in is not None else i for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals)) out_shardings = tuple( - _pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, ctx_mesh, + aval.ndim) if axis_out is not None else o for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals)) # TODO(yashkatariya): Figure out layouts should change under vmap. @@ -1982,8 +1962,8 @@ def _pjit_batcher(axis_data, vals_in, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, - resource_env=resource_env, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2005,8 +1985,8 @@ def _insert_axis_partitions(spec, dim, val): def _pjit_batcher_for_sharding( s: Sharding | UnspecifiedValue, - dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh, - ndim: int): + dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] 
| None, + mesh, ndim: int): if isinstance(s, UnspecifiedValue): return s hlo_s = s._to_xla_hlo_sharding(ndim) @@ -2045,20 +2025,8 @@ def _pjit_batcher_for_sharding( def _pjit_jvp(primals_in, tangents_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): - if any(isinstance(c, core.MutableArray) for c in jaxpr.consts): - jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr) - mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals) - primals_in = [*primals_in, *mut_primals] - tangents_in = [*tangents_in, *mut_tangents] - in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut_primals) - in_layouts = (*in_layouts,) + (None,) * len(mut_primals) - donated_invars = (*donated_invars,) + (False,) * len(mut_primals) - - tangents_in = [ad_util.zeros_like_aval(a) if isinstance(a, AbstractRef) else x - for x, a in zip(tangents_in, jaxpr.in_avals)] - is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in] jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr( jaxpr, is_nz_tangents_in, instantiate=False) @@ -2074,8 +2042,8 @@ def _pjit_jvp(primals_in, tangents_in, out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)), in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)), out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)), - resource_env=resource_env, donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)), + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2091,7 +2059,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp def _pjit_linearization(nzs, *primals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) # constvars will become residuals. Move them to the end of the ordinary args. 
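With resource_env removed, both ways of installing a mesh reach jit through the single ctx_mesh parameter: the legacy `with mesh:` manager is rerouted through mesh_lib.use_mesh in _infer_params, and the new-style context manager sets the concrete mesh directly. A sketch of the two spellings, assuming jax.make_mesh and jax.sharding.use_mesh as in current JAX releases:

import jax
import jax.numpy as jnp

mesh = jax.make_mesh((jax.device_count(),), ("x",))

@jax.jit
def f(x):
  return x * 2

x = jnp.arange(8.0)

with mesh:                          # legacy manager (formerly the 'pjit' path)
  f(x)

with jax.sharding.use_mesh(mesh):   # new-style manager; becomes ctx_mesh as-is
  f(x)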
@@ -2107,8 +2075,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, out_shardings=_filter_zeros(nzs_out, out_shardings), in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts, out_layouts=_filter_zeros(nzs_out, out_layouts), - resource_env=resource_env, donated_invars=_filter_zeros(nzs, donated_invars) + res_donated, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2127,8 +2095,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, out_shardings=(*res_shardings, *out_shardings), in_layouts=in_layouts, out_layouts=(*res_layouts, *out_layouts), - resource_env=resource_env, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2143,7 +2111,7 @@ ad.primitive_linearizations[pjit_p] = _pjit_linearization def _pjit_partial_eval(trace: pe.JaxprTrace, *in_tracers, jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, - in_layouts, out_layouts, resource_env, donated_invars, + in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): in_pvals = [t.pval for t in in_tracers] @@ -2210,8 +2178,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace, jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins), out_shardings=known_out_shardings, in_layouts=keep_where(in_layouts, known_ins), - out_layouts=known_out_layouts, resource_env=resource_env, + out_layouts=known_out_layouts, donated_invars=keep_where(donated_invars, known_ins), + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals) @@ -2242,9 +2211,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace, out_shardings=keep_where(out_shardings, unknown_outs), in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts), out_layouts=keep_where(out_layouts, unknown_outs), - resource_env=resource_env, donated_invars=(keep_where(donated_invars, unknown_ins) + (False,) * num_residuals), + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2330,7 +2299,7 @@ def _pjit_transpose_trace(fun: lu.WrappedFun, def _pjit_transpose(cts_in, *primals_in, jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) @@ -2379,8 +2348,8 @@ def _pjit_transpose(cts_in, *primals_in, out_shardings=transpose_out_shardings, in_layouts=transpose_in_layouts, out_layouts=transpose_out_layouts, - resource_env=resource_env, donated_invars=(False,) * len(primals_and_nz_cts_in), + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2464,9 +2433,8 @@ def _pjit_pp_rule(eqn: core.JaxprEqn, del params['out_layouts'] if not params['keep_unused']: del params['keep_unused'] - if (params['resource_env'] is None or - params['resource_env'].physical_mesh.empty): - del params['resource_env'] + if params['ctx_mesh'] is None or params['ctx_mesh'].empty: + del params['ctx_mesh'] if not params['compiler_options_kvs']: del params['compiler_options_kvs'] @@ -2536,6 +2504,11 @@ def with_sharding_constraint(x, shardings): This is a strict constraint for the GSPMD partitioner and not a hint. For examples of how to use this function, see `Distributed arrays and automatic parallelization`_. 
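The docstring paragraph added in the next hunk notes that with_sharding_constraint can pin an intermediate to an uneven sharding, but that such a value comes back fully replicated if it is returned from the jitted function. A hedged sketch of that situation (needs at least two devices; the mesh and axis name are illustrative):

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ("x",))      # assumes >= 2 devices are available
uneven = NamedSharding(mesh, P("x"))    # 5 rows over 2 devices: uneven

@jax.jit
def f(x):
  y = jnp.sin(x)
  # Valid as an intermediate constraint; because the value is returned, it
  # arrives fully replicated regardless of this annotation (see note below).
  return jax.lax.with_sharding_constraint(y, uneven)

f(jnp.ones((5, 3)))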
+ Inside of a jitted computation, with_sharding_constraint makes it possible to + constrain intermediate values to an uneven sharding. However, if such an + unevenly sharded value is output by the jitted computation, it will come out + as fully replicated, no matter the sharding annotation given. + Args: x: PyTree of jax.Arrays which will have their shardings constrained shardings: PyTree of sharding specifications. Valid values are the same as for @@ -2561,8 +2534,6 @@ def with_sharding_constraint(x, shardings): flatten_axes("with_sharding_constraint layouts", tree, layouts)) del layouts - disallow_use_mesh_and_legacy_mesh_ctx_mgr_together() - context_mesh = ( mesh_lib.get_abstract_mesh() if mesh_lib.get_concrete_mesh() is not None else mesh_lib.thread_resources.env.physical_mesh) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 220342ce5..455a3b98c 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -100,6 +100,9 @@ if _dtypes.float8_e4m3 is not None: if _dtypes.float8_e8m0fnu is not None: _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 +if _dtypes.float4_e2m1fn is not None: + _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 + default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): custom_float_dtypes.insert(0, _dtypes.float8_e3m4) if _dtypes.float8_e8m0fnu is not None: custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu) + if _dtypes.float4_e2m1fn is not None: + custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn) def maybe_upcast(x): if x.dtype in custom_float_dtypes: diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 9c7528d90..a24736ccf 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -2106,7 +2106,7 @@ def expi_jvp(primals, tangents): return expi(x), jnp.exp(x) / x * x_dot -def _expn1(n: Array, x: Array) -> Array: +def _expn1(x: Array, n: Array) -> Array: # exponential integral En _c = _lax_const MACHEP = jnp.finfo(x.dtype).eps @@ -2143,7 +2143,7 @@ def _expn1(n: Array, x: Array) -> Array: return d["z"] ** r * psi / jnp.exp(gammaln(t)) - d["ans"] -def _expn2(n: Array, x: Array) -> Array: +def _expn2(x: Array, n: Array) -> Array: # x > 1. 
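A minimal sketch of the behavior described in the new docstring text: constraining an intermediate value inside a jitted computation. It assumes a host with at least two devices; the mesh and axis names are illustrative.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:2], ("x",))
row_sharded = NamedSharding(mesh, P("x"))

@jax.jit
def f(a):
    b = a * 2
    # A strict constraint on the intermediate, not a hint to the partitioner.
    b = jax.lax.with_sharding_constraint(b, row_sharded)
    return b + 1

print(f(jnp.arange(8.0)).sharding)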
_c = _lax_const BIG = _c(x, 1.44115188075855872e17) @@ -2194,7 +2194,7 @@ def _expn2(n: Array, x: Array) -> Array: return d["ans"] * jnp.exp(-x) -def _expn3(n: Array, x: Array) -> Array: +def _expn3(x: Array, n: Array) -> Array: # n >= 5000 _c = _lax_const one = _c(x, 1.0) @@ -2248,11 +2248,11 @@ def expn(n: ArrayLike, x: ArrayLike) -> Array: jnp.inf, one / n1, # prevent div by zero jnp.exp(-x) / x, - partial(_expn3, n), - partial(_expn2, n), - partial(_expn1, n), + _expn3, + _expn2, + _expn1, ] - ret = jnp.piecewise(x, conds, vals) + ret = jnp.piecewise(x, conds, vals, n=n) return ret diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index c05895c11..b1901f44f 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -305,7 +305,16 @@ class SetNameStackContextManager(contextlib.ContextDecorator): set_name_stack = SetNameStackContextManager -reset_name_stack = lambda: SetNameStackContextManager(NameStack()) + + +# TODO(mattjj,phawkins): figure out why the commented-out reset_name_stack +# implementation doesn't work. Luckily this context manager isn't called much so +# the performance shouldn't matter. See blame commit message for repro. +# reset_name_stack = lambda: SetNameStackContextManager(NameStack()) +@contextlib.contextmanager +def reset_name_stack() -> Iterator[None]: + with set_name_stack(NameStack()): + yield class TransformNameStackContextManager(contextlib.ContextDecorator): diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 7abaa3185..4b627c1cd 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -259,3 +259,26 @@ class NDIndexer: def transform_dtype(self, dtype): return dtype + + def transform_sharding(self, sharding): + # If there are no explicit axes, do nothing. + if all(p is None for p in sharding.spec): + return sharding + # If there are explicit axes, we don't support changing the shape, so we + # don't support int indexers and instead require all slices. + if (self.int_indexer_shape or + not all(isinstance(idx, Slice) for idx in self.indices)): + raise TypeError("sharded ref (array reference) can only be indexed by " + "slices, not integers") + # Moreover, only allow trivial slice(None) slices on explicitly sharded + # axes. Then the sharding stays the same. 
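The expn change above relies on jnp.piecewise forwarding extra keyword arguments to every branch function, which is why _expn1, _expn2 and _expn3 now take n directly instead of being wrapped in partial. A standalone illustration with toy branch functions:

import jax.numpy as jnp

def below_one(x, n):      # used where x < 1
    return x * n

def at_least_one(x, n):   # default branch
    return x + n

x = jnp.array([0.5, 1.5, 2.5])
out = jnp.piecewise(x, [x < 1.0], [below_one, at_least_one], n=10)
# out == [5.0, 11.5, 12.5]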
+    _, slice_indexers, _ = unpack_ndindexer(self) +    for i, (d, sl, s) in enumerate(zip(self.shape, slice_indexers, sharding.spec)): +      if s is None: continue +      if not (type(sl.start) is int and sl.start == 0 and +              type(sl.size) is int and sl.size == d and +              type(sl.stride) is int and sl.stride == 1): +        raise ValueError("sharded ref (array reference) can only be sliced " +                         f"along unsharded axes, but ref of shape {self.shape} " +                         f"was sliced on axis {i}, which is sharded like {s}") +    return sharding diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 2a8b8bcc9..6f7570a5f 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -206,6 +206,13 @@ def _dtype_after_transforming( return dtype +def _sharding_after_transforming(sharding, transforms): + for transform in transforms: + sharding = transform.transform_sharding(sharding) + assert sharding is not None + return sharding + + def _get_abstract_eval(ref_aval: AbstractRef, *args, tree): transforms = tree_util.tree_unflatten(tree, args) @@ -214,10 +221,9 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args, if isinstance(ref_aval.inner_aval, core.ShapedArray): out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) - # TODO(yashkatariya): Transform the sharding too instead of setting it to - # None. - out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype, - sharding=core.get_cur_mesh_sharding()) + out_sharding = _sharding_after_transforming(ref_aval.sharding, transforms) + out_aval = ref_aval.inner_aval.update( + shape=out_shape, dtype=out_dtype, sharding=out_sharding) else: if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") @@ -437,6 +443,8 @@ def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any): ref_primal, x_primal, *idx = primals assert isinstance(ref_primal.aval, AbstractRef) ref_tangent, x_tangent, *_ = tangents assert isinstance(ref_tangent.aval, AbstractRef) x_tangent = ad_util.instantiate(x_tangent) return (swap_p.bind(ref_primal, x_primal, *idx, **params), @@ -657,5 +665,14 @@ mlir.register_lowering( # === AD rules for mutable arrays === -ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g)) +def _mut_jvp(primals, tangents): + (init_val,), (init_val_dot,) = primals, tangents + primal_out = core.mutable_array_p.bind(init_val) + if type(init_val_dot) is ad_util.Zero: + tangent_out = core.mutable_array_p.bind(ad_util.zeros_like_aval(init_val_dot.aval)) + else: + tangent_out = core.mutable_array_p.bind(init_val_dot) + return primal_out, tangent_out + +ad.primitive_jvps[core.mutable_array_p] = _mut_jvp ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g)) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 5efc7f1e0..057242f4c 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -119,6 +119,12 @@ class RefBitcaster: del dtype # Unused return self.dtype + def transform_sharding(self, sharding): + # If there are no explicit axes, do nothing. + if all(p is None for p in sharding.spec): + return sharding + raise NotImplementedError + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -166,6 +172,12 @@ class RefReshaper: del dtype # Unused return self.dtype + def transform_sharding(self, sharding): + # If there are no explicit axes, do nothing.
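The indexing rule enforced by transform_sharding above can be stated on its own: an explicitly sharded ref axis may only be indexed by a full, unit-stride slice. A self-contained sketch of that check, using plain (start, size, stride) tuples in place of Slice objects:

def check_ref_slices(shape, slices, spec):
    # spec entries of None mean "unsharded"; anything else is an explicit axis.
    for axis, (dim, (start, size, stride), s) in enumerate(zip(shape, slices, spec)):
        if s is None:
            continue  # unsharded axes may be sliced arbitrarily
        if not (start == 0 and size == dim and stride == 1):
            raise ValueError(f"axis {axis} is sharded over {s}; only full slices allowed")

check_ref_slices((8, 16), [(0, 8, 1), (2, 4, 1)], spec=("x", None))  # passes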
+ if all(p is None for p in sharding.spec): + return sharding + raise NotImplementedError + class Transform(Protocol): @@ -189,6 +201,10 @@ class Transform(Protocol): """ return dtype + def transform_sharding(self, sharding): + if all(p is None for p in sharding.spec): return sharding # no explicit axes + raise NotImplementedError + @dataclasses.dataclass class RefIndexer: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 9f2bab2b4..18f7efa16 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1640,6 +1640,8 @@ class _LazyDtypes: float_dtypes += [_dtypes.float8_e4m3] if _dtypes.float8_e8m0fnu is not None: float_dtypes += [_dtypes.float8_e8m0fnu] + if _dtypes.float4_e2m1fn is not None: + float_dtypes += [_dtypes.float4_e2m1fn] return self.supported(float_dtypes) @_cached_property diff --git a/jax/experimental/colocated_python/__init__.py b/jax/experimental/colocated_python/__init__.py index 5bd56e732..2d387b37c 100644 --- a/jax/experimental/colocated_python/__init__.py +++ b/jax/experimental/colocated_python/__init__.py @@ -14,7 +14,7 @@ """Colocated Python API.""" # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 # pylint: disable=useless-import-alias from jax.experimental.colocated_python.api import ( diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index 8e2883ea4..effca1fe7 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -201,7 +201,7 @@ def _make_output_specs_and_push_result_fun( devices = specialization.devices - def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]: + def lowered_fun(*args, **kwargs) -> jax.Array: result = info.fun(*args, **kwargs) result_leaves, out_treedef = tree_util.tree_flatten(result) out_spec_leaves = tuple(_get_spec(x) for x in result_leaves) diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py index bfd5ec2e6..1ca29ab12 100644 --- a/jax/experimental/colocated_python/serialization.py +++ b/jax/experimental/colocated_python/serialization.py @@ -13,11 +13,9 @@ # limitations under the License. """Colocated Python serialization utilities.""" -# TODO(jmudigonda): Use a string-typed array for output structure when it -# becomes available. Using a fixed uint8 array is only for prototyping. - from __future__ import annotations +import base64 import collections import functools import io @@ -37,12 +35,6 @@ import numpy as np DeviceList = xc.DeviceList -# Hard-coded limit for serialized specs size. -# TODO(jmudigonda): Use a string-typed array for output structure when it -# becomes available. Using a fixed uint8 array is only for prototyping. -_MAX_SERIALIZED_SPECS_SIZE = 1048576 - - @jax.util.cache(max_size=None) def _get_cpu_device_map() -> dict[int, jax.Device]: """Returns a map from a device id to a matching device.""" @@ -185,23 +177,14 @@ def _deserialize(serialized: bytes) -> Any: def _make_specs_for_serialized_specs( devices: DeviceList, -) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]: +) -> api.ShapeDtypeStruct: """Makes output specs for serialized specs.""" - # TODO(jmudigonda): Use a string-typed array for output structure when it - # becomes available. Using a fixed uint8 array is only for prototyping. 
mesh = jax.sharding.Mesh(tuple(devices), ("x",)) replicated_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec() ) - return ( - api.ShapeDtypeStruct( - shape=(), dtype=np.int32, sharding=replicated_sharding - ), - api.ShapeDtypeStruct( - shape=(_MAX_SERIALIZED_SPECS_SIZE,), - dtype=np.uint8, - sharding=replicated_sharding, - ), + return api.ShapeDtypeStruct( + shape=(), dtype=np.dtypes.StringDType(), sharding=replicated_sharding # type: ignore ) @@ -209,49 +192,49 @@ def _serialize_specs( specs_treedef: tree_util.PyTreeDef, specs_leaves: tuple[api.ShapeDtypeStruct, ...], devices: DeviceList, -) -> tuple[jax.Array, ...]: - """Serializes the output specs into a tuple of arrays. +) -> jax.Array: + """Serializes the output specs into a jax.Array of string type. DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF colocated_python. See serialize() for details. """ - s = _serialize((specs_treedef, specs_leaves)) - assert ( - len(s) <= _MAX_SERIALIZED_SPECS_SIZE - ), f"Too large serialized spec size: {len(s)}" - # TODO(jmudigonda): Use a string-typed array for output structure when it - # becomes available. Using a fixed uint8 array is only for prototyping. - mesh = jax.sharding.Mesh(tuple(devices), ("x",)) + if not hasattr(np.dtypes, "StringDType"): + raise TypeError( + "Serializing Colocated Python requires StringDType. Please use" + " numpy 2.0.0 or later, or explicitly provide an output spec" + " function." + ) + + s_bytes = _serialize((specs_treedef, specs_leaves)) + s_str = base64.b64encode(s_bytes).decode("ascii") + s_np_array = np.array(s_str, dtype=np.dtypes.StringDType()) # type: ignore + + # TODO(jmudigonda): Revisit this when JAX supports HLO sharding for making + # jax.Array via make_array_from_single_device_arrays. We should then use a + # sharding that spans all the execution devices - not just the addressable + # ones. + addressable_devices = devices.addressable_device_list + mesh = jax.sharding.Mesh(tuple(addressable_devices), ("x",)) replicated_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec() ) - len_array = jax.make_array_from_callback( - shape=(), - sharding=replicated_sharding, - data_callback=lambda _: np.array(len(s), dtype=np.int32), + + out_arrays = [ + jax.device_put(s_np_array, device) for device in addressable_devices + ] + return jax.make_array_from_single_device_arrays( + arrays=out_arrays, sharding=replicated_sharding, shape=(), ) - data_array = jax.make_array_from_callback( - shape=(_MAX_SERIALIZED_SPECS_SIZE,), - sharding=replicated_sharding, - data_callback=lambda _: np.frombuffer( - s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)), - dtype=np.uint8, - ), - ) - return len_array, data_array def _deserialize_specs( - serialized_specs: tuple[jax.Array, ...], + serialized_specs: jax.Array, ) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]: """Deserializes the specs from the serialized specs. DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF colocated_python. See serialize() for details. """ - # TODO(jmudigonda): Use a string-typed array for output structure when it - # becomes available. Using a fixed uint8 array is only for prototyping.
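The new specs transport amounts to: serialize, base64-encode to text, wrap in a scalar numpy StringDType array, and reverse the steps on the receiving side. A standalone round-trip sketch; it requires numpy 2.0 or newer, and the payload below is a stand-in for the real (treedef, leaves) tuple:

import base64
import pickle
import numpy as np

payload = {"shape": (4, 8), "dtype": "float32"}        # stand-in for the specs
s_str = base64.b64encode(pickle.dumps(payload)).decode("ascii")
arr = np.array(s_str, dtype=np.dtypes.StringDType())   # what gets shipped as a jax.Array

decoded = pickle.loads(base64.b64decode(arr.item().encode("ascii")))
assert decoded == payload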
- len_array, data_array = serialized_specs - length = int(len_array.addressable_shards[0].data) - data = np.asarray(data_array.addressable_shards[0].data).tobytes() - return _deserialize(data[:length]) + data_array = serialized_specs.addressable_shards[0].data + data = base64.b64decode(data_array.item().encode("ascii")) + return _deserialize(data) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index d58a1bb0d..7f98ce433 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1541,10 +1541,12 @@ tf_not_yet_impl = [ "assert_consumed_value", "consume", "ragged_dot", + "ragged_dot_general", "cholesky_update", "symmetric_product", "from_edtype", "to_edtype", + "reciprocal", # Pallas TPU primitives "bitcast", "repeat", @@ -3571,8 +3573,8 @@ def _pjit(*args: TfVal, in_shardings: Sequence[sharding.Sharding], out_shardings: Sequence[sharding.Sharding], in_layouts, out_layouts, - resource_env: mesh.ResourceEnv, donated_invars, + ctx_mesh, name: str, keep_unused: bool, inline: bool, diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 3ca7a8571..368b47df4 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -28,6 +28,7 @@ from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import llvm +from jax._src.lib.mlir.dialects import math as mlir_math from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf @@ -234,11 +235,6 @@ def _vector_load_op_lowering_rule( ir.ArrayAttr, vector_load_op.attributes["out_layouts"] ) - if not layouts.is_strided_fragmented_layout(out_layout_attr): - raise ValueError( - f"{vector_load_op} has an unsupported layout: {out_layout_attr}" - ) - for i in vector_load_op.indices: index_defining_op = i.owner.opview if ( @@ -254,9 +250,28 @@ def _vector_load_op_lowering_rule( element_type = vector_load_op.result.type.element_type is_signed = False if ir.IntegerType.isinstance(element_type) else None - fragmented_array = fa.FragmentedArray.load_strided( - vector_load_op.base, is_signed=is_signed - ) + if layouts.is_strided_fragmented_layout(out_layout_attr): + strided_layout = layouts.from_strided_fragmented_layout_attr( + out_layout_attr + ) + fragmented_array = fa.FragmentedArray.load_strided( + vector_load_op.base, + is_signed=is_signed, + vec_size=strided_layout.vec_size, + ) + elif layouts.is_wgmma_fragmented_layout(out_layout_attr): + layout = ir.MemRefType(vector_load_op.base.type).layout + swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout) + transformed_ref = transform_memref(vector_load_op.base, transforms) + fragmented_array = fa.FragmentedArray.load_tiled( + transformed_ref, + swizzle=swizzle, + is_signed=is_signed + ) + else: + raise ValueError( + f"{vector_load_op} has an unsupported layout: {out_layout_attr}" + ) return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)] @@ -424,10 +439,77 @@ def _mgpu_async_store_op_lowering_rule( gmem_transform=transforms, uniform=True, predicate=ctx.single_thread_per_warpgroup_predicate, + arrive=store_op.commit_group, ) return [] +def _conversion_op_lowering_rule( + _: LoweringContext, + op: ir.OpView, + source_is_signed: bool | None, + target_is_signed: bool | None, +) -> Sequence[ir.Value]: + [in_layout] = 
inference_utils.in_layouts(op) + [layout] = inference_utils.out_layouts(op) + if in_layout != layout: + raise ValueError("Layout mismatch") + + target_ty = op.result.type.element_type # pytype: disable=attribute-error + operand = _fragmented_array_from_ir(op.operands[0], layout, source_is_signed) + converted = operand.astype(target_ty, is_signed=target_is_signed) + return [_fragmented_array_to_ir(converted, op.result.type)] + + +for op, source_is_signed, target_is_signed in [ + (arith.ExtFOp, None, None), + (arith.ExtSIOp, True, True), + (arith.ExtUIOp, False, False), + (arith.FPToSIOp, None, True), + (arith.FPToUIOp, None, False), + (arith.SIToFPOp, True, None), + (arith.TruncFOp, None, None), + (arith.TruncIOp, False, False), + (arith.UIToFPOp, False, None), +]: + _lowerings[op.OPERATION_NAME] = functools.partial( + _conversion_op_lowering_rule, + source_is_signed=source_is_signed, + target_is_signed=target_is_signed, + ) + + +def _unary_op_lowering_rule( + _: LoweringContext, + op: Any, + impl: Callable[[fa.FragmentedArray], fa.FragmentedArray], + is_signed: bool | None = None, +) -> Sequence[ir.Value]: + in_layouts = inference_utils.in_layouts(op) + [layout] = inference_utils.out_layouts(op) + if any(in_layout != layout for in_layout in in_layouts): + raise ValueError("Layout mismatch") + kwargs = {} + if hasattr(op, "fastmath"): + kwargs = dict( + approx=op.fastmath == ir.Attribute.parse("#arith.fastmath") + ) + a = _fragmented_array_from_ir(op.operand, layout, is_signed) + return [_fragmented_array_to_ir(impl(a, **kwargs), op.result.type)] + + +for op, impl, is_signed in [ + (mlir_math.RsqrtOp, fa.FragmentedArray.rsqrt, None), + (mlir_math.ExpOp, fa.FragmentedArray.exp, None), + (mlir_math.Exp2Op, fa.FragmentedArray.exp2, None), + (mlir_math.LogOp, fa.FragmentedArray.log, None), + (mlir_math.TanhOp, fa.FragmentedArray.tanh, None), +]: + _lowerings[op.OPERATION_NAME] = functools.partial( + _unary_op_lowering_rule, impl=impl, is_signed=is_signed + ) + + def _binary_op_lowering_rule( _: LoweringContext, op: Any, @@ -525,6 +607,25 @@ def _cmpf_op_lowering_rule( return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)] +@_register_lowering(arith.BitcastOp) +def _bitcast_op_lowering_rule( + _: LoweringContext, op: arith.BitcastOp +) -> Sequence[ir.Value]: + in_layouts = inference_utils.in_layouts(op) + [layout] = inference_utils.out_layouts(op) + if any(in_layout != layout for in_layout in in_layouts): + raise ValueError("Layout mismatch") + in_ = _fragmented_array_from_ir(op.in_, layout) + out_element_type = ir.VectorType(op.result.type).element_type + out = in_.bitcast( + out_element_type, + output_is_signed=False + if ir.IntegerType.isinstance(out_element_type) + else None, + ) + return [_fragmented_array_to_ir(out, op.result.type)] + + @_register_lowering(mgpu.WGMMAOp) def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp @@ -689,8 +790,12 @@ def _for_op_lowering_rule( new_args = (new_for_op.induction_variable, *recreated_carry) for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True): old_carry.replace_all_uses_with(new_carry) - for op in ops_to_lower: + + for op in ops_to_lower: + with ir.InsertionPoint(op): ctx.lower_op(op) + + with ir.InsertionPoint(new_for_op.body): new_yield_operands = lower_carry(yield_op.operands) yield_op.erase() scf.yield_(new_yield_operands) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index a6772a575..d15cecbdc 100644 --- 
a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -53,13 +53,12 @@ def build_kernel( index = ir.IndexType.get() swizzle = 128 - tile_k = swizzle // 2 + swizzle_elems = tile_k = swizzle // 2 + tiling = (8, swizzle_elems) in_dtype = jnp.float16 k_loop_iter = k // tile_k max_concurrent_steps = min(max_concurrent_steps, k_loop_iter) - tma_tile_m = 128 - tma_tile_kn = 64 block_tile_m = tile_m block_tile_n = tile_n @@ -123,17 +122,14 @@ def build_kernel( src_ref=a, dst_ref=mgpu.memref_slice(a_smem, slot), gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)), - gmem_transform=mgpu.TileTransform((tma_tile_m, tma_tile_kn)), + gmem_transform=mgpu.TileTransform(tiling), **common_args, ) ctx.async_copy( src_ref=b, dst_ref=mgpu.memref_slice(b_smem, slot), gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)), - gmem_transform=( - mgpu.TileTransform((tma_tile_kn, tma_tile_kn)), - mgpu.TransposeTransform((1, 0, 2, 3)), - ), + gmem_transform=mgpu.TileTransform(tiling), **common_args, ) @@ -145,7 +141,7 @@ def build_kernel( tcgen05.mma( acc, mgpu.memref_slice(a_smem, slot), - mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)), + mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)), a_swizzle=swizzle, b_swizzle=swizzle, accumulate=accumulate, @@ -172,26 +168,23 @@ def build_kernel( src_ref=d_smem, dst_ref=d, gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)), - gmem_transform=mgpu.TileTransform((128, 64)), + gmem_transform=mgpu.TileTransform((128, swizzle_elems)), swizzle=swizzle, ) ctx.await_async_copy(0) compute_buffers = ( jax.ShapeDtypeStruct( - mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k), - (tma_tile_m, tma_tile_kn)), + mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k), tiling), jnp.float16), jax.ShapeDtypeStruct( - mgpu.tile_shape((max_concurrent_steps, tile_k, block_tile_n), - (tma_tile_kn, tma_tile_kn)), + mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling), jnp.float16), ) epilogue_buffer = jax.ShapeDtypeStruct( - mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)), + mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)), jnp.float16) smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer]) - assert block_tile_m == 128 smem = ( smem_buffers, [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 2ade6e848..a52eb329d 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -22,6 +22,7 @@ import math from collections.abc import Callable from typing import Iterable, Protocol, Sequence, TypeVar +import itertools import jax from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -115,6 +116,65 @@ class Tiling: strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) return strides + def tile_nested_shape_strides( + self, + shape: tuple[tuple[int, ...], ...], + strides: tuple[tuple[int, ...], ...], + ) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]: + """A fused version of `tile_shape` and `tile_strides` for nested shapes. + + By nested shape we mean that each logical dimension (i.e. each element of + shape/strides) is actually composed out of multiple physical dimensions. 
+ For example, a row-major array of logical shape (128, 128) that is tiled + into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each + dim is split into two sub-dims) and nested strides of + ((2 * 64 * 64, 64), (64 * 64, 1)). + """ + if len(shape) != len(strides): + raise ValueError( + f"Shape {shape} and strides {strides} must have the same length" + ) + def fail_if(cond, shape=shape): # Capture shape now. + if cond: + raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}") + for tile in self.tiles: + fail_if(len(tile) > len(shape)) + untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):] + untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):] + major_dim_shapes, major_dim_strides = [], [] + minor_dim_shapes, minor_dim_strides = [], [] + for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides): + major_dim_shape_rev, major_dim_stride_rev = [], [] + minor_dim_shape_rev, minor_dim_stride_rev = [], [] + for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True): + if d < t: # We will need to tile more dims + fail_if(t % d != 0) + t //= d + minor_dim_shape_rev.append(d) + minor_dim_stride_rev.append(s) + elif t != 1: # Last dim to tile! + fail_if(d % t != 0) + minor_dim_shape_rev.append(t) + minor_dim_stride_rev.append(s) + if d != t: # No need to insert singleton dims. + major_dim_shape_rev.append(d // t) + major_dim_stride_rev.append(s * t) + t = 1 + else: # Done tiling! + major_dim_shape_rev.append(d) + major_dim_stride_rev.append(s) + fail_if(t != 1) + major_dim_shapes.append(major_dim_shape_rev[::-1]) + minor_dim_shapes.append(minor_dim_shape_rev[::-1]) + major_dim_strides.append(major_dim_stride_rev[::-1]) + minor_dim_strides.append(minor_dim_stride_rev[::-1]) + shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes) + strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides) + return ( + tuple(tuple(d) if d else (1,) for d in shape), + tuple(tuple(d) if d else (1,) for d in strides), + ) + def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: for tile in self.tiles: untiled, tiled = indices[:-len(tile)], indices[-len(tile):] @@ -214,7 +274,7 @@ class TiledLayout: index = ir.IndexType.get() contig_strides = utils.get_contiguous_strides(shape) tile_strides = self.tiling.tile_strides(contig_strides) - dyn_tile_strides = [c(s, i32) for s in tile_strides] + dyn_tile_strides = [c(s, i32) for s in tile_strides[-self.tiled_tiling_rank:]] warp_offset = utils.dyn_dot(self.warp_indices(), dyn_tile_strides) lane_offset = utils.dyn_dot(self.lane_indices(), dyn_tile_strides) dyn_offset = arith.addi(warp_offset, lane_offset) @@ -246,7 +306,12 @@ class TiledLayout: so the tiled shape always ends with this suffix, no matter what array shape it's applied to. 
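The docstring example for tile_nested_shape_strides can be verified with plain arithmetic: the physical layout of a (128, 128) array tiled by (64, 64) is a row-major (2, 2, 64, 64) array, and its strides regroup per logical dimension as follows.

tiles_r, tiles_c, tile_r, tile_c = 2, 2, 64, 64
# Row-major strides of the physical (2, 2, 64, 64) array:
s = (tiles_c * tile_r * tile_c, tile_r * tile_c, tile_c, 1)    # (8192, 4096, 64, 1)

nested_shape = ((tiles_r, tile_r), (tiles_c, tile_c))          # ((2, 64), (2, 64))
nested_strides = ((s[0], s[2]), (s[1], s[3]))                  # ((8192, 64), (4096, 1))
assert nested_strides == ((2 * 64 * 64, 64), (64 * 64, 1))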
""" - return self.tiling.tile_shape(self.base_tile_shape) + base_tile_shape = self.base_tile_shape + return self.tiling.tile_shape(base_tile_shape)[len(base_tile_shape):] + + @functools.cached_property + def tiled_tiling_rank(self) -> int: + return len(self.tiled_tiling_shape) @property def vector_length(self) -> int: @@ -292,16 +357,12 @@ class TiledLayout: def warp_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) - tiled_shape = tuple( - d if i == self.warp_dim else 1 - for i, d in enumerate_negative(self.tiled_tiling_shape) - ) - assert math.prod(tiled_shape) == WARPS_IN_WARPGROUP + tiled_shape_rank = len(self.tiled_tiling_shape) warp_idx = arith.remui( arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), c(WARPS_IN_WARPGROUP, i32), ) - indices = [arith.constant(i32, 0)] * len(tiled_shape) + indices = [arith.constant(i32, 0)] * tiled_shape_rank indices[self.warp_dim] = warp_idx return tuple(indices) @@ -479,6 +540,33 @@ TILED_LAYOUT_WGMMA = TiledLayout( lane_dims=(-4, -3), vector_dim=-1, ) +# This tiled layout is similar to the one above. Above, each warp stores a 8x8 +# submatrix in the following way (we only show the first 4 rows for brevity): +# +# 0 0 1 1 2 2 3 3 +# 4 4 5 5 6 6 7 7 +# 8 8 9 9 10 10 11 11 +# 12 12 13 13 14 14 15 15 +# ... +# +# This tiled layout stores the same 8x8 submatrix in the following way: +# +# 0 4 1 5 2 6 3 7 +# 0 4 1 5 2 6 3 7 +# 8 12 9 13 10 14 11 15 +# 8 12 9 13 10 14 11 15 +# ... +# +# You can see that we have taken 2x2 submatrices from the above layout and +# transposed them. The assigment of lanes to elements is such that in both +# layouts the same two lanes map to a single 2x2 submatrix, making the transpose +# very cheap (one shuffle and permute suffices to change between those layouts). +WGMMA_TRANSPOSED_LAYOUT = TiledLayout( + Tiling(((64, 8), (16, 8), (8, 8), (2, 2), (2, 1))), + warp_dim=-10, + lane_dims=(-6, -3, -5), + vector_dim=-2, +) @jax.tree_util.register_pytree_node_class @dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True) @@ -553,13 +641,22 @@ class FragmentedArray: raise NotImplementedError @classmethod - def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None): + def load_strided( + cls, + ref: ir.Value, + *, + is_signed: bool | None = None, + vec_size: int | None = None, + ): if not ir.MemRefType.isinstance(ref.type): raise TypeError(ref.type) ref_ty = ir.MemRefType(ref.type) shape = tuple(ref_ty.shape) - layout = WGStridedFragLayout.from_shaped_type(ref_ty) + if vec_size is None: + layout = WGStridedFragLayout.from_shaped_type(ref_ty) + else: + layout = WGStridedFragLayout(shape=shape, vec_size=vec_size) vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) try: # Flattening the reference potentially produces simpler PTX but @@ -647,9 +744,35 @@ class FragmentedArray: At the moment, only conversions from ``WGSplatFragLayout`` are supported. 
""" + i32 = ir.IntegerType.get_signless(32) if self.layout == new_layout: return self shape = self.shape + if ( + self.layout == TILED_LAYOUT_WGMMA + and new_layout == WGMMA_TRANSPOSED_LAYOUT + and utils.bitwidth(self.mlir_dtype) == 16 + ): + is_even_row = arith.cmpi( + arith.CmpIPredicate.eq, + arith.remui(arith.divui(utils.thread_idx(), c(4, i32)), c(2, i32)), + c(0, i32), + ) + perm = arith.select(is_even_row, c(0x5410, i32), c(0x3276, i32)) + new_regs = [] + for reg in self.registers.flat: + reg_ty = reg.type + reg = utils.bitcast(reg, i32) + reg_shfl = utils.shfl_bfly(reg, 4) + new_reg = llvm.inline_asm( + i32, [reg, reg_shfl, perm], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r" + ) + new_regs.append(utils.bitcast(new_reg, reg_ty)) + return FragmentedArray( + _registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)), + _layout=new_layout, + _is_signed=self.is_signed, + ) if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 8 == 0: tiled_layout = _tiled_wgmma_layout(shape) if (self.layout == WGMMA_LAYOUT and new_layout == tiled_layout) or ( @@ -966,6 +1089,24 @@ class FragmentedArray: return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32")) return self._pointwise(mlir_math.exp2) + def log(self, *, approx: bool = False): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + if approx: + dtype = self.mlir_dtype + ln2 = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.6931471805599453)) + return self.log2(approx=True) * ln2 + return self._pointwise(mlir_math.log) + + def log2(self, *, approx: bool = False): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError(self.mlir_dtype) + if approx: + if not ir.F32Type.isinstance(self.mlir_dtype): + raise NotImplementedError(self.mlir_dtype) + return self._pointwise(self._lift_fast_instr("lg2.approx.ftz.f32")) + return self._pointwise(mlir_math.log2) + def sin(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError @@ -1190,7 +1331,30 @@ class FragmentedArray: from_integer = ir.IntegerType.isinstance(cur_dtype) to_integer = ir.IntegerType.isinstance(new_dtype) if from_float and to_float: - if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: + cur_ty_width = ir.FloatType(cur_dtype).width + new_ty_width = ir.FloatType(new_dtype).width + if cur_ty_width == new_ty_width: + # There is no instruction to perform conversions between two float types + # of the same width. Go through the next-larger standard type. + # TODO(bchetioui): support conversions between float types of width 8. + # Which larger type to pick will depend on the number of bits in the + # smallest exponent. 
+ if cur_ty_width != 16: + raise NotImplementedError( + "Conversion between float types of width other than 16 not" + " supported" + ) + larger_ty = ir.F32Type.get() + match self.layout: + case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout(): + shape = ir.VectorType(self.registers.flat[0].type).shape + upcast_ty = ir.VectorType.get(shape, larger_ty) + case WGMMARowFragLayout() | WGSplatFragLayout(): + upcast_ty = larger_ty + case _: + raise NotImplementedError(f"Unsupported layout {self.layout}") + convert = lambda ty, x: arith.truncf(ty, arith.extf(upcast_ty, x)) + elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: convert = arith.truncf else: convert = arith.extf @@ -1423,19 +1587,34 @@ class FragmentedArray: if create_array: return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) - def store_untiled(self, ref: ir.Value): + def debug_print(self, fmt: str): + idx_fmt = ", ".join(["{}"] * len(self.shape)) + @self.foreach + def _(val, idx): + fmt_str = fmt.format(f"[{idx_fmt}]: {{}}") + utils.debug_print(fmt_str, *idx, val, uniform=False) + + def store_untiled(self, ref: ir.Value, *, vector_store: bool = True): if not ir.MemRefType.isinstance(ref.type): raise ValueError(ref) + def vs_unsupported(): + if not vector_store: + raise NotImplementedError( + f"Can't use non-vector stores with layout {self.layout}" + ) + match self.layout: case WGMMAFragLayout(): self._store_untiled_wgmma(ref) case WGSplatFragLayout(): + vs_unsupported() self._store_untiled_splat(ref) case WGStridedFragLayout(): + vs_unsupported() self._store_untiled_wg_strided(ref) case TiledLayout(): - self._store_untiled_tiled(ref) + self._store_untiled_tiled(ref, vector_store=vector_store) case _: raise NotImplementedError(self.layout) @@ -1502,7 +1681,7 @@ class FragmentedArray: col = arith.addi(col_base, c(col_tile * 8 + col_idx)) memref.store(value, ref, [row, col]) - def _store_untiled_tiled(self, ref: ir.Value): + def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): """Stores an array with a tiled layout. Not optimized at the moment.""" if utils.bitwidth(self.mlir_dtype) < 8: raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})") @@ -1510,7 +1689,7 @@ class FragmentedArray: layout = self.layout assert isinstance(layout, TiledLayout) ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() - if ref_strides[layout.vector_dim] != 1: + if vector_store and ref_strides[layout.vector_dim] != 1: raise NotImplementedError( "Can't use vector stores with non-unit minormost stride" ) @@ -1524,16 +1703,30 @@ class FragmentedArray: raise NotImplementedError(f"Unexpected ref space {ref_space}") ptr = utils.memref_ptr(ref, memory_space=memory_space) # Fold warp and lane offsets into the pointer once, since they are dynamic. - dyn_strides = [arith.constant(i32, s) for s in strides] + dyn_strides = [ + arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :] + ] warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) dyn_offset = arith.addi(warp_offset, lane_offset) ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) # All warp tile offsets are static and can be fused into the store. 
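The new equal-width branch in astype covers pairs such as bf16 and f16, which have no direct conversion instruction, so values are widened with extf to f32 and then narrowed with truncf. The same round trip expressed at the JAX level, as a plain illustration:

import jax.numpy as jnp

x_bf16 = jnp.array([1.5, 2.25, 3.0], dtype=jnp.bfloat16)
x_f16 = x_bf16.astype(jnp.float32).astype(jnp.float16)   # extf, then truncf
assert x_f16.dtype == jnp.float16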
for tile_idx, reg in np.ndenumerate(self.registers): - lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True)) - reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) - llvm.store(reg, reg_ptr) + if vector_store: + elems = [reg] + else: + index = ir.IndexType.get() + elems = [ + vector.extractelement(reg, position=c(i, index)) + for i in range(ir.VectorType(reg.type).shape[0]) + ] + for i, e in enumerate(elems): + tile_idx_local = list(tile_idx) + tile_idx_local[layout.vector_dim] += i + tile_idx_local = list(tile_idx_local) + lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True)) + reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) + llvm.store(e, reg_ptr) def store_tiled(self, ref, swizzle: int | None): match self.layout: @@ -1714,31 +1907,42 @@ class FragmentedArray: ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type if ref_ty.rank % 2: - raise ValueError("Tiled refence must have even rank") - ref_tiling_shape = tuple(ref_ty.shape[ref_ty.rank // 2:]) + raise ValueError("Tiled reference must have even rank") + ref_logical_rank = ref_ty.rank // 2 + ref_tiling_shape = tuple(ref_ty.shape[ref_logical_rank:]) ref_tiling = Tiling((ref_tiling_shape,)) ref_strides, _ = ref_ty.get_strides_and_offset() if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape: raise ValueError() - if len(layout.base_tile_shape) > len(ref_tiling_shape): - raise ValueError("Memory tiling must be a multiple of the register tiling") - ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):] - if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)): - raise ValueError( - f"Memory tiling ({ref_tiling_suffix}) must be a multiple of the" - f" register tiling ({layout.base_tile_shape})" - ) + nested_ref_shape = tuple( + (ref_ty.shape[i], ref_ty.shape[i + ref_logical_rank]) + for i in range(ref_logical_rank) + ) + nested_ref_strides = tuple( + (ref_strides[i], ref_strides[i + ref_logical_rank]) + for i in range(ref_logical_rank) + ) + tiled_nested_shape, tiled_nested_strides = tiling.tile_nested_shape_strides( + nested_ref_shape, nested_ref_strides + ) - elem_tiled_strides = list(tiling.tile_strides(tuple(ref_strides))) - tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape))) + # We could technically handle this case, but it would be quite complicated. + # If tiling dimensions would have to be expanded into multiple, we'd have to + # adjust the dimension indices in layouts, including expanding some of them + # into multiple indices. Note that for non-tiling dims, we allow the shape + # to be arbitrary, which is why we fix it up below in mem_idx_to_reg_idx. 
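The mem_idx_to_reg_idx fix-up referred to in the comment above collapses the several physical dims that one logical dim was split into back into a single register index, using row-major strides of that dim's nested shape. The same collapse in isolation:

def flatten_idx(idx, dim_shape):
    # Row-major strides of dim_shape, then a dot product with the index.
    strides, stride = [], 1
    for d in reversed(dim_shape):
        strides.append(stride)
        stride *= d
    strides.reverse()
    return sum(i * s for i, s in zip(idx, strides))

assert flatten_idx((1, 3), (2, 4)) == 7   # index (1, 3) in a (2, 4) grid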
+ if any( + len(dim_shape) != 1 for dim_shape in tiled_nested_shape[-layout.tiled_tiling_rank :] + ): + raise NotImplementedError("Memory and register tiling incompatible") + tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape)) + elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides)) elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims] lane_shape = [tiled_shape[d] for d in layout.lane_dims] if elem_tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): tiled_shape[d] = 1 - full_tiling = Tiling((ref_tiling_shape, *tiling.tiles)) - full_layout = dataclasses.replace(layout, tiling=full_tiling) element_bits = mgpu.bitwidth(dtype) if (layout.vector_length * element_bits) % 8 != 0: @@ -1779,9 +1983,11 @@ class FragmentedArray: ) # All offsets are in units of transfer_dtype. - dyn_tiled_strides = [c(s) for s in transfer_tiled_strides] - lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides) - warp_offset = utils.dyn_dot(full_layout.warp_indices(), dyn_tiled_strides) + dyn_tiled_strides = [ + c(s) for s in transfer_tiled_strides[-layout.tiled_tiling_rank :] + ] + lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides) + warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides) dyn_offset = arith.addi(lane_offset, warp_offset) if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): raise ValueError("Tiled stores can be performed into SMEM") @@ -1809,10 +2015,23 @@ class FragmentedArray: reg_ptr = utils.getelementptr(ptr, [offset], transfer_dtype) offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle)) reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], transfer_dtype) - reg_idxs = [ - tiling.tile_indices(full_tiling.untile_indices(idx)) - for idx in indices.tolist() - ] + # Here, registers are organized in an array with shape obtained by tiling + # the logical data bounds. But, the reference was tiled and so each + # logical tiled dimension can map to multiple dims in tiled_shape. + # The transform below maps this potentially higher-rank representation + # back to the lower-rank representation used by the register arrays. + def mem_idx_to_reg_idx(idx): + reg_tiled_idx = [] + base_idx = 0 + for dim_shape in tiled_nested_shape[:ref_logical_rank]: + dim_strides = utils.get_contiguous_strides(dim_shape) + dim_idxs = idx[base_idx:base_idx + len(dim_shape)] + base_idx += len(dim_shape) + reg_tiled_idx.append(sum(i * s for i, s in zip(dim_idxs, dim_strides))) + # We should have fixed up all but the tiling dims. 
+ assert base_idx == len(idx) - layout.tiled_tiling_rank + return (*reg_tiled_idx, *idx[base_idx:]) + reg_idxs = [mem_idx_to_reg_idx(idx) for idx in indices.tolist()] def get_register(regs, reg_idxs=reg_idxs): return plan.select([regs[reg_idx] for reg_idx in reg_idxs]) def update_registers(regs, new, reg_idxs=reg_idxs): diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 7a314a2ea..ce432f26d 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -430,8 +430,8 @@ class LaunchContext: gmem_ref, smem_ref = dst_ref, src_ref if barrier is not None: raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") - if arrive is not None: - raise ValueError("arrive is unsupported for SMEM -> GMEM copies") + if arrive is None: + arrive = True # Commit this copy to the async group by default else: raise ValueError("Only SMEM <-> GMEM copies supported") # TODO(apaszke): This is a very approximate check. Improve it! @@ -683,7 +683,8 @@ class LaunchContext: nvvm.cp_async_bulk_tensor_global_shared_cta( tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate ) - nvvm.cp_async_bulk_commit_group() + if arrive: + nvvm.cp_async_bulk_commit_group() def await_async_copy( self, allow_groups: int, await_read_only: bool = False diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 5971cfb85..044e7537d 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -18,18 +18,23 @@ from collections.abc import Callable, Sequence import dataclasses import enum from functools import partial +import math from typing import cast from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith -from jax._src.lib.mlir.dialects import scf +from jax._src.lib.mlir.dialects import math as mlir_math from jax._src.lib.mlir.dialects import memref +from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +import numpy as np from . import fragmented_array as fa from . import inference_utils from . import layouts as layouts_lib +from . import utils + # mypy: ignore-errors @@ -192,22 +197,44 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts: for op in [ - arith.AddIOp, arith.AddFOp, + arith.AddIOp, + arith.AddFOp, arith.AndIOp, + arith.BitcastOp, arith.CmpFOp, arith.CmpIOp, - arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp, + arith.ExtFOp, + arith.ExtSIOp, + arith.ExtUIOp, + arith.FPToSIOp, + arith.FPToUIOp, arith.MaximumFOp, - arith.MaxUIOp, arith.MaxSIOp, + arith.MaxUIOp, + arith.MaxSIOp, arith.MinimumFOp, - arith.MinUIOp, arith.MinSIOp, - arith.MulIOp, arith.MulFOp, + arith.MinUIOp, + arith.MinSIOp, + arith.MulIOp, + arith.MulFOp, arith.OrIOp, - arith.FloorDivSIOp, arith.DivUIOp, arith.DivFOp, - arith.RemUIOp, arith.RemSIOp, arith.RemFOp, - arith.SubIOp, arith.SubFOp, - arith.TruncFOp, arith.TruncIOp, + arith.FloorDivSIOp, + arith.DivUIOp, + arith.DivFOp, + arith.RemUIOp, + arith.RemSIOp, + arith.RemFOp, + arith.SIToFPOp, + arith.UIToFPOp, + arith.SubIOp, + arith.SubFOp, + arith.TruncFOp, + arith.TruncIOp, arith.XOrIOp, + mlir_math.ExpOp, + mlir_math.Exp2Op, + mlir_math.LogOp, + mlir_math.RsqrtOp, + mlir_math.TanhOp, vector.LoadOp, vector.StoreOp, ]: @@ -487,11 +514,36 @@ def infer_layout(module: ir.Module): # propagated. 
However, it is possible for some operations to remain # unannotated---for example, if there were no annotations on any operation in # the module at the start of this function. We annotate all the remaining ops - # that should be annotated with a strided fragmented layout. + # that should be annotated with a strided fragmented layout, whose vector size + # is derived from the narrowest type and vector size used in the program. We + # make sure to derive a single vector size in order to avoid relayouts at + # lowering time. + default_vector_size = math.inf + + def update_default_vector_size(op: ir.OpView): + nonlocal default_vector_size + for v in list(op.operands) + list(op.results): + if ir.VectorType.isinstance(v.type): + max_vec_size_for_v = ( + np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE + ) + desired_vec_size = 8 // utils.bytewidth(v.type.element_type) + default_vector_size = min( + default_vector_size, max_vec_size_for_v, desired_vec_size + ) + + for op in module.body: + traverse_op(op, update_default_vector_size) + + if default_vector_size is None: # Nothing to annotate. + return + def to_default_layout(ty: ir.Type) -> ir.Attribute | None: if not ir.VectorType.isinstance(ty): return None - layout = fa.WGStridedFragLayout.from_shaped_type(ty) + layout = fa.WGStridedFragLayout( + shape=cast(ir.ShapedType, ty).shape, vec_size=default_vector_size + ) return layouts_lib.to_strided_fragmented_layout_attr(layout) def set_default_layout(op: ir.OpView): diff --git a/jax/experimental/mosaic/gpu/mma_utils.py b/jax/experimental/mosaic/gpu/mma_utils.py new file mode 100644 index 000000000..81f6af1a9 --- /dev/null +++ b/jax/experimental/mosaic/gpu/mma_utils.py @@ -0,0 +1,223 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import enum +import math + +from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import llvm + +from . import utils + +# mypy: ignore-errors + +def tiled_memref_shape(ref: ir.Value): + """Returns the 2D untiled shape and element type of a tiled 4D memref.""" + ref_ty = ir.MemRefType(ref.type) + if ref_ty.rank != 4: + raise ValueError(f"Expected a 4D memref, got: {ref_ty}") + logical_shape = ( + ref_ty.shape[0] * ref_ty.shape[2], ref_ty.shape[1] * ref_ty.shape[3] + ) + return logical_shape, ref_ty.element_type + + +class Dim(enum.Enum): + K = enum.auto() + MN = enum.auto() + + +def create_descriptor( + ref: ir.Value, + swizzle: int, + group_size: tuple[int, int], # Instruction group size on each operand dim. + logical_k_major: bool, # False for LHS, True for RHS. + # Soft deprecated. Use small tiling instead. 
+ large_tile: tuple[int, int] | None = None, +): + ref_ty = ir.MemRefType(ref.type) + element_bytewidth = utils.bytewidth(ref_ty.element_type) + swizzle_elems = swizzle // element_bytewidth + ref_strides, _ = ref_ty.get_strides_and_offset() + ref_byte_strides = [s * element_bytewidth for s in ref_strides] + mn_large_tile = k_large_tile = None + if logical_k_major: + _, mn_tiles, k_tiling, mn_tiling = ref_ty.shape + k_tile_stride, mn_tile_stride, k_tiling_stride, mn_tiling_stride = ( + ref_byte_strides + ) + k_group_size, mn_group_size = group_size + if large_tile is not None: + k_large_tile, mn_large_tile = large_tile + else: + mn_tiles, _, mn_tiling, k_tiling = ref_ty.shape + mn_tile_stride, k_tile_stride, mn_tiling_stride, k_tiling_stride = ( + ref_byte_strides + ) + mn_group_size, k_group_size = group_size + if large_tile is not None: + mn_large_tile, k_large_tile = large_tile + + IGNORED = 0 + MMA_ATOM_ROWS = 8 + MMA_BYTEWIDTH_K = 32 + mma_width_k = MMA_BYTEWIDTH_K // element_bytewidth + # As far as I can tell (which does not seem to fully align with the way MMA is + # documented in PTX docs), MMA expects the data to be tiled into matrices + # of shape 8 x swizzle_elems, with swizzle_elems dim being the fastest + # changing. I call this submatrix an MMA atom. + # + # The role of the SMEM descriptor is to specify the striding pattern between + # those atoms. The fastest changing dimension is called the "leading" + # dimension and it specifies the stride between consecutive atoms that share + # the same coordinate along that dim. The slower dimension is called a + # "stride" dimension. + if ( + large_tile is not None + and k_large_tile == k_tiling + and (mn_large_tile == mn_tiling or mn_tiles == 1 and mn_tiling < mn_large_tile) + # There are configurations where large tiles are same size as small ones. + # We use the small path since it has fewer restrictions. + and set(large_tile) != {MMA_ATOM_ROWS, swizzle_elems} + ): # Large tiles. + if ( + k_tiling_stride == element_bytewidth + and mn_tiling_stride == k_tiling * element_bytewidth + ): + fastest_dim = Dim.K + leading_byte_offset = IGNORED # TC assumes K to be contiguous here. + # MMA atoms in a group are contiguous, so we increment by the MMA atom + # size. However, we only have one level of striding, and so if the group + # size exceeds a single large tile (and there is more than one tile) then + # that tiled dimension must be contiguous after tiles or else we would + # need another striding level. + if ( + mn_tiles > 1 + and mn_group_size > mn_tiling + and mn_tile_stride != math.prod(large_tile) * element_bytewidth + ): + raise ValueError( + "MMA layout with large tiles that is K-fastest only supports" + " multiple MN tiles when the tiled MN dimension is a contiguous" + " stack of tiles " + f"({mn_tiles}, {mn_tile_stride} != {math.prod(large_tile)} * {element_bytewidth})" + ) + stride_byte_offset = MMA_ATOM_ROWS * swizzle + desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous. + elif ( + k_tiling_stride == k_tiling * element_bytewidth + and mn_tiling_stride == element_bytewidth + ): + if k_large_tile != mn_large_tile: + raise ValueError( + "MMA layout with large tiles that is MN-fastest is only supported" + " when the tiling is square" + ) + fastest_dim = Dim.MN + # Next swizzle atom with the same K coordinate is in the next MN tile. + leading_byte_offset = mn_tile_stride + # MMA atoms in a group are contiguous and a group does not exceed a tile. 
+ assert k_large_tile == k_group_size + stride_byte_offset = MMA_ATOM_ROWS * swizzle + # Each row is swizzle bytes wide, and we read mma_width_k rows at a time. + assert mn_large_tile == swizzle // element_bytewidth + desc_k_stride = mma_width_k * swizzle + else: + raise ValueError("MMA tiles must be contiguous") + else: # Small tiles. + if k_tiling_stride > mn_tiling_stride: + slower_tiling, faster_tiling = k_tiling, mn_tiling + else: + faster_tiling, slower_tiling = k_tiling, mn_tiling + if slower_tiling != MMA_ATOM_ROWS or faster_tiling != swizzle_elems: + raise ValueError( + f"Tiling should be ({MMA_ATOM_ROWS}, swizzle_elems) where" + f" swizzle_elems = swizzle // bytewidth(dtype) (= {swizzle} //" + f" {element_bytewidth} = {swizzle_elems}), but got ({slower_tiling}," + f" {faster_tiling})" + ) + if k_tiling_stride == element_bytewidth and mn_tiling_stride == swizzle: + fastest_dim = Dim.K + leading_byte_offset = IGNORED # TC assumes K to be contiguous here. + stride_byte_offset = mn_tile_stride + desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous. + elif k_tiling_stride == swizzle and mn_tiling_stride == element_bytewidth: + fastest_dim = Dim.MN + leading_byte_offset = mn_tile_stride + stride_byte_offset = k_tile_stride + k_tiles_per_mma = mma_width_k // MMA_ATOM_ROWS + desc_k_stride = k_tile_stride * k_tiles_per_mma + else: + raise ValueError("MMA tiles must be contiguous") + desc_base = encode_descriptor( + ref, + leading_byte_offset=leading_byte_offset, + stride_byte_offset=stride_byte_offset, + swizzle=swizzle, + ) + + mn_tiles_per_group, rem = divmod(mn_group_size, mn_tiling) + assert not rem + mn_group_stride = mn_tile_stride * mn_tiles_per_group + k_tiles_per_group, rem = divmod(k_group_size, k_tiling) + assert not rem + k_group_stride = k_tile_stride * k_tiles_per_group + + return ( + (desc_base, desc_k_stride), + (mn_group_stride, k_group_stride), + fastest_dim, + ) + + +def encode_addr(x: int): + result = (x & 0x3FFFF) >> 4 + if result << 4 != x: + raise ValueError(f"Cannot encode value in an MMA descriptor: {x}") + return result + + +def encode_descriptor( + memref_arg, + leading_byte_offset: int, + stride_byte_offset: int, + swizzle: int | mgpu_dialect.SwizzlingMode | None, + const_init: int = 0, +): + i64 = ir.IntegerType.get_signless(64) + ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, 3)) + c = lambda x: arith.constant(i64, x) + if swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: + swizzle_encoding = 0 + elif swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle: + swizzle_encoding = 1 + elif swizzle == mgpu_dialect.SwizzlingMode.k64ByteSwizzle: + swizzle_encoding = 2 + elif swizzle == mgpu_dialect.SwizzlingMode.k32ByteSwizzle: + swizzle_encoding = 3 + else: + raise NotImplementedError(swizzle) + encoded_base_addr = llvm.lshr(llvm.and_(ptr_val, c(0x3FFFF)), c(4)) + # We ignore the offset + desc_const = ( + const_init + | (encode_addr(leading_byte_offset) << 16) + | (encode_addr(stride_byte_offset) << 32) + ) + desc = llvm.or_(arith.shli(c(swizzle_encoding), c(62)), c(desc_const)) + desc = llvm.or_(encoded_base_addr, desc) + return desc diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 1933610b6..7a349f50c 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -18,7 +18,6 @@ from __future__ import annotations import dataclasses import math -from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect from jaxlib.mlir import ir from 
jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import llvm @@ -27,7 +26,7 @@ import numpy as np from . import utils from . import fragmented_array as fa -from . import _wgmma +from . import mma_utils from .launch_context import LaunchContext # MyPy does a terrible job with the MLIR API. @@ -37,21 +36,6 @@ from .launch_context import LaunchContext TMEM_ROWS = 128 TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46 -def create_smem_descriptor( - memref_arg, - leading_byte_offset: int, - stride_byte_offset: int, - swizzle: int | mgpu_dialect.SwizzlingMode | None, -): - return _wgmma.create_descriptor( - memref_arg, - leading_byte_offset, - stride_byte_offset, - swizzle, - memory_space=3, - const_init=TCGEN05_SMEM_DESCRIPTOR_BIT, - ) - def create_instr_descriptor( m: int, n: int, @@ -100,70 +84,126 @@ def mma( collective: bool = False, ): i64 = ir.IntegerType.get_signless(64) + if isinstance(accumulate, bool): + accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate) + if a_swizzle != b_swizzle: + raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}") + swizzle = a_swizzle + num_cta = 2 if collective else 1 + + # Step 1. Establish the shape and element type of the operation. if not ir.MemRefType.isinstance(a.type): raise ValueError(f"A must be a memref, got {a.type}") if not ir.MemRefType.isinstance(b.type): raise ValueError(f"B must be a memref, got: {b.type}") - if a_swizzle != b_swizzle: - raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}") - if isinstance(accumulate, bool): - accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate) + (k, n), element_type = mma_utils.tiled_memref_shape(b) + (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) + if k != k2: + raise ValueError( + "MMA requires A and B to have the same contraction dimension (K)," + f" got: {k2} and {k}" + ) + if element_type != element_type2: + raise ValueError( + "MMA requires A and B to have the same element type, got:" + f" {element_type2} and {element_type}" + ) + if d.shape != (m, n * num_cta): + raise ValueError( + f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}" + ) + f32 = ir.F32Type.get() + if element_type == f32 or element_type == ir.BF16Type.get(): + if d.dtype != f32: + raise ValueError( + f"MMA with element type {element_type} only supports accumulators" + f" of type f32, but got: {d.dtype}" + ) + elif element_type == ir.F16Type.get(): + if d.dtype != element_type and d.dtype != f32: + raise ValueError( + "MMA with element type f16 only supports accumulators of type f32" + f" or f16, but got: {d.dtype}" + ) - m_group_size = d.layout.elements_in_tile[0] - if m_group_size != 128: + # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, + # instructions must be issued in groups of the same width as the swizzle. 
+ m_group_elems = d.layout.elements_in_tile[0] + if m_group_elems != 128: raise NotImplementedError("Only 128-row accumulators supported for now") - - ( - a_desc_base, - b_desc_base, - (m, k, n), - (m_groups, k_groups, n_groups), - (a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride), - mma_params, - ) = _wgmma._validate_mma( - a, - b, - a_swizzle, - m_group_size=m_group_size, - descriptor_const_init=TCGEN05_SMEM_DESCRIPTOR_BIT, - ) - n_group_size = n // n_groups - if n > 512: - raise ValueError(f"N is too big: at most 512 is supported, but got {n}") - num_cta = 2 if collective else 1 + k_group_elems = swizzle // utils.bytewidth(element_type) + if n % 8: + raise ValueError(f"N must be a multiple of 8, got: {n}") + elif n > 256 and n != 512: + raise ValueError("Only N below 256 or N=512 are supported") if num_cta == 2 and n > 256: raise NotImplementedError( "N is too big for collective MMA. Only up to 256 is supported." ) + n_group_elems = min(n, 256) + if m % m_group_elems: + raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}") + if k % k_group_elems: + raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}") + if n % n_group_elems: + raise ValueError(f"N must be a multiple of {n_group_elems}, got: {n}") + m_groups = m // m_group_elems + k_groups = k // k_group_elems + n_groups = n // n_group_elems + # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA. + wgmma_element_type = ( + ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type + ) - # TODO(apaszke): Verify that the cluster shape matches the expectation of - # collective MMA. - expected_acc_shape = (m, n * num_cta) - if d.shape != expected_acc_shape: - raise ValueError( - f"Accumulator shape mismatch: expected {expected_acc_shape}, got {d.shape}" - ) + # Step 3. Compute the operand descriptors. + ( + (a_desc_base, a_k_instr_stride), + (a_m_group_stride, a_k_group_stride), + a_fastest, + ) = mma_utils.create_descriptor( + a, + swizzle=swizzle, + group_size=(m_group_elems, k_group_elems), + logical_k_major=False, + ) + ( + (b_desc_base, b_k_instr_stride), + (b_n_group_stride, b_k_group_stride), + b_fastest, + ) = mma_utils.create_descriptor( + b, + swizzle=swizzle, + group_size=(k_group_elems, n_group_elems), + logical_k_major=True, + ) + # Step 4. Issue the instructions. 
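To make the Step 2 group arithmetic concrete, here is the same computation evaluated by hand for one made-up configuration (bf16 operands, swizzle=128, a single CTA); only the arithmetic is reproduced, not the MLIR emission:

# Hedged sketch: Step 2 of tcgen05.mma for (M, N, K) = (256, 512, 512), bf16.
element_bytewidth = 2                          # bf16
swizzle = 128
m, n, k = 256, 512, 512
m_group_elems = 128                            # d.layout.elements_in_tile[0]; only 128 is supported for now
k_group_elems = swizzle // element_bytewidth   # 64: one swizzle-width slab of K per group
n_group_elems = min(n, 256)                    # N groups are capped at 256
m_groups = m // m_group_elems                  # 2
k_groups = k // k_group_elems                  # 8
n_groups = n // n_group_elems                  # 2
print(m_groups, k_groups, n_groups)            # -> 2 8 2

Each (mi, ni, ki) triple then gets its own descriptor offsets in Step 4 below.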
true = arith.constant(ir.IntegerType.get_signless(1), 1) for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups): a_offset = mi * a_m_group_stride + ki * a_k_group_stride - a_mk = arith.addi(a_desc_base, utils.c(_wgmma.wgmma_encode(a_offset), i64)) + a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) b_offset = ni * b_n_group_stride + ki * b_k_group_stride - b_nk = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(b_offset), i64)) + b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64)) if m_groups != 1: raise NotImplementedError("D needs to be sliced") acc = accumulate if ki == 0 else true _do_mma( d.slice( - slice(None), utils.ds(ni * n_group_size, n_group_size) + slice(None), utils.ds(ni * n_group_elems, n_group_elems) ).address, a_mk, b_nk, d_type=ir.F32Type.get(), - m=m_group_size, + m=m_group_elems, + n=n_group_elems, collective=collective, - **mma_params, + a_transpose=a_fastest != mma_utils.Dim.K, + b_transpose=b_fastest != mma_utils.Dim.K, + a_k_stride=a_k_instr_stride, + b_k_stride=b_k_instr_stride, accumulate=acc, + swizzle=swizzle, + element_type=wgmma_element_type, ) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 191f0fcfc..f90f7ff08 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1172,4 +1172,32 @@ def getelementptr( def dyn_dot(x, y): + assert len(x) == len(y) return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y))) + + +def shfl_bfly(x: ir.Value, distance: int | ir.Value): + i32 = ir.IntegerType.get_signless(32) + if isinstance(distance, int): + distance = c(distance, i32) + assert x.type == i32 + return nvvm.shfl_sync( + i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly, + ) + + +def bitcast(x: ir.Value, new_type: ir.Type): + if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type): + new_type = ir.IntegerType(new_type) + x_ty = ir.VectorType(x.type) + assert new_type.width == bitwidth(x_ty.element_type) * math.prod(x_ty.shape) + i0 = arith.ConstantOp.create_index(0) + return vector.extractelement( + vector.bitcast(ir.VectorType.get((1,), new_type), x), position=i0 + ) + if ir.IntegerType.isinstance(x.type) and ir.VectorType.isinstance(new_type): + new_type = ir.VectorType(new_type) + x_ty = ir.IntegerType(x.type) + assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape) + return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x)) + raise ValueError(f"Can't bitcast {x.type} to {new_type}") diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 33e1ce92e..f3edbe639 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -14,13 +14,10 @@ # ============================================================================== import dataclasses -import enum import functools import itertools -from typing import Any import jax -from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import llvm @@ -29,6 +26,7 @@ from jaxlib.mlir.dialects import vector import numpy as np from . import fragmented_array as fa +from . import mma_utils from . 
import utils # mypy: ignore-errors @@ -84,60 +82,6 @@ class WGMMAAccumulator: return cls(_value=value[0], _sync=False) -def wgmma_encode(x: int): - result = (x & 0x3FFFF) >> 4 - if result << 4 != x: - raise ValueError(f"Cannot encode value in a WGMMA descriptor: {x}") - return result - - -def llvm_add(x, y): - return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none) - - -def create_descriptor( - memref_arg, - leading_byte_offset: int, - stride_byte_offset: int, - swizzle: int | mgpu_dialect.SwizzlingMode | None, - memory_space: int | None = None, - const_init: int = 0, -): - i64 = ir.IntegerType.get_signless(64) - ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, memory_space)) - if swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: - swizzle_encoding = 0 - elif swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle: - swizzle_encoding = 1 - elif swizzle == mgpu_dialect.SwizzlingMode.k64ByteSwizzle: - swizzle_encoding = 2 - elif swizzle == mgpu_dialect.SwizzlingMode.k32ByteSwizzle: - swizzle_encoding = 3 - else: - raise NotImplementedError(swizzle) - encoded_base_addr = llvm.LShrOp( - llvm.AndOp(ptr_val, c(0x3FFFF, i64)).result, c(4, i64) - ) - # We ignore the offset - desc_const = ( - const_init - | (wgmma_encode(leading_byte_offset) << 16) - | (wgmma_encode(stride_byte_offset) << 32) - ) - desc = llvm.or_( - arith.shli(c(swizzle_encoding, i64), c(62, i64)), c(desc_const, i64) - ) - desc = llvm.or_(encoded_base_addr.result, desc) - return desc - - -def _unpack_i32(vec_ty, r): - i32 = ir.IntegerType.get_signless(32) - return vector.bitcast( - vec_ty, vector.splat(ir.VectorType.get((1,), i32), r) - ) - - def _supported_wgmma_types(dtype, abtype) -> bool: input_types_are = lambda ty: ty.isinstance(abtype) if ir.F32Type.isinstance(dtype): @@ -271,14 +215,14 @@ def wgmma_m64( a_args = [_as_i32_reg(v) for v in a_slice.registers.flat] else: if i > 0: - a = llvm_add( + a = _llvm_add( a, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)), ) a_args = [a] # Advance the B descriptor. if i > 0: - b_descriptor = llvm_add( + b_descriptor = _llvm_add( b_descriptor, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)), ) @@ -297,262 +241,6 @@ def wgmma_m64( return to_acc_vec_regs(acc_regs) -class WGMMALayout(enum.Enum): - ROW_MAJOR = enum.auto() - COL_MAJOR = enum.auto() - - -def _validate_mma( - a: Any, - b: ir.Value, - swizzle: int, - m_group_size: int, # The M used by a single instruction. - descriptor_const_init: int = 0, -): - # We need swizzle >= 32 to ensure that our K tiling is larger than the MMA - # instruction's K width. - if swizzle < 32: - raise ValueError(f"Unsupported swizzle: {swizzle}") - - # Get A type. - if a_in_smem := isinstance(a, ir.Value): - if not ir.MemRefType.isinstance(a.type): - raise ValueError(f"When A is an ir.Value, it must be a memref, got: {a.type}") - a_ty = ir.MemRefType(a.type) - a_element_type = a_ty.element_type - a_shape = tuple(a_ty.shape) - if a_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): - raise ValueError("A must be in workgroup memory when it's a reference") - if len(a_shape) != 4: - raise ValueError(f"A must be 4D when it's a reference, got rank {len(a_shape)}") - elif hasattr(a, "shape") and hasattr(a, "mlir_dtype"): - a_element_type = a.mlir_dtype - a_shape = a.shape - else: - raise NotImplementedError(f"Unsupported A type: {type(a)}") - - # Get B type (always a reference). 
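For reference, the bit layout that create_descriptor packed here, and that mma_utils.encode_descriptor now packs, can be modelled in plain Python. The pointer and offsets below are invented, and only the fields visible in this diff are reproduced; const_init can carry extra mode bits such as TCGEN05_SMEM_DESCRIPTOR_BIT (1 << 46):

# Hedged sketch of the 64-bit SMEM descriptor layout:
#   bits [ 0, 14): (base address & 0x3FFFF) >> 4
#   bits [16, 30): leading byte offset >> 4
#   bits [32, 46): stride byte offset >> 4
#   bits [62, 64): swizzle mode (none=0, 128B=1, 64B=2, 32B=3)
SWIZZLE_ENCODING = {None: 0, 128: 1, 64: 2, 32: 3}

def encode_addr_model(x: int) -> int:
  result = (x & 0x3FFFF) >> 4
  if result << 4 != x:
    raise ValueError(f"Cannot encode value in an MMA descriptor: {x}")
  return result

def descriptor_model(base_ptr, leading_byte_offset, stride_byte_offset, swizzle, const_init=0):
  desc = const_init
  desc |= (base_ptr & 0x3FFFF) >> 4
  desc |= encode_addr_model(leading_byte_offset) << 16
  desc |= encode_addr_model(stride_byte_offset) << 32
  desc |= SWIZZLE_ENCODING[swizzle] << 62
  return desc

# Example: 128-byte swizzle, ignored (zero) LBO, 1024-byte SBO.
print(hex(descriptor_model(0x4000, 0, 1024, 128)))  # -> 0x4000004000000400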
- b_ty = ir.MemRefType(b.type) - if b_ty.rank != 4: - raise ValueError(f"B must be 4D, got rank {b_ty.rank}") - - # Veirfy element types and compute the tiling. - if (element_type := a_element_type) != b_ty.element_type: - raise ValueError( - f"A and B must have the same element type, got: {a_element_type} and" - f" {b_ty.element_type}" - ) - supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()} - if element_type not in supported_types: - raise ValueError(a_element_type) - element_bytewidth = bytewidth(element_type) - swizzle_elems = swizzle // element_bytewidth - - # Verify the shape and strides of B are as expected. - b_k_tiles, n_tiles, b_k_tiling, n_tiling = b_ty.shape - k = b_k_tiles * b_k_tiling - n = n_tiles * n_tiling - - b_strides, _ = b_ty.get_strides_and_offset() - b_byte_strides = [s * element_bytewidth for s in b_strides] - b_k_byte_stride, b_n_byte_stride, *b_tile_byte_strides = b_byte_strides - if ( - b_byte_strides[1] != n_tiling * b_k_tiling * element_bytewidth - and n_tiles != 1 # When there's only one tile, we never jump between them - ): - raise ValueError("B tiles must be contiguous along the N dimension") - if b_tile_byte_strides == [swizzle, element_bytewidth]: # N-fastest - b_order = WGMMALayout.ROW_MAJOR - # This first case (n_tiles == 1) is to allow the somewhat weird case of - # loading a small amount of N-fastest data, that needs to be padded to a - # larger tile due to swizzle. In this case we allow slicing the big tile - # before WGMMA to avoid unnecessary compute on padding. - if n_tiles == 1: - if n_tiling % 8: - raise ValueError("N tile size must be a multiple of 8") - elif n_tiling != swizzle_elems: - raise ValueError( - "Row major RHS (N-fastest) requires the N tile size to be equal to" - f" the swizzle tile size ({swizzle_elems}), but got {n_tiling}" - ) - if b_k_tiling not in {8, swizzle_elems}: - raise ValueError( - "Row major RHS (N-fastest) requires the K tile size to be either" - f" the swizzle tile size ({swizzle_elems}) or 8, but got {b_k_tiling}" - ) - elif b_tile_byte_strides == [element_bytewidth, swizzle]: # K-fastest - b_order = WGMMALayout.COL_MAJOR - if b_k_tiling != swizzle_elems: - raise ValueError( - "Column major RHS (K-fastest) requires the K tile size to be equal" - f" to the swizzle tile size ({swizzle_elems}), but got {b_k_tiling}" - ) - # See the explanation in the N-fastest case when n_tiles == 1. - if n_tiles == 1: - if n_tiling % 8: - raise ValueError("N tile size must be a multiple of 8") - elif n_tiling not in {8, swizzle_elems}: - raise ValueError( - "Column major RHS (K-fastest) requires the N tile size to be either" - f" to the swizzle tile size ({swizzle_elems}) or 8, but got {n_tiling}" - ) - else: - raise ValueError(b_byte_strides) - - if n > 256 and n % 256: - raise ValueError( - f"N group size must be a multiple of 256 when larger than 256, got: {n}" - ) - k_group_size = swizzle_elems - n_group_size = min(n, 256) - b_k_tiles_per_group = k_group_size // b_k_tiling - b_k_group_stride = b_k_byte_stride * b_k_tiles_per_group - n_tiles_per_group = n_group_size // n_tiling - b_n_group_stride = b_n_byte_stride * n_tiles_per_group - - # Verify the shape and strides of A are as expected. - if not a_in_smem: - m = a_shape[0] - a_order = a_m_group_stride = a_k_group_stride = None - else: - a_ty = ir.MemRefType(a.type) - m_tiles, a_k_tiles, m_tiling, a_k_tiling = a_ty.shape - m = m_tiles * m_tiling - # TODO(apaszke): I'm not actually convinced that we need this check. 
- if m_tiling != m_group_size: - raise ValueError( - f"A's row tiling must be equal to {m_group_size}, got: {m_tiling}" - ) - if a_k_tiling != swizzle_elems or a_k_tiles * a_k_tiling != k: - raise ValueError(a_ty.shape) - a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset() - a_m_byte_stride, a_k_byte_stride, *a_tile_byte_strides = [ - s * element_bytewidth for s in a_strides - ] - if a_tile_byte_strides == [swizzle, element_bytewidth]: - a_order = WGMMALayout.ROW_MAJOR - elif a_tile_byte_strides == [element_bytewidth, swizzle]: - a_order = WGMMALayout.COL_MAJOR - else: - raise ValueError(a_strides) - if a_order != WGMMALayout.ROW_MAJOR and m_tiling != swizzle_elems: - # Not sure what the layout is like, since the tiles aren't square. - raise NotImplementedError - a_m_tiles_per_group = m_group_size // m_tiling - a_m_group_stride = a_m_byte_stride * a_m_tiles_per_group - a_k_tiles_per_group = k_group_size // a_k_tiling - a_k_group_stride = a_k_byte_stride * a_k_tiles_per_group - - b_k_fastest = b_order == WGMMALayout.COL_MAJOR - a_k_fastest = a_order == WGMMALayout.ROW_MAJOR - # This is the number of rows until consecutive repeats of the swizzle pattern. - swizzle_pattern_rows = swizzle // 16 - # A swizzle atom is a 2D matrix with the dimensions below. - swizzle_atom_bytes = swizzle_pattern_rows * 128 - - # Here "leading" refers to the fastest changing dimension. There are two - # strides we have to define per value: - # Leading byte offset (LBO) - # K-fastest: ignored - # MN-fastest: stride between consecutive swizzle atoms that share the same - # K coordinate. - # Stride byte offset (SBO) - # As far as I can tell this is just the offset between two consecutive - # swizzle atoms along the non-leading dimension. - IGNORED = 0 - a_desc_fields = dict( - # I can't fully explain why WGMMA ignores LBO for A. For a_k_fastest, it - # is documented in the PTX docs, and my best explanation for the other - # case is that the instruction has a fixed shape and so it does not care - # about strides. It's possible that it's an artifact of the fact that we - # use tiling of 64. - leading_byte_offset=IGNORED, - stride_byte_offset=swizzle_atom_bytes, - swizzle=swizzle, - memory_space=3, - ) - # If B is N-fastest, all swizzle atoms within a tile share the same N - # coordinate, so we simply take the stride between consecutive N tiles. - # If B is K-fastest, all swizzle atoms within a tile share the same K - # coordinate, which forces us to lay out the tiles in N-fastest order or else - # they would have uneven strides. - b_desc_fields = dict( - leading_byte_offset=IGNORED if b_k_fastest else b_n_byte_stride, - # N tiles are contiguous, so the next N swizzle atom follows immediately. - # K tiles are not contiguous, so we take the stride between them. - stride_byte_offset=swizzle_atom_bytes - if b_k_fastest or b_k_tiling == swizzle_elems - else b_k_byte_stride, - swizzle=swizzle, - memory_space=3, - ) - # The K strides indicate the stride between the consecutive places where all - # coordinates are 0 except for K being incremented by the instruction width. - # If an input is K-fastest, we increment the descriptor by 32 bytes, since - # that is the K-width of all MMA instructions. - if b_k_fastest: - b_k_wgmma_stride = 32 - elif b_k_tiling == swizzle_elems: - # When B is N-fastest and we use the large square tiling, the relevant - # slices all fall within the first tile. A single MMA instruction for 16-bit - # types reads a subtile of shape 16x(swizzle bytes), giving us the necessary - # expression. 
- assert n_tiling == swizzle_elems or n_tiles == 1 - b_k_wgmma_stride = swizzle * 16 - else: - # If we use the small non-square tiling and N-fastest layout, each tile only - # contains a single swizzle atom with the K coordinate. But, each tile has - # 8 rows, while the WGMMA K width is 16, so we need to jump over 2 tiles. - b_k_wgmma_stride = b_k_byte_stride * 2 - wgmma_params = dict( - a_transpose=not a_k_fastest, - b_transpose=not b_k_fastest, - # TODO(apaszke): This explanation is quite bad. We should better figure - # out how to do LHS transposes. - # We only support swizzle=128 for M-fastest A. In this case the tile is - # swizzle x 64 (= swizzle elems) and so we just take a quarter of its size. - a_k_stride=32 if a_k_fastest else swizzle * 16, - b_k_stride=b_k_wgmma_stride, - swizzle=swizzle, - n=n_group_size, - element_type=ir.FloatTF32Type.get() - if ir.F32Type.isinstance(element_type) - else element_type, - ) - if not a_in_smem: - wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None - a_desc_base = None - else: - a_desc_base = create_descriptor( - a, **a_desc_fields, const_init=descriptor_const_init - ) - b_desc_base = create_descriptor( - b, **b_desc_fields, const_init=descriptor_const_init - ) - - if m % m_group_size: - raise ValueError(f"m must be a multiple of {m_group_size}, got: {m}") - m_groups = m // m_group_size - if k % k_group_size: - raise ValueError(f"k must be a multiple of {k_group_size}, got: {k}") - k_groups = k // k_group_size - if n % n_group_size: - raise ValueError(f"n must be a multiple of {n_group_size}, got: {n}") - n_groups = n // n_group_size - - return ( - a_desc_base, - b_desc_base, - (m, k, n), - (m_groups, k_groups, n_groups), - # Group strides are always in bytes! - (a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride), - wgmma_params, - ) - - -# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer -# transpositions from memref strides. def wgmma( acc: WGMMAAccumulator, a: fa.FragmentedArray | ir.Value, @@ -570,61 +258,129 @@ def wgmma( The refs must be contiguous or be contiguous except for having their two minor dimensions swapped. """ - a_in_regs = isinstance(a, fa.FragmentedArray) - if not a_in_regs and not ir.MemRefType.isinstance(a.type): - raise ValueError(f"Unsupported A type: {type(a)}") + # Step 1. Establish the shape and element type of the operation. if not ir.MemRefType.isinstance(b.type): raise ValueError(f"B must be a memref, got: {b.type}") - - m_group_size = 64 # Hopper has a fixed M instruction shape. - - ( - a_desc_base, - b_desc_base, - (m, k, n), - (m_groups, k_groups, n_groups), - (a_m_group_stride, a_k_group_stride, b_k_group_stride, _), - wgmma_params, - ) = _validate_mma(a, b, swizzle, m_group_size=m_group_size) - - if n_groups > 1: - raise ValueError("N is too big for WGMMA. 
Only up to 256 is supported.") - - if a_in_regs: + (k, n), element_type = mma_utils.tiled_memref_shape(b) + if a_in_regs := isinstance(a, fa.FragmentedArray): + m, k2 = a.shape + element_type2 = a.mlir_dtype if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get(): raise ValueError( f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}" ) - if a.shape[0] % m_group_size: - raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}") - a_m_group_stride = a_k_group_stride = None - + elif ir.MemRefType.isinstance(a.type): + (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) + else: + raise ValueError(f"Unsupported A type: {type(a)}") + if k != k2: + raise ValueError( + "WGMMA requires A and B to have the same contraction dimension (K)," + f" got: {k2} and {k}" + ) + if element_type != element_type2: + raise ValueError( + "WGMMA requires A and B to have the same element type, got:" + f" {element_type2} and {element_type}" + ) if acc.value.shape != (m, n): raise ValueError( f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}" ) + f32 = ir.F32Type.get() + if element_type == f32 or element_type == ir.BF16Type.get(): + if acc.value.mlir_dtype != f32: + raise ValueError( + f"WGMMA with element type {element_type} only supports accumulators" + f" of type f32, but got: {acc.value.mlir_dtype}" + ) + elif element_type == ir.F16Type.get(): + if acc.value.mlir_dtype != element_type and acc.value.mlir_dtype != f32: + raise ValueError( + "WGMMA with element type f16 only supports accumulators of type f32" + f" or f16, but got: {acc.value.mlir_dtype}" + ) + # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, + # instructions must be issued in groups of the same width as the swizzle. + m_group_elems = 64 # Hopper has a fixed M instruction shape. + k_group_elems = swizzle // utils.bytewidth(element_type) + if n > 256 or n % 8: + raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}") + n_group_elems = n # We assume only one N group below. + if m % m_group_elems: + raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}") + if k % k_group_elems: + raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}") + m_groups = m // m_group_elems + k_groups = k // k_group_elems + # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA. + wgmma_element_type = ( + ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type + ) + + # Step 3. Compute the operand descriptors. + if a_in_regs: + a_desc_base = a_m_group_stride = a_k_group_stride = None + a_instr_params = dict(a_transpose=None, a_k_stride=None) + else: + ( + (a_desc_base, a_k_instr_stride), + (a_m_group_stride, a_k_group_stride), + a_fastest, + ) = mma_utils.create_descriptor( + a, + swizzle=swizzle, + large_tile=(m_group_elems, k_group_elems), + group_size=(m_group_elems, k_group_elems), + logical_k_major=False, + ) + a_instr_params = dict(a_transpose=a_fastest != mma_utils.Dim.K, + a_k_stride=a_k_instr_stride) + ( + (b_desc_base, b_k_instr_stride), + (b_n_group_stride, b_k_group_stride), + b_fastest, + ) = mma_utils.create_descriptor( + b, + swizzle=swizzle, + large_tile=(k_group_elems,) * 2, # It's not a typo that we use k for n. + group_size=(k_group_elems, n_group_elems), + logical_k_major=True, + ) + del b_n_group_stride # We only support one N group. + + # Step 4. Issue the instructions. if a_in_regs: a = wgmma_fence(a) # Make sure the registers are ready. 
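The issue loop that follows advances the A and B descriptors by encoded byte offsets, one step per (M group, K group) pair. A plain-Python model of that offset walk, with invented group strides:

# Hedged sketch: the offset arithmetic of the WGMMA issue loop below.
# encode_addr stores offsets in units of 16 bytes, which the >> 4 models here.
a_m_group_stride, a_k_group_stride = 16384, 2048   # bytes, made up
b_k_group_stride = 4096                            # bytes, made up
m_groups, k_groups = 2, 4
for mi in range(m_groups):
  for ki in range(k_groups):
    a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride
    b_group_offset = ki * b_k_group_stride
    # The encoded offsets get added to a_desc_base / b_desc_base.
    print(mi, ki, a_group_offset >> 4, b_group_offset >> 4)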
i64 = ir.IntegerType.get_signless(64) new_acc_regs = acc.value.registers.copy() - k_group_size = k // k_groups for mi in range(m_groups): for ki in range(k_groups): if a_in_regs: a_mk = a[ - mi * m_group_size : (mi + 1) * m_group_size, - ki * k_group_size : (ki + 1) * k_group_size, + mi * m_group_elems : (mi + 1) * m_group_elems, + ki * k_group_elems : (ki + 1) * k_group_elems, ] else: - a_mk = llvm_add( - a_desc_base, - c(wgmma_encode(mi * a_m_group_stride + ki * a_k_group_stride), i64), + a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride + a_mk = _llvm_add( + a_desc_base, c(mma_utils.encode_addr(a_group_offset), i64), ) - b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_group_stride), i64)) + b_k = _llvm_add( + b_desc_base, c(mma_utils.encode_addr(ki * b_k_group_stride), i64) + ) new_acc_regs[mi : mi + 1] = wgmma_m64( - new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params + new_acc_regs[mi : mi + 1], + a_mk, + b_k, + swizzle=swizzle, + n=n_group_elems, + element_type=wgmma_element_type, + b_transpose=b_fastest != mma_utils.Dim.K, + b_k_stride=b_k_instr_stride, + **a_instr_params, ) return WGMMAAccumulator( _value=fa.FragmentedArray( @@ -668,3 +424,14 @@ def _as_i32_reg(v): def _lc(x): i32 = ir.IntegerType.get_signless(32) return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result + + +def _llvm_add(x, y): + return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none) + + +def _unpack_i32(vec_ty, r): + i32 = ir.IntegerType.get_signless(32) + return vector.bitcast( + vec_ty, vector.splat(ir.VectorType.get((1,), i32), r) + ) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 963be0381..1e0abacfc 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -53,6 +53,7 @@ from jax._src.pallas.primitives import max_contiguous as max_contiguous from jax._src.pallas.primitives import multiple_of as multiple_of from jax._src.pallas.primitives import num_programs as num_programs from jax._src.pallas.primitives import program_id as program_id +from jax._src.pallas.primitives import reciprocal as reciprocal from jax._src.pallas.primitives import run_scoped as run_scoped from jax._src.pallas.primitives import store as store from jax._src.pallas.primitives import swap as swap diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index e13dd11ed..631b4f720 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -35,6 +35,7 @@ from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arri from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem +from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py new file mode 100644 index 000000000..30cb20733 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -0,0 +1,721 @@ +# Copyright 2025 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TPU-Friendly Ragged Paged Attention kernel. + +This kernel offers a highly optimized implementation of ragged paged attention, +specifically designed for TPU and compatible with a wide range of model +specifications. It supports mixed prefill and decoding, enhancing throughput +during inference. +""" + +import functools +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp + +DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) + + +class MultiPageAsyncCopyDescriptor: + """Descriptor for async copy of multiple K/V pages from HBM.""" + + def __init__( + self, + pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim] + vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + sem, + page_indices_ref, # i32[max_num_seqs, pages_per_seq] + offset, # [seq_idx, kv_pages_start] + ): + self._vmem_buf = vmem_buf + seq_id, kv_pages_start = offset + self._async_copies = [ + pltpu.make_async_copy( + pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]], + vmem_buf.at[i], + sem, + ) + for i in range(vmem_buf.shape[0]) + ] + + def start(self): + """Starts the async copies.""" + for async_copy in self._async_copies: + async_copy.start() + + def wait(self): + for async_copy in self._async_copies: + async_copy.wait() + return self._vmem_buf + + +def ref_ragged_paged_attention( + queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1], + *, + sm_scale: float = 1.0, + mask_value: float = DEFAULT_MASK_VALUE, +): + _, _, num_kv_heads, head_dim = k_pages.shape + num_q_heads = queries.shape[1] + assert num_q_heads % num_kv_heads == 0 + num_query_per_kv = num_q_heads // num_kv_heads + outputs = [] + for i in range(num_seqs[0]): + q_start = cu_q_lens[i] + q_end = cu_q_lens[i + 1] + q_len = q_end - q_start + kv_len = kv_lens[i] + indices = page_indices[i] + q = queries[q_start:q_end] + k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = jnp.repeat(k, num_query_per_kv, axis=1) + v = jnp.repeat(v, num_query_per_kv, axis=1) + attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) + attn *= sm_scale + q_span = (kv_len - q_len) + jax.lax.broadcasted_iota( + jnp.int32, attn.shape, 1 + ) + kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) + attn += jnp.where(q_span < kv_span, mask_value, 0.0) + attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) + out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) + 
outputs.append(out) + + return jnp.concatenate(outputs, axis=0) + + +# Expect to run these checkes during runtime. +def validate_inputs_on_runtime( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs, # i32[1] +): + check_inputs_shapes( + q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + ) + max_num_batched_tokens = q.shape[0] + page_size = k_pages.shape[1] + max_num_seqs, pages_per_seq = page_indices.shape + if num_seqs[0] > max_num_seqs: + raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") + max_kv_len = jnp.max(kv_lens) + min_pages_per_seq = ceil_div(max_kv_len, page_size) + if pages_per_seq < min_pages_per_seq: + raise ValueError( + f"{pages_per_seq=} must be greater or equal to" + f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}." + ) + if cu_q_lens[num_seqs[0]] > max_num_batched_tokens: + raise ValueError( + f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to" + f" {max_num_batched_tokens=}." + ) + for i in range(num_seqs[0]): + q_len = cu_q_lens[i + 1] - cu_q_lens[i] + kv_len = kv_lens[i] + if q_len > kv_len: + raise ValueError( + f"{q_len=} must be less or equal to {kv_len=} at sequence {i}." + ) + + +# Expect to run these checks during compile time. +def check_inputs_shapes( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs, # i32[1] +): + _, num_q_heads, head_dim = q.shape + _, _, num_kv_heads, head_dim_k = k_pages.shape + max_num_seqs, _ = page_indices.shape + if num_seqs.shape != (1,): + raise ValueError(f"{num_seqs.shape=} must be (1,)") + if k_pages.shape != v_pages.shape: + raise ValueError( + f"{k_pages.shape=} and {v_pages.shape=} must have the same shape." + ) + if head_dim_k != head_dim: + raise ValueError( + f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}." + ) + if kv_lens.shape != (max_num_seqs,): + raise ValueError( + f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" + " `max_num_seqs` is `page_indices.shape[0]`." + ) + if cu_q_lens.shape != (max_num_seqs + 1,): + raise ValueError( + f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" + " `max_num_seqs` is `page_indices.shape[0]`." + ) + if ( + kv_lens.dtype != jnp.int32 + or page_indices.dtype != jnp.int32 + or cu_q_lens.dtype != jnp.int32 + ): + raise ValueError( + "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be" + f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=}," + f" {cu_q_lens.dtype=}." + ) + if num_q_heads % num_kv_heads != 0: + raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") + + +def ragged_paged_attention_kernel( + # Prefetch + kv_lens_ref, # [max_num_seqs] + page_indices_ref, # [max_num_seqs, pages_per_seq] + cu_q_lens_ref, # [max_num_seqs + 1] + seq_buf_idx_ref, + # TODO(jevinjiang): if OOM in SMEM, consider pack to other scalar refs. 
+ num_seqs_ref, + # Input + q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] + k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + # Output + o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] + # Scratch + k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + sems, # [2, 2] + l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + *, + sm_scale: float, + mask_value: float, +): + num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape + num_seqs = num_seqs_ref[0] + _, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape + num_kv_per_blk = num_kv_pages_per_blk * page_size + num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk + heads_blk_idx, q_blk_idx = ( + pl.program_id(0), + pl.program_id(1), + ) + num_heads_blks = pl.num_programs(0) + init_seq_idx = seq_buf_idx_ref[0] + init_buf_idx = seq_buf_idx_ref[1] + q_len_start = q_blk_idx * num_q_per_blk + q_len_end = q_len_start + num_q_per_blk + + def create_kv_async_copy_descriptors( + heads_blk_idx, seq_idx, kv_blk_idx, buf_idx + ): + offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) + heads_start = heads_blk_idx * num_kv_heads_per_blk + async_copy_k = MultiPageAsyncCopyDescriptor( + k_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], + k_bufs.at[buf_idx], + sems.at[buf_idx, 0], + page_indices_ref, + offset, + ) + async_copy_v = MultiPageAsyncCopyDescriptor( + v_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], + v_bufs.at[buf_idx], + sems.at[buf_idx, 1], + page_indices_ref, + offset, + ) + return async_copy_k, async_copy_v + + # TODO(jevinjiang): Add these to Mosaic: + # 1. Support arbitrary strided load/store for any dtype. + # 2. Support arbitrary strided load/store for any last dimension. 
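The strided_load_kv helper defined just below works around these limitations for bfloat16 by reading packed int32 words and shifting the wanted 16-bit half into the position that a float32-to-bfloat16 cast keeps. A small integer-only sketch of that trick; the bit patterns are illustrative, and the even-row-in-low-half packing is the convention the helper assumes:

# Hedged sketch: two bf16 rows packed into one int32 word.
def select_half(word: int, offset: int) -> int:
  shifted = word >> (16 * offset)        # bring the wanted half to the low bits
  high = (shifted << 16) & 0xFFFFFFFF    # move it to the high 16 bits
  return high >> 16                      # the 16 bits a bf16 truncation keeps

lo = 0x3F80   # bf16 bit pattern of 1.0 (row at offset 0)
hi = 0x4040   # bf16 bit pattern of 3.0 (row at offset 1)
packed = (hi << 16) | lo
assert select_half(packed, 0) == 0x3F80
assert select_half(packed, 1) == 0x4040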
+ def strided_load_kv(ref, start, step): + if ref.dtype == jnp.float32: + return ref[start::step, :] + packing = get_dtype_packing(ref.dtype) + assert ref.dtype == jnp.bfloat16 + assert step % packing == 0 + b_start = start // packing + b_offset = start % packing + b_step = step // packing + b_ref = ref.bitcast(jnp.int32) + b = b_ref[b_start::b_step, :] + bw = 32 // packing + b = jnp.right_shift(b, bw * b_offset) + b = jnp.left_shift(b, bw * (packing - 1)) + return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16) + + def fold_on_2nd_minor(vec): + assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32 + assert len(vec.shape) >= 2 + last_dim = vec.shape[-1] + packing = get_dtype_packing(vec.dtype) + if vec.shape[-2] % packing != 0: + vec = vec.astype(jnp.float32) + return vec.reshape(-1, last_dim) + + @pl.when(heads_blk_idx + q_blk_idx == 0) + def prefetch_first_kv_blk(): + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + heads_blk_idx, init_seq_idx, 0, init_buf_idx + ) + async_copy_k.start() + async_copy_v.start() + + def is_cur_q_blk_needed(q_states): + done, cur_seq_idx, _ = q_states + return jnp.logical_and(done == 0, cur_seq_idx < num_seqs) + + def compute_with_cur_q_blk(q_states): + done, cur_seq_idx, cur_buf_idx = q_states + q_start = cu_q_lens_ref[cur_seq_idx] + q_end = cu_q_lens_ref[cur_seq_idx + 1] + q_len = q_end - q_start + kv_len = kv_lens_ref[cur_seq_idx] + + def get_next_prefetch_ids( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx + ): + next_kv_blk_idx = kv_blk_idx + 1 + is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len + next_kv_blk_idx = lax.select( + is_last_kv_blk, + 0, + next_kv_blk_idx, + ) + is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end + next_seq_idx = lax.select( + is_last_kv_blk, + lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1, cur_seq_idx), + cur_seq_idx, + ) + is_last_seq = next_seq_idx == num_seqs + next_seq_idx = lax.select( + is_last_seq, + 0, + next_seq_idx, + ) + next_heads_blk_idx = lax.select( + is_last_seq, + heads_blk_idx + 1, + heads_blk_idx, + ) + next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0) + return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx + + def flash_attention( + q, # [num_q_per_blk * num_q_heads_per_kv_head, head_dim] + k, # [num_kv_per_blk, head_dim] + v, # [num_kv_per_blk, head_dim] + head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] + head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] + head_o_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] + *, + kv_blk_idx, + ): + assert q.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + head_dim, + ) + assert k.shape == ( + num_kv_per_blk, + head_dim, + ), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}" + assert v.shape == (num_kv_per_blk, head_dim) + assert head_m_ref.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) + assert head_l_ref.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) + assert head_o_ref.shape == ( + num_q_per_blk, + num_q_heads_per_kv_head, + head_dim, + ) + kv_len_start = kv_blk_idx * num_kv_per_blk + + def masked_store(ref, val, start, end, group=1): + iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group + mask = jnp.logical_and(iota >= start, iota < end) + pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask) + + qk = ( + jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) + * sm_scale + ) + store_start = jnp.maximum(q_start - q_len_start, 0) + store_end = jnp.minimum(q_end - q_len_start, 
num_q_per_blk) + + @pl.when(kv_blk_idx == 0) + def init_scratch_ref(): + masked_store( + head_m_ref, + jnp.full_like(head_m_ref, -jnp.inf), + store_start, + store_end, + num_q_heads_per_kv_head, + ) + masked_store( + head_l_ref, + jnp.zeros_like(head_l_ref), + store_start, + store_end, + num_q_heads_per_kv_head, + ) + masked_store( + head_o_ref, + jnp.zeros_like(head_o_ref), + store_start, + store_end, + ) + + row_ids = ( + (kv_len - q_len) + + q_len_start + - q_start + + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 0, + ) + // num_q_heads_per_kv_head + ) + col_ids = kv_len_start + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 1, + ) + causal_mask = row_ids < col_ids + qk += jnp.where(causal_mask, mask_value, 0.0) + m_curr = jnp.max(qk, axis=1, keepdims=True) + s_curr = jnp.exp(qk - m_curr) + qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32) + lm_store_shape = head_m_ref.shape + m_curr = jnp.broadcast_to(m_curr, lm_store_shape) + l_curr = jnp.broadcast_to( + s_curr.sum(axis=1, keepdims=True), lm_store_shape + ) + m_prev = head_m_ref[...] + l_prev = head_l_ref[...] + m_next = jnp.maximum(m_prev, m_curr) + masked_store( + head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head + ) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_alpha = alpha * l_prev + l_next = l_alpha + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + masked_store( + head_l_ref, + l_next_safe, + store_start, + store_end, + num_q_heads_per_kv_head, + ) + + def broadcast_to_shape(arr, shape): + if arr.shape == shape: + return arr + assert len(arr.shape) == len(shape) + assert arr.shape[0] == shape[0] + assert shape[1] % arr.shape[1] == 0 + # no-op concatenation. + return jnp.concatenate( + [arr for _ in range(shape[1] // arr.shape[1])], axis=1 + ) + + o_curr = head_o_ref[...].reshape(-1, head_dim) + l_alpha = broadcast_to_shape(l_alpha, qkv.shape) + beta = broadcast_to_shape(beta, qkv.shape) + l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) + out = lax.div( + l_alpha * o_curr + beta * qkv, + l_next_safe, + ).astype(head_o_ref.dtype) + masked_store( + head_o_ref, + out.reshape(head_o_ref.shape), + store_start, + store_end, + ) + + def is_valid_kv_blk_in_cur_seq(kv_states): + kv_blk_idx, _ = kv_states + return kv_blk_idx * num_kv_per_blk < kv_len + + def compute_with_kv_blk_in_cur_seq(kv_states): + kv_blk_idx, cur_buf_idx = kv_states + next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = ( + get_next_prefetch_ids( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx + ) + ) + + @pl.when(next_heads_blk_idx < num_heads_blks) + def prefetch_next_kv_blk(): + # TODO(jevinjiang): reuse the same buffer if it is already prefetched! + # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and + # DMA to fixed size buffer! 
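For reference, the per-KV-block bookkeeping in flash_attention above (m, l, alpha, beta) reduces to the standard streaming-softmax recurrence. This NumPy sketch processes one query row in two blocks and checks it against a direct softmax; toy numbers, no Pallas, and none of the kernel's masking or per-head broadcasting:

import numpy as np

def update(m_prev, l_prev, o_prev, scores, values):
  m_curr = scores.max()
  s = np.exp(scores - m_curr)
  m_next = max(m_prev, m_curr)
  alpha, beta = np.exp(m_prev - m_next), np.exp(m_curr - m_next)
  l_next = alpha * l_prev + beta * s.sum()
  o_next = (alpha * l_prev * o_prev + beta * (s @ values)) / l_next
  return m_next, l_next, o_next

scores = np.array([0.1, 1.2, -0.3, 0.8])
values = np.arange(8.0).reshape(4, 2)
m, l, o = -np.inf, 0.0, np.zeros(2)
for blk in (slice(0, 2), slice(2, 4)):   # two KV blocks
  m, l, o = update(m, l, o, scores[blk], values[blk])
probs = np.exp(scores - scores.max())
assert np.allclose(o, (probs / probs.sum()) @ values)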
+ next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors( + next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx + ) + next_async_copy_k.start() + next_async_copy_v.start() + + cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx + ) + kv_to_load_shape = ( + num_kv_pages_per_blk * page_size * num_kv_heads_per_blk, + head_dim, + ) + k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape) + v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape) + for kv_head_idx in range(num_kv_heads_per_blk): + q_head_idx = kv_head_idx * num_q_heads_per_kv_head + # TODO(jevinjiang): extra handlig for packed type that can start at + # unaligned position! + q = fold_on_2nd_minor( + q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] + ) + k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) + v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) + flash_attention( + q, + k, + v, + l_ref.at[kv_head_idx], + m_ref.at[kv_head_idx], + o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], + kv_blk_idx=kv_blk_idx, + ) + return kv_blk_idx + 1, next_buf_idx + + _, next_buf_idx = lax.while_loop( + is_valid_kv_blk_in_cur_seq, + compute_with_kv_blk_in_cur_seq, + (0, cur_buf_idx), # (kv_blk_idx, buf_idx) + ) + next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx) + done = lax.select(q_end < q_len_end, done, 1) + return done, next_seq_idx, next_buf_idx + + _, seq_idx, buf_idx = lax.while_loop( + is_cur_q_blk_needed, + compute_with_cur_q_blk, + (0, init_seq_idx, init_buf_idx), # (done, seq_idx, buf_idx) + ) + # Reset seq_idx for next kv_heads_blk if run out of seqs! + seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) + seq_buf_idx_ref[1] = buf_idx + + +def ceil_div(a, b): + assert b != 0 + return (a + b - 1) // b + + +def get_dtype_packing(dtype): + if dtype == jnp.float32: + return 1 + if dtype == jnp.bfloat16: + return 2 + if dtype == jnp.int8: + return 4 + if dtype == jnp.int4: + return 8 + raise ValueError(f"Not implemented: unsupported {dtype=}") + + +def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype): + q_packing = get_dtype_packing(q_dtype) + kv_packing = get_dtype_packing(kv_dtype) + + def can_be_xla_fully_tiled(x, packing): + if x % packing != 0: + return False + x //= packing + return x in (1, 2, 4, 8) or x % 8 == 0 + + # TODO(jevinjiang): support unaligned number of heads! + if not can_be_xla_fully_tiled(num_kv_heads, kv_packing): + raise ValueError( + f"Not implemented: {num_kv_heads=} can not be XLA fully tiled." + ) + assert num_q_heads % num_kv_heads == 0 + ratio = num_q_heads // num_kv_heads + # TODO(jevinjiang): we can choose smaller tiling for packed type if large + # second minor tiling is not on. + max_kv_tiling = 8 * kv_packing + min_kv_heads = ( + max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads + ) + min_q_heads = min_kv_heads * ratio + if can_be_xla_fully_tiled(min_q_heads, q_packing): + return min_q_heads, min_kv_heads + return num_q_heads, num_kv_heads + + +@functools.partial( + jax.jit, + static_argnames=[ + "sm_scale", + "mask_value", + "num_kv_pages_per_block", + "num_queries_per_block", + "vmem_limit_bytes", + ], +) +def ragged_paged_attention( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + # TODO(jevinjiang): create a write_to_kv_cache kernel! 
+ k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1] + *, + sm_scale: float = 1.0, + mask_value: float = DEFAULT_MASK_VALUE, + num_kv_pages_per_block: int = 16, + num_queries_per_block: int = 128, + vmem_limit_bytes: int | None = None, +): + """Ragged paged attention that supports mixed prefill and decode. + + Args: + q: all sequences' queries, concatenated. + k_pages: paged K cache. Normally in HBM. + v_pages: paged V cache. Normally in HBM. + kv_lens: padded kv lengths. Only the first num_seqs values are valid. + page_indices: for each sequence, the indices of the kv-cache pages it uses. + Only the first num_seqs rows are valid. + cu_q_lens: the cumulative sum of the effective query lengths. Similar to + kv_lens, only the first num_seqs+1 values are valid. + num_seqs: the dynamic number of sequences. + sm_scale: the softmax scale which will be applied to the Q@K^T. + mask_value: mask value for causal mask. + num_kv_pages_per_block: number of kv pages to be processed in one flash + attention block in the pallas kernel. + num_queries_per_block: number of query tokens to be processed in one flash + attention block in the pallas kernel. + vmem_limit_bytes: the vmem limit for the pallas kernel. + + Returns: + The output of the attention. + """ + check_inputs_shapes( + q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + ) + _, num_q_heads, head_dim = q.shape + _, page_size, num_kv_heads, _ = k_pages.shape + num_q_per_blk = num_queries_per_block + num_kv_pages_per_blk = num_kv_pages_per_block + num_q_heads_per_kv_head = num_q_heads // num_kv_heads + num_q_blks = ceil_div(cu_q_lens[num_seqs[0]], num_q_per_blk) + num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_kv_heads, q.dtype, k_pages.dtype + ) + assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 + num_heads_blks = num_q_heads // num_q_heads_per_blk + grid = (num_heads_blks, num_q_blks) + + def q_index_map(heads_blk_idx, q_blk_idx, *_): + return (q_blk_idx, heads_blk_idx, 0) + + q_block_spec = pl.BlockSpec( + (num_q_per_blk, num_q_heads_per_blk, head_dim), + q_index_map, + ) + in_specs = [ + q_block_spec, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + out_specs = q_block_spec + lm_scratch = pltpu.VMEM( + # TODO(jevinjiang): we use 128 instead of 1 because Mosaic does not support + # unaligned slicing! + (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128), + jnp.float32, + ) + double_buf_scratch = pltpu.VMEM( + ( + 2, # For double buffering during DMA copies.
+ num_kv_pages_per_blk, + page_size, + num_kv_heads_per_blk, + head_dim, + ), + k_pages.dtype, + ) + scratch_shapes = [ + double_buf_scratch, # k_bufs + double_buf_scratch, # v_bufs + pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem] + lm_scratch, # l_ref + lm_scratch, # m_ref + ] + scalar_prefetches = ( + kv_lens, + page_indices, + cu_q_lens, + jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx + num_seqs, + ) + kernel = pl.pallas_call( + functools.partial( + ragged_paged_attention_kernel, + sm_scale=sm_scale, + mask_value=mask_value, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=len(scalar_prefetches), + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + scratch_shapes=scratch_shapes, + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=( + "arbitrary", + "arbitrary", + ), + vmem_limit_bytes=vmem_limit_bytes, + ), + out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32), + name="ragged_paged_attention_kernel", + ) + # TODO(jevinjiang): Use f32 acc scratch for output! So we only need + # to transfer output with desired dtype back to HBM. + return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index f80e5c501..74edcb9cd 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -55,6 +55,57 @@ for prim in it.chain( roofline.register_standard_roofline(prim) +def _unary_p_roofline( + ctx: roofline.RooflineRuleContext, + *args, + **kw, +) -> roofline.RooflineResult: + (x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + return roofline.RooflineResult( + unfused_flops=x.size, + unfused_hbm_bytes=( + x.dtype.itemsize * x.size + out.dtype.itemsize * out.size + ), + ) + +roofline.register_roofline(lax.abs_p)(_unary_p_roofline) +roofline.register_roofline(lax.acos_p)(_unary_p_roofline) +roofline.register_roofline(lax.asin_p)(_unary_p_roofline) +roofline.register_roofline(lax.atan_p)(_unary_p_roofline) +roofline.register_roofline(lax.cbrt_p)(_unary_p_roofline) +roofline.register_roofline(lax.ceil_p)(_unary_p_roofline) +roofline.register_roofline(lax.conj_p)(_unary_p_roofline) +roofline.register_roofline(lax.cos_p)(_unary_p_roofline) +roofline.register_roofline(lax.cosh_p)(_unary_p_roofline) +roofline.register_roofline(lax.exp_p)(_unary_p_roofline) +roofline.register_roofline(lax.expm1_p)(_unary_p_roofline) +roofline.register_roofline(lax.floor_p)(_unary_p_roofline) +roofline.register_roofline(lax.imag_p)(_unary_p_roofline) +roofline.register_roofline(lax.integer_pow_p)(_unary_p_roofline) +roofline.register_roofline(lax.is_finite_p)(_unary_p_roofline) +roofline.register_roofline(lax.log_p)(_unary_p_roofline) +roofline.register_roofline(lax.log1p_p)(_unary_p_roofline) +roofline.register_roofline(lax.logistic_p)(_unary_p_roofline) +roofline.register_roofline(lax.neg_p)(_unary_p_roofline) +roofline.register_roofline(lax.not_p)(_unary_p_roofline) +roofline.register_roofline(lax.real_p)(_unary_p_roofline) +roofline.register_roofline(lax.round_p)(_unary_p_roofline) +roofline.register_roofline(lax.rsqrt_p)(_unary_p_roofline) +roofline.register_roofline(lax.sign_p)(_unary_p_roofline) +roofline.register_roofline(lax.sin_p)(_unary_p_roofline) +roofline.register_roofline(lax.sinh_p)(_unary_p_roofline) +roofline.register_roofline(lax.sqrt_p)(_unary_p_roofline) 
+roofline.register_roofline(lax.square_p)(_unary_p_roofline) +roofline.register_roofline(lax.tan_p)(_unary_p_roofline) +roofline.register_roofline(special.bessel_i0e_p)(_unary_p_roofline) +roofline.register_roofline(special.bessel_i1e_p)(_unary_p_roofline) +roofline.register_roofline(special.digamma_p)(_unary_p_roofline) +roofline.register_roofline(special.erf_inv_p)(_unary_p_roofline) +roofline.register_roofline(special.erf_p)(_unary_p_roofline) +roofline.register_roofline(special.erfc_p)(_unary_p_roofline) +roofline.register_roofline(special.lgamma_p)(_unary_p_roofline) + def _binary_p_roofline( ctx: roofline.RooflineRuleContext, *args, diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 0f161c074..0477e1c90 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -544,6 +544,8 @@ def _shard_map_staging( return out_tracers pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging +# TODO add underscore version, for direct-linearize to consume + def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: assert isinstance(aval, core.ShapedArray) return aval @@ -742,9 +744,8 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, out_avals_ = [x.aval for x in jaxpr.outvars] in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, in_avals_, in_nodes) - new_axis_context = sharding_impls.SPMDAxisContext( - mesh, frozenset(mesh.axis_names) - auto - ) + manual_axes = frozenset(mesh.axis_names) - auto + new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) with _extend_axis_env(mesh, auto): out_nodes_, tokens_out = mlir.call_lowering( @@ -895,7 +896,6 @@ def _match_spec(mesh: Mesh, check_rep: bool, def _match(mesh, check_rep, pspec, x): src = P(mesh.axis_names) - # TODO put back (?) needed for rep checking in eager? 
for now test rewrite return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x) def _rem_singleton(x): return jnp.squeeze(x, axis=0) @@ -914,6 +914,7 @@ class ShardMapTrace(core.Trace): __slots__ = ("mesh", "auto", "check", "context_mesh") mesh: Mesh + auto: frozenset[AxisName] check: bool context_mesh: AbstractMesh @@ -927,7 +928,7 @@ class ShardMapTrace(core.Trace): if isinstance(val, ShardMapTracer): return val.val, val.rep elif isinstance(val, Tracer): - raise Exception("Shouldn't have any non-shard_map tracers") + raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") else: val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh) return val_, None @@ -1609,34 +1610,40 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, out_names_thunk, check_rep, rewrite, auto): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not ad.Zero for t in tangents) - f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, - f.debug_info) + f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz] - all_names = _all_newly_manual_mesh_names(mesh, auto, trace) + res_names = _all_newly_manual_mesh_names(mesh, auto, trace) - @as_hashable_function(closure=(linearize_outs_thunk)) + @as_hashable_function(closure=linearize_outs_thunk) def primal_out_names_thunk(): - residual_avals, _, _ = linearize_outs_thunk() + _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() out_names = out_names_thunk() - # This is incorrect so we set `check_rep=False` as we do in the JVP rule. - return (*({0: all_names} for _ in residual_avals), *out_names) + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + # This is incorrect so we set `check_rep=False` in the tangent (as in JVP). 
+ return (*({0: res_names} for _ in range(num_res_out)), *out_names) primal_params = dict( mesh=mesh, in_names=in_names, out_names_thunk=primal_out_names_thunk, check_rep=check_rep, rewrite=rewrite, auto=auto) all_primal_results = shard_map_p.bind_with_trace( - trace.parent_trace, (f_primal,) + tuple(primals), primal_params) - residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk() - num_residuals = len(residual_avals) - residuals = all_primal_results[:num_residuals] - primals_out = all_primal_results[num_residuals:] - args_to_promote = [getattr(aval, 'shape', ()) == () for aval in residual_avals] - lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) + trace.parent_trace, (f_primal, *primals), primal_params) + residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + non_fwd_res = all_primal_results[:num_res_out] + primals_out = all_primal_results[num_res_out:] + residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) + args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None + for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)] + with core.extend_axis_env_nd(mesh.shape.items()): + lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) out_names = out_names_thunk() - new_in_names = (*({0: all_names} for _ in residual_avals), + residual_names = [in_names[f1] if f1 is not None else + out_names[f2] if f2 is not None else + {0: res_names} for f1, f2 in zip(in_fwd, out_fwd)] + new_in_names = (*residual_names, *({} for _ in range(len(env))), *(ax for ax, nz in zip(in_names, nzs_in) if nz)) - new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),) + new_out_names = tuple(ax for ax, nz in zip(out_names, nzs_out) if nz) @as_hashable_function(closure=(new_out_names)) def tangent_out_names_thunk(): return new_out_names @@ -1645,15 +1652,14 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, out_names_thunk=tangent_out_names_thunk, check_rep=False, rewrite=rewrite, auto=auto) + # TODO TODO don't round-trip def f_tangent(*args): - residuals = args[:num_residuals] - nz_tangents = args[num_residuals:] - return core.eval_jaxpr(lin_jaxpr, (), *residuals, *nz_tangents) + return core.eval_jaxpr(lin_jaxpr, (), *args) nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace, (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), - *residuals, *nz_tangents_in), tangent_params) + *residuals, *env, *nz_tangents_in), tangent_params) nz_tangents_out_iter = iter(nz_tangents_out) tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal) for nz, primal in zip(nzs_out, primals_out)] @@ -1663,13 +1669,13 @@ ad.LinearizeTrace.process_shard_map = _shard_map_linearize @lu.transformation2 def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs): ans = f(*args, **kwargs) - residual_avals, _, _ = linearize_outs_thunk() - num_residuals = len(residual_avals) - residuals = ans[:num_residuals] - primals = ans[num_residuals:] - residuals = tuple(jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x - for x in residuals) - return residuals + primals + _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + residuals = ans[:num_res_out] + primals = ans[num_res_out:] + residuals = 
[jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x + for x in residuals] + return *residuals, *primals @lu.transformation2 def _promote_scalar_residuals(f: Callable, *args, **kwargs): @@ -1798,10 +1804,10 @@ def _partial_eval_jaxpr_custom_rule( _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) newvar = core.gensym() - params_known, params_staged, all_names = _pe_custom_params( + params_known, params_staged, res_names = _pe_custom_params( unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which, dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) - residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval)) + residuals = [newvar(_unshard_aval(mesh, {0: res_names}, var.aval)) for var, w in zip(jaxpr_staged.invars[:num_res], which) if w] eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, @@ -1853,10 +1859,10 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, # prune inputs to jaxpr_known according to unks_in mesh = params_known['mesh'] auto = params_known['auto'] - all_names = _all_newly_manual_mesh_names(mesh, auto) + res_names_ = _all_newly_manual_mesh_names(mesh, auto) in_names_known, _ = partition_list(unks_in, params_known['in_names']) _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) - out_names_known = out_names_known + [{0: all_names}] * sum(which) + out_names_known = out_names_known + [{0: res_names_}] * sum(which) new_params_known = dict(params_known, in_names=tuple(in_names_known), out_names=tuple(out_names_known)) @@ -1864,12 +1870,12 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, _, in_names_staged = partition_list(inst_in, params_staged['in_names']) res_names = [in_names_known[f1] if f1 is not None else out_names_known[f2] if f2 is not None else - {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] + {0: res_names_} for f1, f2 in zip(in_fwd, out_fwd)] in_names_staged = res_names + in_names_staged _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), out_names=tuple(out_names_staged), check_rep=False) - return new_params_known, new_params_staged, all_names + return new_params_known, new_params_staged, res_names_ # TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_mesh_names_except_spmd( @@ -1880,15 +1886,21 @@ def _all_mesh_names_except_spmd( return tuple(name for name in mesh.axis_names if name not in spmd_names and name not in auto) -# TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_newly_manual_mesh_names( mesh: Mesh, auto: frozenset[AxisName], trace=None ) -> tuple[AxisName, ...]: - axis_env = core.get_axis_env() - spmd_names = axis_env.spmd_axis_names - axis_sizes = axis_env.axis_sizes - return tuple(name for name in mesh.axis_names if name not in spmd_names and - name not in auto and name not in axis_sizes) + if not (ctx_mesh := get_abstract_mesh()).empty: + del mesh + already_manual_names = set(ctx_mesh.axis_types.get(AxisTypes.Manual, ())) + return tuple(name for name in ctx_mesh.axis_names + if name not in auto | already_manual_names) + else: + # TODO(mattjj): remove this mechanism when we revise mesh scopes + axis_env = core.get_axis_env() + vmap_spmd_names = set(axis_env.spmd_axis_names) + already_manual_names = set(axis_env.axis_sizes) # may 
include vmap axis_names + return tuple(name for name in mesh.axis_names + if name not in auto | vmap_spmd_names | already_manual_names) # DCE diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index 36d84cb0d..7739af029 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -19,8 +19,18 @@ import math import jax from jax._src import core +from jax._src import ffi from jax._src import util from jax._src.typing import Array +from jax._src.lib import gpu_sparse + + +if hasattr(gpu_sparse, "registrations"): + for platform, targets in gpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) class JAXSparse(util.StrictABC): diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index bd72850bc..582fdf411 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -775,7 +775,7 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, - in_layouts, out_layouts, resource_env, donated_invars, name, + in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): if any(donated_invars): raise NotImplementedError("sparse xla_call with donated_invars") @@ -808,8 +808,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, - resource_env=resource_env, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index a26d15c14..4e376fb66 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -17,6 +17,7 @@ from jax._src.lax.lax import ( DotDimensionNumbers as DotDimensionNumbers, + RaggedDotDimensionNumbers as RaggedDotDimensionNumbers, Precision as Precision, PrecisionLike as PrecisionLike, DotAlgorithm as DotAlgorithm, @@ -158,6 +159,7 @@ from jax._src.lax.lax import ( pow as pow, pow_p as pow_p, ragged_dot as ragged_dot, + ragged_dot_general as ragged_dot_general, real as real, real_p as real_p, reciprocal as reciprocal, diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index d563483a2..cb291bdca 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -254,6 +254,12 @@ from jax._src.numpy.tensor_contractions import ( vdot as vdot, ) +from jax._src.numpy.util import ( + ndim as ndim, + shape as shape, + size as size, +) + from jax._src.numpy.window_functions import ( bartlett as bartlett, blackman as blackman, @@ -279,15 +285,12 @@ from numpy import ( integer as integer, iterable as iterable, nan as nan, - ndim as ndim, newaxis as newaxis, number as number, object_ as object_, pi as pi, save as save, savez as savez, - shape as shape, - size as size, signedinteger as signedinteger, unsignedinteger as unsignedinteger, ) @@ -307,6 +310,7 @@ try: float8_e3m4 as float8_e3m4, float8_e4m3 as float8_e4m3, float8_e8m0fnu as float8_e8m0fnu, + float4_e2m1fn as float4_e2m1fn, ) except ImportError: pass diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index dee61c145..b73a3b95b 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -728,7 +728,7 @@ def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., ddof: int = 0, keepdims: builtins.bool = False, where: ArrayLike | None = ...) -> Array: ... 
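For context on the jax.numpy hunk above: ndim, shape and size are now implemented in jax._src.numpy.util rather than re-exported from NumPy, and the stub changes below give them JAX-native signatures. A minimal usage sketch, consistent with those stubs and with the tests added to tests/lax_numpy_test.py later in this diff:

import jax
import jax.numpy as jnp

x = jnp.zeros((3, 4))
print(jnp.ndim(x))    # 2
print(jnp.shape(x))   # (3, 4)
print(jnp.size(x))    # 12
# The new tests also check that these work on traced values:
print(jax.jit(jnp.shape)(x))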
ndarray = Array -ndim = _np.ndim +def ndim(a: ArrayLike) -> int: ... def negative(x: ArrayLike, /) -> Array: ... newaxis = None def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -842,7 +842,7 @@ def setdiff1d( fill_value: ArrayLike | None = ..., ) -> Array: ... def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... -shape = _np.shape +def shape(a: ArrayLike) -> tuple[int, ...]: ... def sign(x: ArrayLike, /) -> Array: ... def signbit(x: ArrayLike, /) -> Array: ... signedinteger = _np.signedinteger @@ -850,7 +850,7 @@ def sin(x: ArrayLike, /) -> Array: ... def sinc(x: ArrayLike, /) -> Array: ... single: Any def sinh(x: ArrayLike, /) -> Array: ... -size = _np.size +def size(a: ArrayLike, axis: int | None = None) -> int: ... def sort( a: ArrayLike, axis: int | None = ..., diff --git a/jax/version.py b/jax/version.py index 616950577..be20aca06 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import os import pathlib import subprocess -_version = "0.5.2" +_version = "0.5.3" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 1acfbaf22..c747c0abb 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -12,25 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jaxlib import xla_client +from typing import Any from .plugin_support import import_from_plugin _cuda_linalg = import_from_plugin("cuda", "_linalg") _hip_linalg = import_from_plugin("rocm", "_linalg") -if _cuda_linalg: - for _name, _value in _cuda_linalg.registrations().items(): - xla_client.register_custom_call_target( - _name, _value, platform="CUDA", api_version=1 - ) - xla_client.register_custom_call_as_batch_partitionable( - "cu_lu_pivots_to_permutation") +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cuda_linalg), ("ROCM", _hip_linalg)]: + if module: + registrations[platform].extend( + (*i, 1) for i in module.registrations().items()) + return registrations # pytype: disable=bad-return-type -if _hip_linalg: - for _name, _value in _hip_linalg.registrations().items(): - xla_client.register_custom_call_target( - _name, _value, platform="ROCM", api_version=1 - ) - xla_client.register_custom_call_as_batch_partitionable( - "hip_lu_pivots_to_permutation") + +def batch_partitionable_targets() -> list[str]: + targets = [] + if _cuda_linalg: + targets.append("cu_lu_pivots_to_permutation") + if _hip_linalg: + targets.append("hip_lu_pivots_to_permutation") + return targets diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index a40c6bf93..efb58f9a4 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
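The gpu_linalg refactor above (and the matching gpu_solver, gpu_sparse and lapack changes elsewhere in this diff) moves custom-call registration out of import time and into a registrations() table plus a batch_partitionable_targets() list for the caller to consume. The consuming side for gpu_linalg is not part of these hunks; the sketch below mirrors the loop added to jax/experimental/sparse/_base.py for gpu_sparse earlier in this diff, so treat the exact call site and import path as assumptions:

from jax._src import ffi
from jaxlib import gpu_linalg, xla_client

for platform, targets in gpu_linalg.registrations().items():
    for name, value, api_version in targets:
        ffi.register_ffi_target(
            name, value, platform=platform, api_version=api_version)

# Batch-partitionable targets are now reported to the caller instead of being
# registered at import time; the removed code used this xla_client entry point.
for name in gpu_linalg.batch_partitionable_targets():
    xla_client.register_custom_call_as_batch_partitionable(name)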
-from jaxlib import xla_client +from typing import Any from .plugin_support import import_from_plugin @@ -24,45 +24,39 @@ _hipblas = import_from_plugin("rocm", "_blas") _hipsolver = import_from_plugin("rocm", "_solver") _hiphybrid = import_from_plugin("rocm", "_hybrid") -if _cublas: - for _name, _value in _cublas.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") -if _cusolver: - for _name, _value in _cusolver.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. - api_version = 0 - if _name.endswith("_ffi"): - api_version = 1 - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]: + if module: + registrations[platform].extend( + (*i, 0) for i in module.registrations().items()) + for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items() + ) + for platform, module in [("CUDA", _cuhybrid), ("ROCM", _hiphybrid)]: + if module: + registrations[platform].extend( + (*i, 1) for i in module.registrations().items()) + return registrations # pytype: disable=bad-return-type -if _cuhybrid: - for _name, _value in _cuhybrid.registrations().items(): - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=1) -if _hipblas: - for _name, _value in _hipblas.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") +def batch_partitionable_targets() -> list[str]: + targets = [] + for module in [_cusolver, _hipsolver]: + if module: + targets.extend( + name for name in module.registrations() + if name.endswith("_ffi") + ) + for module in [_cuhybrid, _hiphybrid]: + if module: + targets.extend(name for name in module.registrations()) + return targets -if _hipsolver: - for _name, _value in _hipsolver.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. 
- api_version = 0 - if _name.endswith("_ffi"): - api_version = 1 - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) - -if _hiphybrid: - for _name, _value in _hiphybrid.registrations().items(): - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=1) def initialize_hybrid_kernels(): if _cuhybrid: @@ -70,6 +64,7 @@ def initialize_hybrid_kernels(): if _hiphybrid: _hiphybrid.initialize() + def has_magma(): if _cuhybrid: return _cuhybrid.has_magma() diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index d397557df..d8645041c 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -17,13 +17,12 @@ cusparse wrappers for performing sparse matrix computations in JAX import math from functools import partial +from typing import Any import jaxlib.mlir.ir as ir import numpy as np -from jaxlib import xla_client - from .hlo_helpers import custom_call, mk_result_types_and_shapes from .plugin_support import import_from_plugin @@ -31,17 +30,14 @@ from .plugin_support import import_from_plugin _cusparse = import_from_plugin("cuda", "_sparse") _hipsparse = import_from_plugin("rocm", "_sparse") -if _cusparse: - for _name, _value in _cusparse.registrations().items(): - api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) - -if _hipsparse: - for _name, _value in _hipsparse.registrations().items(): - api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items()) + return registrations # pytype: disable=bad-return-type cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 633cd07ab..58a83d9b0 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -325,11 +325,18 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) -def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_name, cpu_name, wheel_version): +def _get_full_wheel_name( + package_name, + no_abi, + platform_independent, + platform_name, + cpu_name, + wheel_version, + py_freethreaded): if no_abi or platform_independent: wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" else: - wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl" + wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}{free_threaded_suffix}-{wheel_platform_tag}.whl" python_version = HERMETIC_PYTHON_VERSION.replace(".", "") return wheel_name_template.format( package_name = package_name, @@ -339,6 +346,7 @@ def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_na wheel_platform_tag = "any" if platform_independent else "_".join( PLATFORM_TAGS_DICT[platform_name, cpu_name], ), + free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else 
"", ) def _get_source_distribution_name(package_name, wheel_version): @@ -352,6 +360,7 @@ def _jax_wheel_impl(ctx): override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value output_path = ctx.attr.output_path[BuildSettingInfo].value git_hash = ctx.attr.git_hash[BuildSettingInfo].value + py_freethreaded = ctx.attr.py_freethreaded[BuildSettingInfo].value executable = ctx.executable.wheel_binary if include_cuda_libs and not override_include_cuda_libs: @@ -387,6 +396,7 @@ def _jax_wheel_impl(ctx): platform_name = platform_name, cpu_name = cpu, wheel_version = full_wheel_version, + py_freethreaded = py_freethreaded, ) wheel_file = ctx.actions.declare_file(output_path + "/" + wheel_name) @@ -463,6 +473,7 @@ _jax_wheel = rule( "enable_rocm": attr.bool(default = False), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), + "py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")), }, implementation = _jax_wheel_impl, executable = False, @@ -522,9 +533,10 @@ def jax_wheel( # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. cpu = select({ "//jaxlib/tools:macos_arm64": "arm64", + "//jaxlib/tools:macos_x86_64": "x86_64", "//jaxlib/tools:win_amd64": "AMD64", - "//jaxlib/tools:arm64": "aarch64", - "@platforms//cpu:x86_64": "x86_64", + "//jaxlib/tools:linux_aarch64": "aarch64", + "//jaxlib/tools:linux_x86_64": "x86_64", }), source_files = source_files, ) diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index c5a59e314..330fcb992 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -12,23 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import numpy as np +from typing import Any -from jaxlib import xla_client +import numpy as np from .cpu import _lapack from .cpu._lapack import eig from .cpu._lapack import schur -for _name, _value in _lapack.registrations().items(): - api_version = 0 - if _name.endswith("_ffi"): - api_version = 1 - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target( - _name, _value, platform="cpu", api_version=api_version - ) - EigComputationMode = eig.ComputationMode SchurComputationMode = schur.ComputationMode @@ -43,6 +34,17 @@ LAPACK_DTYPE_PREFIX = { } +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + return {"cpu": [ + (name, value, int(name.endswith("_ffi"))) + for name, value in _lapack.registrations().items() + ]} + + +def batch_partitionable_targets() -> list[str]: + return [name for name in _lapack.registrations() if name.endswith("_ffi")] + + def prepare_lapack_call(fn_base, dtype): """Initializes the LAPACK library and returns the LAPACK target name.""" _lapack.initialize() diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 964039d82..dbc829832 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -347,7 +347,8 @@ def MosaicGPU_AsyncStoreOp : Op:$commit_group ); let assemblyFormat = [{ diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 25ba46dfe..4b5ed3493 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -538,6 +538,16 @@ def TPU_WeirdOp : TPU_Op<"weird", [Pure, ElementwiseMappable]> { let hasVerifier = 1; } +def TPU_ReciprocalOp : TPU_Op<"reciprocal", [Pure, SameOperandsAndResultType, ElementwiseMappable]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$input, + DefaultValuedAttr:$approx + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> { let arguments = (ins Variadic:$input); let results = (outs AnyVectorOfNonZeroRank:$output); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 69c29e51f..5a8042120 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/include/mlir/IR/Builders.h" #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" +#include "mlir/include/mlir/IR/Diagnostics.h" #include "mlir/include/mlir/IR/IRMapping.h" #include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" @@ -1164,6 +1165,13 @@ LogicalResult LogBufferOp::verify() { return success(); } +LogicalResult ReciprocalOp::verify() { + if (!getType().getElementType().isF32()) { + return emitOpError("Not implemented: Reciprocal op for non-f32 dtypes"); + } + return success(); +} + void PackSubelementsOp::build(OpBuilder &builder, OperationState &state, const VectorType output_type, const ArrayRef padded_sources, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 62fa622ae..17e8d12b5 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -684,10 +684,8 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(layouts_in.front().has_value()); TPU_ASSERT_OP(layouts_out.front().has_value()); auto extf_op = cast(op); - if (layouts_in.front()->bitwidth() != 16 || - layouts_out.front()->bitwidth() != 32) { - return op.emitOpError( - "Not implemented: Only 16-bit to 32-bit conversion supported"); + if (layouts_out.front()->bitwidth() != 32) { + return op.emitOpError("Not implemented: Only support conversion to 32-bit"); } ImplicitLocOpBuilder builder(op.getLoc(), &op); FAILUREOR_ASSIGN_OR_RETURN( diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 33912fddf..082e1204c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1674,8 +1674,7 @@ class VectorLayoutInferer { auto some_layout = getLayout(op->getOperand(0)); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); if (dyn_cast(op)) { - TPU_CHECK_OP(src_bitwidth == 16 && dst_bitwidth == 32, - "Only 16-bit to 32-bit extensions supported"); + TPU_CHECK_OP(dst_bitwidth == 32, "Only supported extensions to 32-bit"); } auto &layout = *some_layout; Layout src_layout; diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 312605df8..2f415912f 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -182,8 +182,9 @@ void callback_complete(CUcontext context, uint32_t streamId, // Convert integer nanoseconds to floating point milliseconds to match // the interface of the events-based profiler. 
double duration_ms = (kernel->end - kernel->start) / 1e6; + const char* kernel_name = kernel->name; profiler_state.timings.push_back( - std::make_tuple(kernel->name, duration_ms)); + std::make_tuple(kernel_name, duration_ms)); } } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { // no more records available diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index b95483b22..baf996d50 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -19,7 +19,7 @@ load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( - "@tsl//third_party/py:py_manylinux_compliance_test.bzl", + "@xla//third_party/py:py_manylinux_compliance_test.bzl", "verify_manylinux_compliance_test", ) load( @@ -186,6 +186,14 @@ selects.config_setting_group( ], ) +selects.config_setting_group( + name = "macos_x86_64", + match_all = [ + "@platforms//cpu:x86_64", + ":macos", + ], +) + selects.config_setting_group( name = "win_amd64", match_all = [ @@ -194,6 +202,22 @@ selects.config_setting_group( ], ) +selects.config_setting_group( + name = "linux_x86_64", + match_all = [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], +) + +selects.config_setting_group( + name = "linux_aarch64", + match_all = [ + ":arm64", + "@platforms//os:linux", + ], +) + string_flag( name = "jaxlib_git_hash", build_setting_default = "", diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 9c7f61fc2..4c50cff16 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -70,6 +70,12 @@ def build_wheel( env = dict(os.environ) if git_hash: env["JAX_GIT_HASH"] = git_hash + if is_windows() and ( + "USERPROFILE" not in env + and "HOMEDRIVE" not in env + and "HOMEPATH" not in env + ): + env["USERPROFILE"] = env.get("SYSTEMDRIVE", "C:") subprocess.run( [sys.executable, "-m", "build", "-n"] + (["-w"] if build_wheel_only else []), diff --git a/pyproject.toml b/pyproject.toml index e32b14a89..a1b9e7dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,9 @@ filterwarnings = [ # https://github.com/protocolbuffers/protobuf/issues/12186#issuecomment-1745679358 "ignore:Type google\\._upb\\._message\\.(Scalar|Message)MapContainer uses PyType_Spec with a metaclass that has custom tp_new\\. This is deprecated and will no longer be allowed in Python 3\\.14\\.:DeprecationWarning", + # TODO(b/401588349): Remove this once transparent hugepages are enabled. + "ignore:Transparent hugepages", + # NOTE: this is probably not where you want to add code to suppress a # warning. Only pytest tests look at this list, whereas Bazel tests also # check for warnings and do not check this list. Most likely, you should diff --git a/tests/BUILD b/tests/BUILD index d5d6c7be5..0ffa68ed8 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -838,6 +838,13 @@ jax_multiplatform_test( jax_multiplatform_test( name = "profiler_test", srcs = ["profiler_test.py"], + backend_tags = { + "gpu": [ + # disable suspicious leaking in cupti/cuda, + # TODO: remove this once b/372714955 is resolved. 
+ "noasan", + ], + }, enable_backends = [ "cpu", "gpu", diff --git a/tests/api_test.py b/tests/api_test.py index e8fef8011..c9cf28e0a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -20,6 +20,7 @@ from collections.abc import Callable import concurrent.futures from contextlib import contextmanager import copy +import dataclasses import enum import functools from functools import partial @@ -57,6 +58,7 @@ from jax._src import xla_bridge from jax._src import debugging from jax._src import pjit as pjit_lib from jax._src.ad_checkpoint import saved_residuals +from jax._src.interpreters import ad as ad_internal from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled @@ -1974,6 +1976,48 @@ class APITest(jtu.JaxTestCase): x2=jnp.ones(2, dtype=jnp.float32) ) + def test_vmap_inconsistent_sizes_constructs_proper_error_message_starargs(self): + # regression test for https://github.com/jax-ml/jax/issues/26908 + def f(x, *args): + return x - functools.reduce(jnp.add, args) + + with self.assertRaisesRegex( + ValueError, + "vmap got inconsistent sizes for array axes to be mapped:" + ): + jax.vmap(f)(jnp.ones(4), jnp.ones(2), jnp.ones(2)) + + def test_vmap_sentinel(self): + + @jax.tree_util.register_dataclass + @dataclasses.dataclass + class Foo: + x: jax.Array + + def __init__(self, x): + nonlocal saw_sentinel + if x is jax._src.api_util.SENTINEL: + saw_sentinel += 1 + self.x = x + + x = jnp.arange(10) + + # assert that sentinel is seen once for vmap in_axes + saw_sentinel = 0 + jax.vmap(lambda f: f.x)(Foo(x)) + self.assertEqual(saw_sentinel, 1) + + # assert that sentinel is seen once for vmap out_axes + saw_sentinel = 0 + jax.vmap(Foo)(x) + self.assertEqual(saw_sentinel, 1) + + # assert that sentinel is seen twice with vmap in_axes and out_axes + saw_sentinel = 0 + jax.vmap(lambda f: Foo(f.x + 1))(Foo(x)) + self.assertEqual(saw_sentinel, 2) + + def test_device_get_scalar(self): x = np.arange(12.).reshape((3, 4)).astype("float32") x = api.device_put(x) @@ -4721,6 +4765,19 @@ class APITest(jtu.JaxTestCase): check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0)) + def test_deferred_primal_with_direct_linearize(self): + def my_sin_lin(nzs, x): + nz, = nzs + return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) + + my_sin_p = core.Primitive("my_sin_p") + my_sin_p.def_impl(lax.sin) + my_sin_p.def_abstract_eval(lambda x: x) + ad_internal.primitive_linearizations[my_sin_p] = my_sin_lin + + with config.use_direct_linearize(True): + jax.grad(my_sin_p.bind)(1.0) # doesn't crash + class RematTest(jtu.JaxTestCase): @@ -6435,14 +6492,10 @@ class JaxprTest(jtu.JaxTestCase): e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b f:f32[] = cond[ branches=( - { lambda ; g_:f32[] h:f32[] i:f32[] j:f32[]. let - k:f32[] = sub j h - in (k,) } - { lambda ; l:f32[] m_:f32[] n:f32[] o:f32[]. let - p:f32[] = add n l - in (p,) } + { lambda ; g:f32[] h:f32[] i:f32[]. let j:f32[] = sub i g in (j,) } + { lambda ; k:f32[] l:f32[] m:f32[]. 
let n:f32[] = add l k in (n,) } ) - ] e a a c d + ] e a c d in (f,) }""" jaxpr = api.make_jaxpr(f)(jnp.float32(3.)) self.assertMultiLineStrippedEqual(expected, str(jaxpr)) @@ -8104,6 +8157,29 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose( api.jvp(f1, (x, y), (0.0, 1.0)), (f1(x, y), -0.5 * jnp.sin(y))) + def test_resolve_kwargs_error_message(self): + @jax.custom_jvp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + @f.defjvp + def f_jvp(primals, tangents): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + class CustomVJPTest(jtu.JaxTestCase): @@ -9762,6 +9838,33 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose( api.grad(f1, argnums=(0, 1))(x, y), (1.5, -0.5 * jnp.sin(y))) + def test_resolve_kwargs_error_message(self): + @jax.custom_vjp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + def f_fwd(x, y): + self.fail("should not be executed") + + def f_bwd(res, cts): + self.fail("should not be executed") + + f.defvjp(f_fwd, f_bwd) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vjp-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vjp-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + def transpose_unary(f, x_example): def transposed(y): @@ -10490,6 +10593,29 @@ class CustomDceTest(jtu.JaxTestCase): self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), expected[0]) self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), expected[1]) + def test_resolve_kwargs_error_message(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y, *, z=None): + return jnp.sin(x) * y, x * jnp.sin(y) + + @f.def_dce + def f_dce_rule(used_outs, x, y): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + class CustomVmapTest(jtu.JaxTestCase): @@ -11115,6 +11241,29 @@ class CustomVmapTest(jtu.JaxTestCase): out, f_vjp = jax.vjp(f, xs, y) f_vjp(out) # Doesn't crash. 
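The test_resolve_kwargs_error_message tests added above (custom_jvp, custom_vjp, custom_dce) and below (custom_vmap) all pin down the same behavior: calling the decorated function with a missing positional argument, or with a keyword argument that cannot be mapped to a position, raises a TypeError before the custom rule is ever invoked. A small runnable sketch of the custom_jvp case; the other decorators report the same two messages:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y, *, z=None):
    return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    raise AssertionError("never reached in the error cases below")

try:
    f(0.5)  # missing a required argument: 'y'
except TypeError as e:
    print(e)

try:
    f(0.5, 0.1, z=1.0)  # keyword arguments could not be resolved to positions: z
except TypeError as e:
    print(e)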
+ def test_resolve_kwargs_error_message(self): + @jax.custom_batching.custom_vmap + def f(x, y, *, z=None): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index d6abe8bec..52d494904 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -66,6 +66,14 @@ _count_colocated_python_specialization_cache_miss = jtu.count_events( class ColocatedPythonTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if np.lib.NumpyVersion(np.__version__) < "2.0.0": + self.skipTest( + "Serialization in Colocated Python needs StringDType, and thus" + " requires NumPy 2.0.0 or later" + ) + def testMakeColocatedPythonProgram(self): def add_one(x): return x + 1 @@ -382,8 +390,6 @@ class ColocatedPythonTest(jtu.JaxTestCase): del colocated_python._testing_global_state def testStringProcessing(self): - if np.lib.NumpyVersion(np.__version__) < "2.0.0": - self.skipTest("StringDType requires NumPy 2.0.0 or later") cpu_devices = _colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") @@ -425,8 +431,6 @@ class ColocatedPythonTest(jtu.JaxTestCase): ) def testBinaryDataProcessing(self): - if np.lib.NumpyVersion(np.__version__) < "2.0.0": - self.skipTest("StringDType requires NumPy 2.0.0 or later") cpu_devices = _colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 1: self.skipTest("Need at least one CPU devices") diff --git a/tests/core_test.py b/tests/core_test.py index 5fc906bd3..c46d493bd 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -43,13 +43,14 @@ __ = pe.PartialVal.unknown(ShapedArray((), np.float32)) def call(f, *args): return jit(f)(*args) -@util.curry def core_call(f, *args): args, in_tree = jax.tree.flatten(args) dbg = debug_info("core_call_test", f, args, {}) f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree) out = core.call_p.bind(f, *args) return jax.tree.unflatten(out_tree(), out) +# call = core_call +core_call = util.curry(core_call) @util.curry def core_closed_call(f, *args): diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 2c5d6a772..a39b53c3a 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -896,6 +896,8 @@ class DebugInfoTest(jtu.JaxTestCase): # TODO(necula): result_paths "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", # TODO(necula): arg_names + "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=," + if config.use_direct_linearize.value else "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ @@ -1324,6 +1326,8 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", + 
"traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=" + if config.use_direct_linearize.value else "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", ], expected_tracer_debug_infos=[ @@ -1608,8 +1612,10 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): arg_names and result_paths? - "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", + "traced_for=jit, fun=my_f, arg_names=x,, result_paths=," + if config.use_direct_linearize.value else + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 8127aed7a..87380443f 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -73,6 +73,12 @@ if dtypes.float8_e8m0fnu is not None: float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes +fp4_dtypes = [] +if dtypes.float4_e2m1fn is not None: + fp4_dtypes += [np.dtype(dtypes.float4_e2m1fn)] +float_dtypes += fp4_dtypes +custom_float_dtypes += fp4_dtypes + complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')] @@ -238,6 +244,8 @@ class DtypesTest(jtu.JaxTestCase): continue if t1 in intn_dtypes: continue + if t1 in fp4_dtypes: + continue self.assertEqual(np.dtype(np.complex128), dtypes.promote_types(t1, np.complex128)) @@ -247,6 +255,8 @@ class DtypesTest(jtu.JaxTestCase): continue if t2 in intn_dtypes: continue + if t2 in fp4_dtypes: + continue # Symmetry self.assertEqual(dtypes.promote_types(t1, t2), dtypes.promote_types(t2, t1)) @@ -261,6 +271,8 @@ class DtypesTest(jtu.JaxTestCase): # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. if t in fp8_dtypes: continue + if t in fp4_dtypes: + continue if t in intn_dtypes or i in intn_dtypes: continue self.assertEqual(t, dtypes.promote_types(t, i)) @@ -951,10 +963,12 @@ class TestPromotionTables(jtu.JaxTestCase): self.skipTest("XLA support for int2 and int4 is incomplete.") if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest("TPU does not support float8_e8m0fnu.") + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest("TPU does not support float4_e2m1fn.") x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) if weak_type: expected = dtypes.canonicalize_dtype( - dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes] else x.dtype.kind]) + dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes, *fp4_dtypes] else x.dtype.kind]) else: expected = x.dtype self.assertEqual(dtypes.result_type(x), expected) @@ -971,6 +985,18 @@ class TestPromotionTables(jtu.JaxTestCase): ".*8-bit floats do not support implicit promotion"): x + y + @jax.numpy_dtype_promotion('standard') + def testFloat4PromotionError(self): + for dtype in fp4_dtypes: + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + # TPU does not support float4_e2m1fn. 
+ continue + x = jnp.array(1, dtype=dtype) + y = jnp.array(1, dtype='float32') + with self.assertRaisesRegex(dtypes.TypePromotionError, + ".*4-bit floats do not support implicit promotion"): + x + y + @jax.numpy_dtype_promotion('standard') @jtu.run_on_devices('tpu') def testInt2PromotionError(self): @@ -995,6 +1021,8 @@ class TestPromotionTables(jtu.JaxTestCase): def testBinaryNonPromotion(self, dtype, weak_type, promotion): if dtype in fp8_dtypes: self.skipTest("XLA support for float8 is incomplete.") + if dtype in fp4_dtypes: + self.skipTest("XLA support for float4 is incomplete.") if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") # Regression test for https://github.com/jax-ml/jax/issues/6051 @@ -1027,6 +1055,8 @@ class TestPromotionTables(jtu.JaxTestCase): self.skipTest('XLA support for int2 is incomplete.') if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest('TPU does not support float8_e8m0fnu.') + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest('TPU does not support float4_e2m1fn.') val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) rep = repr(val) self.assertStartsWith(rep, 'Array(') diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 8ac435cbb..5cdde30b1 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -148,7 +148,8 @@ class ErrorCheckTests(jtu.JaxTestCase): with self.assertRaisesRegex(JaxValueError, "x must be less than 10"): error_check.raise_if_error() - def test_error_check_works_with_scan(self): + @parameterized.product(jit=[True, False]) + def test_error_check_works_with_scan(self, jit): def f(carry, x): error_check.set_error_if(x >= 4, "x must be less than 4") return carry + x, x + 1 @@ -156,6 +157,9 @@ class ErrorCheckTests(jtu.JaxTestCase): def body(init, xs): return jax.lax.scan(f, init=init, xs=xs) + if jit: + body = jax.jit(body) + init = jnp.int32(0) xs = jnp.arange(5, dtype=jnp.int32) _ = body(init, xs) @@ -166,5 +170,26 @@ class ErrorCheckTests(jtu.JaxTestCase): _ = body(init, xs) error_check.raise_if_error() # should not raise error + @parameterized.product(jit=[True, False]) + def test_raise_if_error_fails_in_traced_context(self, jit): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), 1, dtype=jnp.int32) + f(x) + with self.assertRaises( + ValueError, + msg=( + "raise_if_error() should not be called within a traced context," + " such as within a jitted function." 
+ ), + ): + jax.jit(error_check.raise_if_error)() + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/export_test.py b/tests/export_test.py index 60c96fca4..6baecebe1 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1014,6 +1014,8 @@ class JaxExportTest(jtu.JaxTestCase): self.skipTest(f"TODO: serialization not supported for {str(dtype)}") if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest("TPU does not support float8_e8m0fnu.") + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest("TPU does not support float4_e2m1fn.") @jax.jit def f_jax(x): return x + x diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 921b8544b..af0b18b02 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -14,10 +14,6 @@ from functools import partial from absl.testing import absltest -import os - -os.environ["XLA_FLAGS"] = \ - "--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true" import numpy as np import jax @@ -30,7 +26,6 @@ from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention, check_is_flash_attention, check_cudnn_version, - get_large_negative_number, MaskType, AttentionLayout, ) @@ -90,6 +85,9 @@ cast_to_representable = partial( quantize = partial(quantize_to_fp8, scale=1) +def get_large_negative_number(dtype): + return 0.7 * jnp.finfo(dtype).min + def sdpa_train(query: Array, key: Array, value: Array, @@ -168,7 +166,7 @@ def sdpa_ref(query: Array, B, T, qN, H = query.shape _, _, kN, _ = key.shape - logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) + logits = jnp.einsum("bqhd,bkhd->bhqk", query, key, preferred_element_type=jnp.float32) if scale != 1.0: logits = logits * scale if mask_type == MaskType.CAUSAL: @@ -182,28 +180,31 @@ def sdpa_ref(query: Array, bias = get_sliding_window_mask(logits, sliding_window_length) if mask is not None: large_negative_number = get_large_negative_number(logits.dtype) - mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number) + mask = jnp.where(mask, 0, large_negative_number) + # combine bias and mask if bias is None: bias = mask elif mask is not None: + bias = bias.astype(logits.dtype) bias += mask + # apply bias to logits if bias is not None: if bias.shape != logits.shape: bias = jnp.broadcast_to(bias, logits.shape) logits = logits + bias.astype(logits.dtype) - probs = jax.nn.softmax(logits, axis=-1) + probs = jax.nn.softmax(logits, axis=-1).astype(query.dtype) if dropout_rate > 0.: keep_prob = 1.0 - dropout_rate dropout_rng = jax.random.key(0) keep = jax.random.bernoulli(dropout_rng, keep_prob, probs.shape) probs = jax.lax.select(keep, probs / keep_prob, jnp.zeros_like(probs)) - encoded = jnp.einsum("bhqk,bkhd->bqhd", probs, value) + encoded = jnp.einsum("bhqk,bkhd->bqhd", probs, value, preferred_element_type=jnp.float32) if mask_type == MaskType.PADDING: # cuDNN padding mask generation will mask out output accordingly # make sure the behavior is the same encoded_mask = get_encoded_padding_mask(encoded) encoded = encoded * encoded_mask - return encoded + return encoded.astype(query.dtype) def sdpa_train_ref(query: Array, key: Array, @@ -239,7 +240,7 @@ def sdpa_train_fp8( f_p = partial( dot_product_attention, scale=scale, mask_type=mask_type, use_fp8=True ) - return f_p(query, key, value, None, None, None, None, fp8_metas) + return f_p(query, key, value, fp8_params=fp8_metas) out, 
sdpa_vjp = jax.vjp( dot_product_attention_fp8, query, key, value, fp8_metas @@ -274,7 +275,7 @@ class DotProductAttentionTest(jtu.JaxTestCase): use_mask=[False, True], use_bias=[False, True], mask_type=[MaskType.NO_MASK], - dropout_rate=[0, 0.5], + dropout_rate=[0], scale=[0.5], dtype=[jnp.float16, jnp.bfloat16] ) @@ -351,18 +352,13 @@ class DotProductAttentionTest(jtu.JaxTestCase): jitted_sdpa_train(query, key, value, grad, bias, mask) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ jitted_sdpa_train_ref(query, key, value, grad, bias, mask) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) - if seq_len > 512: - # query_grad in flash attention is not deterministic - self.assertArraysAllClose( - query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) - else: - self.assertArraysAllClose( - query_grad_ref, query_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) self.assertArraysAllClose( - key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) + query_grad_ref, query_grad, rtol=2e-1, atol=2e-1) self.assertArraysAllClose( - value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + key_grad_ref, key_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose( + value_grad_ref, value_grad, rtol=2e-1, atol=2e-1) @jtu.run_on_devices("cuda") def test_sdpa_inference(self): @@ -381,9 +377,8 @@ class DotProductAttentionTest(jtu.JaxTestCase): with Mesh(devices, ("dp", "tp")) as mesh: qkv_spec = PartitionSpec("dp", None, "tp", None) qkv_sharding = NamedSharding(mesh, qkv_spec) - replicated = NamedSharding(mesh, PartitionSpec()) in_shardings = ( - qkv_sharding, qkv_sharding, qkv_sharding, replicated, replicated) + qkv_sharding, qkv_sharding, qkv_sharding) out_shardings = qkv_sharding query = jax.device_put(query, qkv_sharding) key = jax.device_put(key, qkv_sharding) @@ -403,15 +398,14 @@ class DotProductAttentionTest(jtu.JaxTestCase): out_shardings=out_shardings ) - out = jitted_sdpa_inference(query, key, value, None, None) - out_ref = jitted_sdpa_inference_ref(query, key, value, None, None) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) + out = jitted_sdpa_inference(query, key, value) + out_ref = jitted_sdpa_inference_ref(query, key, value) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) @jtu.run_on_devices("cuda") def test_sdpa_var_seq(self): if jax.device_count() < 4: self.skipTest("Requires more than 4 devices.") - self.skipTest("Skip before fixed.") k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) query = jax.random.normal( k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) @@ -432,13 +426,13 @@ class DotProductAttentionTest(jtu.JaxTestCase): ) out, (query_grad, key_grad, value_grad) = \ - jitted_sdpa_train(query, key, value, grad, None, None) + jitted_sdpa_train(query, key, value, grad) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ - jitted_sdpa_train_ref(query, key, value, grad, None, None) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) - self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + jitted_sdpa_train_ref(query, key, value, grad) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=2e-1, 
atol=2e-1) @jtu.run_on_devices("cuda") def test_sdpa_broadcast_bias_and_dbias(self): @@ -472,9 +466,8 @@ class DotProductAttentionTest(jtu.JaxTestCase): qkv_sharding = NamedSharding(mesh, qkv_spec) bias_spec = PartitionSpec("tp", None, None) bias_sharding = NamedSharding(mesh, bias_spec) - replicated = NamedSharding(mesh, PartitionSpec()) in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, - qkv_sharding, bias_sharding, replicated) + qkv_sharding, bias_sharding) out_shardings = (qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding)) query = jax.device_put(query, qkv_sharding) key = jax.device_put(key, qkv_sharding) @@ -496,14 +489,14 @@ class DotProductAttentionTest(jtu.JaxTestCase): ) out, (query_grad, key_grad, value_grad, bias_grad) = \ - jitted_sdpa_train(query, key, value, grad, bias, None) + jitted_sdpa_train(query, key, value, grad, bias) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref) = \ - jitted_sdpa_train_ref(query, key, value, grad, bias, None) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) - self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(bias_grad_ref, bias_grad, rtol=1e-5, atol=1e-5) + jitted_sdpa_train_ref(query, key, value, grad, bias) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(bias_grad_ref, bias_grad, rtol=2e-1, atol=2e-1) @jtu.sample_product( batch_size=[1, 16], @@ -573,13 +566,13 @@ class DotProductAttentionTest(jtu.JaxTestCase): ) out, (query_grad, key_grad, value_grad) = \ - jitted_sdpa_train(query, key, value, grad, None, None) + jitted_sdpa_train(query, key, value, grad) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ - jitted_sdpa_train_ref(query, key, value, grad, None, None) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) - self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + jitted_sdpa_train_ref(query, key, value, grad) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=2e-1, atol=2e-1) @jtu.run_on_devices("cuda") def test_sdpa_large_head_size(self): @@ -607,12 +600,12 @@ class DotProductAttentionTest(jtu.JaxTestCase): sdpa_train_ref, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0) ) - out_ans, grads_ans = sdpa_train_ans(query, key, value, grad, None, None) - out_ref, grads_ref = sdpa_train_rfc(query, key, value, grad, None, None) + out_ans, grads_ans = sdpa_train_ans(query, key, value, grad) + out_ref, grads_ref = sdpa_train_rfc(query, key, value, grad) self.assertArraysAllClose(out_ref, out_ans) - self.assertArraysAllClose(grads_ref[0], grads_ans[0]) - self.assertArraysAllClose(grads_ref[1], grads_ans[1]) - self.assertArraysAllClose(grads_ref[2], grads_ans[2]) + 
self.assertArraysAllClose(grads_ref[0], grads_ans[0], rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(grads_ref[1], grads_ans[1], rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(grads_ref[2], grads_ans[2], rtol=2e-1, atol=2e-1) @jtu.run_on_devices("cuda") def test_sdpa_packed_layout(self): @@ -679,7 +672,7 @@ class DotProductAttentionTest(jtu.JaxTestCase): kv_seqlen = q_seqlen.copy() mask = generate_padding_mask(segment_ids, q_seqlen.shape[1], query.shape, query.dtype) - bias = generate_segment_mask(segment_ids, query.dtype) + bias = generate_segment_mask(segment_ids, jnp.float32) devices = np.array(jax.local_devices()[:4]) devices = devices.reshape((2, 2)) @@ -757,8 +750,8 @@ class DotProductAttentionTest(jtu.JaxTestCase): value = jax.random.normal(k2, (B, S, N, H), dtype=dtype) grad = jax.random.normal(k3, (B, T, N, H), dtype=dtype) - btnh_fn = jax.jit(partial(sdpa_train_ref, scale=.5, - mask_type=MaskType.CAUSAL, dropout_rate=0.0)) + btnh_fn = jax.jit(partial(sdpa_train, scale=.5, + mask_type=MaskType.CAUSAL, is_bnth=False, dropout_rate=0.0)) out_ref, (dq_ref, dk_ref, dv_ref) = btnh_fn(query, key, value, grad) def _cvt(x): @@ -877,7 +870,7 @@ class DotProductAttentionF8Test(jtu.JaxTestCase): fp8_metas, ) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = ( - jitted_sdpa_train_ref(query, key, value, grad, None, None) + jitted_sdpa_train_ref(query, key, value, grad) ) self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=5e-1, atol=5e-1) @@ -938,7 +931,7 @@ class DotProductAttentionF8Test(jtu.JaxTestCase): qkv_layout=qkv_layout, use_fp8=True, ) - return f_p(query, key, value, None, None, None, None, fp8_metas) + return f_p(query, key, value, fp8_params=fp8_metas) jitted_sdpa_inference = jax.jit( dot_product_attention_fp8, diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 4c61426c6..15fc37805 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2120,7 +2120,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash def testIssue804(self): - # https://github.com/google/jax/issues/804 + # https://github.com/jax-ml/jax/issues/804 num_devices = jax.device_count() f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.) jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 50773e23b..98f10d9c0 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6140,6 +6140,42 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol, check_dtypes=False) + @jtu.sample_product( + shape=all_shapes, + dtype=default_dtypes, + op=['ndim', 'shape', 'size'], + ) + def testNdimShapeSize(self, shape, dtype, op): + rng = jtu.rand_default(self.rng()) + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + x = rng(shape, dtype) + expected = np_op(x) + self.assertEqual(expected, jnp_op(x)) # np.ndarray or scalar input. + self.assertEqual(expected, jnp_op(jnp.asarray(x))) # jax.Array input. + self.assertEqual(expected, jax.jit(jnp_op)(x)) # Traced input. 
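The two tests that follow exercise the remaining behavior of the new helpers: jnp.size accepts an optional axis argument, and passing a plain Python sequence emits a DeprecationWarning matching "ndim requires ndarray or scalar arguments". A brief sketch consistent with those tests:

import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)
print(jnp.size(x, axis=0))  # 3
print(jnp.size(x, axis=1))  # 4

# Warns with DeprecationWarning ("ndim requires ndarray or scalar arguments"):
jnp.ndim([1, 2, 3])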
+ + @jtu.sample_product( + shape=nonzerodim_shapes, + dtype=default_dtypes, + ) + def testSizeAlongAxis(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + axis = self.rng().randint(-len(shape), len(shape)) + np_op = partial(np.size, axis=axis) + jnp_op = partial(jnp.size, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + op=[jnp.ndim, jnp.shape, jnp.size], + ) + def testNdimShapeSizeNonArrayInput(self, op): + msg = f"{op.__name__} requires ndarray or scalar arguments" + with self.assertWarnsRegex(DeprecationWarning, msg): + op([1, 2, 3]) + # Most grad tests are at the lax level (see lax_test.py), but we add some here # as needed for e.g. particular compound ops of interest. diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 575362895..96d48dcd3 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -273,6 +273,11 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase): with self.assertRaises(TypeError): lsp_special.beta(x=1, y=1) + def testExpnTracerLeaks(self): + # Regression test for https://github.com/jax-ml/jax/issues/26972 + with jax.checking_leaks(): + lsp_special.expi(jnp.ones(())) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8497bf389..ad6b2a0bc 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4820,5 +4820,228 @@ class RaggedTest(jtu.JaxTestCase): self._CheckAgainstNumpy( lax_reference.ragged_dot, lax.ragged_dot, args_maker) + @parameterized.parameters( + { + "lhs_shape": lhs_shape, + "rhs_shape": rhs_shape, + "group_sizes_shape": group_sizes_shape, + "ragged_dot_dimension_numbers": ragged_dot_dimension_numbers, + "err_msg": err_msg, + } + for lhs_shape, rhs_shape, group_sizes_shape, ragged_dot_dimension_numbers, err_msg in [ + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0, 1], + rhs_group_dimensions=[0], + ), + "ragged_dot_general expects exactly one lhs ragged dimension", + ), + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[2], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires lhs ragged dimension numbers to " + "be nonnegative and less than the number of axes of the lhs" + ), + ), + ( + [11, 5], + [3, 5, 7], + [2, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + r"expected group_sizes to have shape \(3,\), got \(2, 3\)", + ), + ( + [19, 17, 11, 5], + [3, 19, 5, 7], + [19, 11, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([3], [2]), ([0], [1])), + lhs_ragged_dimensions=[2], + rhs_group_dimensions=[0], + ), + ( + r"expected group_sizes to have shape \(19, 17, 3\), " + r"got \(19, 11, 3\)" + ), + ), + ( + [19, 11, 17, 5], + [19, 17, 5, 7], + [19, 11, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2, 3], [1, 2]), ([0], [0])), + lhs_ragged_dimensions=[3], + rhs_group_dimensions=[], + ), + ( + r"expected group_sizes to have shape \(19, 17, 3\), " + r"got \(19, 11, 3\)" + ), + ), + ( + [17, 19, 11, 5], + [17, 19, 5, 7], + [19, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([3], [2]), ([0, 1], [0, 1])), + 
lhs_ragged_dimensions=[1], + rhs_group_dimensions=[], + ), + ( + r"expected group_sizes to have shape \(17, 3\), " + r"got \(19, 3\)" + ), + ), + ( + [19, 11, 5], + [19, 5, 7], + [19, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2], [1]), ([0], [0])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires rhs group dimension numbers to " + "be distinct from contracting and batch dimensions" + ), + ), + ( + [11, 3], + [3, 3, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[1], + ), + ( + "ragged_dot_general requires rhs group dimension numbers to " + "be distinct from contracting and batch dimensions" + ), + ), + ( + [11, 5], + [3, 5, 7], + [2], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + "expected rhs group dimension size to be 2, got 3", + ), + ( + [2, 11, 5], + [3, 2, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2], [2]), ([0], [1])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires zero group dimensions in " + "the rhs when lhs ragged dimension is contracting or batch" + ), + ), + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires zero group dimensions in " + "the rhs when lhs ragged dimension is contracting or batch" + ), + ), + ( + [11, 5], + [5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [0]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[], + ), + ( + "ragged_dot_general requires exactly one rhs group dimension " + "when lhs ragged dimension is noncontracting" + ), + ), + ] + ) + def test_ragged_dot_general_shape_inference_failure( + self, lhs_shape, rhs_shape, group_sizes_shape, + ragged_dot_dimension_numbers, err_msg): + lhs = jnp.ones(lhs_shape, dtype=jnp.float32) + rhs = jnp.ones(rhs_shape, dtype=jnp.float32) + group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) + with self.assertRaisesRegex(TypeError, err_msg): + lax.ragged_dot_general(lhs, rhs, group_sizes, + ragged_dot_dimension_numbers) + + @parameterized.parameters( + { + "lhs_shape": lhs_shape, + "rhs_shape": rhs_shape, + "group_sizes_shape": group_sizes_shape, + "ragged_dnums": ragged_dnums, + "out_shape": out_shape, + } + for lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape in [ + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + (11, 7), + ), + ( + [11, 5], + [5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [0]), ([], [])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[], + ), + (3, 11, 7), + ), + ] + ) + def test_ragged_dot_general_shape_inference_success( + self, lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape): + lhs = jnp.ones(lhs_shape, dtype=jnp.float32) + rhs = jnp.ones(rhs_shape, dtype=jnp.float32) + group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) + self.assertEqual( + lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape, + out_shape, + ) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/memories_test.py 
b/tests/memories_test.py index acb13336e..a08c5f36c 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1834,7 +1834,7 @@ class ActivationOffloadingTest(jtu.JaxTestCase): self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)") - self.assertRegex(compiled_text, r"dynamic-slice-done.*S\(5\)") + self.assertIn("dynamic-slice-start", compiled_text) compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index 94d5d6714..ba9d23fa5 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -18,6 +18,7 @@ from typing import Callable from absl.testing import parameterized import jax +from jax import numpy as jnp from jax._src import config from jax._src import test_util as jtu from jax._src.interpreters import mlir as mlir_interpreter @@ -824,6 +825,43 @@ class DialectLoweringTest(MosaicGpuTest): ) ) + @parameterized.parameters( + (arith.ExtFOp, jnp.bfloat16, jnp.float32), + (arith.ExtSIOp, jnp.int16, jnp.int32), + (arith.ExtUIOp, jnp.int16, jnp.uint32), + (arith.FPToSIOp, jnp.float32, jnp.int32), + (arith.FPToUIOp, jnp.float32, jnp.uint32), + (arith.SIToFPOp, jnp.int16, jnp.float32), + (arith.TruncFOp, jnp.float32, jnp.float16), + (arith.TruncIOp, jnp.int32, jnp.int16), + (arith.UIToFPOp, jnp.uint32, jnp.float32), + ) + def test_lower_conversion_op_lowers_to_same_op(self, op, in_dtype, out_dtype): + shape = (4, 32) + + with ir.InsertionPoint(self.module.body): + scalar_in_ty = mgpu_utils.dtype_to_ir_type(in_dtype) + scalar_out_ty = mgpu_utils.dtype_to_ir_type(out_dtype) + in_ty = ir.VectorType.get(shape, scalar_in_ty) + out_ty = ir.VectorType.get(shape, scalar_out_ty) + if ir.IntegerType.isinstance(scalar_in_ty): + zero = ir.IntegerAttr.get(scalar_in_ty, 0) + else: + zero = ir.FloatAttr.get(scalar_in_ty, 0) + splat_zero = arith.ConstantOp( + in_ty, ir.DenseElementsAttr.get_splat(in_ty, zero) + ) + op(out_ty, splat_zero) + + mgpu.infer_layout(self.module) + mgpu.lower_mgpu_dialect(self.module, None) + + conversion_ops = find_if(self.module, lambda o: isinstance(o, op)) + # This is a splat, so we expect a single conversion op involving a scalar + # after lowering. 
+ self.assertLen(conversion_ops, 1) + self.assertEqual(conversion_ops[0].result.type, scalar_out_ty) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f3d94d917..cc654eb2b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -485,12 +485,20 @@ def get_packed_shape(strides, shape): class WGMMALayoutTest(TestCase): - @parameterized.named_parameters(("f32", jnp.float32), ("f16", jnp.float16)) - def test_store_untiled(self, dtype): + @parameterized.product(dtype=[jnp.float16, jnp.float32], + tiled_layout=[False, True], + transposed_smem=[False, True]) + def test_store_untiled(self, dtype, tiled_layout, transposed_smem): def kernel(ctx, out, _): del ctx - iota_tensor(64, 64, dtype).store_untiled(out) + if transposed_smem: + out = memref_transpose(out, (1, 0)) + iota_tensor(64, 64, dtype, tiled_layout=tiled_layout).store_untiled( + out, vector_store=not transposed_smem + ) expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64) + if transposed_smem: + expected = expected.T iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() @@ -652,7 +660,8 @@ class WGMMATest(TestCase): k_steps=(1, 2), swizzle=(32, 64, 128), jax_out_dtype=(jnp.float16, jnp.float32), - small_rhs_tile=(False, True,), + rhs_tiling_kind=("large", "small", "small+no_transpose"), + lhs_tiling_kind=("large", "small", "small+no_transpose"), ) def test_wgmma_basic( self, @@ -664,12 +673,17 @@ class WGMMATest(TestCase): rhs_transpose, swizzle, jax_out_dtype, - small_rhs_tile, + rhs_tiling_kind, + lhs_tiling_kind, ): if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type: - raise self.skipTest("Only f16 input is supported for f16 output.") - if swizzle != 128 and lhs_transpose: - raise self.skipTest("Transpose only supported in 128B swizzled WGMMA") + self.skipTest("Only f16 input is supported for f16 output.") + if swizzle != 128 and lhs_transpose and lhs_tiling_kind == "large": + self.skipTest("Transpose only supported in 128B swizzled WGMMA") + if rhs_tiling_kind == "small+no_transpose" and not rhs_transpose: + self.skipTest("No transpose happening anyway") + if lhs_tiling_kind == "small+no_transpose" and not lhs_transpose: + self.skipTest("No transpose happening anyway") in_mlir_dtype = in_mlir_dtype_cls.get() out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype) @@ -695,19 +709,20 @@ class WGMMATest(TestCase): k = nk_tile * k_steps assert m % 64 == 0 and n % nk_tile == 0 - small_nk_tile = 8 - rhs_tiling = ( - (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) - ) + small_rhs_tile = rhs_tiling_kind != "large" + transpose_rhs_tiles = rhs_tiling_kind != "small+no_transpose" + rhs_tiling = (8, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) + small_lhs_tile = lhs_tiling_kind != "large" + transpose_lhs_tiles = lhs_tiling_kind != "small+no_transpose" + lhs_tiling = (8, nk_tile) if small_lhs_tile else (64, nk_tile) def kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers = scratch - lhs_transform = (mgpu.TileTransform((64, nk_tile)),) - if lhs_transpose: - assert nk_tile == 64 # Make sure we didn't have to transpose tiling. 
+ lhs_transform = (mgpu.TileTransform(lhs_tiling),) + if lhs_transpose and transpose_lhs_tiles: lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) rhs_transform = (mgpu.TileTransform(rhs_tiling),) - if rhs_transpose: + if rhs_transpose and transpose_rhs_tiles: rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=lhs, @@ -727,9 +742,11 @@ class WGMMATest(TestCase): barriers[i].wait() init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype) if lhs_transpose: - lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2)) + perm = (0, 1, 3, 2) if transpose_lhs_tiles else (1, 0, 3, 2) + lhs_smem = memref_transpose(lhs_smem, perm) if rhs_transpose: - rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2)) + perm = (0, 1, 3, 2) if transpose_rhs_tiles else (1, 0, 3, 2) + rhs_smem = memref_transpose(rhs_smem, perm) acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) @@ -744,15 +761,19 @@ class WGMMATest(TestCase): y_shape = (n, k) if rhs_transpose else (k, n) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype) - rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling + if transpose_rhs_tiles: + rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling + rhs_smem_shape = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling) + else: + rhs_smem_shape = tile_shape(y_shape, rhs_tiling) + if transpose_lhs_tiles: + lhs_tiling_t = lhs_tiling[::-1] if lhs_transpose else lhs_tiling + lhs_smem_shape = (m // lhs_tiling_t[0], k // lhs_tiling_t[1], *lhs_tiling) + else: + lhs_smem_shape = tile_shape(x_shape, lhs_tiling) scratch_shape = [ - jax.ShapeDtypeStruct( - (m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype - ), - jax.ShapeDtypeStruct( - (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling), - in_jax_dtype, - ), + jax.ShapeDtypeStruct(lhs_smem_shape, in_jax_dtype), + jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype), mgpu.TMABarrier(2), ] z = mgpu.as_gpu_kernel( @@ -890,7 +911,7 @@ class TCGen05Test(TestCase): self.skipTest("Only works on GPU with capability sm_100a or sm_101a") @parameterized.product( - lhs_transpose=(False,), # TODO(apaszke): True + lhs_transpose=(False, True), rhs_transpose=(False, True), in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation @@ -898,7 +919,8 @@ class TCGen05Test(TestCase): n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 k_steps=(1, 2), swizzle=(32, 64, 128,), - small_rhs_tile=(False, True), + rhs_transpose_tiles=(False, True), + lhs_transpose_tiles=(False, True), ) def test_mma_basic( self, @@ -910,28 +932,24 @@ class TCGen05Test(TestCase): rhs_transpose, in_jax_dtype, out_jax_dtype, - small_rhs_tile, + rhs_transpose_tiles, + lhs_transpose_tiles, ): if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: - raise self.skipTest("Only f16 input is supported for f16 output.") + self.skipTest("Only f16 input is supported for f16 output.") in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) - m_tile = 128 - nk_tile = swizzle // bytewidth(in_mlir_dtype) - k = nk_tile * k_steps - assert m % m_tile == 0 and n % nk_tile == 0 - - small_nk_tile = 8 - rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps + lhs_tiling = rhs_tiling = (8, swizzle_elems) def 
kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers, acc = scratch - lhs_transform = (mgpu.TileTransform((m_tile, nk_tile)),) - if lhs_transpose: - assert nk_tile == m_tile # Make sure we didn't have to transpose tiling + lhs_transform = (mgpu.TileTransform(lhs_tiling),) + if lhs_transpose_tiles: lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) rhs_transform = (mgpu.TileTransform(rhs_tiling),) - if rhs_transpose: + if rhs_transpose_tiles: rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=lhs, @@ -950,10 +968,14 @@ class TCGen05Test(TestCase): barriers[0].wait() barriers[1].wait() with mgpu.single_thread(): + if lhs_transpose_tiles: + lhs_smem = memref_transpose(lhs_smem, (1, 0, 2, 3)) if lhs_transpose: - lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2)) + lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2)) + if rhs_transpose_tiles: + rhs_smem = memref_transpose(rhs_smem, (1, 0, 2, 3)) if rhs_transpose: - rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2)) + rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2)) tcgen05.mma( acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, ) @@ -972,15 +994,21 @@ class TCGen05Test(TestCase): y_shape = (n, k) if rhs_transpose else (k, n) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) - rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling + if rhs_transpose_tiles: + rhs_smem_shape = ( + y_shape[1] // rhs_tiling[1], y_shape[0] // rhs_tiling[0], *rhs_tiling, + ) + else: + rhs_smem_shape = tile_shape(y_shape, rhs_tiling) + if lhs_transpose_tiles: + lhs_smem_shape = ( + x_shape[1] // lhs_tiling[1], x_shape[0] // lhs_tiling[0], *lhs_tiling, + ) + else: + lhs_smem_shape = tile_shape(x_shape, lhs_tiling) scratch_shape = [ - jax.ShapeDtypeStruct( - tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype - ), - jax.ShapeDtypeStruct( - (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling), - in_jax_dtype, - ), + jax.ShapeDtypeStruct(lhs_smem_shape, in_jax_dtype), + jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype), mgpu.TMABarrier(3), mgpu.TMEM((128, n), out_jax_dtype), ] @@ -993,15 +1021,14 @@ class TCGen05Test(TestCase): np.testing.assert_allclose(z, ref, atol=atol) @parameterized.product( - lhs_transpose=(False,), # TODO(apaszke): True - rhs_transpose=(True,), + lhs_transpose=(False, True), + rhs_transpose=(False, True), in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32 out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation m=(256,), # TODO(apaszke): 64, 192, 256 n=(128, 256), # TODO(apaszke): 512, 192, other non-power-of-2 k_steps=(1, 2), swizzle=(32, 64, 128,), - small_rhs_tile=(False, True), ) def test_mma_collective( self, @@ -1013,42 +1040,27 @@ class TCGen05Test(TestCase): rhs_transpose, in_jax_dtype, out_jax_dtype, - small_rhs_tile, ): if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: raise self.skipTest("Only f16 input is supported for f16 output.") in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) m_block_tile = m // 2 - m_tma_tile = 128 n_block_tile = n // 2 - nk_tma_tile = swizzle // bytewidth(in_mlir_dtype) - k = nk_tma_tile * k_steps - assert m % m_tma_tile == 0 and n % nk_tma_tile == 0 + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps index = ir.IndexType.get() - small_nk_tile = 8 if rhs_transpose else 16 - rhs_tiling = ( - (small_nk_tile, nk_tma_tile) - if small_rhs_tile - else (nk_tma_tile, nk_tma_tile) - ) + tiling 
= (8, swizzle_elems) def kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers, acc = scratch - lhs_transform = (mgpu.TileTransform((m_tma_tile, nk_tma_tile)),) - if lhs_transpose: - assert nk_tma_tile == m_tma_tile # Make sure we didn't have to transpose tiling - lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) - rhs_transform = (mgpu.TileTransform(rhs_tiling),) - if rhs_transpose: - rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) block_id = gpu.cluster_block_id(gpu.Dimension.x) ctx.async_copy( src_ref=lhs, dst_ref=lhs_smem, swizzle=swizzle, - gmem_transform=lhs_transform, + gmem_transform=mgpu.TileTransform(tiling), barrier=barriers[0], collective=gpu.Dimension.x, partitioned=1 if lhs_transpose else 0, # Split non-contracting dim. @@ -1057,7 +1069,7 @@ class TCGen05Test(TestCase): src_ref=rhs, dst_ref=rhs_smem, swizzle=swizzle, - gmem_transform=rhs_transform, + gmem_transform=mgpu.TileTransform(tiling), barrier=barriers[1], collective=gpu.Dimension.x, partitioned=0 if rhs_transpose else 1, # Split non-contracting dim. @@ -1068,9 +1080,9 @@ class TCGen05Test(TestCase): barriers[0].wait() barriers[1].wait() if lhs_transpose: - lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2)) + lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2)) if rhs_transpose: - rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2)) + rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2)) tcgen05.mma( acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, collective=True ) @@ -1086,20 +1098,15 @@ class TCGen05Test(TestCase): return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) x_shape = (k, m) if lhs_transpose else (m, k) + x_block_shape = (k, m_block_tile) if lhs_transpose else (m_block_tile, k) x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) + y_block_shape = (n_block_tile, k) if rhs_transpose else (k, n_block_tile) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) - rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling scratch_shape = [ - jax.ShapeDtypeStruct( - tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)), - in_jax_dtype, - ), - jax.ShapeDtypeStruct( - (k // rhs_tiling_t[0], n_block_tile // rhs_tiling_t[1], *rhs_tiling), - in_jax_dtype, - ), + jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype), mgpu.TMABarrier(3), mgpu.TMEM((128, n), out_jax_dtype, collective=True), ] @@ -2186,15 +2193,18 @@ class LayoutTest(TestCase): dtype=[jnp.int8, jnp.int16, jnp.int32], swizzle=[16, 32, 64, 128], num_col_tiles=[1, 2, 3], + row_tiling=[8, 64], ) - def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles): + def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles, row_tiling): + if (not load_tiled or not load_tiled) and row_tiling != 64: + self.skipTest("Old code path does not support this") mlir_dtype = utils.dtype_to_ir_type(dtype) bw = bytewidth(mlir_dtype) col_tiling = swizzle // bw if col_tiling % 8: self.skipTest("WGMMA layout requires col_tiling % 8 == 0") m, n = 128, col_tiling * num_col_tiles - tiling = (64, col_tiling) + tiling = (row_tiling, col_tiling) tiled_layout = fa._tiled_wgmma_layout((m, n)) load_layout = tiled_layout if load_tiled else mgpu.TILED_LAYOUT_WGMMA store_layout = tiled_layout if store_tiled else mgpu.TILED_LAYOUT_WGMMA @@ -2277,6 
+2287,47 @@ class LayoutTest(TestCase): ) np.testing.assert_array_equal(f(x), x) + @parameterized.product( + dtype=[jnp.int16], # TODO(apaszke): More dtypes + # TODO(apaszke): swizzle=64 <- not implemented in transfer_tiled right now + swizzle=[16, 32, 128], + ) + def test_transpose_tiled(self, dtype, swizzle): + mlir_dtype = utils.dtype_to_ir_type(dtype) + bw = bytewidth(mlir_dtype) + col_tiling = swizzle // bw + m, n = 128, 256 + tiling = (8, col_tiling) + transpose_layout = fa.WGMMA_TRANSPOSED_LAYOUT + def kernel(ctx, in_, out, smems): + smem_in, smem_out, barrier = smems + ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) + barrier.wait() + t = mgpu.FragmentedArray.load_tiled( + smem_in, swizzle=swizzle, is_signed=True, layout=fa.TILED_LAYOUT_WGMMA + ) + smem_out_t = memref_transpose(smem_out, (1, 0, 3, 2)) + t.to_layout(transpose_layout).store_tiled(smem_out_t, swizzle=swizzle) + mgpu.commit_shared() + ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle) + ctx.await_async_copy(0) + x = ( + np.arange(m * n, dtype=dtype) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .transpose(0, 2, 1, 3) + ) + y_ref = ( + np.arange(m * n, dtype=dtype) + .reshape(m, n) + .T.reshape(n // tiling[0], tiling[0], m // tiling[1], tiling[1]) + .transpose(0, 2, 1, 3) + ) + + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, y_ref, [x, y_ref, mgpu.TMABarrier()], + )(x) + np.testing.assert_array_equal(y, y_ref) + @dataclasses.dataclass(frozen=True) class Tile: @@ -2704,6 +2755,7 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): transforms_b: tuple[Tile | Transpose | Swizzle, ...] = () transpose_a: bool = False transpose_b: bool = False + load_a_in_registers: bool = False result = [] for swizzle in [ @@ -2735,6 +2787,13 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): transforms_a=[Tile([64, k]), Swizzle(swizzle)], transforms_b=[Tile([k, k]), Swizzle(swizzle)], ), + TestCaseInput( + shape_a=[groups_m * 64, groups_k * k], + shape_b=[groups_k * k, groups_n * k], + shape_res=[groups_m * 64, groups_n * k], + transforms_a=[Tile([64, k]), Swizzle(swizzle)], + load_a_in_registers=True, + ), ]) # The below only works for 128-byte swizzling. 
Regardless of transposing, # TMA needs the size of the last dimension to be compatible with the @@ -2798,6 +2857,14 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, parity) + # SMEM -> Registers + a_operand = a_smem_ref + zero_index = arith.constant(ir.IndexType.get(), 0) + if test_case.load_a_in_registers: + a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type) + zero_vector_indices = [zero_index] * len(test_case.shape_a) + a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices) + # Computation shape_result = ir.MemRefType(result_gmem_ref.type).shape result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type @@ -2809,7 +2876,7 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): ) result = mgpu_dialect.wgmma( accumulator, - a_smem_ref, + a_operand, b_smem_ref, transpose_a=test_case.transpose_a, transpose_b=test_case.transpose_b, @@ -2819,8 +2886,7 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): nvvm.wgmma_wait_group_sync_aligned(0) # Registers -> SMEM - zero_index = arith.constant(ir.IndexType.get(), 0) - vector.store(result, result_smem_ref, [zero_index, zero_index]) + vector.store(result, result_smem_ref, [zero_index] * len(shape_result)) # SMEM -> GMEM mgpu_dialect.async_store( @@ -2909,4 +2975,4 @@ class SerializationTest(absltest.TestCase): if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) + absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 9a6c5c167..c510c2cfa 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -131,51 +131,6 @@ class MutableArrayTest(jtu.JaxTestCase): out = f() self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False) - @parameterized.parameters([True, False]) - def test_refs_in_vjps(self, jit): - def gradient_history_calculator_fwd(x, ref): - return x, ref - - def gradient_history_calculator_bwd(amax_history, grad_output): - amax_update = jnp.max(jnp.abs(grad_output)) - shifted = jnp.roll(amax_history[:], 1) - shifted = shifted.at[0].set(amax_update) - amax_history[:] = shifted - amax_from_history = jnp.max(amax_history[:]) - grad_output = grad_output / amax_from_history - return grad_output, None - - @jax.custom_vjp - def gradient_history_calculator(x, ref): - return x - - gradient_history_calculator.defvjp( - gradient_history_calculator_fwd, - gradient_history_calculator_bwd) - - class DotOp: - def __init__(self): - self.amax_history = core.mutable_array(jnp.zeros(5,)) - - def forward(self, x, y): - out = jnp.dot(x, y) - out = gradient_history_calculator(out, self.amax_history) - return out - - dot_op = DotOp() - x_top = jnp.ones((5,)) - y_top = jnp.ones((5,)) - - def loss(x, y): - return dot_op.forward(x, y).sum() - - if jit: - loss = jax.jit(loss) - - for i in range(3): - jax.grad(loss, (0,1))(x_top, y_top) - self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False) - @parameterized.parameters([True, False]) def test_scan_internal_mut_array(self, jit): def body_fun(_, x): @@ -254,6 +209,23 @@ class MutableArrayTest(jtu.JaxTestCase): self.assertEqual(s, a.sharding) self.assertEqual(s, y.sharding) + def test_explicit_sharding_after_indexing(self): + # https://github.com/jax-ml/jax/issues/26936 + mesh = jax.make_mesh((1, 1), ('x', 'y'), explicit_axes=('x', 'y')) + sharding = NamedSharding(mesh, P('x', 
'y')) + + @jax.jit + def f(x_ref): + self.assertEqual(core.get_ty(x_ref).sharding.spec, + core.get_ty(x_ref[...]).sharding.spec) + y = x_ref[...] + 1 + return y + + with jax.sharding.use_mesh(mesh): + x = jnp.zeros((4, 4), jnp.int32, device=sharding) + x_ref = core.mutable_array(x) + y = f(x_ref) + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): @@ -371,17 +343,18 @@ class MutableArrayErrorsTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, "x_ref and y_ref"): f(x_ref, x_ref) - @parameterized.parameters([False, True]) - def test_argument_aliases_custom_vjp_fwd(self, jit): - @jax.custom_vjp - def f(x_ref, y_ref): - ... - f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None)) - if jit: - f = jax.jit(f) - x_ref = core.mutable_array(0.) - with self.assertRaisesRegex(ValueError, "x_ref and y_ref"): - jax.vjp(f, x_ref, x_ref) + # TODO(mattjj): re-enable test after direct-linearize + # @parameterized.parameters([False, True]) + # def test_argument_aliases_custom_vjp_fwd(self, jit): + # @jax.custom_vjp + # def f(x_ref, y_ref): + # ... + # f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None)) + # if jit: + # f = jax.jit(f) + # x_ref = core.mutable_array(0.) + # with self.assertRaisesRegex(ValueError, "x_ref and y_ref"): + # jax.vjp(f, x_ref, x_ref) # TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index 270707934..f371c431e 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -263,9 +263,9 @@ class NameStackTransformationTest(jtu.JaxTestCase): return g(x) hlo_text = _get_hlo(f)(2.) - self.assertIn('jvp(pjit(f))/pjit(g)/sin', hlo_text) - self.assertIn('jvp(pjit(f))/pjit(g)/cos', hlo_text) - self.assertIn('transpose(jvp(pjit(f)))/pjit(g)/mul', hlo_text) + self.assertIn('jvp(jit(f))/jit(g)/sin', hlo_text) + self.assertIn('jvp(jit(f))/jit(g)/cos', hlo_text) + self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text) def test_remat_appears_in_hlo(self): @ad_checkpoint.remat diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 35363de8c..987a3aa9d 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -127,6 +127,43 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "ops_test_mgpu", + srcs = [ + "ops_test.py", + ], + disable_configs = [ + "gpu_v100", + "gpu_v100_x32", + "gpu_p100", + "gpu_p100_x32", + "gpu_a100", + "gpu_a100_x32", + ], + enable_backends = [ + "gpu", + ], + enable_configs = [ + "gpu_h100", + "gpu_h100_x32", + ], + env = { + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + "JAX_PALLAS_VERBOSE_ERRORS": "0", + }, + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. + ], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_mosaic_gpu", # build_cleaner: keep + "//jax:pallas_tpu", + ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), +) + jax_multiplatform_test( name = "indexing_test", srcs = [ @@ -469,6 +506,24 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "tpu_ragged_paged_attention_test", + srcs = ["tpu_ragged_paged_attention_test.py"], + disable_configs = [ + "tpu_v5p_1x1", + ], + enable_backends = ["tpu"], + shard_count = 24, + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. 
+ ], + deps = [ + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_multiplatform_test( name = "tpu_splash_attention_kernel_test", srcs = [ diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7dae608a4..0a3af26de 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -81,24 +81,37 @@ class PallasSm90ATest(PallasTest, jtu.CudaArchSpecificTest): class PallasCallTest(PallasTest): - @parameterized.named_parameters( - ("add_one", lambda x: x + 1.), - ("logistic", jax.lax.logistic), - ("exp", jax.lax.exp), - ("square", lambda x: x ** 2), - ("rsqrt", jax.lax.rsqrt), - ("tanh", jax.lax.tanh, 1e-6), + @parameterized.product( + op=[ + lax.neg, + lax.bitwise_not, + lax.logistic, + lax.exp, + lambda x: x**2, + lax.rsqrt, + lax.tanh, + lax.log, + ], + approx_math=[True, False], + thread_semantics=[*plgpu.ThreadSemantics], ) - def test_unary_op(self, unary, rtol=1e-7): + def test_unary_op(self, op, approx_math, thread_semantics): + dtype = jnp.int32 if op is lax.bitwise_not else jnp.float32 + @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct([256], dtype), + compiler_params=plgpu.GPUCompilerParams( + approx_math=approx_math, thread_semantics=thread_semantics + ), ) def kernel(x_ref, o_ref): - o_ref[...] = unary(x_ref[...]) + o_ref[...] = op(x_ref[...]) - x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol) + x = jnp.arange(256).astype(dtype) + np.testing.assert_allclose( + kernel(x), op(x), rtol=1e-5 if approx_math else 3e-7 + ) @parameterized.product( op=[ @@ -641,18 +654,25 @@ class PallasCallTest(PallasTest): x = jnp.arange(128).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + x.sum()*2) - @parameterized.parameters(False, True) - def test_rsqrt(self, approx_math): + @parameterized.named_parameters( + ("rsqrt", jax.lax.rsqrt, ), + ("log", jax.lax.log, 5e-7), + ("exp", jax.lax.exp, ), + ("exp2", jax.lax.exp2, 5e-7), + ("logistic", jax.lax.logistic, ), + ("tanh", jax.lax.tanh, 5e-7), + ) + def test_approx_math_unary_op(self, unary_op, rtol=1e-7): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), ) def kernel(x_ref, o_ref): - o_ref[...] = jax.lax.rsqrt(x_ref[...]) + o_ref[...] 
= unary_op(x_ref[...]) - x = jnp.arange(128).astype(jnp.float32) - np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x)) + x = jnp.arange(128).astype(jnp.float32) / 128 + np.testing.assert_allclose(kernel(x), unary_op(x), rtol=rtol, atol=1e-5) @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): @@ -706,7 +726,7 @@ class PallasCallTest(PallasTest): shape = (128, 64) size = math.prod(shape) def kernel(x_ref, o_ref): - pl.debug_print("{}", x_ref[...]) + pl.debug_print("prefix {}", x_ref[...]) spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))) x = jnp.arange(size, dtype=jnp.float32).reshape(shape) f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) @@ -715,8 +735,8 @@ class PallasCallTest(PallasTest): jax.block_until_ready(f(x)) output = get_output() - results = re.findall(r"\[(\d+), (\d+)\]/\[128, 64\]: (\d+)", output) - self.assertLen(results, size) + results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output) + self.assertLen(results, size, output) for i, j, v in results: i, j, v = map(int, (i, j, v)) self.assertEqual(v, i * shape[1] + j) @@ -766,7 +786,7 @@ class PallasCallTest(PallasTest): with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) - self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output()) + self.assertIn("x: [1, 0, 43, 23]: 6871\n", output()) def test_load_scalar(self): @functools.partial( @@ -1153,29 +1173,38 @@ class PallasCallTest(PallasTest): self.assertEqual(data.count('"name": "store"'), 2) np.testing.assert_array_equal(y, x + x) - @parameterized.parameters( - (jnp.float16, jnp.float16), # Noop - (jnp.int16, jnp.bfloat16), - (jnp.int16, jnp.float16), - (jnp.uint16, jnp.float16), - (jnp.float32, jnp.int32), - (jnp.float32, jnp.uint32), - (jnp.uint32, jnp.int32), - (jnp.int32, jnp.uint32), + @parameterized.product( + dtypes=[ + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), + (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), + ], + thread_semantics=[*plgpu.ThreadSemantics], ) - def test_bitcast_convert_type(self, in_dtype, out_dtype): + def test_bitcast_convert_type(self, dtypes, thread_semantics): + in_dtype, out_dtype = dtypes m, n = 16, 8 out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) - grid = () - @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid) + @functools.partial( + pl.pallas_call, + out_shape=out_shape, + compiler_params=plgpu.GPUCompilerParams( + thread_semantics=thread_semantics + ), + ) def convert(x_ref, y_ref): y_ref[...] 
= jax.lax.bitcast_convert_type(x_ref[...], out_shape) x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n)) - y = convert(x) - y_ref = jax.lax.bitcast_convert_type(x, out_dtype) - np.testing.assert_array_equal(y, y_ref) + np.testing.assert_array_equal( + convert(x), jax.lax.bitcast_convert_type(x, out_dtype) + ) class PallasCallSm90ATest(PallasSm90ATest): diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index daa384cc5..907f4601a 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -38,10 +38,15 @@ import jax.numpy as jnp import numpy as np if sys.platform != "win32": - from jax.experimental.pallas import triton as plgpu + try: + from jax.experimental.pallas import mosaic_gpu as plgpu_mgpu + except ImportError: + plgpu_mgpu = None + from jax.experimental.pallas import triton as plgpu_triton from jax.experimental.pallas import tpu as pltpu else: - plgpu = None + plgpu_mgpu = None + plgpu_triton = None pltpu = None try: @@ -58,6 +63,7 @@ import hypothesis.strategies as hps jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=50) +use_mosaic_gpu = jax.config.read("jax_pallas_use_mosaic_gpu") intx = dtypes.canonicalize_dtype(jnp.int64) floatx = dtypes.canonicalize_dtype(jnp.float64) @@ -98,6 +104,7 @@ _DTYPES_32BIT = ( # TODO(apaszke): Add 8-bit floats. _DTYPES_SUB_32BIT = ( "bfloat16", + "float16", "int16", "int8", "int4", @@ -105,6 +112,9 @@ _DTYPES_SUB_32BIT = ( "uint8", "uint4", "bool", + "float8_e4m3b11fnuz", + "float8_e5m2", + "float8_e4m3fn", ) _DTYPES = (*_DTYPES_32BIT, *_DTYPES_SUB_32BIT) @@ -273,13 +283,27 @@ class PallasBaseTest(jtu.JaxTestCase): if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPUs with capability >= sm80") + if (jtu.test_device_matches(["cuda"]) and use_mosaic_gpu and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Mosaic GPU requires capability >= sm90") super().setUp() @classmethod def pallas_call(cls, *args, **kwargs): + if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: + assert plgpu_mgpu is not None + compiler_params = plgpu_mgpu.GPUCompilerParams( + thread_semantics=plgpu_mgpu.ThreadSemantics.Warpgroup + ) + kwargs["compiler_params"] = compiler_params + return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) + def skip_if_mosaic_gpu(self): + if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: + self.skipTest("TODO: Mosaic GPU does not support this yet") + class OpsTest(PallasBaseTest): @@ -295,6 +319,8 @@ class OpsTest(PallasBaseTest): ] ) def test_weak_dtype(self, fn, dtype): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), dtype), ) @@ -332,6 +358,7 @@ class OpsTest(PallasBaseTest): We don't really expect that the results would be wrong, but rather we want to exercise the lowering rules. """ + self.skip_if_mosaic_gpu() def kernel(x_ref, y_ref, o_ref): x = x_ref[0, 0] @@ -393,6 +420,7 @@ class OpsTest(PallasBaseTest): We don't really expect that the results would be wrong, but rather we want to exercise the lowering rules. 
""" + self.skip_if_mosaic_gpu() def kernel(x_ref, y_ref, o_ref): x = x_ref[:] @@ -532,6 +560,8 @@ class OpsTest(PallasBaseTest): ) @hp.given(hps.data()) def test_unary_primitives(self, name, func, shape_dtype_strategy, data): + self.skip_if_mosaic_gpu() + if self.INTERPRET: self.skipTest("This hypothesis test is slow, even more so in interpret mode.") # We want exact equality here to match how JAX lowers to XLA @@ -555,6 +585,12 @@ class OpsTest(PallasBaseTest): @parameterized.product(from_dtype=_DTYPES_32BIT, to_dtype=_DTYPES) @hp.given(hps.data()) def test_cast_from_32bit(self, from_dtype, to_dtype, data): + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu + if to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: + if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: + self.skipTest("Not supported on this hardware") + if not jtu.if_cloud_tpu_at_least(2025, 3, 8): + self.skipTest("Test requires libtpu from 2025/3/8 or later") if from_dtype == to_dtype: self.skipTest("Unnecessary test") if jtu.is_device_tpu(version=4): @@ -568,6 +604,10 @@ class OpsTest(PallasBaseTest): self.skipTest("Not supported on this TPU generation") if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}: self.skipTest("int4/uint4 casts are buggy on GPU") # b/391292861 + if to_dtype == "float16" and not sut_is_mosaic_gpu: + self.skipTest("float16 is only supported with Mosaic GPU") + if sut_is_mosaic_gpu and to_dtype == "bool": + self.skipTest("Sub-byte types are not yet supported with Mosaic GPU") # XLA does not specify the float->int conversion result for NaNs. elements = dict(allow_nan=not jnp.issubdtype(to_dtype, jnp.integer)) @@ -599,6 +639,8 @@ class OpsTest(PallasBaseTest): # miss bugs that would be hidden due to exhaustive enumeration being in order. @parameterized.product(from_dtype=_DTYPES_SUB_32BIT, to_dtype=_DTYPES, randomize=(False, True)) def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu + if from_dtype == to_dtype: self.skipTest("Unnecessary test") if jtu.is_device_tpu(version=4): @@ -617,6 +659,24 @@ class OpsTest(PallasBaseTest): self.skipTest("Not supported on this TPU generation") if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}: self.skipTest("int4/uint4 casts are buggy on GPU") # b/391292861 + if from_dtype == "float16" or to_dtype == "float16" and not sut_is_mosaic_gpu: + self.skipTest("float16 is only supported with Mosaic GPU") + if sut_is_mosaic_gpu: + unsupported_types = {"bool", "int4", "uint4"} + if to_dtype in unsupported_types or from_dtype in unsupported_types: + self.skipTest("Sub-byte types are not yet supported with Mosaic GPU") + if not randomize: + # TODO(bchetioui): rework the test shapes to make this work. 
+ self.skipTest("Exhaustive tests may run out of SMEM with Mosaic GPU") + if from_dtype in { + "float8_e4m3b11fnuz", + "float8_e5m2", + "float8_e4m3fn", + } or to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: + if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: + self.skipTest("Not supported on this hardware") + if not jtu.if_cloud_tpu_at_least(2025, 3, 9): + self.skipTest("Test requires libtpu from 2025/3/9 or later") from_int = np.issubdtype(np.dtype(from_dtype), np.integer) to_int = np.issubdtype(np.dtype(to_dtype), np.integer) @@ -653,12 +713,21 @@ class OpsTest(PallasBaseTest): else: x = jax.lax.bitcast_convert_type( jnp.arange(1 << from_bitwidth, dtype=from_int_dtype), from_dtype - ).reshape(8, -1) + ) + if sut_is_mosaic_gpu: + # TMA loads only support max 256 elements per dimension, so we make + # sure that all the dimensions don't exceed that. + if x.shape[0] > 256: + x = x.reshape(256, -1) + else: + x = x.reshape(8, -1) else: if randomize: x = random.randint(random.key(234), (16, 16), 0, 1, jnp.int32) != 0 else: - x = jnp.asarray([[False, True], [True, False]], dtype="bool") + x = jnp.tile( + jnp.asarray([[False, True], [True, False]], dtype="bool"), (8, 8) + ) assert x.dtype == jnp.dtype(from_dtype) # XLA does not specify the float->int conversion result for NaNs. if jnp.issubdtype(from_dtype, jnp.floating): @@ -717,6 +786,7 @@ class OpsTest(PallasBaseTest): dtype=(jnp.int32, jnp.int16, jnp.int8), ) def test_scalar_map(self, shape, dtype): + self.skip_if_mosaic_gpu() if pltpu is None: self.skipTest("No TPU module available.") if dtype != jnp.int32 and len(shape) < 2: @@ -754,6 +824,7 @@ class OpsTest(PallasBaseTest): self.assertAllClose(f(x).item(), 10.0) def test_concat_constant(self): + self.skip_if_mosaic_gpu() if pltpu is None: self.skipTest("No TPU module available.") axis = 0 @@ -794,6 +865,8 @@ class OpsTest(PallasBaseTest): for value in values ) def test_sign(self, dtype, value): + self.skip_if_mosaic_gpu() + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -825,6 +898,7 @@ class OpsTest(PallasBaseTest): jnp.int32, ) def test_add_constant(self, dtype): + self.skip_if_mosaic_gpu() shape = (256, 256) @@ -844,6 +918,8 @@ class OpsTest(PallasBaseTest): -3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4, ) def test_erf_inv(self, value): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), floatx), @@ -935,6 +1011,8 @@ class OpsTest(PallasBaseTest): for fn, dtype in itertools.product(*args) ) def test_elementwise(self, fn, dtype): + self.skip_if_mosaic_gpu() + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -997,6 +1075,8 @@ class OpsTest(PallasBaseTest): for fn, dtype in itertools.product(*args) ) def test_elementwise_scalar(self, fn, dtype): + self.skip_if_mosaic_gpu() + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -1044,6 +1124,8 @@ class OpsTest(PallasBaseTest): self.assertAllClose(kernel(x), fn(x), rtol=1e-6) def test_abs_weak_type(self): + self.skip_if_mosaic_gpu() + # see https://github.com/jax-ml/jax/issues/23191 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), floatx), @@ -1061,6 +1143,8 @@ class OpsTest(PallasBaseTest): ("float64", "float64"), ) def test_pow(self, x_dtype, y_dtype): + self.skip_if_mosaic_gpu() + if not 
jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -1079,6 +1163,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3) def test_integer_pow(self, y): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), ) @@ -1097,6 +1183,8 @@ class OpsTest(PallasBaseTest): ) ) def test_nextafter(self, dtype, x, y): + self.skip_if_mosaic_gpu() + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -1132,6 +1220,8 @@ class OpsTest(PallasBaseTest): ) ) def test_comparison(self, fn, dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: self.skipTest("Not implemented on GPU.") @@ -1159,6 +1249,8 @@ class OpsTest(PallasBaseTest): ) ) def test_comparison_scalar(self, fn, dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]) and dtype == jnp.float16: self.skipTest("float16 is not supported on TPU") @@ -1188,6 +1280,8 @@ class OpsTest(PallasBaseTest): self.assertArraysEqual(out, expected) def test_isnan(self): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), ) @@ -1220,6 +1314,8 @@ class OpsTest(PallasBaseTest): ("bfloat16", "bfloat16"), ) def test_true_divide(self, dtype, out_dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): if out_dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6): self.skipTest("bfloat16 is not supported on older TPU generations") @@ -1249,6 +1345,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters("float16", "bfloat16") def test_true_divide_unsupported(self, dtype): + self.skip_if_mosaic_gpu() + if self.INTERPRET: self.skipTest("No lowering in interpret mode") @@ -1286,6 +1384,8 @@ class OpsTest(PallasBaseTest): for fn, dtype in itertools.product(*args) ) def test_binary(self, f, dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") @@ -1309,6 +1409,8 @@ class OpsTest(PallasBaseTest): for fn, dtype in itertools.product(*args) ) def test_binary_scalar(self, f, dtype): + self.skip_if_mosaic_gpu() + if not jtu.test_device_matches(["tpu"]): self.skipTest("Test only supported on TPU.") if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: @@ -1336,6 +1438,8 @@ class OpsTest(PallasBaseTest): ((8, 16, 2), jnp.int8, 1), ) def test_broadcasted_iota(self, shape, dtype, dimension): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Only 32-bit integer iota supported") @@ -1351,6 +1455,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters("float16", "bfloat16", "float32") def test_approx_tanh(self, dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") @@ -1365,7 +1471,7 @@ class OpsTest(PallasBaseTest): self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), ) def kernel(x_ref, o_ref): - o_ref[...] = plgpu.approx_tanh(x_ref[...]) + o_ref[...] 
= plgpu_triton.approx_tanh(x_ref[...]) x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) # We upcast to float32 because NumPy <2.0 does not handle custom dtypes @@ -1378,6 +1484,8 @@ class OpsTest(PallasBaseTest): ) def test_elementwise_inline_asm(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented: elementwise_inline_asm_p") @@ -1391,7 +1499,7 @@ class OpsTest(PallasBaseTest): out_shape=jax.ShapeDtypeStruct((256,), jnp.float16), ) def kernel(x_ref, o_ref): - [o_ref[...]] = plgpu.elementwise_inline_asm( + [o_ref[...]] = plgpu_triton.elementwise_inline_asm( "tanh.approx.f16x2 $0, $1;", args=[x_ref[...]], constraints="=r,r", @@ -1403,6 +1511,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) def test_debug_barrier(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented: debug_barrier_p") @@ -1415,16 +1525,18 @@ class OpsTest(PallasBaseTest): ) def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] - plgpu.debug_barrier() + plgpu_triton.debug_barrier() x = jnp.array([4.2, 2.4]).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x) @unittest.skipIf( sys.platform == "win32", - "plgpu.TritonCompilerParams unavailable on Windows", + "plgpu_triton.TritonCompilerParams unavailable on Windows", ) def test_debug_print(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Test for TPU is covered in tpu_pallas_test.py") @@ -1435,7 +1547,9 @@ class OpsTest(PallasBaseTest): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1) + compiler_params=plgpu_triton.TritonCompilerParams( + num_warps=1, num_stages=1 + ), ) def kernel(x_ref, o_ref): pl.debug_print("It works!") @@ -1449,7 +1563,7 @@ class OpsTest(PallasBaseTest): @unittest.skipIf( sys.platform == "win32", - "plgpu.TritonCompilerParams unavailable on Windows", + "plgpu_triton.TritonCompilerParams unavailable on Windows", ) def test_debug_print_with_values(self): if jtu.test_device_matches(["tpu"]): @@ -1462,7 +1576,9 @@ class OpsTest(PallasBaseTest): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1) + compiler_params=plgpu_triton.TritonCompilerParams( + num_warps=1, num_stages=1 + ), ) def kernel(x_ref, o_ref): pl.debug_print("x[0] =", x_ref[0]) @@ -1481,6 +1597,8 @@ class OpsTest(PallasBaseTest): ((64,), (32, 2)), ) def test_reshape(self, in_shape, out_shape): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1512,6 +1630,8 @@ class OpsTest(PallasBaseTest): # fmt: on ) def test_reshape_noop_or_singleton_dims(self, in_shape, out_shape): + self.skip_if_mosaic_gpu() + # Unsupported implicit dim change: from "32,{0,0},(2,128),-1" to none if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1528,6 +1648,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_allclose(f(x), expected) def test_num_programs(self): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_specs=pl.BlockSpec(memory_space=smem_on_tpu()), @@ -1542,6 +1664,8 @@ class OpsTest(PallasBaseTest): ) def test_where_broadcasting(self): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): 
self.skipTest("Not supported on TPU") @@ -1570,6 +1694,8 @@ class OpsTest(PallasBaseTest): ((), (2, 2), ()), ) def test_broadcast_in_dim(self, in_shape, out_shape, dims): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1617,6 +1743,8 @@ class OpsTest(PallasBaseTest): trans_y=[False, True], ) def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): + self.skip_if_mosaic_gpu() + # TODO(apaszke): Remove after 12 weeks have passed. if not jtu.if_cloud_tpu_at_least(2024, 12, 19): self.skipTest("Requires libtpu built after 2024-12-19") @@ -1679,6 +1807,8 @@ class OpsTest(PallasBaseTest): block_size=[1, 2, 32, 64, 128], ) def test_masked_load_store(self, size, block_size): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented") @@ -1699,6 +1829,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_allclose(kernel(x), x + 1.0, atol=1e-5, rtol=1e-5) def test_masked_oob_load_store_slice(self): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1723,6 +1855,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(out, o_new) def test_strided_load(self): + self.skip_if_mosaic_gpu() + # Reproducer from https://github.com/jax-ml/jax/issues/20895. @functools.partial( self.pallas_call, @@ -1735,6 +1869,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(kernel(x), x[::4]) def test_broadcasted_load_store(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Unimplemented primitive: broadcast_to") @@ -1758,6 +1894,8 @@ class OpsTest(PallasBaseTest): ((16, 32), (16, 16)), ) def test_invalid_broadcasted_load(self, x_shape, mask_shape): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1784,6 +1922,8 @@ class OpsTest(PallasBaseTest): self.fail("Expected exception due to invalid broadcasting") def test_swap(self): + self.skip_if_mosaic_gpu() + # TODO: skipped due to https://github.com/jax-ml/jax/issues/24023 if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU this is only supported in interpret mode") @@ -1807,6 +1947,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(out[1], x) def test_masked_swap(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") @@ -1830,6 +1972,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(out[1], jnp.where(mask, x, y)) def test_masked_oob_swap_slice(self): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1871,6 +2015,8 @@ class OpsTest(PallasBaseTest): ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min), ) def test_scalar_atomic(self, op, value, numpy_op): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1906,6 +2052,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters((0,), (1,)) def test_array_atomic_add(self, axis): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): 
self.skipTest("Unimplemented primitive: broadcast_to") @@ -1946,6 +2094,8 @@ class OpsTest(PallasBaseTest): (2, 1, 1), ) def test_atomic_cas(self, init_value, cmp, new_value): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1968,6 +2118,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters(1, 2, 3, 4, 8) def test_atomic_counter(self, num_threads): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1997,6 +2149,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters(False, True) def test_reduce_only_dim(self, use_store): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -2040,6 +2194,8 @@ class OpsTest(PallasBaseTest): ] ]) def test_array_reduce(self, op, dtype, axis): + self.skip_if_mosaic_gpu() + if not isinstance(axis, int): self.skipTest("TODO: tuple axes are not yet supported") @@ -2097,6 +2253,8 @@ class OpsTest(PallasBaseTest): dtype=["float16", "float32", "int32", "uint32"], ) def test_cumsum(self, dtype, axis): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") @@ -2134,6 +2292,8 @@ class OpsTest(PallasBaseTest): (-1, jnp.bfloat16), ) def test_triu(self, k, dtype): + self.skip_if_mosaic_gpu() + if dtype == jnp.bfloat16 and jtu.test_device_matches(["tpu"]): # TODO(mvoz): b/376330700 raise unittest.SkipTest('NYI - bf16 select') @@ -2159,6 +2319,8 @@ class OpsTest(PallasBaseTest): (jnp.int32, jnp.uint32), ) def test_bitcast_convert_type(self, in_dtype, out_dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") @@ -2176,6 +2338,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(y, y_ref) def test_bitcast_convert_type_scalar(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") @@ -2333,6 +2497,26 @@ class PallasPrimitivesTest(PallasBaseTest): wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) self.assertIn(expected, jaxpr.pretty_print(use_color=False)) + @parameterized.product(approx=[False, True]) + def test_reciprocal(self, approx): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on non-TPU devices") + if not jtu.if_cloud_tpu_at_least(2025, 3, 8): + self.skipTest("Test requires libtpu from 2025/3/8 or later") + shape = (32, 256) + x = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape) + + def kernel(x_ref, o_ref): + o_ref[...] 
= pl.reciprocal(x_ref[...], approx=approx) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32) + )(x) + kwargs = {} + if approx: + kwargs.update(dict(atol=2e-5, rtol=2e-5)) + np.testing.assert_allclose(out, jax.lax.reciprocal(x), **kwargs) + class PallasPrimitivesInterpretTest(PallasPrimitivesTest): INTERPRET = True diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index bb3826dbf..faa75d455 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -733,6 +733,27 @@ class PallasCallTest(PallasBaseTest): ) self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3) + @parameterized.parameters(jnp.int8, jnp.uint8) + def test_integer_dot(self, dtype): + if jtu.test_device_matches(["tpu"]) and not jtu.is_device_tpu_at_least(5): + self.skipTest("`int8` dot is only supported on v5 TPUs and newer.") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((32, 64), jnp.int32), + ) + def dot_kernel(x_ref, y_ref, o_ref): + o_ref[()] = pl.dot(x_ref[()], y_ref[()]) + + key0, key1 = random.split(random.key(0)) + # FIXME(cjfj): TPU fails with `uint8` values >= 128. + kwargs = dict(minval=jnp.iinfo(dtype).min, maxval=128, dtype=dtype) + # TODO(cjfj): Investigate why this fails on GPU with `k == 16`. + x = random.randint(key0, (32, 128), **kwargs) + y = random.randint(key1, (128, 64), **kwargs) + expected = jnp.dot(x, y, preferred_element_type=jnp.int32) + self.assertAllClose(dot_kernel(x, y), expected, atol=0.0, rtol=0.0) + def test_dot_with_vector(self): if not jtu.test_device_matches(["gpu"]) or self.INTERPRET: self.skipTest( @@ -2104,13 +2125,15 @@ class PallasOutOfBoundsInterpretTest(PallasBaseTest): # TODO(justinfu): This test has low precision on GPU. Improve precision. if jtu.test_device_matches(["gpu"]): atol = 1e-2 + rtol = 5e-3 else: atol = 1e-5 + rtol = 1e-7 # With a masked matmul implementation, uninitialized values will be # masked before computation. This should return the correct result. with self.subTest('MaskedOutputIsCorrect'): - np.testing.assert_allclose(out, expected, atol=atol) + np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) class PallasCheckifyTest(PallasBaseTest): diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index e5fbbdd4b..518c16ed2 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -18,6 +18,8 @@ To work around https://github.com/jax-ml/jax/issues/25671 , this file contains only tests that use shard_map. 
""" +import functools + from absl.testing import absltest from absl.testing import parameterized @@ -39,12 +41,16 @@ P = jax.sharding.PartitionSpec class InterpretDistributedTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if jax.device_count() < 4: + self.skipTest(f'requires at least 4 devices, found {jax.device_count()}') - @parameterized.parameters('eager', 'on_wait') - def test_right_permute_example(self, dma_execution_mode): + @parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) + def test_right_permute_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -61,36 +67,26 @@ class InterpretDistributedTest(jtu.JaxTestCase): right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices)) barrier_sem = pltpu.get_barrier_semaphore() - def _body(ijk): - i, (j, k) = ijk - lax.cond( - (i == 0) | (j == 0), - lambda: pltpu.semaphore_signal( - barrier_sem, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH), - lambda: pltpu.semaphore_signal( - barrier_sem, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH)) - return (i + 1, (j + 1, k + 1)) - lax.while_loop(lambda ijk: ijk[0] < 2, _body, (0, (0, 0))) + pltpu.semaphore_signal( + barrier_sem, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH) + pltpu.semaphore_signal( + barrier_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH) pltpu.semaphore_wait(barrier_sem, 2) - def _body2(i, a): - remote_copy_op = pltpu.make_async_remote_copy( + remote_copy_op = pltpu.make_async_remote_copy( src_ref=input_ref, dst_ref=output_ref, send_sem=send_sem, recv_sem=recv_sem, device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, - ) - remote_copy_op.start() - remote_copy_op.wait() - - return i + 1, a + 1 - _ = lax.scan(_body2, 0, jnp.arange(4.0), unroll=2) + ) + remote_copy_op.start() + remote_copy_op.wait() out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( @@ -111,7 +107,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): grid_spec=grid_spec, compiler_params=pltpu.TPUCompilerParams(collective_id=13), interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=detect_races), ) # Wrap the kernel within a shard_map to call. 
pallas_result = jax.jit( @@ -133,12 +129,14 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) - @parameterized.parameters('eager', 'on_wait') - def test_all_gather_example(self, dma_execution_mode): + @parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) + def test_all_gather_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P('x', None) mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -230,7 +228,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): out_shape=out_shape, grid_spec=grid_spec, interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) @@ -254,12 +252,14 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) - @parameterized.parameters('eager', 'on_wait') - def test_all_reduce_sum_example(self, dma_execution_mode): + @parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) + def test_all_reduce_sum_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -388,7 +388,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): out_shape=out_shape, grid_spec=grid_spec, interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) @@ -413,12 +413,14 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) - @parameterized.parameters('eager', 'on_wait') - def test_reduce_scatter_sum_example(self, dma_execution_mode): + @parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) + def test_reduce_scatter_sum_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -670,7 +672,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): out_shape=out_shape, grid_spec=grid_spec, interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=True), compiler_params=pltpu.TPUCompilerParams(collective_id=7), )(input_arr)[0] @@ -700,17 +702,19 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) - @parameterized.parameters('eager', 'on_wait') + 
@parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) def test_reduce_scatter_sum_with_emit_pipeline_example( - self, dma_execution_mode): + self, dma_execution_mode, detect_races): self.skipTest('requires a patched pallas.emit_pipeline to specify/fake ' 'the TPU generation') if jax.config.jax_enable_x64: self.skipTest('pallas.emit_pipeline + x64 is not currently supported') num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -972,7 +976,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): out_shape=out_shape, grid_spec=grid_spec, interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.TPUCompilerParams(collective_id=19), )(input_arr)[0] @@ -1001,6 +1005,95 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) + + def test_race_detection(self): + num_devices = 4 + mesh = jax.sharding.Mesh(np.array(jax.devices()[:4]), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, P('x', None)) + + input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128)) + input_arr = jax.device_put(input_arr, sharding) + + def kernel(src_dst_ids_ref, x_ref, o_ref, send_sem, recv_sem): + # Barrier with all devices before doing any DMAs. + barrier_sem = pltpu.get_barrier_semaphore() + @functools.partial(jax.lax.fori_loop, 0, num_devices, init_val=None) + def _(i, _): + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(jnp.int32(i),), + device_id_type=pltpu.DeviceIdType.MESH, + ) + return None + pltpu.semaphore_wait(barrier_sem, num_devices) + + # Send the specified DMAs. + my_id = lax.axis_index('x') + src_dst_ids = src_dst_ids_ref[:] + recv_count = 0 + for i in range(src_dst_ids.shape[0]): + src_id = src_dst_ids[i, 0] + dst_id = src_dst_ids[i, 1] + @pl.when(src_id == my_id) + def _(): + dma = pltpu.make_async_remote_copy( + src_ref=x_ref, + dst_ref=o_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=(dst_id,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + dma.start() + dma.wait_send() + recv_count += jnp.where(dst_id == my_id, 1, 0) + + # Wait until we have received all DMAs. 
+ @pl.when(recv_count > 0) + def _(): + fake_dma = pltpu.make_async_remote_copy( + src_ref=x_ref.at[pl.ds(0, 8 * recv_count)], + dst_ref=o_ref.at[pl.ds(0, 8 * recv_count)], + send_sem=send_sem, + recv_sem=recv_sem, + device_id=(my_id,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + fake_dma.wait_recv() + + @jax.jit + def run(src_dst_ids): + return shard_map.shard_map( + pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), input_arr.dtype), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + interpret=mosaic_interpret.TPUInterpretParams( + dma_execution_mode='eager', + detect_races=True, + ), + ), + mesh=mesh, + in_specs=(P(None), P('x', None)), + out_specs=P('x', None), + check_rep=False, + )(src_dst_ids, input_arr) + + run(jnp.array([[0, 1], [1, 2], [2, 3]], jnp.int32)).block_until_ready() + self.assertFalse(mosaic_interpret.races.races_found) + + # Racing writes to device 2. + run(jnp.array([[0, 1], [1, 2], [3, 2], [3, 0]], jnp.int32)).block_until_ready() + self.assertTrue(mosaic_interpret.races.races_found) if __name__ == "__main__": diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 21bfb57d6..71e91a697 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -19,11 +19,13 @@ contains only tests that do not use shard_map. """ from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import test_util as jtu import jax._src.pallas.mosaic.interpret as mosaic_interpret from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -33,13 +35,14 @@ jax.config.parse_flags_with_absl() class InterpretTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + self.num_devices = jax.device_count() + if self.num_devices > 1: + # Workaround for https://github.com/jax-ml/jax/issues/25671 + self.skipTest(f'requires 1 device, found {self.num_devices}') def test_matmul_example(self): - num_devices = jax.device_count() - if num_devices > 1: - # Workaround for https://github.com/jax-ml/jax/issues/25671 - self.skipTest(f'requires 1 device, found {num_devices}') - def matmul_kernel(x_ref, y_ref, z_ref): z_ref[...] = x_ref[...] @ y_ref[...] @@ -65,30 +68,93 @@ class InterpretTest(jtu.JaxTestCase): z = matmul(x, y) np.testing.assert_allclose(z, x @ y, atol=1e-4) - def test_dynamic_grid(self): - num_devices = jax.device_count() - if num_devices > 1: - # Workaround for https://github.com/jax-ml/jax/issues/25671 - self.skipTest(f'requires 1 device, found {num_devices}') - - def kernel(x_ref, o_ref): - o_ref[...] = x_ref[...] + def test_dynamic_grid_and_aliasing(self): + def kernel(s_ref, x_ref, o_ref): + o_ref[...] = x_ref[...] 
+ s_ref[0].astype(x_ref.dtype) iters = jax.random.randint(jax.random.key(0), (), 10, 20, dtype=jnp.int32) @jax.jit - def f(x): + def f(s, x): return pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), grid=(iters,), - in_specs=(pl.BlockSpec(x.shape, lambda i: (0, 0)),), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(x.shape, lambda i: (0, 0)), + ], out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), + input_output_aliases={1: 0}, interpret=mosaic_interpret.TPUInterpretParams() - )(x) + )(s, x) + s = jnp.array([1], dtype=jnp.int32) x = jnp.arange(32 * 128.).reshape((32, 128)) - y = f(x) - np.testing.assert_allclose(y, x) + y = f(s, x) + np.testing.assert_allclose(y, x + 1.0) + + @parameterized.parameters('eager', 'on_wait') + def test_race_detection(self, dma_execution_mode): + def kernel_without_race(x_ref, o_ref, t_ref, sem): + copy = pltpu.make_async_copy(x_ref, t_ref, sem) + copy.start() + copy.wait() + o_ref[...] = t_ref[...] + 1.0 + + def kernel_with_race(x_ref, o_ref, t_ref, sem): + copy = pltpu.make_async_copy(x_ref, t_ref, sem) + copy.start() + # This read of t_ref races with the above DMA's write of t_ref. + o_ref[...] = t_ref[...] + 1.0 + copy.wait() + + x = jnp.zeros((8, 128), jnp.float32) + y = pl.pallas_call(kernel_without_race, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + scratch_shapes=[ + pltpu.VMEM(x.shape, x.dtype), + pltpu.SemaphoreType.DMA, + ], + interpret=mosaic_interpret.TPUInterpretParams( + detect_races=True, dma_execution_mode=dma_execution_mode), + )(x).block_until_ready() + self.assertFalse(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, x + 1.0) + + pl.pallas_call(kernel_with_race, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + scratch_shapes=[ + pltpu.VMEM(x.shape, x.dtype), + pltpu.SemaphoreType.DMA, + ], + interpret=mosaic_interpret.TPUInterpretParams( + detect_races=True, dma_execution_mode=dma_execution_mode), + )(x).block_until_ready() + self.assertTrue(mosaic_interpret.races.races_found) + + def test_skip_floating_point_ops(self): + def matmul_kernel(x_ref, y_ref, z_ref): + z_ref[...] = x_ref[...] @ y_ref[...] + + def matmul(x: jax.Array, y: jax.Array): + return pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), + interpret=mosaic_interpret.TPUInterpretParams( + skip_floating_point_ops=True + ), + )(x, y) + + k1, k2 = jax.random.split(jax.random.key(0)) + x = jax.random.normal(k1, (1024, 1024)) + y = jax.random.normal(k2, (1024, 1024)) + z = jax.jit(matmul)(x, y) + np.testing.assert_array_equal(z, jnp.full_like(z, jnp.inf)) + + lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo") + self.assertNotIn("dot_general", lowered) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py new file mode 100644 index 000000000..cca8e3bc8 --- /dev/null +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -0,0 +1,305 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas.ops.tpu.ragged_paged_attention import ( + ragged_paged_attention, + ref_ragged_paged_attention, + validate_inputs_on_runtime, +) +import jax.numpy as jnp + + +jax.config.parse_flags_with_absl() + + +def ceil_div(x, a): + assert a != 0 + return (x + a - 1) // a + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class PagedAttentionKernelTest(jtu.JaxTestCase): + + def _test_ragged_paged_attention( + self, + seq_lens, # List[(q_len, kv_len)] + num_heads, # [num_q_heads, num_kv_heads] + head_dim, + page_size, + dtype, + num_pages, + *, + num_kv_pages_per_block=8, + num_queries_per_block=64, + vmem_limit_bytes=32 * 1024 * 1024, + max_num_batched_tokens=512, + max_num_seq=8, + ): + if not jtu.is_device_tpu_at_least(version=4): + self.skipTest("Expect TPUv4+") + cu_q_lens = [0] + kv_lens = [] + for q_len, kv_len in seq_lens: + assert q_len <= kv_len + cu_q_lens.append(cu_q_lens[-1] + q_len) + kv_lens.append(kv_len) + + max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens) + max_num_seq = max(len(seq_lens), max_num_seq) + max_kv_len = max(kv_lens) + pages_per_seq = ceil_div(max_kv_len, page_size) + pages_per_seq = ( + ceil_div(pages_per_seq, num_kv_pages_per_block) + * num_kv_pages_per_block + ) + num_q_heads, num_kv_heads = num_heads + + cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) + kv_lens = jnp.array(kv_lens, dtype=jnp.int32) + cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) + kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) + prng_key = jax.random.key(1234) + k0, k1, k2, k3 = jax.random.split(prng_key, 4) + q = jax.random.normal( + k0, + (max_num_batched_tokens, num_q_heads, head_dim), + dtype=dtype, + ) + k_pages = jax.random.normal( + k1, + (num_pages, page_size, num_kv_heads, head_dim), + dtype=dtype, + ) + v_pages = jax.random.normal( + k2, + (num_pages, page_size, num_kv_heads, head_dim), + dtype=dtype, + ) + page_indices = jax.random.randint( + k3, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 + ) + + num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) + + validate_inputs_on_runtime( + q, + k_pages, + v_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + ) + + output = ragged_paged_attention( + q, + k_pages, + v_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs=num_seqs, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + )[: cu_q_lens[num_seqs[0]]] + + expected = ref_ragged_paged_attention( + q, + k_pages, + v_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs=num_seqs, + ) + tols = { + "float32": 1e-1, + "bfloat16": 2e-1, + } + tol = tols[jnp.dtype(dtype).name] + self.assertAllClose(output, expected, atol=tol, rtol=tol) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_ragged_paged_attention_basic(self, dtype): + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + 
head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_ragged_paged_attention_decode_only(self, dtype): + seq_lens = [ + (1, 18), + (1, 129), + (1, 597), + (1, 122), + (1, 64), + (1, 322), + (1, 463), + (1, 181), + (1, 1107), + (1, 123), + (1, 31), + (1, 18), + (1, 1229), + (1, 229), + (1, 87), + (1, 1328), + ] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_ragged_paged_attention_prefill_only(self, dtype): + seq_lens = [ + (5, 18), + (15, 129), + (120, 597), + (100, 122), + (21, 64), + (32, 322), + (251, 463), + (40, 181), + (64, 1107), + (99, 123), + (10, 31), + (5, 18), + (3, 1229), + (120, 229), + (9, 87), + (2, 1328), + ] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_ragged_paged_attention_mixed(self, dtype): + seq_lens = [ + (5, 18), + (1, 129), + (120, 597), + (1, 122), + (1, 64), + (32, 322), + (251, 463), + (1, 181), + (1, 1107), + (99, 123), + (1, 31), + (5, 18), + (3, 1229), + (117, 229), + (1, 87), + (1, 1328), + ] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @parameterized.product( + num_seqs=[1, 5, 16], + # TODO(jevinjiang): Support more num_heads! + num_heads=[(32, 8), (32, 16), (12, 2), (4, 4)], + dtype=[jnp.float32, jnp.bfloat16], + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + ) + def test_ragged_paged_attention_complex( + self, + num_seqs, + num_heads, + dtype, + num_kv_pages_per_block, + num_queries_per_block, + ): + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + # TODO(jevinjiang): Support non-128 head_dim! 
+ head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0cb3f9d28..bd7954d60 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1205,8 +1205,7 @@ class PJitTest(jtu.BufferDonationTestCase): with self.assertRaisesRegex( ValueError, r"One of with_sharding_constraint.*Sharding " - r"NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), " - r"spec=PartitionSpec\(None, 'mdl', None, None\).*\) is only " + r"NamedSharding.*PartitionSpec\(None, 'mdl', None, None\).*\) is only " "valid for values of rank at least 4, but was applied to a value of rank 1"): pjit_f(jnp.array([1, 2, 3])) @@ -2076,7 +2075,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with global_mesh: with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation"): + ValueError, "Received incompatible devices for jitted computation"): pjit(lambda x: x)(input_array) def test_array_lower_compile(self): @@ -2177,7 +2176,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with m1: with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation"): + ValueError, "Received incompatible devices for jitted computation"): pjit(lambda x, y: (x, y), out_shardings=(NamedSharding(m1, spec), NamedSharding(m2, spec)))(a1, a1) @@ -2192,7 +2191,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with m1: with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation"): + ValueError, "Received incompatible devices for jitted computation"): pjit( lambda x, y: (x, y), in_shardings=NamedSharding(m2, spec), @@ -2348,7 +2347,7 @@ class ArrayPjitTest(jtu.JaxTestCase): arr = jnp.array([1, 2, 3]) with self.assertRaisesRegex( RuntimeError, - r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or' + r'jit requires a non-empty mesh if you are passing `PartitionSpec`s or' r' `None` to in_shardings.*'): pjit(lambda x: x, in_shardings=P('x'))(arr) @@ -2396,7 +2395,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with jtu.create_mesh((2, 2), ('x', 'y')): with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation"): + "Received incompatible devices for jitted computation"): pjit(lambda x, y: (x, y))(uarr, carr) def test_pjit_uncommitted_array_multi_devices(self): @@ -2418,7 +2417,7 @@ class ArrayPjitTest(jtu.JaxTestCase): b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1]) with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation. Got argument " + "Received incompatible devices for jitted computation. Got argument " r"x of.*\ with shape int.*\[3\] and device ids \[0\].*and " r"argument y of.*\ with shape int.*\[3\] and device ids \[1\].*"): pjit(lambda x, y: (x, y))(a, b) @@ -2430,7 +2429,7 @@ class ArrayPjitTest(jtu.JaxTestCase): b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1]) with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation. Got argument " + "Received incompatible devices for jitted computation. 
Got argument " r"x\[0\] of.*\ with shape int.*\[3\] and device ids \[0\].*and " r"argument x\[1\] of.*\ with shape int.*\[3\] and device ids " r"\[1\].*"): @@ -2443,7 +2442,7 @@ class ArrayPjitTest(jtu.JaxTestCase): c = jax.device_put(np.arange(16).reshape(8, 2), NamedSharding(mesh, P('x', 'y'))) - msg = ("Received incompatible devices for pjitted computation. Got " + msg = ("Received incompatible devices for jitted computation. Got " r"argument {} of.* with shape int.*\[3\] and device ids " r"\[0\].*and argument {} of.* with shape int.*\[8,2\] and " r"device ids.*") @@ -2617,9 +2616,9 @@ class ArrayPjitTest(jtu.JaxTestCase): return f(inp1, inp2, inp3) with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation. Got argument " + "Received incompatible devices for jitted computation. Got argument " r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*" - r"pjit inside pjit with device ids.*"): + r"pjit inside jit with device ids.*"): my_nested_pjit(committed_inp, committed_inp, committed_inp) @jtu.ignore_warning(category=DeprecationWarning, @@ -6240,6 +6239,37 @@ class ShardingInTypesTest(jtu.JaxTestCase): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) + @jtu.with_user_mesh((4,), ('x',)) + def test_concat_vmap(self, mesh): + @jax.jit + def _f(sharded_array, replicated_array): + def _single_array(a, b): + return jnp.concatenate([a, b], axis=-1) + + _first_vmap = jax.vmap(_single_array, in_axes=(None, 0)) + _second_vmap = jax.vmap(_first_vmap, in_axes=(0, None)) + return jax.vmap(_second_vmap, in_axes=(0, None))(sharded_array, replicated_array) + + np_inp = np.ones((4 * 4, 10, 5, 4)) + arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + arr2 = jax.device_put( + jnp.ones((10, 5, 3)), NamedSharding(mesh, P())) + + out = _f(arr1, arr2) + self.assertEqual(out.sharding, + NamedSharding(mesh, P('x', None, None, None, None))) + + out = _f(arr1, jnp.ones((10, 5, 3))) + self.assertEqual(out.sharding, + NamedSharding(mesh, P('x', None, None, None, None))) + + def test_aval_spec_explicit_auto_complete(self): + abstract_mesh = mesh_lib.AbstractMesh( + (('x', 2),), axis_types={AxisTypes.Explicit: 'x'}) + s = NamedSharding(abstract_mesh, P('x')) + out = core.ShapedArray((8, 2), jnp.int32, sharding=s) + self.assertEqual(out.sharding.spec, P('x', None)) + @jtu.with_user_mesh((2, 2), ('x', 'y'), axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')}) def test_full_user_mode(self, mesh): @@ -6294,6 +6324,34 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.shape, (16, 8, 16)) self.assertEqual(out.sharding, NamedSharding(mesh, P('data', None, None))) + @jtu.with_user_mesh((4,), ('data',)) + def test_intermediate_einsum_auto_complete_spec(self, mesh): + s = NamedSharding(mesh, P('data')) + + shape1 = (8, 32, 2*16) + shape2 = (8, 32, 2, 8) + shape3 = (8, 32, 2, 8) + np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) + np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp3 = np.arange(math.prod(shape3)).reshape(shape3) + + arr1 = jax.device_put(np_inp1, s) + arr2 = jax.device_put(np_inp2, s) + arr3 = jax.device_put(np_inp3, s) + + @jax.jit + def f(x, y, z): + x = jnp.reshape(x, (8, 32, 2, 16)) + out = jnp.einsum('bthD, bthi, bthj->ijD', x, y, z, + out_sharding=P('data')) + self.assertEqual(out.shape, (8, 8, 16)) + self.assertEqual(out.aval.sharding.spec, P('data', None, None)) + return out + + out = f(arr1, arr2, arr3) + self.assertEqual(out.shape, (8, 8, 16)) + 
self.assertEqual(out.sharding, NamedSharding(mesh, P('data', None, None))) + def test_where_with_prng_sharded_inp(self): mesh = jax.sharding.Mesh(jax.devices(), axis_names=['batch']) sharding = jax.sharding.NamedSharding( @@ -6814,31 +6872,6 @@ class ShardingInTypesTest(jtu.JaxTestCase): ' axis_types are `Auto`'): NamedSharding(mesh, P(P.UNCONSTRAINED)) - def test_use_mesh_legacy_mesh_ctx_mgr_mix_error(self): - mesh = jtu.create_mesh((1, 1), ('x', 'y')) - - with self.assertRaisesRegex( - ValueError, - 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' - ' together is not allowed'): - with jax.sharding.use_mesh(mesh), mesh: - jax.jit(lambda x: x)(jnp.arange(8)) - - with self.assertRaisesRegex( - ValueError, - 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' - ' together is not allowed'): - with jax.sharding.use_mesh(mesh), mesh: - jnp.zeros((8, 2), dtype=jnp.int32) - - x = jnp.arange(8) - with self.assertRaisesRegex( - ValueError, - 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' - ' together is not allowed'): - with jax.sharding.use_mesh(mesh), mesh: - jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P())) - def test_pspec_einsum_no_context_mesh(self): mesh = jtu.create_mesh((1, 1), ('x', 'y'), axis_types={AxisTypes.Explicit: ('x', 'y')}) @@ -7013,6 +7046,28 @@ class ShardingInTypesTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, "Context mesh.*cannot be empty"): auto_axes(f, out_shardings=s)(arr) + def test_divisbility_aval_error(self): + abstract_mesh = mesh_lib.AbstractMesh( + (('x', 2),), axis_types={AxisTypes.Explicit: 'x'}) + s = NamedSharding(abstract_mesh, P('x')) + with self.assertRaisesRegex( + ValueError, 'does not evenly divide the dimension size'): + core.ShapedArray((5, 2), jnp.int32, sharding=s) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_scan_unroll(self, mesh): + np_inp = np.arange(64, dtype=jnp.float32).reshape(8, 8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'y'))) + carry = jnp.ones((8,), dtype=jnp.float32) + + @jax.jit + def f(carry, xs): + def body(carry, x): + return carry + x, x + return jax.lax.scan(body, carry, xs, unroll=2) + + f(carry, arr) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): @@ -7212,7 +7267,7 @@ class PJitErrorTest(jtu.JaxTestCase): xshape = (2, 5, 6) x = jnp.arange(math.prod(xshape)).reshape(xshape) with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation.*"): + ValueError, "Received incompatible devices for jitted computation.*"): f(x) @parameterized.named_parameters( diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 4ad83556f..2fd3a24d3 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -424,6 +424,59 @@ class RooflineTest(jtu.JaxTestCase): ) self.assertDataclassEqual(bwd_results, expected) + @jtu.parameterized.named_parameters( + ("abs", lax.abs, float), + ("acos", lax.acos, float), + ("asin", lax.asin, float), + ("atan", lax.atan, float), + ("cbrt", lax.cbrt, float), + ("ceil", lax.ceil, float), + ("conj", lax.conj, complex), + ("cos", lax.cos, float), + ("cosh", lax.cosh, float), + ("exp", lax.exp, float), + ("expm1", lax.expm1, float), + ("floor", lax.floor, float), + ("imag", lax.imag, complex), + ("integer_pow", lambda a: lax.integer_pow(a, 5), int), + ("is_finite", lax.is_finite, float), + ("log", lax.log, float), + ("log1p", lax.log1p, float), + ("logistic", lax.logistic, float), + ("neg", lax.neg, float), + 
("not", lax.bitwise_not, bool), + ("real", lax.real, complex), + ("round", lax.round, float), + ("rsqrt", lax.rsqrt, float), + ("sign", lax.sign, float), + ("sin", lax.sin, float), + ("sinh", lax.sinh, float), + ("sqrt", lax.sqrt, float), + ("square", lax.square, float), + ("tan", lax.tan, float), + ("bessel_i0e", lax.bessel_i0e, float), + ("bessel_i1e", lax.bessel_i1e, float), + ("digamma", lax.digamma, float), + ("erf_inv", lax.erf_inv, float), + ("erf", lax.erf, float), + ("erfc", lax.erfc, float), + ("lgamma", lax.lgamma, float), + ) + def test_unary_ops(self, f, dtype): + data = jnp.zeros((3, 8), dtype=dtype) + out, result = roofline.roofline( + f, + in_specs=(P()), + out_specs=P(), + )(data) + with self.subTest("flops"): + self.assertEqual(result.unfused_flops, 3 * 8) + with self.subTest("hbm_bytes"): + self.assertEqual( + result.unfused_hbm_bytes, + data.dtype.itemsize * 3 * 8 + out.dtype.itemsize * 3 * 8, + ) + def test_binary_ops(self): for f in [ lambda a, b: a ^ b, diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 8e51b3153..520fd10c9 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2205,20 +2205,23 @@ class ShardMapTest(jtu.JaxTestCase): mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): + # manual: 'i', 'j' return x * x def h(x): + # auto: 'j', manual: 'i' return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + in_specs=P(None, 'j'), + out_specs=P(None, 'j'))(x) @jax.jit def f(x): + # auto: 'i', 'j' return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x).sum() + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2814,7 +2817,7 @@ def sample(num: int, make_gen: Callable[[], Chooser]) -> Iterator[CaseSpec]: name, *case = sample_one(rng, make_gen()) if name not in seen: seen.add(name) - yield name, *case + yield case # To sample one test spec, we run the generator, getting back sequences of # options from it and sending in our choices from those options until finally a @@ -2929,7 +2932,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): def make_mesh(mesh_shape): return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape)) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)) def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) @@ -2938,7 +2941,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)) def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) @@ -2947,9 +2950,9 @@ class ShardMapSystematicTest(jtu.JaxTestCase): expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) - @parameterized.named_parameters( - (name + f'_check_rep={check_rep}', *params, check_rep) - for (name, *params) in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap) + @parameterized.parameters( + (*params, check_rep) + for params in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap) for check_rep in [True, False] ) @jax.default_matmul_precision("float32") @@ -2961,7 
+2964,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): f = jax.jit(f) jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)) @jax.default_matmul_precision("float32") def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _): @@ -2980,7 +2983,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): return g(*args) jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, partial(sample_shmap_batched, 5))) def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref): @@ -3003,7 +3006,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): tol = 1e-2 if jtu.test_device_matches(['tpu']) else None self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, partial(sample_shmap_batched, 5))) def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _): diff --git a/tests/version_test.py b/tests/version_test.py index 1036d958f..b78e61ae0 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -24,11 +24,15 @@ import jax from jax._src.lib import check_jaxlib_version from jax._src import test_util as jtu -# This is a subset of the full PEP440 pattern; for example we skip pre & post releases +# This is a subset of the full PEP440 pattern; for example we skip post releases VERSION_PATTERN = re.compile(r""" ^ # start of string (?P<version>[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' - (?:\.dev(?P<dev>[0-9]+))? # optional dev version; like '.dev20230908' + (?: + (?:rc(?P<rc>[0-9]+))? # optional rc version; like 'rc1' + | # or + (?:\.dev(?P<dev>[0-9]+))? # optional dev version; like '.dev20230908' + )? (?:\+(?P<local>[a-zA-Z0-9_.]+))? # optional local version; like '+g6643af3c3' $ # end of string """, re.VERBOSE) @@ -170,6 +174,18 @@ class JaxVersionTest(unittest.TestCase): self.assertEqual(version, f"{base_version}.dev20250101+1c0f1076erc1") self.assertValidVersion(version) + with jtu.set_env( + JAX_RELEASE="1", + JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY=None, + WHEEL_VERSION_SUFFIX="rc0", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, f"{base_version}rc0") + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3") diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b9972c039..6710de12a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,15 +21,15 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum # and update XLA_SHA256 with the result.
-XLA_COMMIT = "217a88ec8d4a0b31697e1479a0befb798546eb11" -XLA_SHA256 = "e3b5674e2b1cd485929684ab92dd763cdc62e5ff576efb662331cad5ac000717" +XLA_COMMIT = "fae64d49aa41e774922ca46e94cd754c800b6240" +XLA_SHA256 = "846ce8037cc0cba5135bff0bfd6fd02810e72b42ce0928002c595c97bf7b3603" def repo(): tf_http_archive( name = "xla", sha256 = XLA_SHA256, strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), - urls = tf_mirror_urls("https://github.com/rocm/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), + urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), ) # For development, one often wants to make changes to the TF repository as well