mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge branch 'main' into refs-in-vjps
This commit is contained in:
commit
f313a46916
171
.github/workflows/upstream-nightly.yml
vendored
Normal file
171
.github/workflows/upstream-nightly.yml
vendored
Normal file
@ -0,0 +1,171 @@
|
||||
name: CI - with Numpy/Scipy nightly wheels (nightly)
|
||||
# This configures a github action that runs the JAX test suite against nightly development builds
|
||||
# of numpy and scipy, in order to catch issues with new package versions prior to their release.
|
||||
# Unlike our other CI, this is one that we expect to fail frequently, and so we don't run it against
|
||||
# every commit and PR in the repository. Rather, we run it on a schedule, and failures lead to an
|
||||
# issue being created or updated.
|
||||
# Portions of this adapted from https://github.com/pydata/xarray/blob/main/.github/workflows/upstream-dev-ci.yaml
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 12 * * *" # Daily at 12:00 UTC
|
||||
workflow_dispatch: # allows triggering the workflow run manually
|
||||
pull_request: # Automatically trigger on pull requests affecting this file
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**workflows/upstream-nightly.yml'
|
||||
|
||||
jobs:
|
||||
upstream-dev:
|
||||
runs-on: ubuntu-20.04-16core
|
||||
permissions:
|
||||
contents: read
|
||||
checks: write # for upload-artifact
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -l {0}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
outputs:
|
||||
artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }}
|
||||
steps:
|
||||
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install JAX test requirements
|
||||
run: |
|
||||
pip install -r build/test-requirements.txt
|
||||
pip install pytest-reportlog
|
||||
- name: Install numpy & scipy development versions
|
||||
run: |
|
||||
pip install \
|
||||
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \
|
||||
--no-deps \
|
||||
--pre \
|
||||
--upgrade \
|
||||
numpy \
|
||||
scipy
|
||||
- name: Install JAX
|
||||
run: |
|
||||
pip install .[ci]
|
||||
- name: Run tests
|
||||
if: success()
|
||||
id: status
|
||||
env:
|
||||
JAX_NUM_GENERATED_CASES: 1
|
||||
JAX_ENABLE_X64: true
|
||||
JAX_ENABLE_CHECKS: true
|
||||
JAX_SKIP_SLOW_TESTS: true
|
||||
PY_COLORS: 1
|
||||
run: |
|
||||
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
|
||||
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
|
||||
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
|
||||
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
|
||||
pytest -n auto --tb=short -rf --maxfail=20 \
|
||||
--report-log output-${{ matrix.python-version }}-log.jsonl \
|
||||
tests \
|
||||
|| (
|
||||
echo 'ARTIFACTS_AVAILABLE=true' >> $GITHUB_OUTPUT && false
|
||||
)
|
||||
- name: Upload artifacts
|
||||
if: |
|
||||
failure()
|
||||
&& steps.status.outcome == 'failure'
|
||||
&& github.event_name == 'schedule'
|
||||
&& github.repository == 'google/jax'
|
||||
uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # ratchet: actions/upload-artifact@v4
|
||||
with:
|
||||
name: output-${{ matrix.python-version }}-log.jsonl
|
||||
path: output-${{ matrix.python-version }}-log.jsonl
|
||||
retention-days: 5
|
||||
|
||||
report:
|
||||
name: report
|
||||
needs: upstream-dev
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
if: |
|
||||
failure()
|
||||
&& github.event_name == 'schedule'
|
||||
&& needs.upstream-dev.outputs.artifacts_availability == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
|
||||
- uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.x"
|
||||
- uses: actions/download-artifact@c850b930e6ba138125429b7e5c93fc707a7f8427 # ratchet:actions/download-artifact@v4
|
||||
with:
|
||||
path: /tmp/workspace/logs
|
||||
- name: install requirements
|
||||
run: |
|
||||
python -m pip install pytest
|
||||
- name: Move all log files into a single directory
|
||||
run: |
|
||||
rsync -a /tmp/workspace/logs/output-*/ ./logs
|
||||
ls -R ./logs
|
||||
cat logs/*.jsonl > pytest-logs.txt
|
||||
python .github/workflows/parse_logs.py pytest-logs.txt --outfile=parsed-logs.txt
|
||||
- name: Report failures
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # ratchet:actions/github-script@v7
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const parsed_logs = fs.readFileSync('parsed-logs.txt', 'utf8');
|
||||
const title = "⚠️ Nightly upstream-dev CI failed ⚠️"
|
||||
const workflow_url = `https://github.com/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
|
||||
const issue_body = `[Workflow Run URL](${workflow_url})\n${parsed_logs}`
|
||||
// Run GraphQL query against GitHub API to find the most recent open issue used for reporting failures
|
||||
const query = `query($owner:String!, $name:String!, $creator:String!, $label:String!){
|
||||
repository(owner: $owner, name: $name) {
|
||||
issues(first: 1, states: OPEN, filterBy: {createdBy: $creator, labels: [$label]}, orderBy: {field: CREATED_AT, direction: DESC}) {
|
||||
edges {
|
||||
node {
|
||||
body
|
||||
id
|
||||
number
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`;
|
||||
const variables = {
|
||||
owner: context.repo.owner,
|
||||
name: context.repo.repo,
|
||||
label: 'CI',
|
||||
creator: "github-actions[bot]"
|
||||
}
|
||||
const result = await github.graphql(query, variables)
|
||||
// If no issue is open, create a new issue,
|
||||
// else update the body of the existing issue.
|
||||
if (result.repository.issues.edges.length === 0) {
|
||||
github.rest.issues.create({
|
||||
owner: variables.owner,
|
||||
repo: variables.name,
|
||||
body: issue_body,
|
||||
title: title,
|
||||
labels: [variables.label]
|
||||
})
|
||||
} else {
|
||||
github.rest.issues.update({
|
||||
owner: variables.owner,
|
||||
repo: variables.name,
|
||||
issue_number: result.repository.issues.edges[0].node.number,
|
||||
body: issue_body
|
||||
})
|
||||
}
|
@ -9,7 +9,7 @@
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
rev: v4.6.0
|
||||
hooks:
|
||||
- id: check-ast
|
||||
- id: check-merge-conflict
|
||||
@ -26,7 +26,7 @@ repos:
|
||||
files: \.py$
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.3.2
|
||||
rev: v0.3.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
|
||||
@ -40,7 +40,7 @@ repos:
|
||||
args: [--config=pyproject.toml]
|
||||
|
||||
- repo: https://github.com/mwouts/jupytext
|
||||
rev: v1.16.0
|
||||
rev: v1.16.1
|
||||
hooks:
|
||||
- id: jupytext
|
||||
args: [--sync]
|
||||
|
@ -6,7 +6,7 @@
|
||||
version: 2
|
||||
|
||||
build:
|
||||
os: "ubuntu-20.04"
|
||||
os: "ubuntu-22.04"
|
||||
tools:
|
||||
python: "3.9"
|
||||
|
||||
|
35
CHANGELOG.md
35
CHANGELOG.md
@ -8,10 +8,25 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.27
|
||||
|
||||
* Changes
|
||||
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
|
||||
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
|
||||
the old behavior by transforming the arguments via
|
||||
`jax.tree.map(np.asarray, args)` before passing them to the callback.
|
||||
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
|
||||
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
|
||||
|
||||
* Deprecations & Removals
|
||||
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
|
||||
lowering pass via Triton Python APIs has been removed and the
|
||||
`JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
|
||||
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
|
||||
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
|
||||
`max` ({jax-issue}`20550`).
|
||||
* The `device()` method of JAX arrays has been removed, after being deprecated
|
||||
since JAX v0.4.21. Use `arr.devices()` instead.
|
||||
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
|
||||
is deprecated; empty inputs to softmax are now supported without setting this.
|
||||
|
||||
|
||||
## jaxlib 0.4.27
|
||||
@ -25,6 +40,13 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Changes
|
||||
* Complex-valued {func}`jax.numpy.geomspace` now chooses the logarithmic spiral
|
||||
branch consistent with that of NumPy 2.0.
|
||||
* The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'`
|
||||
and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has
|
||||
changed](https://github.com/google/jax/issues/19085) so that
|
||||
mapping over keys results in random generation only from the first
|
||||
key in the batch.
|
||||
* Docs now use `jax.random.key` for construction of PRNG key arrays
|
||||
rather than `jax.random.PRNGKey`.
|
||||
|
||||
* Deprecations & Removals
|
||||
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
|
||||
@ -40,6 +62,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
|
||||
* The `jax.experimental.host_callback` module is deprecated.
|
||||
Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
|
||||
Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the
|
||||
new callbacks. See {jax-issue}`#20385` for a discussion.
|
||||
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
|
||||
that cannot be converted to a JAX array now results in an exception.
|
||||
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
|
||||
@ -47,14 +71,13 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* The previously-deprecated imports `jax.interpreters.ad.config` and
|
||||
`jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
|
||||
and `jax.extend.source_info_util` instead.
|
||||
* JAX export does not support anymore older serialization version. Version 9
|
||||
* JAX export does not support older serialization versions anymore. Version 9
|
||||
has been supported since October 27th, 2023 and has become the default
|
||||
since February 1, 2024.
|
||||
See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
|
||||
This change could break clients that set a specific
|
||||
JAX serialization version lower than 9.
|
||||
|
||||
|
||||
## jaxlib 0.4.26 (April 3, 2024)
|
||||
|
||||
* Changes
|
||||
@ -131,8 +154,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
cannot interact, e.g., in arithmetic operations.
|
||||
Scopes are introduced by {func}`jax.experimental.jax2tf.convert`,
|
||||
{func}`jax.experimental.export.symbolic_shape`, {func}`jax.experimental.export.symbolic_args_specs`.
|
||||
The scope of a symbolic expression `e` can be read with `e.scope` and passed in
|
||||
to the above functions to direct them to construct symbolic expressions in
|
||||
The scope of a symbolic expression `e` can be read with `e.scope` and passed
|
||||
into the above functions to direct them to construct symbolic expressions in
|
||||
a given scope.
|
||||
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
|
||||
* simplified and faster equality comparisons, where we consider two symbolic dimensions
|
||||
@ -313,7 +336,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Bug fixes
|
||||
* Only process 0 in a multicontroller distributed JAX program will write
|
||||
persistent compilation cache entries. This fixes write contention if the
|
||||
cache is placed on a network filesystem such as GCS.
|
||||
cache is placed on a network file system such as GCS.
|
||||
* The version check for cusolver and cufft no longer considers the patch
|
||||
versions when determining if the installed version of these libraries is at
|
||||
least as new as the versions against which JAX was built.
|
||||
@ -1441,7 +1464,7 @@ Changes:
|
||||
special autodiff handling for hcb.id_tap and id_print.
|
||||
From now on, only the primals are tapped. The old behavior can be
|
||||
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
|
||||
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
|
||||
environment variable, or the ```--jax_host_callback_ad_transforms``` flag.
|
||||
Additionally, added documentation for how to implement the old behavior
|
||||
using JAX custom AD APIs ({jax-issue}`#8678`).
|
||||
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the
|
||||
|
@ -33,9 +33,7 @@ from jax.experimental import multihost_utils
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from jax import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
partial = functools.partial
|
||||
|
@ -15,12 +15,12 @@
|
||||
|
||||
import google_benchmark as benchmark
|
||||
|
||||
from jax import config
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax.experimental import export
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@benchmark.register
|
||||
|
File diff suppressed because it is too large
Load Diff
528
docs/Custom_Operation_for_GPUs.py
Normal file
528
docs/Custom_Operation_for_GPUs.py
Normal file
@ -0,0 +1,528 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial, reduce
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from build import gpu_ops
|
||||
from jax import core, dtypes
|
||||
from jax.core import ShapedArray
|
||||
from jax.experimental.custom_partitioning import custom_partitioning
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.interpreters import batching, mlir, xla
|
||||
from jax.interpreters.mlir import ir
|
||||
from jax.lib import xla_client
|
||||
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
||||
from jaxlib.hlo_helpers import custom_call
|
||||
from jax._src import dispatch
|
||||
|
||||
|
||||
######################################################################
|
||||
# Created Primitives for unsharded RMS norm reference implementation #
|
||||
######################################################################
|
||||
|
||||
# Create _rms_norm_fwd_p for forward operation.
|
||||
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
|
||||
_rms_norm_fwd_p.multiple_results = True
|
||||
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))
|
||||
|
||||
|
||||
def rms_norm_fwd(x, weight, eps=1e-05):
|
||||
output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
|
||||
return output, (invvar, x, weight)
|
||||
|
||||
|
||||
# Create _rms_norm_bwd_p for backward operation.
|
||||
_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
|
||||
_rms_norm_bwd_p.multiple_results = True
|
||||
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))
|
||||
|
||||
|
||||
def rms_norm_bwd(eps, res, g):
|
||||
invvar, x, weight = res
|
||||
grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
|
||||
g, invvar, x, weight, eps=eps
|
||||
)
|
||||
return grad_input, grad_weight
|
||||
|
||||
|
||||
####################
|
||||
# Lowering to MLIR #
|
||||
####################
|
||||
|
||||
|
||||
# Register functions defined in gpu_ops as custom call target for GPUs
|
||||
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="gpu")
|
||||
|
||||
|
||||
def element_type_to_descriptor_type_mapping(element_type):
|
||||
_element_type_to_descriptor_type_mapping = {
|
||||
ir.BF16Type.get(): gpu_ops.ElementType.BF16,
|
||||
ir.F16Type.get(): gpu_ops.ElementType.F16,
|
||||
ir.F32Type.get(): gpu_ops.ElementType.F32,
|
||||
ir.F64Type.get(): gpu_ops.ElementType.F64,
|
||||
}
|
||||
return _element_type_to_descriptor_type_mapping.get(element_type)
|
||||
|
||||
|
||||
def default_layouts(*shapes):
|
||||
return [range(len(shape) - 1, -1, -1) for shape in shapes]
|
||||
|
||||
|
||||
def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
|
||||
x_type = ir.RankedTensorType(x.type)
|
||||
x_shape = x_type.shape
|
||||
w_type = ir.RankedTensorType(weight.type)
|
||||
w_shape = w_type.shape
|
||||
iv_element_type = (
|
||||
ir.F32Type.get()
|
||||
if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()]
|
||||
else x_type.element_type
|
||||
)
|
||||
|
||||
n2 = math.prod(w_shape)
|
||||
n1 = math.prod(x_shape) // n2
|
||||
|
||||
opaque = gpu_ops.create_rms_norm_descriptor(
|
||||
n1,
|
||||
n2,
|
||||
eps,
|
||||
element_type_to_descriptor_type_mapping(x_type.element_type),
|
||||
element_type_to_descriptor_type_mapping(w_type.element_type),
|
||||
0, # unused
|
||||
)
|
||||
out = custom_call(
|
||||
b"rms_forward_affine_mixed_dtype",
|
||||
result_types=[
|
||||
ir.RankedTensorType.get(x_shape, w_type.element_type),
|
||||
ir.RankedTensorType.get((n1,), iv_element_type),
|
||||
],
|
||||
operands=[x, weight],
|
||||
backend_config=opaque,
|
||||
operand_layouts=default_layouts(x_shape, w_shape),
|
||||
result_layouts=default_layouts(x_shape, (n1,)),
|
||||
).results
|
||||
return out
|
||||
|
||||
|
||||
mlir.register_lowering(
|
||||
_rms_norm_fwd_p,
|
||||
_rms_norm_fwd_cuda_lowering,
|
||||
platform="gpu",
|
||||
)
|
||||
|
||||
|
||||
def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
|
||||
x_type = ir.RankedTensorType(x.type)
|
||||
x_shape = x_type.shape
|
||||
w_type = ir.RankedTensorType(weight.type)
|
||||
w_shape = w_type.shape
|
||||
iv_type = ir.RankedTensorType(invvar.type)
|
||||
|
||||
n2 = reduce(lambda x, y: x * y, w_shape)
|
||||
n1 = reduce(lambda x, y: x * y, x_shape) // n2
|
||||
|
||||
part_grad_shape = ctx.avals_out[-1].shape
|
||||
|
||||
opaque = gpu_ops.create_rms_norm_descriptor(
|
||||
n1,
|
||||
n2,
|
||||
eps,
|
||||
element_type_to_descriptor_type_mapping(x_type.element_type),
|
||||
element_type_to_descriptor_type_mapping(w_type.element_type),
|
||||
part_grad_shape[0],
|
||||
)
|
||||
out = custom_call(
|
||||
b"rms_backward_affine",
|
||||
result_types=[
|
||||
ir.RankedTensorType.get(x_shape, x_type.element_type),
|
||||
ir.RankedTensorType.get(w_shape, w_type.element_type),
|
||||
ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
|
||||
],
|
||||
operands=[grad_output, invvar, x, weight],
|
||||
backend_config=opaque,
|
||||
operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
|
||||
result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
|
||||
).results
|
||||
return out
|
||||
|
||||
|
||||
mlir.register_lowering(
|
||||
_rms_norm_bwd_p,
|
||||
_rms_norm_bwd_cuda_lowering,
|
||||
platform="gpu",
|
||||
)
|
||||
|
||||
|
||||
#######################
|
||||
# Abstract evaluation #
|
||||
#######################
|
||||
|
||||
|
||||
def _rms_norm_fwd_abstract(x, weight, eps):
|
||||
w_dtype = dtypes.canonicalize_dtype(weight.dtype)
|
||||
iv_dtype = dtypes.canonicalize_dtype(x.dtype)
|
||||
if iv_dtype in [jnp.float16, jnp.bfloat16]:
|
||||
iv_dtype = jnp.float32
|
||||
n2 = math.prod(weight.shape)
|
||||
n1 = math.prod(x.shape) // n2
|
||||
return (
|
||||
ShapedArray(x.shape, w_dtype, named_shape=x.named_shape), # output
|
||||
ShapedArray((n1,), iv_dtype, named_shape=x.named_shape), # invvar
|
||||
)
|
||||
|
||||
|
||||
_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)
|
||||
|
||||
|
||||
def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
|
||||
iv_dtype = dtypes.canonicalize_dtype(invvar.dtype)
|
||||
w_dtype = dtypes.canonicalize_dtype(weight.dtype)
|
||||
x_dtype = dtypes.canonicalize_dtype(x.dtype)
|
||||
n2 = reduce(lambda x, y: x * y, weight.shape)
|
||||
n1 = reduce(lambda x, y: x * y, x.shape) // n2
|
||||
part_grad_shape = (16, n2)
|
||||
assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
|
||||
assert grad_output.shape == x.shape
|
||||
assert invvar.shape == (n1,)
|
||||
assert (
|
||||
iv_dtype == jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype
|
||||
)
|
||||
assert grad_output.named_shape == x.named_shape
|
||||
weight_named_shape = (
|
||||
weight.named_shape if weight.named_shape else grad_output.named_shape
|
||||
)
|
||||
return (
|
||||
ShapedArray(
|
||||
x.shape, x_dtype, named_shape=x.named_shape
|
||||
), # grad input
|
||||
ShapedArray(
|
||||
weight.shape, w_dtype, named_shape=weight_named_shape
|
||||
), # grad weight
|
||||
ShapedArray(
|
||||
part_grad_shape, iv_dtype, named_shape=weight_named_shape
|
||||
), # part grad
|
||||
)
|
||||
|
||||
|
||||
_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
|
||||
|
||||
|
||||
#######################################
|
||||
# Top-level interface with custom vjp #
|
||||
#######################################
|
||||
|
||||
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(2,))
|
||||
def rms_norm(x, weight, eps=1e-05):
|
||||
output, _ = rms_norm_fwd(x, weight, eps=eps)
|
||||
return output
|
||||
|
||||
|
||||
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
|
||||
|
||||
###########################################################
|
||||
# Create primitives for RMS norm with custom_partitioning #
|
||||
###########################################################
|
||||
|
||||
def _check_valid_batch_dims(bdims):
|
||||
"""
|
||||
Assert out non-supported bath dims
|
||||
"""
|
||||
for dim in bdims:
|
||||
assert dim in [0, None], \
|
||||
"Currently only support batch_dim in [0, None], " \
|
||||
f"but got {dim=}"
|
||||
|
||||
def register_primitive(cls):
|
||||
"""
|
||||
register jax primitive
|
||||
|
||||
The order of calls. Each operation is composed of two primitives: Inner and Outer.
|
||||
|
||||
Inner, only the basic to wrap the custom_call itself.
|
||||
- impl to XLA custom_call in C.
|
||||
- abstract to know the static shapes
|
||||
- lower to StableHLO XLA custom_call.
|
||||
Outer, mostly all the rest:
|
||||
- impl: Bind to the inner primitive. Not used for real computation, but only for tracing. So we only need to bind.
|
||||
- abstract: same
|
||||
- lower to StableHLO custom_p. (XLA will call the python callback from it)
|
||||
- custom_p
|
||||
- vmap: could be added here.
|
||||
VJP is based on Outer, but not handled in this function.
|
||||
"""
|
||||
|
||||
def name_of_wrapper_p():
|
||||
return cls.name + "_wrapper"
|
||||
|
||||
inner_p = core.Primitive(cls.name)
|
||||
dispatch.prim_requires_devices_during_lowering.add(inner_p)
|
||||
inner_p.multiple_results = cls.multiple_results
|
||||
inner_p.def_impl(partial(xla.apply_primitive, inner_p))
|
||||
inner_p.def_abstract_eval(cls.abstract)
|
||||
mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
|
||||
cls.inner_primitive = inner_p
|
||||
|
||||
outer_p = core.Primitive(name_of_wrapper_p())
|
||||
dispatch.prim_requires_devices_during_lowering.add(outer_p)
|
||||
outer_p.multiple_results = cls.multiple_results
|
||||
outer_p.def_impl(cls.impl)
|
||||
outer_p.def_abstract_eval(cls.abstract)
|
||||
batching.primitive_batchers[outer_p] = cls.batcher
|
||||
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
|
||||
outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
|
||||
partition=cls.partition)
|
||||
mlir.register_lowering(outer_p,
|
||||
mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
|
||||
cls.outer_primitive = outer_p
|
||||
|
||||
|
||||
class RmsNormFwdClass:
|
||||
name = "rms_forward_affine_mixed_dtype"
|
||||
multiple_results = True
|
||||
impl_static_args = (2,) # eps
|
||||
inner_primitive = None
|
||||
outer_primitive = None
|
||||
|
||||
@staticmethod
|
||||
def abstract(x_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument
|
||||
return _rms_norm_fwd_abstract(x_aval, gamma_aval, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def lowering(ctx, x, gamma, *, eps):
|
||||
return _rms_norm_fwd_cuda_lowering(ctx, x, gamma, eps)
|
||||
|
||||
@staticmethod
|
||||
def impl(x, gamma, eps):
|
||||
assert RmsNormFwdClass.inner_primitive is not None
|
||||
out, rsigma = RmsNormFwdClass.inner_primitive.bind(x, gamma, eps=eps)
|
||||
return out, rsigma
|
||||
|
||||
@staticmethod
|
||||
def batcher(batched_args, batch_dims, *, eps):
|
||||
_check_valid_batch_dims(batch_dims)
|
||||
assert RmsNormFwdClass.outer_primitive is not None
|
||||
x, gamma = batched_args
|
||||
x_bdim, _ = batch_dims
|
||||
|
||||
out_bdims = x_bdim, x_bdim
|
||||
return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
|
||||
|
||||
@staticmethod
|
||||
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
|
||||
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
|
||||
result_infos : Tuple[jax._src.core.ShapedArray]):
|
||||
del eps, result_infos # Not needed for this example.
|
||||
x_info, weight_info = arg_infos
|
||||
assert len(x_info.shape) == 3
|
||||
assert len(weight_info.shape) == 2
|
||||
# partition() will force all dims to be replicated except the
|
||||
# first dim of x that will be kept as is.
|
||||
x_spec = arg_infos[0].sharding.spec
|
||||
output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
|
||||
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
|
||||
return (output_sharding, invvar_sharding)
|
||||
|
||||
@staticmethod
|
||||
def partition(eps : float, mesh : jax.sharding.Mesh,
|
||||
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
|
||||
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
|
||||
del result_infos # Not needed for this example.
|
||||
x_info, weight_info = arg_infos
|
||||
assert len(x_info.shape) == 3
|
||||
assert len(weight_info.shape) == 2
|
||||
x_spec = arg_infos[0].sharding.spec
|
||||
# We only support sharding on the batch dimensions.
|
||||
# Force sharding on all others dimensions with None.
|
||||
arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
|
||||
NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything.
|
||||
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
|
||||
output_shardings = (arg_shardings[0], invvar_sharding)
|
||||
# Sharded_impl only accepts positional arugments
|
||||
# And they should be Jax traceable variables
|
||||
impl = partial(RmsNormFwdClass.impl, eps=eps)
|
||||
|
||||
return mesh, impl, output_shardings, arg_shardings
|
||||
|
||||
register_primitive(RmsNormFwdClass)
|
||||
|
||||
class RmsNormBwdClass:
|
||||
name = "rms_norm_bwd"
|
||||
multiple_results = True
|
||||
impl_static_args = (4,) # eps
|
||||
inner_primitive = None
|
||||
outer_primitive = None
|
||||
|
||||
@staticmethod
|
||||
def abstract(grad_output, invvar, x, weight, eps): # pylint: disable=unused-argument
|
||||
return _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps)
|
||||
|
||||
@staticmethod
|
||||
def lowering(ctx, grad_output, invvar, x, weight, eps):
|
||||
return _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps)
|
||||
|
||||
@staticmethod
|
||||
def impl(grad_output, invvar, x, weight, eps):
|
||||
assert RmsNormBwdClass.inner_primitive is not None
|
||||
gx, gw, part_grad = RmsNormBwdClass.inner_primitive.bind(grad_output, invvar, x, weight, eps=eps)
|
||||
return gx, gw, part_grad
|
||||
|
||||
@staticmethod
|
||||
def batcher(batched_args, batch_dims, *, eps):
|
||||
# TODO: Add to the tutorial!
|
||||
_check_valid_batch_dims(batch_dims)
|
||||
assert RmsNormBwdClass.outer_primitive is not None
|
||||
x, gamma = batched_args
|
||||
x_bdim, _ = batch_dims
|
||||
|
||||
out_bdims = x_bdim, x_bdim
|
||||
return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
|
||||
|
||||
@staticmethod
|
||||
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
|
||||
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
|
||||
result_infos : Tuple[jax._src.core.ShapedArray]):
|
||||
del eps, result_infos # Not needed for this example.
|
||||
g_info, invvar_info, x_info, weight_info = arg_infos
|
||||
assert len(g_info.shape) == 3
|
||||
assert len(invvar_info.shape) == 1
|
||||
assert len(x_info.shape) == 3
|
||||
assert len(weight_info.shape) == 2
|
||||
# partition() will force all dims to be replicated except the batch dimension.
|
||||
x_spec = x_info.sharding.spec
|
||||
output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
|
||||
invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
|
||||
return (output_sharding, invvar_sharding, output_sharding, )
|
||||
|
||||
@staticmethod
|
||||
def partition(eps : float, mesh : jax.sharding.Mesh,
|
||||
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
|
||||
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
|
||||
del result_infos # Not needed for this example.
|
||||
g_info, invvar_info, x_info, weight_info = arg_infos
|
||||
assert len(g_info.shape) == 3
|
||||
assert len(invvar_info.shape) == 1
|
||||
assert len(x_info.shape) == 3
|
||||
assert len(weight_info.shape) == 2
|
||||
|
||||
# We only support sharding on the batch dimensions.
|
||||
# Force sharding on all others dimensions with None.
|
||||
# Also force gx, x and invvar to have the same batch sharding/replication.
|
||||
x_spec = x_info.sharding.spec
|
||||
arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
|
||||
NamedSharding(mesh, PartitionSpec(x_spec[0],)),
|
||||
NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
|
||||
NamedSharding(mesh, PartitionSpec(None, None)))
|
||||
|
||||
output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
|
||||
invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
|
||||
output_shardings = (output_sharding, invvar_sharding, invvar_sharding)
|
||||
|
||||
|
||||
# Sharded_impl only accepts positional arugments
|
||||
# And they should be Jax traceable variables
|
||||
def sharded_impl(g, invvar, x, weight):
|
||||
grad_input, grad_weight, part_grad = RmsNormBwdClass.impl(
|
||||
g, invvar, x, weight, eps=eps
|
||||
)
|
||||
# We need to sum the weight gradient from all partition.
|
||||
# when the input is sharded and weights are replicated
|
||||
global_weight = grad_weight
|
||||
if x_spec[0]:
|
||||
global_weight = jax.lax.psum(grad_weight, x_spec[0])
|
||||
return grad_input, global_weight, part_grad
|
||||
return mesh, sharded_impl, output_shardings, arg_shardings
|
||||
|
||||
register_primitive(RmsNormBwdClass)
|
||||
|
||||
def custom_p_rms_norm_fwd(x, weight, eps=1e-05):
|
||||
output, invvar = RmsNormFwdClass.outer_primitive.bind(x, weight, eps=eps)
|
||||
return output, (invvar, x, weight)
|
||||
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(2,))
|
||||
def custom_p_rms_norm(x, weight, eps=1e-05):
|
||||
output, _ = custom_p_rms_norm_fwd(x, weight, eps=eps)
|
||||
return output
|
||||
|
||||
def custom_p_rms_norm_bwd(eps, res, g):
|
||||
invvar, x, weight = res
|
||||
grad_input, grad_weight, part_grad = RmsNormBwdClass.outer_primitive.bind(
|
||||
g, invvar, x, weight, eps=eps)
|
||||
return grad_input, grad_weight
|
||||
|
||||
custom_p_rms_norm.defvjp(custom_p_rms_norm_fwd, custom_p_rms_norm_bwd)
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
|
||||
|
||||
import jax
|
||||
|
||||
per_core_batch_size = 4
|
||||
seq_len = 512
|
||||
emb_dim = 512
|
||||
assert jax.local_device_count() > 1, "Only 1 GPU, the example work, but it is this really what you want?"
|
||||
x = jax.random.normal(
|
||||
jax.random.PRNGKey(0),
|
||||
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
|
||||
dtype=jnp.float16,
|
||||
)
|
||||
norm_shape = x.shape[-2:]
|
||||
weight = jnp.ones(norm_shape, dtype=jnp.float16)
|
||||
|
||||
|
||||
def ref_loss(x, weight):
|
||||
predictions = rms_norm(x, weight)
|
||||
return -jnp.mean(predictions**2)
|
||||
|
||||
ref_out = jax.grad(ref_loss, argnums=(0, 1))(x, weight)
|
||||
|
||||
def custom_p_loss(x, weight):
|
||||
predictions = custom_p_rms_norm(x, weight)
|
||||
return -jnp.mean(predictions**2)
|
||||
|
||||
with Mesh(jax.local_devices(), ("x",)):
|
||||
def run_and_verify(loss):
|
||||
pjitted = pjit(
|
||||
jax.grad(loss, argnums=(0, 1)),
|
||||
# Shard x by batch dimension and replicate weight on all devices.
|
||||
in_shardings=(
|
||||
PartitionSpec("x", None, None),
|
||||
PartitionSpec(None, None),
|
||||
),
|
||||
# Shard the output by batch dimension and replicate weight grad on all devices.
|
||||
out_shardings=(
|
||||
PartitionSpec("x", None, None),
|
||||
PartitionSpec(None, None),
|
||||
),
|
||||
)
|
||||
hlo = pjitted.lower(x, weight).compile().as_text()
|
||||
out = pjitted(x, weight)
|
||||
print(hlo)
|
||||
assert "all-reduce-done" in hlo, "The gradient will produce wrong value!"
|
||||
if "all-gather-start" in hlo:
|
||||
print("NOT OPTIMIZED, ALL_GATHER in the graph!")
|
||||
return out
|
||||
|
||||
custom_p_out = run_and_verify(custom_p_loss)
|
||||
|
||||
|
||||
for r, o in zip(ref_out, custom_p_out):
|
||||
print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6))
|
@ -6,7 +6,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -20,7 +20,7 @@
|
||||
# extension: .py
|
||||
# format_name: light
|
||||
# format_version: '1.5'
|
||||
# jupytext_version: 1.16.0
|
||||
# jupytext_version: 1.16.1
|
||||
# kernelspec:
|
||||
# display_name: Python 3
|
||||
# name: python3
|
||||
|
13
docs/build_custom_gpu.sh
Normal file
13
docs/build_custom_gpu.sh
Normal file
@ -0,0 +1,13 @@
|
||||
python -m pip install pybind11==2.10.1
|
||||
mkdir -p build
|
||||
touch build/__init__.py
|
||||
pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())")
|
||||
python_executable=$(python -c 'import sys; print(sys.executable)')
|
||||
#python_include_path=$(python -c 'from distutils.sysconfig import get_python_inc;print(get_python_inc())')
|
||||
echo pybind_include_path=$pybind_include_path
|
||||
echo python_executable=$python_executable
|
||||
|
||||
nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
|
||||
c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}3-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
|
||||
c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}3-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
|
||||
strip build/gpu_ops$(${python_executable}3-config --extension-suffix)
|
@ -12,14 +12,14 @@ JAX offers flags and context managers that enable catching errors more easily.
|
||||
|
||||
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
|
||||
* setting the `JAX_DEBUG_NANS=True` environment variable;
|
||||
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
|
||||
### Example(s)
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
config.update("jax_debug_nans", True)
|
||||
import jax
|
||||
jax.config.update("jax_debug_nans", True)
|
||||
|
||||
def f(x, y):
|
||||
return x / y
|
||||
@ -47,14 +47,14 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
|
||||
|
||||
You can disable JIT-compilation by:
|
||||
* setting the `JAX_DISABLE_JIT=True` environment variable;
|
||||
* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
|
||||
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
|
||||
* adding `jax.config.update("jax_disable_jit", True)` near the top of your main file;
|
||||
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
|
||||
|
||||
### Examples
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
config.update("jax_disable_jit", True)
|
||||
import jax
|
||||
jax.config.update("jax_disable_jit", True)
|
||||
|
||||
def f(x):
|
||||
y = jnp.log(x)
|
||||
|
@ -82,8 +82,8 @@ Click [here](checkify_guide) to learn more!
|
||||
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
config.update("jax_debug_nans", True)
|
||||
import jax
|
||||
jax.config.update("jax_debug_nans", True)
|
||||
|
||||
def f(x, y):
|
||||
return x / y
|
||||
|
@ -206,7 +206,7 @@ python build/build.py --configure_only
|
||||
You may pass additional options to `build.py` to configure the build; see the
|
||||
`jaxlib` build documentation for details.
|
||||
|
||||
By default the Bazel build runs the JAX tests using `jaxlib` built form source.
|
||||
By default the Bazel build runs the JAX tests using `jaxlib` built from source.
|
||||
To run JAX tests, run:
|
||||
|
||||
```
|
||||
|
30
docs/faq.rst
30
docs/faq.rst
@ -413,7 +413,7 @@ speed of code using JAX:
|
||||
use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see
|
||||
`Double (64 bit) precision`_) for a fair comparison.
|
||||
4. **Transferring data between CPUs and accelerators takes time.** If you only
|
||||
want to measure the how long it takes to evaluate a function, you may want to
|
||||
want to measure how long it takes to evaluate a function, you may want to
|
||||
transfer data to the device on which you want to run it first (see
|
||||
:ref:`faq-data-placement`).
|
||||
|
||||
@ -814,6 +814,32 @@ computation at runtime. For example:
|
||||
For more information on runtime callbacks and examples of their use,
|
||||
see `External callbacks in JAX`_.
|
||||
|
||||
Why do some CUDA libraries fail to load/initialize?
|
||||
---------------------------------------------------
|
||||
|
||||
When resolving dynamic libraries, JAX uses the usual `dynamic linker search pattern`_.
|
||||
JAX sets :code:`RPATH` to point to the JAX-relative location of the
|
||||
pip-installed NVIDIA CUDA packages, preferring them if installed. If :code:`ld.so`
|
||||
cannot find your CUDA runtime libraries along its usual search path, then you
|
||||
must include the paths to those libraries explicitly in :code:`LD_LIBRARY_PATH`.
|
||||
The easiest way to ensure your CUDA files are discoverable is to simply install
|
||||
the :code:`nvidia-*-cu12` pip packages, which are included in the standard
|
||||
:code:`jax[cuda_12]` install option.
|
||||
|
||||
Occasionally, even when you have ensured that your runtime libraries are discoverable,
|
||||
there may still be some issues with loading or initializing them. A common cause of
|
||||
such issues is simply having insufficient memory for CUDA library initialization at
|
||||
runtime. This sometimes occurs because JAX will pre-allocate too large of a chunk of
|
||||
currently available device memory for faster execution, occasionally resulting in
|
||||
insufficient memory being left available for runtime CUDA library initialization.
|
||||
|
||||
This is especially likely when running multiple JAX instances, running JAX in
|
||||
tandem with TensorFlow which performs its own pre-allocation, or when running
|
||||
JAX on a system where the GPU is being heavily utilized by other processes. When
|
||||
in doubt, try running the program again with reduced pre-allocation, either by
|
||||
reducing :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` from the default of :code:`.75`,
|
||||
or setting :code:`XLA_PYTHON_CLIENT_PREALLOCATE=false`. For more details, please
|
||||
see the page on `JAX GPU memory allocation`_.
|
||||
|
||||
.. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables
|
||||
.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
|
||||
@ -822,3 +848,5 @@ see `External callbacks in JAX`_.
|
||||
.. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function
|
||||
.. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function
|
||||
.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266
|
||||
.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
|
||||
.. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html
|
@ -53,7 +53,7 @@ JAX Glossary of Terms
|
||||
|
||||
pytree
|
||||
A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more
|
||||
general containers of array values in a uniform way. Refer to {ref}`working-with-pytrees`
|
||||
general containers of array values in a uniform way. Refer to :ref:`working-with-pytrees`
|
||||
for a more detailed discussion.
|
||||
|
||||
reverse-mode autodiff
|
||||
|
45
docs/gpu_ops/gpu_ops.cpp
Normal file
45
docs/gpu_ops/gpu_ops.cpp
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "kernels.h"
|
||||
#include "pybind11_kernel_helpers.h"
|
||||
|
||||
namespace {
|
||||
pybind11::dict RMSNormRegistrations() {
|
||||
pybind11::dict dict;
|
||||
dict["rms_forward_affine_mixed_dtype"] =
|
||||
gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes);
|
||||
dict["rms_backward_affine"] =
|
||||
gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(gpu_ops, m) {
|
||||
m.def("get_rms_norm_registrations", &RMSNormRegistrations);
|
||||
m.def("create_rms_norm_descriptor",
|
||||
[](int n1, int n2, double eps, gpu_ops::ElementType x_type,
|
||||
gpu_ops::ElementType w_type, int part_grad_size) {
|
||||
return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
|
||||
n1, n2, eps, x_type, w_type, part_grad_size});
|
||||
});
|
||||
|
||||
pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
|
||||
.value("BF16", gpu_ops::ElementType::BF16)
|
||||
.value("F16", gpu_ops::ElementType::F16)
|
||||
.value("F32", gpu_ops::ElementType::F32)
|
||||
.value("F64", gpu_ops::ElementType::F64);
|
||||
|
||||
}
|
||||
} // namespace
|
64
docs/gpu_ops/kernel_helpers.h
Normal file
64
docs/gpu_ops/kernel_helpers.h
Normal file
@ -0,0 +1,64 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header is not specific to our application and you'll probably want
|
||||
// something like this for any extension you're building. This includes the
|
||||
// infrastructure needed to serialize descriptors that are used with the
|
||||
// "opaque" parameter of the GPU custom call. In our example we'll use this
|
||||
// parameter to pass the size of our problem.
|
||||
|
||||
#ifndef _GPU_OPS_KERNEL_HELPERS_H_
|
||||
#define _GPU_OPS_KERNEL_HELPERS_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#define JAX_APEX_WARP_SIZE 32
|
||||
|
||||
namespace gpu_ops {
|
||||
|
||||
// https://en.cppreference.com/w/cpp/numeric/bit_cast
|
||||
template <class To, class From>
|
||||
typename std::enable_if<sizeof(To) == sizeof(From) &&
|
||||
std::is_trivially_copyable<From>::value &&
|
||||
std::is_trivially_copyable<To>::value,
|
||||
To>::type
|
||||
bit_cast(const From &src) noexcept {
|
||||
static_assert(std::is_trivially_constructible<To>::value,
|
||||
"This implementation additionally requires destination type to "
|
||||
"be trivially constructible");
|
||||
|
||||
To dst;
|
||||
memcpy(&dst, &src, sizeof(To));
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <typename T> std::string PackDescriptorAsString(const T &descriptor) {
|
||||
return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
|
||||
if (opaque_len != sizeof(T)) {
|
||||
throw std::runtime_error("Invalid opaque object size");
|
||||
}
|
||||
return bit_cast<const T *>(opaque);
|
||||
}
|
||||
|
||||
} // namespace gpu_ops
|
||||
|
||||
#endif
|
44
docs/gpu_ops/kernels.h
Normal file
44
docs/gpu_ops/kernels.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef _GPU_OPS_KERNELS_H_
|
||||
#define _GPU_OPS_KERNELS_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
namespace gpu_ops {
|
||||
|
||||
enum ElementType { BF16, F16, F32, F64 };
|
||||
|
||||
struct RMSNormDescriptor {
|
||||
int n1;
|
||||
int n2;
|
||||
double eps;
|
||||
ElementType x_type;
|
||||
ElementType w_type;
|
||||
int part_grad_size;
|
||||
};
|
||||
|
||||
void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers,
|
||||
const char *opaque,
|
||||
std::size_t opaque_len);
|
||||
void rms_backward_affine(cudaStream_t stream, void **buffers,
|
||||
const char *opaque, std::size_t opaque_len);
|
||||
} // namespace gpu_ops
|
||||
|
||||
#endif
|
41
docs/gpu_ops/pybind11_kernel_helpers.h
Normal file
41
docs/gpu_ops/pybind11_kernel_helpers.h
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header extends kernel_helpers.h with the pybind11 specific interface to
|
||||
// serializing descriptors. It also adds a pybind11 function for wrapping our
|
||||
// custom calls in a Python capsule. This is separate from kernel_helpers so
|
||||
// that the CUDA code itself doesn't include pybind11. I don't think that this
|
||||
// is strictly necessary, but they do it in jaxlib, so let's do it here too.
|
||||
|
||||
#ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
|
||||
#define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "kernel_helpers.h"
|
||||
|
||||
namespace gpu_ops {
|
||||
|
||||
template <typename T> pybind11::bytes PackDescriptor(const T &descriptor) {
|
||||
return pybind11::bytes(PackDescriptorAsString(descriptor));
|
||||
}
|
||||
|
||||
template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
|
||||
return pybind11::capsule(bit_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
|
||||
}
|
||||
|
||||
} // namespace gpu_ops
|
||||
|
||||
#endif
|
970
docs/gpu_ops/rms_norm_kernels.cu
Normal file
970
docs/gpu_ops/rms_norm_kernels.cu
Normal file
@ -0,0 +1,970 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "kernel_helpers.h"
|
||||
#include "kernels.h"
|
||||
#include "stdio.h"
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace {
|
||||
|
||||
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, \
|
||||
NAME, ...) \
|
||||
switch (TYPEIN) { \
|
||||
case gpu_ops::ElementType::F64: { \
|
||||
using scalar_t_in = double; \
|
||||
using accscalar_t = double; \
|
||||
switch (TYPEOUT) { \
|
||||
case gpu_ops::ElementType::F64: { \
|
||||
using scalar_t_out = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F32: { \
|
||||
using scalar_t_out = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F16: { \
|
||||
using scalar_t_out = __half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::BF16: { \
|
||||
using scalar_t_out = __nv_bfloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
break; \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F32: { \
|
||||
using scalar_t_in = float; \
|
||||
using accscalar_t = float; \
|
||||
switch (TYPEOUT) { \
|
||||
case gpu_ops::ElementType::F64: { \
|
||||
using scalar_t_out = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F32: { \
|
||||
using scalar_t_out = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F16: { \
|
||||
using scalar_t_out = __half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::BF16: { \
|
||||
using scalar_t_out = __nv_bfloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
break; \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F16: { \
|
||||
using scalar_t_in = __half; \
|
||||
using accscalar_t = float; \
|
||||
switch (TYPEOUT) { \
|
||||
case gpu_ops::ElementType::F64: { \
|
||||
using scalar_t_out = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F32: { \
|
||||
using scalar_t_out = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F16: { \
|
||||
using scalar_t_out = __half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::BF16: { \
|
||||
using scalar_t_out = __nv_bfloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
break; \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::BF16: { \
|
||||
using scalar_t_in = __nv_bfloat16; \
|
||||
using accscalar_t = float; \
|
||||
switch (TYPEOUT) { \
|
||||
case gpu_ops::ElementType::F64: { \
|
||||
using scalar_t_out = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F32: { \
|
||||
using scalar_t_out = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::F16: { \
|
||||
using scalar_t_out = __half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case gpu_ops::ElementType::BF16: { \
|
||||
using scalar_t_out = __nv_bfloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
break; \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
break; \
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) {
|
||||
count = count + U(1);
|
||||
U delta = curr - mu;
|
||||
U lmean = mu + delta / count;
|
||||
mu = lmean;
|
||||
U delta2 = curr - lmean;
|
||||
sigma2 = sigma2 + delta * delta2;
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
|
||||
U &mu, U &sigma2, U &count) {
|
||||
U delta = muB - mu;
|
||||
U nA = count;
|
||||
U nB = countB;
|
||||
count = count + countB;
|
||||
U nX = count;
|
||||
if (nX > U(0)) {
|
||||
nA = nA / nX;
|
||||
nB = nB / nX;
|
||||
mu = nA * mu + nB * muB;
|
||||
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
|
||||
} else {
|
||||
mu = U(0);
|
||||
sigma2 = U(0);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U> __device__ void cuRMSOnlineSum(const U curr, U &sigma2) {
|
||||
sigma2 = sigma2 + curr * curr;
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__device__ void cuChanRMSOnlineSum(const U sigma2B, U &sigma2) {
|
||||
sigma2 = sigma2 + sigma2B;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
|
||||
const int n2, const int i1, U &mu, U &sigma2,
|
||||
U *buf, bool rms_only) {
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == warpSize
|
||||
// 2) Tensor is contiguous
|
||||
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
|
||||
//
|
||||
// compute variance and mean over n2
|
||||
U count = U(0);
|
||||
mu = U(0);
|
||||
sigma2 = U(0);
|
||||
if (i1 < n1) {
|
||||
// one warp normalizes one n1 index,
|
||||
// synchronization is implicit
|
||||
// initialize with standard Welford algorithm
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
const T *lvals = vals + i1 * n2;
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
U curr = static_cast<U>(lvals[l + k]);
|
||||
if (!rms_only) {
|
||||
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
|
||||
} else {
|
||||
cuRMSOnlineSum<U>(curr, sigma2);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
U curr = static_cast<U>(lvals[l]);
|
||||
if (!rms_only) {
|
||||
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
|
||||
} else {
|
||||
cuRMSOnlineSum<U>(curr, sigma2);
|
||||
}
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int l = 0; l <= 4; ++l) {
|
||||
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
|
||||
U sigma2B = __shfl_sync(0xffffffff, sigma2, srcLaneB, warpSize);
|
||||
if (!rms_only) {
|
||||
U muB = __shfl_sync(0xffffffff, mu, srcLaneB, warpSize);
|
||||
U countB = __shfl_sync(0xffffffff, count, srcLaneB, warpSize);
|
||||
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
|
||||
} else {
|
||||
cuChanRMSOnlineSum<U>(sigma2B, sigma2);
|
||||
}
|
||||
}
|
||||
// threadIdx.x == 0 has correct values for each warp
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
U *ubuf = (U *)buf;
|
||||
U *ibuf = (U *)(ubuf + blockDim.y);
|
||||
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.x == 0 && threadIdx.y >= offset &&
|
||||
threadIdx.y < 2 * offset) {
|
||||
const int wrt_y = threadIdx.y - offset;
|
||||
if (!rms_only) {
|
||||
ubuf[2 * wrt_y] = mu;
|
||||
ibuf[wrt_y] = count;
|
||||
}
|
||||
ubuf[2 * wrt_y + 1] = sigma2;
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.x == 0 && threadIdx.y < offset) {
|
||||
U sigma2B = ubuf[2 * threadIdx.y + 1];
|
||||
if (!rms_only) {
|
||||
U muB = ubuf[2 * threadIdx.y];
|
||||
U countB = ibuf[threadIdx.y];
|
||||
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
|
||||
} else {
|
||||
cuChanRMSOnlineSum<U>(sigma2B, sigma2);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
if (!rms_only) {
|
||||
ubuf[0] = mu;
|
||||
}
|
||||
ubuf[1] = sigma2;
|
||||
}
|
||||
__syncthreads();
|
||||
if (!rms_only) {
|
||||
mu = ubuf[0];
|
||||
}
|
||||
sigma2 = ubuf[1] / U(n2);
|
||||
// don't care about final value of count, we know count == n2
|
||||
} else {
|
||||
if (!rms_only) {
|
||||
mu = __shfl_sync(0xffffffff, mu, 0, warpSize);
|
||||
}
|
||||
sigma2 = __shfl_sync(0xffffffff, sigma2 / U(n2), 0, warpSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void cuWelfordMuSigma2(const __half *__restrict__ vals, const int n1,
|
||||
const int n2, const int i1, float &mu,
|
||||
float &sigma2, float *buf, bool rms_only) {
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == warpSize
|
||||
// 2) Tensor is contiguous
|
||||
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
|
||||
//
|
||||
// compute variance and mean over n2
|
||||
float count = 0.0f;
|
||||
mu = float(0);
|
||||
sigma2 = float(0);
|
||||
if (i1 < n1) {
|
||||
// one warp normalizes one n1 index,
|
||||
// synchronization is implicit
|
||||
// initialize with standard Welford algorithm
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
const __half *lvals = vals + i1 * n2;
|
||||
int l = 8 * thrx;
|
||||
if ((((size_t)lvals) & 3) != 0) {
|
||||
// 16 bit alignment
|
||||
// first thread consumes first point
|
||||
if (thrx == 0) {
|
||||
float curr = static_cast<float>(lvals[0]);
|
||||
if (!rms_only) {
|
||||
cuWelfordOnlineSum(curr, mu, sigma2, count);
|
||||
} else {
|
||||
cuRMSOnlineSum(curr, sigma2);
|
||||
}
|
||||
}
|
||||
++l;
|
||||
}
|
||||
// at this point, lvals[l] are 32 bit aligned for all threads.
|
||||
for (; l + 7 < n2; l += 8 * numx) {
|
||||
for (int k = 0; k < 8; k += 2) {
|
||||
float2 curr = __half22float2(*((__half2 *)(lvals + l + k)));
|
||||
if (!rms_only) {
|
||||
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
|
||||
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
|
||||
} else {
|
||||
cuRMSOnlineSum(curr.x, sigma2);
|
||||
cuRMSOnlineSum(curr.y, sigma2);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
float curr = static_cast<float>(lvals[l]);
|
||||
if (!rms_only) {
|
||||
cuWelfordOnlineSum(curr, mu, sigma2, count);
|
||||
} else {
|
||||
cuRMSOnlineSum(curr, sigma2);
|
||||
}
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int l = 0; l <= 4; ++l) {
|
||||
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
|
||||
float sigma2B = __shfl_sync(0xffffffff, sigma2, srcLaneB, warpSize);
|
||||
if (!rms_only) {
|
||||
float muB = __shfl_sync(0xffffffff, mu, srcLaneB, warpSize);
|
||||
float countB = __shfl_sync(0xffffffff, count, srcLaneB, warpSize);
|
||||
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
|
||||
} else {
|
||||
cuChanRMSOnlineSum(sigma2B, sigma2);
|
||||
}
|
||||
}
|
||||
// threadIdx.x == 0 has correct values for each warp
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
float *ubuf = (float *)buf;
|
||||
float *ibuf = (float *)(ubuf + blockDim.y);
|
||||
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.x == 0 && threadIdx.y >= offset &&
|
||||
threadIdx.y < 2 * offset) {
|
||||
const int wrt_y = threadIdx.y - offset;
|
||||
ubuf[2 * wrt_y + 1] = sigma2;
|
||||
if (!rms_only) {
|
||||
ubuf[2 * wrt_y] = mu;
|
||||
ibuf[wrt_y] = count;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.x == 0 && threadIdx.y < offset) {
|
||||
float sigma2B = ubuf[2 * threadIdx.y + 1];
|
||||
if (!rms_only) {
|
||||
float muB = ubuf[2 * threadIdx.y];
|
||||
float countB = ibuf[threadIdx.y];
|
||||
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
|
||||
} else {
|
||||
cuChanRMSOnlineSum(sigma2B, sigma2);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
if (!rms_only) {
|
||||
ubuf[0] = mu;
|
||||
}
|
||||
ubuf[1] = sigma2;
|
||||
}
|
||||
__syncthreads();
|
||||
if (!rms_only) {
|
||||
mu = ubuf[0];
|
||||
}
|
||||
sigma2 = ubuf[1] / float(n2);
|
||||
// don't care about final value of count, we know count == n2
|
||||
} else {
|
||||
if (!rms_only) {
|
||||
mu = __shfl_sync(0xffffffff, mu, 0, warpSize);
|
||||
}
|
||||
sigma2 = __shfl_sync(0xffffffff, sigma2 / float(n2), 0, warpSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is the un-specialized struct. Note that we prevent instantiation of
|
||||
// this struct by putting an undefined symbol in the function body so it won't
|
||||
// compile.
|
||||
// template <typename T>
|
||||
// struct SharedMemory
|
||||
// {
|
||||
// // Ensure that we won't compile any un-specialized types
|
||||
// __device__ T *getPointer()
|
||||
// {
|
||||
// extern __device__ void error(void);
|
||||
// error();
|
||||
// return NULL;
|
||||
// }
|
||||
// };
|
||||
// https://github.com/NVIDIA/apex/issues/246
|
||||
template <typename T> struct SharedMemory;
|
||||
|
||||
template <> struct SharedMemory<float> {
|
||||
__device__ float *getPointer() {
|
||||
extern __shared__ float s_float[];
|
||||
return s_float;
|
||||
}
|
||||
};
|
||||
|
||||
template <> struct SharedMemory<double> {
|
||||
__device__ double *getPointer() {
|
||||
extern __shared__ double s_double[];
|
||||
return s_double;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__device__ void cuApplyLayerNorm_(V *__restrict__ output_vals,
|
||||
U *__restrict__ mean, U *__restrict__ invvar,
|
||||
const T *__restrict__ vals, const int n1,
|
||||
const int n2, const U epsilon,
|
||||
const V *__restrict__ gamma,
|
||||
const V *__restrict__ beta, bool rms_only) {
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == warpSize
|
||||
// 2) Tensors are contiguous
|
||||
//
|
||||
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
SharedMemory<U> shared;
|
||||
U *buf = shared.getPointer();
|
||||
U mu, sigma2;
|
||||
cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, rms_only);
|
||||
|
||||
const T *lvals = vals + i1 * n2;
|
||||
V *ovals = output_vals + i1 * n2;
|
||||
U c_invvar = rsqrt(sigma2 + epsilon);
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
if (gamma != NULL && (beta != NULL || rms_only)) {
|
||||
for (int i = thrx; i < n2; i += numx) {
|
||||
U curr = static_cast<U>(lvals[i]);
|
||||
if (!rms_only) {
|
||||
ovals[i] =
|
||||
gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
|
||||
} else {
|
||||
ovals[i] = gamma[i] * static_cast<V>(c_invvar * curr);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = thrx; i < n2; i += numx) {
|
||||
U curr = static_cast<U>(lvals[i]);
|
||||
if (!rms_only) {
|
||||
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
|
||||
} else {
|
||||
ovals[i] = static_cast<V>(c_invvar * curr);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
if (!rms_only) {
|
||||
mean[i1] = mu;
|
||||
}
|
||||
invvar[i1] = c_invvar;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V = T>
|
||||
__global__ void
|
||||
cuApplyRMSNorm(V *__restrict__ output_vals, U *__restrict__ invvar,
|
||||
const T *__restrict__ vals, const int n1, const int n2,
|
||||
const U epsilon, const V *__restrict__ gamma) {
|
||||
cuApplyLayerNorm_<T, U, V>(output_vals, NULL, invvar, vals, n1, n2, epsilon,
|
||||
gamma, NULL, true);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V = T>
|
||||
void HostApplyRMSNorm(cudaStream_t stream, V *output, U *invvar, const T *input,
|
||||
int n1, int n2, double epsilon, const V *gamma) {
|
||||
auto getMaxGridY = []() {
|
||||
int device;
|
||||
int val;
|
||||
cudaGetDevice(&device);
|
||||
cudaDeviceGetAttribute(&val, cudaDevAttrMaxGridDimY, device);
|
||||
return val;
|
||||
};
|
||||
const dim3 threads(32, 4, 1);
|
||||
const uint64_t maxGridY = getMaxGridY();
|
||||
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
|
||||
int nshared =
|
||||
threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
|
||||
cuApplyRMSNorm<<<blocks, threads, nshared, stream>>>(
|
||||
output, invvar, input, n1, n2, U(epsilon), gamma);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__device__ void cuLoadWriteStridedInputs(
|
||||
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
|
||||
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
|
||||
const T *input, const V *dout, const int i1_end, const int n2,
|
||||
const U *__restrict__ mean, const U *__restrict__ invvar, bool rms_only) {
|
||||
int i1 = i1_block + thr_load_row_off;
|
||||
if (i1 < i1_end) {
|
||||
U curr_mean;
|
||||
if (!rms_only) {
|
||||
curr_mean = mean[i1];
|
||||
}
|
||||
U curr_invvar = invvar[i1];
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int i2 = i2_off + k;
|
||||
int load_idx = i1 * n2 + i2;
|
||||
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
|
||||
if (i2 < n2) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
U curr_dout = static_cast<U>(dout[load_idx]);
|
||||
if (!rms_only) {
|
||||
warp_buf1[write_idx] = curr_dout;
|
||||
warp_buf2[write_idx] =
|
||||
curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
} else {
|
||||
warp_buf2[write_idx] = curr_dout * (curr_input)*curr_invvar;
|
||||
}
|
||||
} else {
|
||||
if (!rms_only) {
|
||||
warp_buf1[write_idx] = U(0);
|
||||
}
|
||||
warp_buf2[write_idx] = U(0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
|
||||
if (!rms_only) {
|
||||
warp_buf1[write_idx] = U(0);
|
||||
}
|
||||
warp_buf2[write_idx] = U(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__device__ void cuLoadAddStridedInputs(
|
||||
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
|
||||
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
|
||||
const T *input, const V *dout, const int i1_end, const int n2,
|
||||
const U *__restrict__ mean, const U *__restrict__ invvar, bool rms_only) {
|
||||
int i1 = i1_block + thr_load_row_off;
|
||||
if (i1 < i1_end) {
|
||||
U curr_mean;
|
||||
if (!rms_only) {
|
||||
curr_mean = mean[i1];
|
||||
}
|
||||
U curr_invvar = invvar[i1];
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int i2 = i2_off + k;
|
||||
int load_idx = i1 * n2 + i2;
|
||||
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
|
||||
if (i2 < n2) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
U curr_dout = static_cast<U>(dout[load_idx]);
|
||||
if (!rms_only) {
|
||||
warp_buf1[write_idx] += curr_dout;
|
||||
warp_buf2[write_idx] +=
|
||||
curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
} else {
|
||||
warp_buf2[write_idx] += curr_dout * (curr_input)*curr_invvar;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__global__ void cuComputePartGradGammaBeta(
|
||||
const V *__restrict__ dout, const T *__restrict__ input, const int n1,
|
||||
const int n2, const U *__restrict__ mean, const U *__restrict__ invvar,
|
||||
U epsilon, U *part_grad_gamma, U *part_grad_beta, bool rms_only) {
|
||||
const int numsegs_n1 =
|
||||
(n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
|
||||
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
|
||||
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
|
||||
const int i1_beg_plus_one =
|
||||
(blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
|
||||
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
|
||||
const int row_stride = blockDim.x + 1;
|
||||
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
|
||||
const int thr_load_row_off =
|
||||
(threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
|
||||
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
|
||||
SharedMemory<U> shared;
|
||||
U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y *
|
||||
// blockDim.y + (blockDim.y -
|
||||
// 1)*(blockDim.x/blockDim.y) elements
|
||||
U *warp_buf1 = (U *)buf;
|
||||
U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
|
||||
// compute partial sums from strided inputs
|
||||
// do this to increase number of loads in flight
|
||||
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
|
||||
row_stride, warp_buf1, warp_buf2, input, dout,
|
||||
i1_end, n2, mean, invvar, rms_only);
|
||||
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
|
||||
i1_block += blockDim.y * blockDim.y) {
|
||||
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
|
||||
row_stride, warp_buf1, warp_buf2, input, dout,
|
||||
i1_end, n2, mean, invvar, rms_only);
|
||||
}
|
||||
__syncthreads();
|
||||
// inter-warp reductions
|
||||
// sum within each warp
|
||||
U acc1 = U(0);
|
||||
U acc2 = U(0);
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int row1 = threadIdx.y + k * blockDim.y;
|
||||
int idx1 = row1 * row_stride + threadIdx.x;
|
||||
if (!rms_only) {
|
||||
acc1 += warp_buf1[idx1];
|
||||
}
|
||||
acc2 += warp_buf2[idx1];
|
||||
}
|
||||
if (!rms_only) {
|
||||
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
|
||||
}
|
||||
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
|
||||
__syncthreads();
|
||||
// sum all warps
|
||||
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
|
||||
if (threadIdx.y < offset) {
|
||||
int row1 = threadIdx.y;
|
||||
int row2 = threadIdx.y + offset;
|
||||
int idx1 = row1 * row_stride + threadIdx.x;
|
||||
int idx2 = row2 * row_stride + threadIdx.x;
|
||||
if (!rms_only) {
|
||||
warp_buf1[idx1] += warp_buf1[idx2];
|
||||
}
|
||||
warp_buf2[idx1] += warp_buf2[idx2];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (threadIdx.y == 0 && i2 < n2) {
|
||||
int row1 = threadIdx.y;
|
||||
int row2 = threadIdx.y + 1;
|
||||
int idx1 = row1 * row_stride + threadIdx.x;
|
||||
int idx2 = row2 * row_stride + threadIdx.x;
|
||||
if (!rms_only) {
|
||||
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
|
||||
}
|
||||
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, typename V>
|
||||
__global__ void
|
||||
cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta,
|
||||
const int part_size, const int n1, const int n2,
|
||||
V *grad_gamma, V *grad_beta, bool rms_only) {
|
||||
// sum partial gradients for gamma and beta
|
||||
SharedMemory<U> shared;
|
||||
U *buf = shared.getPointer();
|
||||
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i2 < n2) {
|
||||
// each warp does sequential reductions until reduced part_size is num_warps
|
||||
int num_warp_reductions = part_size / blockDim.y;
|
||||
U sum_gamma = U(0);
|
||||
U sum_beta = U(0);
|
||||
const U *part_grad_gamma_ptr =
|
||||
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
|
||||
const U *part_grad_beta_ptr =
|
||||
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
|
||||
for (int warp_offset = 0; warp_offset < num_warp_reductions;
|
||||
++warp_offset) {
|
||||
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
|
||||
if (!rms_only) {
|
||||
sum_beta += part_grad_beta_ptr[warp_offset * n2];
|
||||
}
|
||||
}
|
||||
// inter-warp reductions
|
||||
const int nbsize3 = blockDim.x * blockDim.y / 2;
|
||||
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
|
||||
// top half write to shared memory
|
||||
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
|
||||
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
|
||||
buf[write_idx] = sum_gamma;
|
||||
if (!rms_only) {
|
||||
buf[write_idx + nbsize3] = sum_beta;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
// bottom half sums
|
||||
if (threadIdx.y < offset) {
|
||||
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
sum_gamma += buf[read_idx];
|
||||
if (!rms_only) {
|
||||
sum_beta += buf[read_idx + nbsize3];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// write out fully summed gradients
|
||||
if (threadIdx.y == 0) {
|
||||
grad_gamma[i2] = sum_gamma;
|
||||
if (!rms_only) {
|
||||
grad_beta[i2] = sum_beta;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__global__ void
|
||||
cuComputeGradInput(const V *__restrict__ dout, const T *__restrict__ input,
|
||||
const int n1, const int n2, const U *__restrict__ mean,
|
||||
const U *__restrict__ invvar, U epsilon, const V *gamma,
|
||||
T *grad_input, bool rms_only) {
|
||||
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
U sum_loss1 = U(0);
|
||||
U sum_loss2 = U(0);
|
||||
U c_mean;
|
||||
if (!rms_only) {
|
||||
c_mean = mean[i1];
|
||||
}
|
||||
const U c_invvar = invvar[i1];
|
||||
const T *k_input = input + i1 * n2;
|
||||
const V *k_dout = dout + i1 * n2;
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
if (gamma != NULL) {
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l + k]);
|
||||
if (!rms_only) {
|
||||
sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);
|
||||
sum_loss2 += c_loss * static_cast<U>(gamma[l + k]) *
|
||||
(c_h - c_mean) * c_invvar;
|
||||
} else {
|
||||
sum_loss2 += c_loss * static_cast<U>(gamma[l + k]) * (c_h)*c_invvar;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
if (!rms_only) {
|
||||
sum_loss1 += c_loss * static_cast<U>(gamma[l]);
|
||||
sum_loss2 +=
|
||||
c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
|
||||
} else {
|
||||
sum_loss2 += c_loss * static_cast<U>(gamma[l]) * (c_h)*c_invvar;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l + k]);
|
||||
if (!rms_only) {
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
} else {
|
||||
sum_loss2 += c_loss * (c_h)*c_invvar;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
if (!rms_only) {
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
} else {
|
||||
sum_loss2 += c_loss * (c_h)*c_invvar;
|
||||
}
|
||||
}
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
|
||||
if (!rms_only) {
|
||||
sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize);
|
||||
}
|
||||
sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize);
|
||||
}
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
SharedMemory<U> shared;
|
||||
U *buf = shared.getPointer();
|
||||
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
|
||||
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
|
||||
if (!rms_only) {
|
||||
buf[2 * wrt_i] = sum_loss1;
|
||||
}
|
||||
buf[2 * wrt_i + 1] = sum_loss2;
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.y < offset) {
|
||||
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
if (!rms_only) {
|
||||
sum_loss1 += buf[2 * read_i];
|
||||
}
|
||||
sum_loss2 += buf[2 * read_i + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (threadIdx.y == 0) {
|
||||
if (!rms_only) {
|
||||
buf[2 * threadIdx.x] = sum_loss1;
|
||||
}
|
||||
buf[2 * threadIdx.x + 1] = sum_loss2;
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.y != 0) {
|
||||
if (!rms_only) {
|
||||
sum_loss1 = buf[2 * threadIdx.x];
|
||||
}
|
||||
sum_loss2 = buf[2 * threadIdx.x + 1];
|
||||
}
|
||||
}
|
||||
// all threads now have the two sums over l
|
||||
U fH = (U)n2;
|
||||
U term1 = (U(1) / fH) * c_invvar;
|
||||
T *k_grad_input = grad_input + i1 * n2;
|
||||
if (gamma != NULL) {
|
||||
for (int l = thrx; l < n2; l += numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
|
||||
if (!rms_only) {
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
} else {
|
||||
f_grad_input -= (c_h)*c_invvar * sum_loss2;
|
||||
}
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
} else {
|
||||
for (int l = thrx; l < n2; l += numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss;
|
||||
if (!rms_only) {
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
} else {
|
||||
f_grad_input -= (c_h)*c_invvar * sum_loss2;
|
||||
}
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
}
|
||||
// prevent race where buf is written again before reads are done
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U = float, typename V = T>
|
||||
void HostRMSNormGradient(cudaStream_t stream, const V *dout, const U *invvar,
|
||||
const T *input, int n1, int n2, const V *gamma,
|
||||
double epsilon, T *grad_input, V *grad_gamma,
|
||||
int part_size, U *part_grad_gamma) {
|
||||
auto getMaxGridY = []() {
|
||||
int device;
|
||||
int val;
|
||||
cudaGetDevice(&device);
|
||||
cudaDeviceGetAttribute(&val, cudaDevAttrMaxGridDimY, device);
|
||||
return val;
|
||||
};
|
||||
const uint64_t maxGridY = getMaxGridY();
|
||||
if (gamma != NULL) {
|
||||
const dim3 threads2(32, 4, 1);
|
||||
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
|
||||
const int nshared2_a =
|
||||
2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
|
||||
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
|
||||
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
|
||||
// note (mkozuki): I can hard code part_grad_gamma's dtype as float given
|
||||
// that the `cuda_layer_norm_gradient` doesn't support double.
|
||||
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
|
||||
dout, input, n1, n2,
|
||||
invvar, // unused
|
||||
invvar, U(epsilon), part_grad_gamma, part_grad_gamma, /* unused */
|
||||
true);
|
||||
|
||||
const dim3 threads3(32, 8, 1);
|
||||
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
|
||||
const int nshared3 = threads3.x * threads3.y * sizeof(U);
|
||||
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
|
||||
part_grad_gamma, part_grad_gamma, /* unused */
|
||||
part_size, n1, n2, grad_gamma, grad_gamma, /* unused */
|
||||
true);
|
||||
}
|
||||
|
||||
// compute grad_input
|
||||
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
|
||||
const dim3 threads1(32, 4, 1);
|
||||
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
|
||||
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
|
||||
dout, input, n1, n2, invvar, /* unused */
|
||||
invvar, U(epsilon), gamma, grad_input, true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace gpu_ops {
|
||||
|
||||
void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers,
|
||||
const char *opaque,
|
||||
std::size_t opaque_len) {
|
||||
const RMSNormDescriptor &d =
|
||||
*UnpackDescriptor<RMSNormDescriptor>(opaque, opaque_len);
|
||||
|
||||
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
|
||||
d.x_type, d.w_type, "rms_norm_cuda_kernel",
|
||||
HostApplyRMSNorm<scalar_t_in, accscalar_t, scalar_t_out>(
|
||||
stream, static_cast<scalar_t_out *>(buffers[2]),
|
||||
static_cast<accscalar_t *>(buffers[3]),
|
||||
static_cast<scalar_t_in *>(buffers[0]), d.n1, d.n2, d.eps,
|
||||
/*gamma=*/static_cast<scalar_t_out *>(buffers[1]));)
|
||||
}
|
||||
|
||||
void rms_backward_affine(cudaStream_t stream, void **buffers,
|
||||
const char *opaque, std::size_t opaque_len) {
|
||||
const RMSNormDescriptor &d =
|
||||
*UnpackDescriptor<RMSNormDescriptor>(opaque, opaque_len);
|
||||
|
||||
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
|
||||
d.x_type, d.w_type, "cuComputeGradInputRMS",
|
||||
HostRMSNormGradient(
|
||||
stream,
|
||||
/*dout=*/static_cast<scalar_t_out *>(buffers[0]),
|
||||
/*invvar=*/static_cast<accscalar_t *>(buffers[1]),
|
||||
/*input=*/static_cast<scalar_t_in *>(buffers[2]), d.n1, d.n2,
|
||||
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
|
||||
// if gamma Tensor is NULL on input.
|
||||
/*gamma=*/static_cast<scalar_t_out *>(buffers[3]), d.eps,
|
||||
/*grad_input=*/static_cast<scalar_t_in *>(buffers[4]),
|
||||
/*grad_gamma=*/static_cast<scalar_t_out *>(buffers[5]),
|
||||
d.part_grad_size,
|
||||
/*part_grad_gamma=*/static_cast<accscalar_t *>(buffers[6]));)
|
||||
}
|
||||
|
||||
} // namespace gpu_ops
|
@ -69,7 +69,7 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
|
||||
The default value is False.
|
||||
* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
|
||||
this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
|
||||
i-th layer computation. It also enables enable overlapping (i+1)-th layer
|
||||
i-th layer computation. It also enables overlapping (i+1)-th layer
|
||||
weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default
|
||||
value is False. **There are some bugs when this flag is turned on.**
|
||||
* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -326,7 +326,7 @@
|
||||
"id": "iuHqht-OYqca"
|
||||
},
|
||||
"source": [
|
||||
"The outputs of the inner `jax.pmap(convolve)` never left their devices when being fed into the outer `jax.pmap(convolve)`."
|
||||
"The outputs of the inner `jax.pmap(convolve)` have never left their devices when being fed into the outer `jax.pmap(convolve)`."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
@ -125,7 +125,7 @@ jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws))
|
||||
|
||||
+++ {"id": "iuHqht-OYqca"}
|
||||
|
||||
The outputs of the inner `jax.pmap(convolve)` never left their devices when being fed into the outer `jax.pmap(convolve)`.
|
||||
The outputs of the inner `jax.pmap(convolve)` have never left their devices when being fed into the outer `jax.pmap(convolve)`.
|
||||
|
||||
+++ {"id": "vEFAJXN2q3dV"}
|
||||
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -920,7 +920,7 @@
|
||||
"id": "ORMVVGZJgSVi"
|
||||
},
|
||||
"source": [
|
||||
"Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up."
|
||||
"Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1946,9 +1946,9 @@
|
||||
"\n",
|
||||
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
|
||||
"\n",
|
||||
"* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
|
||||
"* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
|
||||
"\n",
|
||||
"* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
|
||||
"* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
|
||||
"\n",
|
||||
"This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
|
||||
"\n",
|
||||
@ -2135,30 +2135,30 @@
|
||||
"\n",
|
||||
"There are a few ways to do this:\n",
|
||||
"\n",
|
||||
"1. You can enable 64bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n",
|
||||
"1. You can enable 64-bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n",
|
||||
"\n",
|
||||
"2. You can manually set the `jax_enable_x64` configuration flag at startup:\n",
|
||||
"\n",
|
||||
" ```python\n",
|
||||
" # again, this only works on startup!\n",
|
||||
" from jax import config\n",
|
||||
" config.update(\"jax_enable_x64\", True)\n",
|
||||
" import jax\n",
|
||||
" jax.config.update(\"jax_enable_x64\", True)\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"3. You can parse command-line flags with `absl.app.run(main)`\n",
|
||||
"\n",
|
||||
" ```python\n",
|
||||
" from jax import config\n",
|
||||
" config.config_with_absl()\n",
|
||||
" import jax\n",
|
||||
" jax.config.config_with_absl()\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n",
|
||||
"\n",
|
||||
" ```python\n",
|
||||
" from jax import config\n",
|
||||
" import jax\n",
|
||||
" if __name__ == '__main__':\n",
|
||||
" # calls config.config_with_absl() *and* runs absl parsing\n",
|
||||
" config.parse_flags_with_absl()\n",
|
||||
" # calls jax.config.config_with_absl() *and* runs absl parsing\n",
|
||||
" jax.config.parse_flags_with_absl()\n",
|
||||
" ```\n",
|
||||
"\n",
|
||||
"Note that #2-#4 work for _any_ of JAX's configuration options.\n",
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
@ -407,7 +407,7 @@ print(np.random.random())
|
||||
|
||||
+++ {"id": "ORMVVGZJgSVi"}
|
||||
|
||||
Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up.
|
||||
Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 7Pyp2ajzfPO2
|
||||
@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo
|
||||
|
||||
* setting the `JAX_DEBUG_NANS=True` environment variable;
|
||||
|
||||
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
|
||||
* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
|
||||
|
||||
This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.
|
||||
|
||||
@ -1081,30 +1081,30 @@ To use double-precision numbers, you need to set the `jax_enable_x64` configurat
|
||||
|
||||
There are a few ways to do this:
|
||||
|
||||
1. You can enable 64bit mode by setting the environment variable `JAX_ENABLE_X64=True`.
|
||||
1. You can enable 64-bit mode by setting the environment variable `JAX_ENABLE_X64=True`.
|
||||
|
||||
2. You can manually set the `jax_enable_x64` configuration flag at startup:
|
||||
|
||||
```python
|
||||
# again, this only works on startup!
|
||||
from jax import config
|
||||
config.update("jax_enable_x64", True)
|
||||
import jax
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
```
|
||||
|
||||
3. You can parse command-line flags with `absl.app.run(main)`
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
config.config_with_absl()
|
||||
import jax
|
||||
jax.config.config_with_absl()
|
||||
```
|
||||
|
||||
4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use
|
||||
|
||||
```python
|
||||
from jax import config
|
||||
import jax
|
||||
if __name__ == '__main__':
|
||||
# calls config.config_with_absl() *and* runs absl parsing
|
||||
config.parse_flags_with_absl()
|
||||
# calls jax.config.config_with_absl() *and* runs absl parsing
|
||||
jax.config.parse_flags_with_absl()
|
||||
```
|
||||
|
||||
Note that #2-#4 work for _any_ of JAX's configuration options.
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -17,11 +17,7 @@
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
|
||||
"\n",
|
||||
"This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.\n",
|
||||
"\n",
|
||||
"Refer to the [`jax.Array migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide to learn how to migrate the existing JAX pre-v0.4.1 codebases to `jax.Array`.\n",
|
||||
"\n",
|
||||
"**Note:** The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Google Cloud TPU and Kaggle TPU VMs."
|
||||
"This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2369,6 +2365,7 @@
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "TPU",
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"toc_visible": true
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
@ -21,10 +21,6 @@ kernelspec:
|
||||
|
||||
This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.
|
||||
|
||||
Refer to the [`jax.Array migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide to learn how to migrate the existing JAX pre-v0.4.1 codebases to `jax.Array`.
|
||||
|
||||
**Note:** The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Google Cloud TPU and Kaggle TPU VMs.
|
||||
|
||||
```{code-cell}
|
||||
:id: FNxScTfq3vGF
|
||||
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -7,7 +7,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3 (ipykernel)
|
||||
language: python
|
||||
|
@ -147,8 +147,10 @@ grid axes over cores. This is an opt-in procedure. To allow that,
|
||||
..
|
||||
pallas_call(
|
||||
...,
|
||||
mosaic_params=dict(
|
||||
dimension_semantics=["parallel", "parallel", "arbitrary"]
|
||||
compiler_params=dict(
|
||||
mosaic=dict(
|
||||
dimension_semantics=["parallel", "parallel", "arbitrary"]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
|
74
docs/persistent_compilation_cache.md
Normal file
74
docs/persistent_compilation_cache.md
Normal file
@ -0,0 +1,74 @@
|
||||
# Persistent Compilation Cache
|
||||
|
||||
JAX has an optional disk cache for compiled programs. If enabled, JAX will
|
||||
store copies of compiled programs on disk, which can save recompilation time
|
||||
when running the same or similar tasks repeatedly.
|
||||
|
||||
## Usage
|
||||
|
||||
The compilation cache is enabled when the
|
||||
[cache-location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
|
||||
is set. This should be done prior to the first compilation. Set the location as
|
||||
follows:
|
||||
|
||||
```
|
||||
import jax
|
||||
|
||||
# Make sure this is called before jax runs any operations!
|
||||
jax.config.update("jax_compilation_cache_dir", "cache-location")
|
||||
```
|
||||
|
||||
See the sections below for more detail on `cache-location`.
|
||||
|
||||
[`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
|
||||
is an alternate way of setting `cache-location`.
|
||||
|
||||
### Local filesystem
|
||||
|
||||
`cache-location` can be a directory on the local filesystem. For example:
|
||||
|
||||
```
|
||||
import jax
|
||||
|
||||
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
|
||||
```
|
||||
|
||||
Note: the cache does not have an eviction mechanism implemented. If the
|
||||
cache-location is a directory in the local filesystem, its size will continue
|
||||
to grow unless files are manually deleted.
|
||||
|
||||
### Google Cloud
|
||||
|
||||
When running on Google Cloud, the compilation cache can be placed on a Google
|
||||
Cloud Storage (GCS) bucket. We recommend the following configuration:
|
||||
|
||||
* Create the bucket in the same region as where the workload will run.
|
||||
|
||||
* Create the bucket in the same project as the workload’s VM(s). Ensure that
|
||||
permissions are set so that the VM(s) can write to the bucket.
|
||||
|
||||
* There is no need for replication for smaller workloads. Larger workloads
|
||||
could benefit from replication.
|
||||
|
||||
* Use “Standard” for the default storage class for the bucket.
|
||||
|
||||
* Set the soft delete policy to its shortest: 7 days.
|
||||
|
||||
* Set the object lifecycle to the expected duration of the workload run.
|
||||
For example, if the workload is expected to run for 10 days, set the object
|
||||
lifecycle to 10 days. That should cover restarts that occur during the entire
|
||||
run. Use `age` for the lifecycle condition and `Delete` for the action. See
|
||||
[Object Lifecycle Management](https://cloud.google.com/storage/docs/lifecycle)
|
||||
for details. If the object lifecycle is not set, the cache will continue to
|
||||
grow since there is no eviction mechanism implemented.
|
||||
|
||||
* All encryption policies are supported.
|
||||
|
||||
Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as
|
||||
follows:
|
||||
|
||||
```
|
||||
import jax
|
||||
|
||||
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
|
||||
```
|
@ -40,8 +40,8 @@ One is by using :code:`jax.config` in your code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from jax import config
|
||||
config.update("jax_numpy_rank_promotion", "warn")
|
||||
import jax
|
||||
jax.config.update("jax_numpy_rank_promotion", "warn")
|
||||
|
||||
You can also set the option using the environment variable
|
||||
:code:`JAX_NUMPY_RANK_PROMOTION`, for example as
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -5,7 +5,7 @@ jupytext:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.0
|
||||
jupytext_version: 1.16.1
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
|
@ -15,7 +15,7 @@ or deployed codebases.
|
||||
device_memory_profiling
|
||||
debugging/index
|
||||
gpu_performance_tips
|
||||
|
||||
persistent_compilation_cache
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
@ -22,6 +22,7 @@ from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
@ -30,8 +31,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from examples import kernel_lsq
|
||||
sys.path.pop()
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
|
||||
|
@ -17,10 +17,11 @@
|
||||
|
||||
from absl import app
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import grad
|
||||
from jax import jit
|
||||
from jax import vmap
|
||||
from jax import config
|
||||
import jax.numpy as jnp
|
||||
import jax.random as random
|
||||
import jax.scipy as scipy
|
||||
@ -125,5 +126,5 @@ def main(unused_argv):
|
||||
mu.flatten() - std * 2, mu.flatten() + std * 2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.config_with_absl()
|
||||
jax.config.config_with_absl()
|
||||
app.run(main)
|
||||
|
@ -201,6 +201,7 @@ py_library_providing_imports_info(
|
||||
"_src/debugging.py",
|
||||
"_src/dispatch.py",
|
||||
"_src/dlpack.py",
|
||||
"_src/earray.py",
|
||||
"_src/flatten_util.py",
|
||||
"_src/interpreters/__init__.py",
|
||||
"_src/interpreters/ad.py",
|
||||
@ -997,7 +998,11 @@ pytype_library(
|
||||
|
||||
pytype_library(
|
||||
name = "experimental_host_callback",
|
||||
srcs = ["experimental/host_callback.py"],
|
||||
srcs = [
|
||||
"experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False
|
||||
"experimental/host_callback.py",
|
||||
"experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":jax",
|
||||
|
@ -125,6 +125,7 @@ from jax._src.api import value_and_grad as value_and_grad
|
||||
from jax._src.api import vjp as vjp
|
||||
from jax._src.api import vmap as vmap
|
||||
from jax._src.api import xla_computation as xla_computation
|
||||
from jax._src.sharding_impls import NamedSharding as NamedSharding
|
||||
|
||||
# Force import, allowing jax.interpreters.* to be used after import jax.
|
||||
from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla
|
||||
|
@ -2559,7 +2559,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
|
||||
# TODO(jakevdp): provide a default for devices that considers both local
|
||||
# devices and pods
|
||||
if not isinstance(shards, Sequence):
|
||||
raise ValueError("device_put_sharded `shards` input must be a sequence; "
|
||||
raise TypeError("device_put_sharded `shards` input must be a sequence; "
|
||||
f"got {type(shards)}")
|
||||
if len(shards) != len(devices):
|
||||
raise ValueError(f"len(shards) = {len(shards)} must equal "
|
||||
@ -2911,7 +2911,7 @@ def named_scope(
|
||||
... return jax.nn.relu(logits)
|
||||
"""
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("named_scope name argument must be a string.")
|
||||
raise TypeError("named_scope name argument must be a string.")
|
||||
with source_info_util.extend_name_stack(name):
|
||||
yield
|
||||
|
||||
|
@ -30,11 +30,9 @@ from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import errors
|
||||
from jax._src import layout
|
||||
from jax._src import profiler
|
||||
from jax._src import tree_util
|
||||
from jax._src import xla_bridge
|
||||
@ -47,10 +45,10 @@ from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
|
||||
device_replica_id_map, hashed_index)
|
||||
from jax._src.typing import ArrayLike
|
||||
from jax._src.layout import DeviceLocalLayout, Layout
|
||||
from jax._src.typing import ArrayLike, DLDeviceType
|
||||
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
|
||||
|
||||
deprecations.register(__name__, "device-method")
|
||||
|
||||
Shape = tuple[int, ...]
|
||||
Device = xc.Device
|
||||
@ -406,11 +404,25 @@ class ArrayImpl(basearray.Array):
|
||||
kwds = {} if copy is None else {'copy': copy}
|
||||
return np.asarray(self._value, dtype=dtype, **kwds)
|
||||
|
||||
def __dlpack__(self, *, stream: int | Any | None = None):
|
||||
if len(self._arrays) != 1:
|
||||
raise BufferError("__dlpack__ only supported for unsharded arrays.")
|
||||
def __dlpack__(self, *, stream: int | Any | None = None,
|
||||
max_version: tuple[int, int] | None = None,
|
||||
dl_device: tuple[DLDeviceType, int] | None = None,
|
||||
copy: bool | None = None):
|
||||
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
|
||||
return to_dlpack(self, stream=stream)
|
||||
|
||||
device_set = self.sharding.device_set
|
||||
if len(device_set) > 1:
|
||||
raise BufferError(
|
||||
"to_dlpack can only pack a dlpack tensor from an array on a singular "
|
||||
f"device, but an array with a Sharding over {len(device_set)} devices "
|
||||
"was provided."
|
||||
)
|
||||
device, = device_set
|
||||
return to_dlpack(self, stream=stream,
|
||||
max_version=max_version,
|
||||
src_device=device,
|
||||
dl_device=dl_device,
|
||||
copy=copy)
|
||||
|
||||
def __dlpack_device__(self) -> tuple[enum.Enum, int]:
|
||||
if len(self._arrays) != 1:
|
||||
@ -471,21 +483,6 @@ class ArrayImpl(basearray.Array):
|
||||
per_shard_size = arr.on_device_size_in_bytes() # type: ignore
|
||||
return per_shard_size * len(self.sharding.device_set)
|
||||
|
||||
# TODO(yashkatariya): Remove this method when everyone is using devices().
|
||||
def device(self) -> Device:
|
||||
if deprecations.is_accelerated(__name__, "device-method"):
|
||||
raise NotImplementedError("arr.device() is deprecated. Use arr.devices() instead.")
|
||||
else:
|
||||
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._check_if_deleted()
|
||||
device_set = self.sharding.device_set
|
||||
if len(device_set) == 1:
|
||||
single_device, = device_set
|
||||
return single_device
|
||||
raise ValueError('Length of devices is greater than 1. '
|
||||
'Please use `.devices()`.')
|
||||
|
||||
def devices(self) -> set[Device]:
|
||||
self._check_if_deleted()
|
||||
return self.sharding.device_set
|
||||
@ -531,13 +528,15 @@ class ArrayImpl(basearray.Array):
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
# TODO(yashkatariya): Remove the deleted check from here.
|
||||
if self.is_deleted():
|
||||
return Layout(None, self.sharding)
|
||||
try:
|
||||
return layout.Layout(layout.DeviceLocalLayout(self._pjrt_layout),
|
||||
self.sharding)
|
||||
return Layout(DeviceLocalLayout(self._pjrt_layout), self.sharding)
|
||||
except xe.XlaRuntimeError as e:
|
||||
msg, *_ = e.args
|
||||
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
|
||||
return layout.Layout(None, self.sharding)
|
||||
return Layout(None, self.sharding)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
@ -196,7 +196,6 @@ class Array(abc.ABC):
|
||||
def block_until_ready(self) -> Array: ...
|
||||
def copy_to_host_async(self) -> None: ...
|
||||
def delete(self) -> None: ...
|
||||
def device(self) -> Device: ...
|
||||
def devices(self) -> set[Device]: ...
|
||||
@property
|
||||
def sharding(self) -> Sharding: ...
|
||||
|
@ -14,14 +14,12 @@
|
||||
"""Module for JAX callbacks."""
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -33,9 +31,10 @@ from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lax.control_flow.loops import map as lax_map
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.sharding_impls import SingleDeviceSharding
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -73,11 +72,14 @@ def pure_callback_impl(
|
||||
vectorized: bool,
|
||||
):
|
||||
del sharding, vectorized, result_avals
|
||||
try:
|
||||
return callback(*args)
|
||||
except BaseException:
|
||||
logger.exception("jax.pure_callback failed")
|
||||
raise
|
||||
cpu_device, *_ = jax.local_devices(backend="cpu")
|
||||
args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
|
||||
with jax.default_device(cpu_device):
|
||||
try:
|
||||
return tree_util.tree_map(np.asarray, callback(*args))
|
||||
except BaseException:
|
||||
logger.exception("jax.pure_callback failed")
|
||||
raise
|
||||
|
||||
|
||||
pure_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
|
||||
@ -398,11 +400,14 @@ def io_callback_impl(
|
||||
ordered: bool,
|
||||
):
|
||||
del result_avals, sharding, ordered
|
||||
try:
|
||||
return callback(*args)
|
||||
except BaseException:
|
||||
logger.exception("jax.io_callback failed")
|
||||
raise
|
||||
cpu_device, *_ = jax.local_devices(backend="cpu")
|
||||
args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
|
||||
with jax.default_device(cpu_device):
|
||||
try:
|
||||
return tree_util.tree_map(np.asarray, callback(*args))
|
||||
except BaseException:
|
||||
logger.exception("jax.io_callback failed")
|
||||
raise
|
||||
|
||||
|
||||
io_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
|
||||
@ -439,16 +444,16 @@ def io_callback_batching_rule(
|
||||
):
|
||||
if ordered:
|
||||
raise ValueError("Cannot `vmap` ordered IO callback.")
|
||||
return pure_callback_batching_rule(
|
||||
args,
|
||||
dims,
|
||||
callback=callback,
|
||||
sharding=sharding,
|
||||
vectorized=False,
|
||||
result_avals=result_avals,
|
||||
)
|
||||
|
||||
|
||||
is_batched = [d is not batching.not_mapped for d in dims]
|
||||
new_args = [arg if dim is batching.not_mapped else
|
||||
batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
|
||||
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
|
||||
def _batch_fun(batched_args):
|
||||
merged = util.merge_lists(is_batched, unbatched_args, batched_args)
|
||||
return io_callback_p.bind(*merged, callback=callback, sharding=sharding,
|
||||
result_avals=result_avals, ordered=False)
|
||||
out_vals = lax_map(_batch_fun, batched_args)
|
||||
return out_vals, (0,) * len(out_vals)
|
||||
batching.primitive_batchers[io_callback_p] = io_callback_batching_rule
|
||||
|
||||
|
||||
|
@ -895,9 +895,9 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
|
||||
error_checks[lax.while_p] = while_loop_error_check
|
||||
|
||||
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
in_shardings, out_shardings, resource_env,
|
||||
donated_invars, name,
|
||||
inline, keep_unused):
|
||||
in_shardings, out_shardings,
|
||||
in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, inline, keep_unused):
|
||||
# jaxpr to checked_jaxpr
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
new_vals_in = [*err_vals, *vals_in]
|
||||
@ -908,10 +908,12 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
# Update pjit params to account for extra error values.
|
||||
num_error_vals = len(err_vals)
|
||||
num_out_error_vals = out_tree.num_leaves - len(out_shardings)
|
||||
sharding = sharding_impls.UNSPECIFIED
|
||||
|
||||
sharding = sharding_impls.UNSPECIFIED
|
||||
new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
|
||||
new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
|
||||
new_in_layouts = (*[None] * num_error_vals, *in_layouts)
|
||||
new_out_layouts = (*[None] * num_out_error_vals, *out_layouts)
|
||||
new_donated_invars = (*[False] * num_error_vals, *donated_invars)
|
||||
|
||||
err_and_out = pjit.pjit_p.bind(
|
||||
@ -919,6 +921,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
jaxpr=checked_jaxpr,
|
||||
in_shardings=new_in_shardings,
|
||||
out_shardings=new_out_shardings,
|
||||
in_layouts=new_in_layouts,
|
||||
out_layouts=new_out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=new_donated_invars,
|
||||
name=name,
|
||||
@ -1296,6 +1300,6 @@ def check_error(error: Error) -> None:
|
||||
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
|
||||
"""
|
||||
if not isinstance(error, Error):
|
||||
raise ValueError('check_error takes an Error as argument, '
|
||||
raise TypeError('check_error takes an Error as argument, '
|
||||
f'got type {type(error)} instead.')
|
||||
_check_error(error, debug=False)
|
||||
|
@ -628,7 +628,7 @@ def define_string_state(
|
||||
|
||||
def validator(new_val):
|
||||
if not isinstance(new_val, str):
|
||||
raise ValueError('new string config value must be of type str,'
|
||||
raise TypeError('new string config value must be of type str,'
|
||||
f' got {new_val} of type {type(new_val)}.')
|
||||
|
||||
return define_string_or_object_state(
|
||||
@ -1390,6 +1390,13 @@ eager_pmap = define_bool_state(
|
||||
upgrade=True,
|
||||
help='Enable eager-mode pmap when jax_disable_jit is activated.')
|
||||
|
||||
# TODO(mattjj): remove once we land mutable array plumbing, or face great shame
|
||||
custom_vjp_disable_shape_check = define_bool_state(
|
||||
name='jax_custom_vjp_disable_shape_check',
|
||||
default=False,
|
||||
upgrade=True,
|
||||
help='Disable the check from #19009 to enable some custom_vjp hacks.')
|
||||
|
||||
xla_runtime_errors = define_bool_state(
|
||||
name='jax_experimental_unsafe_xla_runtime_errors',
|
||||
default=False,
|
||||
|
@ -54,6 +54,7 @@ from jax._src.lib import jax_jit
|
||||
from jax._src import traceback_util
|
||||
from jax._src.typing import Array, DimSize, Shape
|
||||
from jax._src import typing
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
@ -832,14 +833,14 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
|
||||
@property
|
||||
def block_until_ready(self):
|
||||
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
|
||||
# Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
|
||||
raise AttributeError(self,
|
||||
f"The 'block_until_ready' method is not available on {self._error_repr()}."
|
||||
f"{self._origin_msg()}")
|
||||
|
||||
@property
|
||||
def copy_to_host_async(self):
|
||||
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
|
||||
# Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
|
||||
raise AttributeError(self,
|
||||
f"The 'copy_to_host_async' method is not available on {self._error_repr()}."
|
||||
f"{self._origin_msg()}")
|
||||
@ -849,11 +850,6 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
f"The delete() method was called on {self._error_repr()}."
|
||||
f"{self._origin_msg()}")
|
||||
|
||||
def device(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The device() method was called on {self._error_repr()}."
|
||||
f"{self._origin_msg()}")
|
||||
|
||||
def devices(self):
|
||||
raise ConcretizationTypeError(self,
|
||||
f"The devices() method was called on {self._error_repr()}."
|
||||
@ -910,10 +906,20 @@ class EvalTrace(Trace):
|
||||
lift = sublift = pure
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
return primitive.impl(*tracers, **params)
|
||||
if config.debug_key_reuse.value:
|
||||
# Import here to avoid circular imports
|
||||
from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error
|
||||
return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
|
||||
else:
|
||||
return primitive.impl(*tracers, **params)
|
||||
|
||||
def process_call(self, primitive, f, tracers, params):
|
||||
return primitive.impl(f, *tracers, **params)
|
||||
if config.debug_key_reuse.value:
|
||||
# Import here to avoid circular imports
|
||||
from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error
|
||||
return call_impl_with_key_reuse_checks(primitive, primitive.impl, f, *tracers, **params)
|
||||
else:
|
||||
return primitive.impl(f, *tracers, **params)
|
||||
process_map = process_call
|
||||
|
||||
def process_custom_transpose(self, primitive, call, tracers, **_):
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from functools import partial, reduce
|
||||
import operator
|
||||
from typing import Optional
|
||||
@ -36,6 +37,17 @@ Array = jnp.ndarray
|
||||
DType = jnp.dtype
|
||||
PRNGKey = jnp.ndarray
|
||||
|
||||
class AttentionLayout(Enum):
|
||||
BTNH = 0
|
||||
BNTH = 1
|
||||
|
||||
def _normalize_layout(layout: str) -> AttentionLayout:
|
||||
layout_upper = layout.upper()
|
||||
if layout_upper in ['BSNH', 'BNSH', 'BTNH', 'BNTH']:
|
||||
return AttentionLayout[layout_upper.replace('S', 'T')]
|
||||
else:
|
||||
raise ValueError(f"Unsupported qkv_layout: {layout}")
|
||||
|
||||
def element_type_to_backend_config_type_mapping(dtype):
|
||||
_element_type_to_backend_config_type_mapping = {
|
||||
ir.BF16Type.get(): "BF16",
|
||||
@ -56,18 +68,18 @@ def create_dot_product_attention_backend_config(batch,
|
||||
dropout_rate,
|
||||
is_flash_attention,
|
||||
is_causal_mask,
|
||||
layout,
|
||||
is_bwd):
|
||||
# b q_seq num_heads head_dim -> Q
|
||||
# b kv_seq num_heads head_dim -> K
|
||||
# b kv_seq num_heads head_dim -> V
|
||||
# b num_heads q_seq kv_seq -> P
|
||||
# b q_seq num_heads head_dim -> O
|
||||
# bmm1: Q @ K -> P
|
||||
# bmm2: P @ V -> O
|
||||
# bmm2Grad1: P @ dO -> dV
|
||||
# bmm2Grad2: dO @ V -> dP
|
||||
# bmm1Grad1: dP @ Q -> dK
|
||||
# bmm1Grad2: dP @ K -> dQ
|
||||
# Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H
|
||||
# P: BMM1 output in shape of BNTS
|
||||
# O: BMM2 output in the same shape with Q
|
||||
# BMM1: Q @ K -> P
|
||||
# BMM2: P @ V -> O
|
||||
# BMM1Grad1: dP @ Q -> dK
|
||||
# BMM1Grad2: dP @ K -> dQ
|
||||
# BMM2Grad1: P @ dO -> dV
|
||||
# BMM2Grad2: dO @ V -> dP
|
||||
|
||||
cudnn_fmha_backend_config = {
|
||||
"algorithm": {
|
||||
"algo_id": "0",
|
||||
@ -100,46 +112,47 @@ def create_dot_product_attention_backend_config(batch,
|
||||
"is_flash_attention": is_flash_attention,
|
||||
"is_causal_mask": is_causal_mask,
|
||||
}
|
||||
fwd_dot_number = {
|
||||
"bmm1_dot_dimension_numbers": {
|
||||
"lhs_contracting_dimensions": ["3"],
|
||||
"rhs_contracting_dimensions": ["3"],
|
||||
"lhs_batch_dimensions": ["0", "2"],
|
||||
"rhs_batch_dimensions": ["0", "2"],
|
||||
},
|
||||
"bmm2_dot_dimension_numbers": {
|
||||
"lhs_contracting_dimensions": ["3"],
|
||||
"rhs_contracting_dimensions": ["1"],
|
||||
"lhs_batch_dimensions": ["0", "1"],
|
||||
"rhs_batch_dimensions": ["0", "2"],
|
||||
},
|
||||
}
|
||||
bwd_dot_number = {
|
||||
"bmm1_grad_gemm1_dot_dimension_numbers": {
|
||||
"lhs_contracting_dimensions": ["2"],
|
||||
"rhs_contracting_dimensions": ["1"],
|
||||
"lhs_batch_dimensions": ["0", "1"],
|
||||
"rhs_batch_dimensions": ["0", "2"],
|
||||
},
|
||||
"bmm1_grad_gemm2_dot_dimension_numbers": {
|
||||
"lhs_contracting_dimensions": ["3"],
|
||||
"rhs_contracting_dimensions": ["1"],
|
||||
"lhs_batch_dimensions": ["0", "1"],
|
||||
"rhs_batch_dimensions": ["0", "2"],
|
||||
},
|
||||
"bmm2_grad_gemm1_dot_dimension_numbers": {
|
||||
"lhs_contracting_dimensions": ["2"],
|
||||
"rhs_contracting_dimensions": ["1"],
|
||||
"lhs_batch_dimensions": ["0", "1"],
|
||||
"rhs_batch_dimensions": ["0", "2"],
|
||||
},
|
||||
"bmm2_grad_gemm2_dot_dimension_numbers": {
|
||||
"lhs_contracting_dimensions": ["3"],
|
||||
"rhs_contracting_dimensions": ["3"],
|
||||
"lhs_batch_dimensions": ["0", "2"],
|
||||
"rhs_batch_dimensions": ["0", "2"],
|
||||
},
|
||||
}
|
||||
|
||||
# We define the contracting and batch dims in the format of
|
||||
# ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
|
||||
# rhs_batch_dims)).
|
||||
if layout == AttentionLayout.BNTH.value:
|
||||
dims = [
|
||||
((3, 3), ((0, 1), (0, 1))), # BMM1: BNTH,BNSH->BNTS
|
||||
((3, 2), ((0, 1), (0, 1))), # BMM2: BNTS,BNSH->BNTH
|
||||
((2, 2), ((0, 1), (0, 1))), # BMM1_grad_1: BNTS,BNTH->BNSH
|
||||
((3, 2), ((0, 1), (0, 1))), # BMM1_grad_2: BNTS,BNSH->BNTH
|
||||
((2, 2), ((0, 1), (0, 1))), # BMM2_grad_1: BNTS,BNTH->BNSH
|
||||
((3, 3), ((0, 1), (0, 1))), # BMM2_grad_2: BNTH,BNSH->BNTS
|
||||
]
|
||||
else:
|
||||
dims = [
|
||||
((3, 3), ((0, 2), (0, 2))), # BMM1: BTNH,BSNH->BNTS
|
||||
((3, 1), ((0, 1), (0, 2))), # BMM2: BNTS,BSNH->BTNH
|
||||
((2, 1), ((0, 1), (0, 2))), # BMM1_grad_1: BNTS,BTNH->BSNH
|
||||
((3, 1), ((0, 1), (0, 2))), # BMM1_grad_2: BNTS,BSNH->BTNH
|
||||
((2, 1), ((0, 1), (0, 2))), # BMM2_grad_1: BNTS,BTNH->BSNH
|
||||
((3, 3), ((0, 2), (0, 2))), # BMM2_grad_2: BTNH,BSNH->BNTS
|
||||
]
|
||||
keys = [
|
||||
"bmm1_dot_dimension_numbers",
|
||||
"bmm2_dot_dimension_numbers",
|
||||
"bmm1_grad_gemm1_dot_dimension_numbers",
|
||||
"bmm1_grad_gemm2_dot_dimension_numbers",
|
||||
"bmm2_grad_gemm1_dot_dimension_numbers",
|
||||
"bmm2_grad_gemm2_dot_dimension_numbers",
|
||||
]
|
||||
fwd_dot_number = {}
|
||||
bwd_dot_number = {}
|
||||
for idx, (key, ((lc, rc), (lb, rb))) in enumerate(zip(keys, dims)):
|
||||
dims_to_write = fwd_dot_number if idx < 2 else bwd_dot_number
|
||||
dims_to_write[key] = {
|
||||
"lhs_contracting_dimensions": [str(lc)],
|
||||
"rhs_contracting_dimensions": [str(rc)],
|
||||
"lhs_batch_dimensions": [str(i) for i in lb],
|
||||
"rhs_batch_dimensions": [str(i) for i in rb],
|
||||
}
|
||||
|
||||
if is_bwd:
|
||||
cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **bwd_dot_number}
|
||||
else:
|
||||
@ -178,46 +191,59 @@ _custom_name_maps = {
|
||||
def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
|
||||
return _custom_name_maps[(is_bwd, has_dropout, has_mask, has_bias)]
|
||||
|
||||
def check_qkv_layout(query, key, value):
|
||||
assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \
|
||||
"query, key and value should have rank 4."
|
||||
def check_qkv_layout(query, key, value, layout):
|
||||
def check_eq(a, b, c, msg):
|
||||
if not (a == b == c):
|
||||
raise ValueError(f"{msg} must be same, got {a}, {b}, {b}")
|
||||
|
||||
# Only support fp16 and bf16 here
|
||||
query_dtype = query.dtype
|
||||
key_dtype = key.dtype
|
||||
value_dtype = value.dtype
|
||||
assert query_dtype == key_dtype == value_dtype and query_dtype in [jnp.float16, jnp.bfloat16], \
|
||||
"query, key and value should have same dtype and should be float16 or bfloat16"
|
||||
q_rank, k_rank, v_rank = len(query.shape), len(key.shape), len(value.shape)
|
||||
if q_rank != 4:
|
||||
raise ValueError(f"Q must have a rank of 4, got {q_rank}")
|
||||
check_eq(q_rank, k_rank, v_rank, 'QKV rank')
|
||||
|
||||
q_batch, q_seq_len, q_num_heads, q_head_dim = query.shape
|
||||
k_batch, k_seq_len, k_num_heads, k_head_dim = key.shape
|
||||
v_batch, v_seq_len, v_num_heads, v_head_dim = value.shape
|
||||
if not((q_batch == k_batch == v_batch)
|
||||
and (k_seq_len == v_seq_len)
|
||||
and (q_num_heads == k_num_heads == v_num_heads)
|
||||
and (q_head_dim == k_head_dim == v_head_dim)):
|
||||
raise ValueError(
|
||||
"query should have layout [batch, q_seq, num_heads, head_dim], " \
|
||||
"key and value should have layout [batch, kv_seq, num_heads, head_dim].")
|
||||
q_dtype, k_dtype, v_dtype = query.dtype, key.dtype, value.dtype
|
||||
assert q_dtype in [jnp.float16, jnp.bfloat16], "Q must be fp16 or bf16"
|
||||
check_eq(q_dtype, k_dtype, v_dtype, 'QKV dtype')
|
||||
|
||||
def check_is_flash_attention(query, key, cudnn_version, has_bias, is_training):
|
||||
batch, q_seq_len, num_heads, head_dim = query.shape
|
||||
_, kv_seq_len, _, _ = key.shape
|
||||
if layout == AttentionLayout.BNTH:
|
||||
qB, qN, _, qH = query.shape
|
||||
kB, kN, kS, kH = key.shape
|
||||
vB, vN, vS, vH = value.shape
|
||||
else:
|
||||
assert layout == AttentionLayout.BTNH
|
||||
qB, _, qN, qH = query.shape
|
||||
kB, kS, kN, kH = key.shape
|
||||
vB, vS, vN, vH = value.shape
|
||||
|
||||
# check if attention pattern is supported by flash attention or fused attention
|
||||
if q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64 \
|
||||
and (not is_training or q_seq_len % 64 == 0 and kv_seq_len % 64 == 0):
|
||||
check_eq(qB, kB, vB, 'QKV batch')
|
||||
check_eq(qN, kN, vN, 'QKV num_head')
|
||||
check_eq(qH, kH, vH, 'QKV dim_per_head')
|
||||
if kS != vS:
|
||||
raise ValueError(f'KV must have same seq length, got {kS} vs {vS}')
|
||||
|
||||
def check_is_flash_attention(
|
||||
query, key, layout, cudnn_version, has_bias, is_training):
|
||||
if layout == AttentionLayout.BNTH:
|
||||
_, N, T, H = query.shape
|
||||
_, _, S, _ = key.shape
|
||||
else:
|
||||
_, T, N, H = query.shape
|
||||
_, S, _, _ = key.shape
|
||||
|
||||
# check if attention pattern is supported by flash attention or fused attention.
|
||||
if ((T <= 512 and S <= 512 and H == 64) and
|
||||
(not is_training or T % 64 == 0 and S % 64 == 0)):
|
||||
# check if regular fused attention is supported
|
||||
# for training, seqlen should be divisible by 64
|
||||
is_flash_attention = False
|
||||
elif head_dim <= 128 and head_dim % 8 == 0 \
|
||||
and (not is_training or not has_bias or q_seq_len % 2 == 0 and kv_seq_len % 2 == 0):
|
||||
elif ((H <= 128 and H % 8 == 0) and
|
||||
(not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)):
|
||||
# check if flash attention is supported
|
||||
# for training, for patterns with bias, seqlen should be divisible by 2
|
||||
is_flash_attention = True
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported sequence length Q {q_seq_len}, KV {kv_seq_len} and head dim {head_dim}.")
|
||||
f"Unsupported sequence length Q {T}, KV {S} and head dim {H}.")
|
||||
# check if minimum cudnn version requirement is satisfied
|
||||
if is_flash_attention and cudnn_version < 8904:
|
||||
raise RuntimeError("JAX requires cuDNN >= 8.9.4 to use flash cross attention.")
|
||||
@ -232,61 +258,78 @@ def check_cudnn_version():
|
||||
raise RuntimeError("cuDNN is not detected.")
|
||||
return cuda_versions.cudnn_get_version()
|
||||
|
||||
def _dot_product_attention_fwd(query, key, value, bias, mask,
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
|
||||
def _dot_product_attention_fwd(
|
||||
query, key, value, bias, mask, scale, seed, dropout_rate, variadic_args,
|
||||
is_flash_attention, is_causal_mask, layout, is_training):
|
||||
outputs = _dot_product_attention_fwd_p_wrapper.bind(
|
||||
query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask, is_training=is_training)
|
||||
query, key, value, bias, mask, scale=scale, seed=seed,
|
||||
dropout_rate=dropout_rate, variadic_args=variadic_args,
|
||||
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
|
||||
layout=layout, is_training=is_training)
|
||||
output = outputs[0]
|
||||
return output
|
||||
|
||||
def _dot_product_attention_fwd_rule(query, key, value, bias, mask,
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
|
||||
def _dot_product_attention_fwd_rule(
|
||||
query, key, value, bias, mask, scale, seed, dropout_rate, variadic_args,
|
||||
is_flash_attention, is_causal_mask, layout, is_training):
|
||||
outputs = _dot_product_attention_fwd_p_wrapper.bind(
|
||||
query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask, is_training=is_training)
|
||||
query, key, value, bias, mask, scale=scale, seed=seed,
|
||||
dropout_rate=dropout_rate, variadic_args=variadic_args,
|
||||
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
|
||||
layout=layout, is_training=is_training)
|
||||
res = (query, key, value, bias, mask, outputs[1], outputs[0]) if is_training else None
|
||||
return outputs[0], res
|
||||
|
||||
def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, res, grad_output):
|
||||
def _dot_product_attention_bwd_rule(
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention,
|
||||
is_causal_mask, layout, is_training, res, grad_output):
|
||||
query, key, value, bias, mask, activation, fwd_output = res
|
||||
grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
|
||||
query, key, value, bias, mask, activation, fwd_output, grad_output,
|
||||
scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask)
|
||||
query, key, value, bias, mask, activation, fwd_output, grad_output,
|
||||
scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask, layout=layout,
|
||||
)
|
||||
grads = (grad_query, grad_key, grad_value, None, None)
|
||||
return grads
|
||||
|
||||
def _dot_product_attention_fwd_impl(query, key, value, bias, mask,
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
|
||||
def _dot_product_attention_fwd_impl(
|
||||
query, key, value, bias, mask, scale, seed, dropout_rate, variadic_args,
|
||||
is_flash_attention, is_causal_mask, layout, is_training):
|
||||
# args: {Q, K, V, mask*, bias*}
|
||||
outputs = _dot_product_attention_fwd_p.bind(
|
||||
query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask, is_training=is_training)
|
||||
query, key, value, bias, mask, scale=scale, seed=seed,
|
||||
dropout_rate=dropout_rate, variadic_args=variadic_args,
|
||||
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
|
||||
layout=layout, is_training=is_training)
|
||||
return outputs
|
||||
|
||||
def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, fwd_output, grad_output,
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
|
||||
def _dot_product_attention_bwd_impl(
|
||||
query, key, value, bias, mask, activation, fwd_output, grad_output, scale,
|
||||
seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask,
|
||||
layout):
|
||||
grad_query, grad_key, grad_value = _dot_product_attention_bwd_p.bind(
|
||||
query, key, value, bias, mask, activation, fwd_output, grad_output,
|
||||
scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask)
|
||||
query, key, value, bias, mask, activation, fwd_output, grad_output,
|
||||
scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask, layout=layout,
|
||||
)
|
||||
grads = (grad_query, grad_key, grad_value)
|
||||
return grads
|
||||
|
||||
def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
|
||||
*, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
|
||||
def _dot_product_attention_fwd_abstract(
|
||||
query, key, value, bias, mask, *, scale, seed, dropout_rate, variadic_args,
|
||||
is_flash_attention, is_causal_mask, layout, is_training):
|
||||
query_dtype = dtypes.canonicalize_dtype(query.dtype)
|
||||
batch, q_seq_len, num_heads, head_dim = query.shape
|
||||
_, kv_seq_len, _, _ = key.shape
|
||||
output_shape = (batch, q_seq_len, num_heads, head_dim)
|
||||
activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
|
||||
softmax_stat_shape = (batch, num_heads, q_seq_len)
|
||||
if layout == AttentionLayout.BNTH.value:
|
||||
B, N, T, _ = query.shape
|
||||
_, _, S, _ = key.shape
|
||||
else:
|
||||
B, T, N, _ = query.shape
|
||||
_, S, _, _ = key.shape
|
||||
output_shape = query.shape
|
||||
activation_shape = (B, N, T, S)
|
||||
softmax_stat_shape = (B, N, T)
|
||||
|
||||
if is_flash_attention:
|
||||
# is flash attention
|
||||
@ -309,8 +352,10 @@ def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
|
||||
core.ShapedArray(output_shape, query_dtype), # output
|
||||
)
|
||||
|
||||
def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activation, fwd_output, grad_output,
|
||||
*, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
|
||||
def _dot_product_attention_bwd_abstract(
|
||||
query, key, value, bias, mask, activation, fwd_output, grad_output, *,
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention,
|
||||
is_causal_mask, layout):
|
||||
query_dtype = dtypes.canonicalize_dtype(query.dtype)
|
||||
key_dtype = dtypes.canonicalize_dtype(key.dtype)
|
||||
value_dtype = dtypes.canonicalize_dtype(value.dtype)
|
||||
@ -327,8 +372,9 @@ def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activatio
|
||||
), # part value
|
||||
)
|
||||
|
||||
def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
|
||||
def _dot_product_attention_fwd_cuda_lowering(
|
||||
ctx, query, key, value, bias, mask, scale, seed, dropout_rate,
|
||||
variadic_args, is_flash_attention, is_causal_mask, layout, is_training):
|
||||
query_type = ir.RankedTensorType(query.type)
|
||||
query_shape = query_type.shape
|
||||
key_type = ir.RankedTensorType(key.type)
|
||||
@ -336,18 +382,26 @@ def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
|
||||
value_type = ir.RankedTensorType(value.type)
|
||||
value_shape = value_type.shape
|
||||
|
||||
batch, q_seq_len, num_heads, head_dim = query_shape
|
||||
_, kv_seq_len, _, _ = key_shape
|
||||
if layout == AttentionLayout.BNTH.value:
|
||||
B, N, T, H = query_shape
|
||||
_, _, S, _ = key_shape
|
||||
output_layout = (3, 2, 1, 0)
|
||||
output_transpose_perm = mlir.dense_int_array((0, 1, 2, 3))
|
||||
else:
|
||||
B, T, N, H = query_shape
|
||||
_, S, _, _ = key_shape
|
||||
output_layout = (3, 1, 2, 0)
|
||||
output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
|
||||
|
||||
output_shape = (batch, num_heads, q_seq_len, head_dim)
|
||||
output_layout = (3, 1, 2, 0)
|
||||
output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
|
||||
activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
|
||||
softmax_stat_shape = (batch, num_heads, q_seq_len)
|
||||
output_shape = (B, N, T, H)
|
||||
activation_shape = (B, N, T, S)
|
||||
softmax_stat_shape = (B, N, T)
|
||||
scratch_shape = (0,)
|
||||
scratch_type = ir.IntegerType.get_unsigned(8)
|
||||
# get backend config
|
||||
backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, False)
|
||||
backend_config = create_dot_product_attention_backend_config(
|
||||
B, N, T, S, query_type.element_type, scale, seed, dropout_rate,
|
||||
is_flash_attention, is_causal_mask, layout, is_bwd=False,
|
||||
)
|
||||
# {Q, K, V, mask*, bias*}
|
||||
# {output, scratch, activation*}
|
||||
has_dropout = dropout_rate > 0
|
||||
@ -403,8 +457,10 @@ def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
|
||||
else:
|
||||
return [hlo.transpose(out.results[0], output_transpose_perm)]
|
||||
|
||||
def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask, activation, fwd_output, grad_output,
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
|
||||
def _dot_product_attention_bwd_cuda_lowering(
|
||||
ctx, query, key, value, bias, mask, activation, fwd_output, grad_output,
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention,
|
||||
is_causal_mask, layout):
|
||||
query_type = ir.RankedTensorType(query.type)
|
||||
query_shape = query_type.shape
|
||||
key_type = ir.RankedTensorType(key.type)
|
||||
@ -416,18 +472,28 @@ def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask,
|
||||
grad_output_type = ir.RankedTensorType(grad_output.type)
|
||||
grad_output_shape = grad_output_type.shape
|
||||
|
||||
batch, q_seq_len, num_heads, head_dim = query_shape
|
||||
_, kv_seq_len, _, _ = key_shape
|
||||
if layout == AttentionLayout.BNTH.value:
|
||||
B, N, T, H = query_shape
|
||||
_, _, S, _ = key_shape
|
||||
grad_layout = (3, 2, 1, 0)
|
||||
grad_transpose_perm = mlir.dense_int_array((0, 1, 2, 3))
|
||||
else:
|
||||
B, T, N, H = query_shape
|
||||
_, S, _, _ = key_shape
|
||||
grad_layout = (3, 1, 2, 0)
|
||||
grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
|
||||
|
||||
scratch_shape = (0,)
|
||||
scratch_type = ir.IntegerType.get_unsigned(8)
|
||||
|
||||
grad_query_shape = (batch, num_heads, q_seq_len, head_dim)
|
||||
grad_key_shape = (batch, num_heads, kv_seq_len, head_dim)
|
||||
grad_value_shape = (batch, num_heads, kv_seq_len, head_dim)
|
||||
softmax_sum_shape = (batch, num_heads, q_seq_len)
|
||||
grad_layout = (3, 1, 2, 0)
|
||||
grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
|
||||
backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, True)
|
||||
grad_query_shape = (B, N, T, H)
|
||||
grad_key_shape = (B, N, S, H)
|
||||
grad_value_shape = (B, N, S, H)
|
||||
softmax_sum_shape = (B, N, T)
|
||||
backend_config = create_dot_product_attention_backend_config(
|
||||
B, N, T, S, query_type.element_type, scale, seed, dropout_rate,
|
||||
is_flash_attention, is_causal_mask, layout, is_bwd=True,
|
||||
)
|
||||
# {Q, K, V, activation, dO, mask*, bias*, O*}
|
||||
# {dQ, dK, dV, d_S*, softmax_sum*, d_Q_accum*, scratch, dbias*}
|
||||
has_dropout = dropout_rate > 0
|
||||
@ -484,7 +550,9 @@ def _check_valid_batch_dims(bdims):
|
||||
raise NotImplementedError("Currently only support batch_dim in [0, None], " \
|
||||
f"but got {dim=}")
|
||||
|
||||
def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
|
||||
def _dot_product_attention_fwd_batcher(
|
||||
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
|
||||
is_flash_attention, is_causal_mask, layout, is_training):
|
||||
_check_valid_batch_dims(batch_dims)
|
||||
query, key, value, bias, mask = batched_args
|
||||
query_bdim = batch_dims[0]
|
||||
@ -493,74 +561,84 @@ def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed,
|
||||
else:
|
||||
out_bdims = (query_bdim,)
|
||||
|
||||
*batch_tuple, q_seq_len, num_heads, head_dim = query.shape
|
||||
*_, kv_seq_len, _, _ = key.shape
|
||||
new_batch = reduce(operator.mul, batch_tuple)
|
||||
if layout == AttentionLayout.BNTH.value:
|
||||
*Bs, N, T, _ = query.shape
|
||||
*_, _, S, _ = key.shape
|
||||
else:
|
||||
*Bs, T, N, _ = query.shape
|
||||
*_, S, _, _ = key.shape
|
||||
B = reduce(operator.mul, Bs)
|
||||
has_bias, has_mask = variadic_args
|
||||
# reshape to 4D shape
|
||||
query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, head_dim))
|
||||
key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim))
|
||||
value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim))
|
||||
query = jnp.reshape(query, (B,) + query.shape[-3:])
|
||||
key = jnp.reshape(key, (B,) + key.shape[-3:])
|
||||
value = jnp.reshape(value, (B,) + key.shape[-3:])
|
||||
if has_bias:
|
||||
bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len))
|
||||
bias = jnp.reshape(bias, (B, N, T, S))
|
||||
if has_mask:
|
||||
mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len))
|
||||
mask = jnp.reshape(mask, (B, N, T, S))
|
||||
|
||||
outputs = _dot_product_attention_fwd_p_wrapper.bind(
|
||||
query, key, value, bias, mask,
|
||||
scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask, is_training=is_training)
|
||||
query, key, value, bias, mask, scale=scale, seed=seed,
|
||||
dropout_rate=dropout_rate, variadic_args=variadic_args,
|
||||
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
|
||||
layout=layout, is_training=is_training)
|
||||
|
||||
# reshape to original shape
|
||||
output = outputs[0]
|
||||
output = jnp.reshape(output, (*batch_tuple, q_seq_len, num_heads, head_dim))
|
||||
output = jnp.reshape(output, query.shape)
|
||||
if is_training:
|
||||
activation = outputs[1]
|
||||
if is_flash_attention:
|
||||
activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len))
|
||||
activation = jnp.reshape(activation, (*Bs, N, T))
|
||||
else:
|
||||
activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len))
|
||||
activation = jnp.reshape(activation, (*Bs, N, T, S))
|
||||
return (output, activation), out_bdims
|
||||
else:
|
||||
return (output,), out_bdims
|
||||
|
||||
def _dot_product_attention_bwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
|
||||
def _dot_product_attention_bwd_batcher(
|
||||
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
|
||||
is_flash_attention, is_causal_mask, layout):
|
||||
_check_valid_batch_dims(batch_dims)
|
||||
query, key, value, bias, mask, activation, fwd_output, grad_output = batched_args
|
||||
query_bdim = batch_dims[0]
|
||||
out_bdims = query_bdim, query_bdim, query_bdim
|
||||
|
||||
*batch_tuple, q_seq_len, num_heads, head_dim = query.shape
|
||||
*_, kv_seq_len, _, _ = key.shape
|
||||
new_batch = reduce(operator.mul, batch_tuple)
|
||||
if layout == AttentionLayout.BNTH.value:
|
||||
*Bs, N, T, _ = query.shape
|
||||
*_, _, S, _ = key.shape
|
||||
else:
|
||||
*Bs, T, N, _ = query.shape
|
||||
*_, S, _, _ = key.shape
|
||||
B = reduce(operator.mul, Bs)
|
||||
has_bias, has_mask = variadic_args
|
||||
# reshape to 4D shape
|
||||
query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, head_dim))
|
||||
key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim))
|
||||
value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim))
|
||||
query = jnp.reshape(query, (B,) + query.shape[-3:])
|
||||
key = jnp.reshape(key, (B,) + key.shape[-3:])
|
||||
value = jnp.reshape(value, (B,) + key.shape[-3:])
|
||||
if has_bias:
|
||||
bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len))
|
||||
bias = jnp.reshape(bias, (B, N, T, S))
|
||||
if has_mask:
|
||||
mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len))
|
||||
mask = jnp.reshape(mask, (B, N, T, S))
|
||||
if is_flash_attention:
|
||||
activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len))
|
||||
activation = jnp.reshape(activation, (B, N, T))
|
||||
else:
|
||||
activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len, kv_seq_len))
|
||||
fwd_output = jnp.reshape(fwd_output, (new_batch, q_seq_len, num_heads, head_dim))
|
||||
grad_output = jnp.reshape(grad_output, (new_batch, q_seq_len, num_heads, head_dim))
|
||||
activation = jnp.reshape(activation, (B, N, T, S))
|
||||
fwd_output = jnp.reshape(fwd_output, (B,) + query.shape[-3:])
|
||||
grad_output = jnp.reshape(grad_output, (B,) + query.shape[-3:])
|
||||
|
||||
grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
|
||||
query, key, value, bias,
|
||||
mask, activation, fwd_output, grad_output,
|
||||
scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask)
|
||||
query, key, value, bias, mask, activation, fwd_output, grad_output,
|
||||
scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
|
||||
is_causal_mask=is_causal_mask, layout=layout,
|
||||
)
|
||||
|
||||
# reshape to original shape
|
||||
grad_query = jnp.reshape(grad_query, (*batch_tuple, q_seq_len, num_heads, head_dim))
|
||||
grad_key = jnp.reshape(grad_key, (*batch_tuple, kv_seq_len, num_heads, head_dim))
|
||||
grad_value = jnp.reshape(grad_value, (*batch_tuple, kv_seq_len, num_heads, head_dim))
|
||||
grad_query = jnp.reshape(grad_query, query.shape)
|
||||
grad_key = jnp.reshape(grad_key, key.shape)
|
||||
grad_value = jnp.reshape(grad_value, value.shape)
|
||||
grads = (grad_query, grad_key, grad_value)
|
||||
return grads, out_bdims
|
||||
|
||||
@ -617,17 +695,25 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training):
|
||||
return [out_sharding, activation_sharding]
|
||||
return [out_sharding]
|
||||
|
||||
_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10,11))
|
||||
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, mesh, arg_shapes, result_shape):
|
||||
_dot_product_attention_fwd_lower = custom_partitioning(
|
||||
_dot_product_attention_fwd_impl, static_argnums=(5, 6, 7, 8, 9, 10, 11, 12))
|
||||
|
||||
def _dot_product_attention_fwd_infer_sharding_from_operands(
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention,
|
||||
is_causal_mask, layout, is_training, mesh, arg_shapes, result_shape):
|
||||
return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training)
|
||||
|
||||
def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, mesh, arg_shapes, result_shape):
|
||||
def _dot_product_attention_fwd_partition(
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention,
|
||||
is_causal_mask, layout, is_training, mesh, arg_shapes, result_shape):
|
||||
# args sharding
|
||||
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
|
||||
out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training)
|
||||
impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
|
||||
is_training=is_training)
|
||||
impl = partial(
|
||||
_dot_product_attention_fwd_impl, scale=scale, seed=seed,
|
||||
dropout_rate=dropout_rate, variadic_args=variadic_args,
|
||||
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
|
||||
layout=layout, is_training=is_training)
|
||||
return mesh, impl, out_shardings, arg_shardings
|
||||
|
||||
# bwd custom partition
|
||||
@ -648,16 +734,27 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args):
|
||||
out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
|
||||
return out_shardings
|
||||
|
||||
_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13))
|
||||
def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
|
||||
_dot_product_attention_bwd_lower = custom_partitioning(
|
||||
_dot_product_attention_bwd_impl, static_argnums=(8, 9, 10, 11, 12, 13, 14)
|
||||
)
|
||||
|
||||
def _dot_product_attention_bwd_infer_sharding_from_operands(
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention,
|
||||
is_causal_mask, layout, mesh, arg_shapes, result_shape):
|
||||
return _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)
|
||||
|
||||
def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
|
||||
def _dot_product_attention_bwd_partition(
|
||||
scale, seed, dropout_rate, variadic_args, is_flash_attention,
|
||||
is_causal_mask, layout, mesh, arg_shapes, result_shape):
|
||||
out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)
|
||||
# args sharding
|
||||
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
|
||||
impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
|
||||
variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
|
||||
impl = partial(
|
||||
_dot_product_attention_bwd_impl, scale=scale, seed=seed,
|
||||
dropout_rate=dropout_rate, variadic_args=variadic_args,
|
||||
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
|
||||
layout=layout,
|
||||
)
|
||||
return mesh, impl, out_shardings, arg_shardings
|
||||
|
||||
# Create dot_product_attention_fwd_p for forward operation.
|
||||
@ -717,24 +814,25 @@ dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p_
|
||||
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p)
|
||||
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p_wrapper)
|
||||
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11))
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12))
|
||||
def _dot_product_attention(query: Array,
|
||||
key: Array,
|
||||
value: Array,
|
||||
bias: Array,
|
||||
mask: Array,
|
||||
scale: float,
|
||||
seed: int,
|
||||
dropout_rate: float,
|
||||
variadic_args: tuple[bool, ...],
|
||||
is_flash_attention: bool,
|
||||
is_causal_mask: bool,
|
||||
is_training: bool):
|
||||
key: Array,
|
||||
value: Array,
|
||||
bias: Array,
|
||||
mask: Array,
|
||||
scale: float,
|
||||
seed: int,
|
||||
dropout_rate: float,
|
||||
variadic_args: tuple[bool, ...],
|
||||
is_flash_attention: bool,
|
||||
is_causal_mask: bool,
|
||||
layout: int,
|
||||
is_training: bool):
|
||||
output = _dot_product_attention_fwd(
|
||||
query, key, value, bias, mask,
|
||||
scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
|
||||
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
|
||||
is_training=is_training)
|
||||
query, key, value, bias, mask, scale=scale, seed=seed,
|
||||
dropout_rate=dropout_rate, variadic_args=variadic_args,
|
||||
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
|
||||
layout=layout, is_training=is_training)
|
||||
return output
|
||||
|
||||
# _dot_product_attention_fwd must have the same func signature as _dot_product_attention
|
||||
@ -751,40 +849,51 @@ def dot_product_attention(query: Array,
|
||||
is_causal_mask: bool = False,
|
||||
seed: int = 42,
|
||||
dropout_rate: float = 0.,
|
||||
qkv_layout: str = 'BTNH',
|
||||
is_training = False):
|
||||
"""Computes dot-product attention given query, key, and value.
|
||||
This is the core function for applying attention based on
|
||||
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
|
||||
query and key and combines the values using the attention weights.
|
||||
batch seq num_heads, head_dim // but all assume Q, K and V will have same
|
||||
b q_seq num_heads head_dim -> Q
|
||||
b kv_seq num_heads head_dim -> K
|
||||
b kv_seq num_heads head_dim -> V
|
||||
"""Computes dot-product attention given query (Q), key (K), and value (V).
|
||||
|
||||
This function serves as the core operation for applying attention
|
||||
mechanisms as described in the paper [https://arxiv.org/abs/1706.03762].
|
||||
Initially, it determines the attention weights by processing Q and K,
|
||||
subsequently combining the outcomes using K. Throughout this function, we
|
||||
utilize the following uppercase letters to represent specific parameters of
|
||||
array:
|
||||
|
||||
B = batch size
|
||||
S = length of the key/value (source)
|
||||
T = length of the query (target)
|
||||
N = number of attention heads
|
||||
H = dimensions of each attention head.
|
||||
|
||||
The supported layouts for Q, K, V are either BT(S)NH or BNT(S)H, and they must
|
||||
adhere to the same layout. The output layout remains consistent with Q,
|
||||
defaulting to BT(S)NH.
|
||||
|
||||
Args:
|
||||
query: queries for calculating attention with shape of `[batch, q_length,
|
||||
num_heads, qk_depth_per_head]`.
|
||||
key: keys for calculating attention with shape of `[batch, kv_length,
|
||||
num_heads, qk_depth_per_head]`.
|
||||
value: values to be used in attention with shape of `[batch, kv_length,
|
||||
num_heads, v_depth_per_head]`.
|
||||
bias: bias to be added to logits with shape of `[batch, num_heads,
|
||||
q_length, kv_length]`.
|
||||
mask: mask used mask out logits with shape of `[batch, num_heads,
|
||||
q_length, kv_length]`.
|
||||
scale: scale for the query.
|
||||
is_causal_mask: choose to apply a causal mask or not.
|
||||
seed: used for dropout mask generation.
|
||||
dropout_rate: dropout rate.
|
||||
query: Queries for attention calculation with a shape of BTNH or BNTH.
|
||||
key: Keys for attention calculation with a shape of BSNH or BNSH.
|
||||
value: Values to be used in attention with a shape of BSNH or BNSH.
|
||||
bias: Bias to be added to logits with a shape of BNTS.
|
||||
mask: Mask used to filter out logits with a shape of BNTS.
|
||||
scale: Scale for the query.
|
||||
dropout_rate: Dropout rate.
|
||||
qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH,
|
||||
BNSH.
|
||||
is_training: choose to save activation or not.
|
||||
|
||||
Returns:
|
||||
Output of shape `[batch, q_length, num_heads, v_depth_per_head]`.
|
||||
Output of the same shape as the query.
|
||||
"""
|
||||
# check if cuDNN is installed
|
||||
cudnn_version = check_cudnn_version()
|
||||
|
||||
layout = _normalize_layout(qkv_layout)
|
||||
# check query, key and value shape and data type
|
||||
check_qkv_layout(query, key, value)
|
||||
check_qkv_layout(query, key, value, layout)
|
||||
# check if flash attention is supported for this attention pattern
|
||||
is_flash_attention = check_is_flash_attention(query, key, cudnn_version, bias is not None, is_training)
|
||||
is_flash_attention = check_is_flash_attention(
|
||||
query, key, layout, cudnn_version, bias is not None, is_training)
|
||||
if mask is not None and is_causal_mask:
|
||||
raise ValueError("can not apply a mask and generate a causal_mask at the same time.")
|
||||
if not is_flash_attention and is_causal_mask:
|
||||
@ -795,7 +904,7 @@ def dot_product_attention(query: Array,
|
||||
if mask is None:
|
||||
mask = jnp.zeros(0, dtype=query.dtype)
|
||||
output = _dot_product_attention(
|
||||
query, key, value, bias, mask,
|
||||
scale, seed, dropout_rate, variadic_args,
|
||||
is_flash_attention, is_causal_mask, is_training)
|
||||
query, key, value, bias, mask, scale, seed, dropout_rate, variadic_args,
|
||||
is_flash_attention, is_causal_mask, layout.value, is_training
|
||||
)
|
||||
return output
|
||||
|
@ -772,7 +772,8 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
|
||||
results.append(Zero(ct.aval))
|
||||
else:
|
||||
if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct))
|
||||
and not _temporary_dtype_exception(a, a_)):
|
||||
and not (_temporary_dtype_exception(a, a_) or
|
||||
_temporary_shape_exception(a, a_))):
|
||||
msg = ("Custom VJP bwd rule must produce an output with the same "
|
||||
"shape/dtypes as the args tuple of the primal function, but at "
|
||||
f"output{keystr(kp)} the bwd rule produced an output of "
|
||||
@ -790,6 +791,9 @@ def _temporary_dtype_exception(a, a_) -> bool:
|
||||
dtypes.issubdtype(a.dtype, dtypes.np.inexact)))
|
||||
return False
|
||||
|
||||
# TODO(mattjj): remove both these exceptions to cotangent compatibility check
|
||||
def _temporary_shape_exception(a, a_) -> bool:
|
||||
return config.custom_vjp_disable_shape_check.value
|
||||
|
||||
class CustomVJPCallPrimitive(core.CallPrimitive):
|
||||
initial_style: core.Primitive
|
||||
|
@ -108,7 +108,7 @@ def simple_impl(prim):
|
||||
RuntimeToken = Any
|
||||
|
||||
class RuntimeTokenSet(threading.local):
|
||||
"""See docstring for effect.py module for the calling convention for tokens."""
|
||||
"""See docstring for effects.py module for the calling convention for tokens."""
|
||||
|
||||
# For each ordered effect, the token returned by the last dispatched
|
||||
# computation, sharded over the devices in that computation.
|
||||
@ -125,6 +125,16 @@ class RuntimeTokenSet(threading.local):
|
||||
def get_token_input(self, eff: core.Effect,
|
||||
devices: list[Device]) -> jax.Array:
|
||||
tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
|
||||
|
||||
if isinstance(tok, jax.Array):
|
||||
# The order of devices may change, so we need to reshard if necessary.
|
||||
# TODO(yueshengys): This might still be buggy in a multi-process SPMD
|
||||
# scenario. Revise the logic later. A distributed shutdown barrier inside
|
||||
# the XLA program may be needed.
|
||||
return jax.device_put(tok, jax.sharding.PositionalSharding(devices))
|
||||
|
||||
# We only use replicated sharding for the first time when the token for the
|
||||
# order effect hasn't been created.
|
||||
s = jax.sharding.GSPMDSharding.get_replicated(devices)
|
||||
sharded_tok = pxla.shard_args([s], [tok])[0]
|
||||
self.current_tokens[eff] = sharded_tok
|
||||
@ -452,10 +462,7 @@ def _device_put_impl(
|
||||
return x
|
||||
if x_dll is None and dll is None:
|
||||
return _device_put_sharding_impl(x, aval, l.sharding)
|
||||
# TODO(yashkatariya): Pass layout to out_shardings directly and remove
|
||||
# out_layouts from lower.
|
||||
return api.jit(_identity_fn, out_shardings=l.sharding).lower(
|
||||
x, _out_layouts=l).compile()(x)
|
||||
return api.jit(_identity_fn, out_shardings=l)(x)
|
||||
|
||||
return _device_put_sharding_impl(x, aval, device)
|
||||
|
||||
|
@ -14,17 +14,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from jax._src.api import device_put
|
||||
from jax import numpy as jnp
|
||||
from jax._src import array
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lax.lax import _array_copy
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.typing import Array
|
||||
from jax._src.typing import Array, DLDeviceType
|
||||
from jax._src.sharding import Sharding
|
||||
|
||||
DLPACK_VERSION = (0, 8)
|
||||
MIN_DLPACK_VERSION = (0, 5)
|
||||
|
||||
# A set of dtypes that dlpack supports.
|
||||
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
|
||||
@ -42,56 +45,218 @@ if xla_extension_version >= 231:
|
||||
SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})
|
||||
|
||||
|
||||
# Mirror of dlpack.h enum
|
||||
class DLDeviceType(enum.IntEnum):
|
||||
kDLCPU = 1
|
||||
kDLCUDA = 2
|
||||
kDLROCM = 10
|
||||
def _to_dlpack(x: Array, stream: int | Any | None,
|
||||
src_device: xla_client.Device | None = None,
|
||||
device: xla_client.Device | None = None,
|
||||
copy: bool | None = None):
|
||||
|
||||
if src_device is None:
|
||||
src_device, = x.devices()
|
||||
if device and (src_device is None or device != src_device):
|
||||
if copy is not None and not copy:
|
||||
raise ValueError(
|
||||
f"Specified {device=} which requires a copy since the source device "
|
||||
f"is {repr(src_device)}, however copy=False. Set copy=True or "
|
||||
"copy=None to perform the requested operation."
|
||||
)
|
||||
else:
|
||||
arr = device_put(x, device)
|
||||
else:
|
||||
arr = _array_copy(x) if copy else x
|
||||
return xla_client._xla.buffer_to_dlpack_managed_tensor(
|
||||
arr.addressable_data(0), stream=stream
|
||||
)
|
||||
|
||||
def to_dlpack(x: Array, take_ownership: bool = False,
|
||||
stream: int | Any | None = None):
|
||||
def to_dlpack(x: Array, stream: int | Any | None = None,
|
||||
src_device: xla_client.Device | None = None,
|
||||
dl_device: tuple[DLDeviceType, int] | None = None,
|
||||
max_version: tuple[int, int] | None = None,
|
||||
copy : bool | None = None):
|
||||
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
|
||||
|
||||
Args:
|
||||
x: a :class:`~jax.Array`, on either CPU or GPU.
|
||||
take_ownership: Deprecated. It is a no-op to set take_ownership. Will be
|
||||
deleted in 01/2024.
|
||||
stream: optional platform-dependent stream to wait on until the buffer is
|
||||
ready. This corresponds to the `stream` argument to ``__dlpack__``
|
||||
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
|
||||
src_device: either a CPU or GPU :class:`~jax.Device`.
|
||||
dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
|
||||
format e.g. as produced by ``__dlpack_device__``.
|
||||
max_version: the maximum DLPack version that the consumer (i.e. caller of
|
||||
``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
|
||||
This function is not guaranteed to return a capsule of version
|
||||
``max_version``.
|
||||
copy: a boolean indicating whether or not to copy the input. If
|
||||
``copy=True`` then the function must always copy. When
|
||||
``copy=False`` then the function must never copy, and must raise an error
|
||||
when a copy is deemed necessary. If ``copy=None`` then the function must
|
||||
avoid a copy if possible but may copy if needed.
|
||||
|
||||
Returns:
|
||||
A dlpack PyCapsule object.
|
||||
A DLPack PyCapsule object.
|
||||
|
||||
Note:
|
||||
While JAX arrays are always immutable, dlpack buffers cannot be marked as
|
||||
immutable, and it is possible for processes external to JAX to mutate them
|
||||
in-place. If a dlpack buffer derived from a JAX array is mutated, it may
|
||||
lead to undefined behavior when using the associated JAX array.
|
||||
While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
|
||||
cannot be marked as immutable, and it is possible for processes external
|
||||
to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
|
||||
is mutated, it may lead to undefined behavior when using the associated JAX
|
||||
array. When JAX eventually supports ``DLManagedTensorVersioned``
|
||||
(DLPack 1.0), it will be possible to specify that a buffer is read-only.
|
||||
"""
|
||||
if not isinstance(x, array.ArrayImpl):
|
||||
raise TypeError("Argument to to_dlpack must be a jax.Array, "
|
||||
f"got {type(x)}")
|
||||
assert len(x.devices()) == 1
|
||||
if take_ownership:
|
||||
warnings.warn(
|
||||
"take_ownership in to_dlpack is deprecated and it is a no-op."
|
||||
|
||||
device = None
|
||||
dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
|
||||
if dl_device_type:
|
||||
try:
|
||||
dl_device_platform = {
|
||||
DLDeviceType.kDLCPU: "cpu",
|
||||
DLDeviceType.kDLCUDA: "cuda",
|
||||
DLDeviceType.kDLROCM: "rocm",
|
||||
}[dl_device_type]
|
||||
backend = xla_bridge.get_backend(dl_device_platform)
|
||||
device = backend.device_from_local_hardware_id(local_hardware_id)
|
||||
except TypeError:
|
||||
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
|
||||
# recommends using BufferError.
|
||||
raise BufferError(
|
||||
"The device specification passed to to_dlpack contains an unsupported "
|
||||
f"device type (DLDeviceType: {dl_device_type})")
|
||||
|
||||
# As new versions are adopted over time, we can maintain some legacy paths
|
||||
# for compatability mediated through the max_version parameter.
|
||||
# TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
|
||||
# supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
|
||||
# current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
|
||||
if max_version is None or max_version >= DLPACK_VERSION:
|
||||
# Latest
|
||||
return _to_dlpack(
|
||||
x, stream=stream,
|
||||
src_device=src_device,
|
||||
device=device,
|
||||
copy=copy
|
||||
)
|
||||
elif max_version >= MIN_DLPACK_VERSION:
|
||||
# Oldest supported
|
||||
return _to_dlpack(
|
||||
x, stream=stream,
|
||||
src_device=src_device,
|
||||
device=device,
|
||||
copy=copy
|
||||
)
|
||||
else:
|
||||
raise BufferError(
|
||||
f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
|
||||
f"version ({max_version}) was requested."
|
||||
)
|
||||
return xla_client._xla.buffer_to_dlpack_managed_tensor(
|
||||
x.addressable_data(0), stream=stream
|
||||
) # type: ignore
|
||||
|
||||
def _place_array(_arr, device, dlpack_device, copy):
|
||||
if device and dlpack_device != device:
|
||||
if copy is not None and not copy:
|
||||
raise ValueError(
|
||||
f"Specified {device=} which requires a copy since the source device "
|
||||
f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
|
||||
"copy=None to perform the requested operation."
|
||||
)
|
||||
else:
|
||||
return device_put(_arr, device)
|
||||
if copy:
|
||||
return jnp.array(_arr, copy=True)
|
||||
return _arr
|
||||
|
||||
def from_dlpack(external_array):
|
||||
def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
|
||||
copy: bool | None = None):
|
||||
preferred_platform = getattr(device, "platform", None)
|
||||
if device and preferred_platform == "gpu":
|
||||
preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
|
||||
|
||||
cpu_backend = xla_bridge.get_backend("cpu")
|
||||
gpu_backend = None
|
||||
|
||||
if preferred_platform in {"cuda", "rocm"}:
|
||||
try:
|
||||
gpu_backend = xla_bridge.get_backend(preferred_platform)
|
||||
except RuntimeError:
|
||||
raise TypeError(
|
||||
f"A {str.upper(preferred_platform)} device was specified, however no "
|
||||
f"{str.upper(preferred_platform)} backend was found."
|
||||
)
|
||||
|
||||
if preferred_platform is None:
|
||||
try:
|
||||
gpu_backend = xla_bridge.get_backend("cuda")
|
||||
except RuntimeError:
|
||||
pass
|
||||
# Try ROCm if CUDA backend not found
|
||||
if gpu_backend is None:
|
||||
try:
|
||||
gpu_backend = xla_bridge.get_backend("rocm")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
|
||||
dlpack, cpu_backend, gpu_backend)) # type: ignore
|
||||
dlpack_device, = _arr.devices()
|
||||
return _place_array(_arr, device, dlpack_device, copy)
|
||||
|
||||
def _from_dlpack(external_array, device: xla_client.Device | None = None,
|
||||
copy: bool | None = None):
|
||||
dl_device_type, device_id = external_array.__dlpack_device__()
|
||||
try:
|
||||
dl_device_platform = {
|
||||
DLDeviceType.kDLCPU: "cpu",
|
||||
DLDeviceType.kDLCUDA: "cuda",
|
||||
DLDeviceType.kDLROCM: "rocm",
|
||||
}[dl_device_type]
|
||||
except TypeError:
|
||||
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
|
||||
# TypeError.
|
||||
raise TypeError(
|
||||
"Array passed to from_dlpack is on unsupported device type "
|
||||
f"(DLDeviceType: {dl_device_type}, array: {external_array}")
|
||||
|
||||
backend = xla_bridge.get_backend(dl_device_platform)
|
||||
dlpack_device = backend.device_from_local_hardware_id(device_id)
|
||||
try:
|
||||
stream = dlpack_device.get_stream_for_external_ready_events()
|
||||
except xla_client.XlaRuntimeError as err: # type: ignore
|
||||
if "UNIMPLEMENTED" in str(err):
|
||||
stream = None
|
||||
else:
|
||||
raise
|
||||
dlpack = external_array.__dlpack__(stream=stream)
|
||||
|
||||
_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
|
||||
dlpack, dlpack_device, stream))
|
||||
return _place_array(_arr, device, dlpack_device, copy)
|
||||
|
||||
def from_dlpack(external_array,
|
||||
device: xla_client.Device | Sharding | None = None,
|
||||
copy: bool | None = None):
|
||||
"""Returns a :class:`~jax.Array` representation of a DLPack tensor.
|
||||
|
||||
The returned :class:`~jax.Array` shares memory with ``external_array``.
|
||||
The returned :class:`~jax.Array` shares memory with ``external_array`` if no
|
||||
device transfer or copy was requested.
|
||||
|
||||
Args:
|
||||
external_array: an array object that has __dlpack__ and __dlpack_device__
|
||||
external_array: An array object that has __dlpack__ and __dlpack_device__
|
||||
methods, or a DLPack tensor on either CPU or GPU (legacy API).
|
||||
|
||||
device: The (optional) :py:class:`Device`, representing the device on which
|
||||
the returned array should be placed. If given, then the result is committed
|
||||
to the device. If unspecified, the resulting array will be unpacked onto the
|
||||
same device it originated from. Setting ``device`` to a device different from
|
||||
the source of ``external_array`` will require a copy, meaning ``copy`` must be
|
||||
set to either ``True`` or ``None``.
|
||||
|
||||
copy: An (optional) boolean, controlling whether or not to a copy is performed.
|
||||
If ``copy=True`` then a copy is always performed, even if unpacked onto the
|
||||
same device. If ``copy=False`` then the copy is never peformed and will raise
|
||||
an error if necessary. When ``copy=None`` then a copy may be performed if
|
||||
needed for a device transfer.
|
||||
|
||||
Returns:
|
||||
A jax.Array
|
||||
|
||||
@ -102,49 +267,16 @@ def from_dlpack(external_array):
|
||||
is later modified in-place, it may lead to undefined behavior when using
|
||||
the associated JAX array.
|
||||
"""
|
||||
if isinstance(device, Sharding):
|
||||
device_set = device.device_set
|
||||
if len(device_set) > 1:
|
||||
raise ValueError(
|
||||
"from_dlpack can only unpack a dlpack tensor onto a singular device, but "
|
||||
f"a Sharding with {len(device_set)} devices was provided."
|
||||
)
|
||||
device, = device_set
|
||||
if hasattr(external_array, "__dlpack__"):
|
||||
dl_device_type, device_id = external_array.__dlpack_device__()
|
||||
try:
|
||||
device_platform = {
|
||||
DLDeviceType.kDLCPU: "cpu",
|
||||
DLDeviceType.kDLCUDA: "cuda",
|
||||
DLDeviceType.kDLROCM: "rocm",
|
||||
}[dl_device_type]
|
||||
except TypeError:
|
||||
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
|
||||
# TypeError.
|
||||
raise TypeError(
|
||||
"Array passed to from_dlpack is on unsupported device type "
|
||||
f"(DLDeviceType: {dl_device_type}, array: {external_array}")
|
||||
return _from_dlpack(external_array, device, copy)
|
||||
|
||||
backend = xla_bridge.get_backend(device_platform)
|
||||
device = backend.device_from_local_hardware_id(device_id)
|
||||
try:
|
||||
stream = device.get_stream_for_external_ready_events()
|
||||
except xla_client.XlaRuntimeError as err: # type: ignore
|
||||
if "UNIMPLEMENTED" in str(err):
|
||||
stream = None
|
||||
else:
|
||||
raise
|
||||
dlpack = external_array.__dlpack__(stream=stream)
|
||||
|
||||
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
|
||||
dlpack, device, stream))
|
||||
else:
|
||||
# Legacy path
|
||||
dlpack = external_array
|
||||
cpu_backend = xla_bridge.get_backend("cpu")
|
||||
try:
|
||||
gpu_backend = xla_bridge.get_backend("cuda")
|
||||
except RuntimeError:
|
||||
gpu_backend = None
|
||||
|
||||
# Try ROCm if CUDA backend not found
|
||||
if gpu_backend is None:
|
||||
try:
|
||||
gpu_backend = xla_bridge.get_backend("rocm")
|
||||
except RuntimeError:
|
||||
gpu_backend = None
|
||||
|
||||
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
|
||||
dlpack, cpu_backend, gpu_backend))
|
||||
# Legacy path
|
||||
return _legacy_from_dlpack(external_array, device, copy)
|
||||
|
110
jax/_src/earray.py
Normal file
110
jax/_src/earray.py
Normal file
@ -0,0 +1,110 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import core
|
||||
from jax._src import tree_util
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
# EArray is an Array that can contain extended dtypes.
|
||||
class EArray(basearray.Array):
|
||||
__slots__ = ['aval', '_data']
|
||||
__hash__ = None # type: ignore[assignment]
|
||||
__array_priority__ = 100
|
||||
|
||||
def __init__(self, aval, data):
|
||||
self.aval = aval
|
||||
self._data = data
|
||||
|
||||
def block_until_ready(self):
|
||||
_ = self._data.block_until_ready()
|
||||
return self
|
||||
|
||||
def copy_to_host_async(self):
|
||||
self._data.copy_to_host_async()
|
||||
|
||||
def copy(self):
|
||||
return EArray(self.aval, self._data.copy())
|
||||
|
||||
def __repr__(self):
|
||||
return 'E' + repr(self._data)
|
||||
|
||||
def __iter__(self):
|
||||
if self.ndim == 0: raise TypeError('iteration over a 0-d array')
|
||||
raise NotImplementedError
|
||||
|
||||
# forward to aval
|
||||
shape = property(lambda self: self.aval.shape) # type: ignore[assignment]
|
||||
dtype = property(lambda self: self.aval.dtype) # type: ignore[assignment]
|
||||
|
||||
# computed from shape and dtype
|
||||
ndim = property(lambda self: len(self.aval.shape)) # type: ignore[assignment]
|
||||
size = property(lambda self: math.prod(self.aval.shape)) # type: ignore[assignment]
|
||||
itemsize = property(lambda self: self.aval.dtype.itemsize) # type: ignore[assignment]
|
||||
def __len__(self):
|
||||
if self.ndim == 0: raise TypeError('len() of unsized object')
|
||||
return self.shape[0]
|
||||
|
||||
# forward to self._data
|
||||
devices = property(lambda self: self._data.devices) # type: ignore[assignment]
|
||||
_committed = property(lambda self: self._data._committed)
|
||||
is_fully_addressable = property(lambda self: self._data.is_fully_addressable) # type: ignore[assignment]
|
||||
is_fully_replicated = property(lambda self: self._data.is_fully_replicated) # type: ignore[assignment]
|
||||
delete = property(lambda self: self._data.delete) # type: ignore[assignment]
|
||||
is_deleted = property(lambda self: self._data.is_deleted) # type: ignore[assignment]
|
||||
on_device_size_in_bytes = property(lambda self: self._data.on_device_size_in_bytes) # type: ignore[assignment]
|
||||
unsafe_buffer_pointer = property(lambda self: self._data.unsafe_buffer_pointer) # type: ignore[assignment]
|
||||
|
||||
# defer to extended dtype rules
|
||||
@property
|
||||
def sharding(self):
|
||||
phys_sharding = self._data.sharding
|
||||
return self.aval.dtype._rules.logical_sharding(self.aval, phys_sharding)
|
||||
|
||||
# TODO(mattjj): not implemented below here, need more methods from ArrayImpl
|
||||
|
||||
def addressable_data(self, index: int) -> EArray:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def addressable_shards(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def global_shards(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO(mattjj): _set_array_base_attributes
|
||||
|
||||
def _earray_shard_arg_handler(x, sharding):
|
||||
arr = x._data
|
||||
phys_sharding = x.aval.dtype._rules.physical_sharding(x.aval, sharding)
|
||||
return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
|
||||
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
|
||||
|
||||
api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
|
||||
core.pytype_aval_mappings[EArray] = lambda x: x.aval
|
||||
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
|
||||
tree_util.dispatch_registry.register_node(
|
||||
EArray, lambda x: ((x._data,), x.aval), lambda a, xs: EArray(a, xs[0]))
|
@ -23,6 +23,7 @@ import collections
|
||||
import itertools
|
||||
from typing import Union, cast
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util
|
||||
@ -30,8 +31,7 @@ from jax._src.util import safe_map, safe_zip
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
@ -75,6 +75,13 @@ _JAX_DUMP_IR_TO = config.DEFINE_string(
|
||||
"Supports the special value 'sponge' to pick the path from the "
|
||||
"environment variable TEST_UNDECLARED_OUTPUTS_DIR.")
|
||||
|
||||
_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.DEFINE_string(
|
||||
'jax_include_debug_info_in_dumps',
|
||||
os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"),
|
||||
help="Determine whether or not to keep debug symbols and location information "
|
||||
"when dumping IR code. By default, debug information will be preserved in "
|
||||
"the IR dump. To avoid exposing source code and potentially sensitive "
|
||||
"information, set to false")
|
||||
lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
|
||||
|
||||
|
||||
@ -474,9 +481,12 @@ def dump_module_message(module: ir.Module, stage_name: str) -> str:
|
||||
def _make_string_safe_for_filename(s: str) -> str:
|
||||
return re.sub(r'[^\w.)( -]', '', s)
|
||||
|
||||
def module_to_string(module: ir.Module) -> str:
|
||||
def module_to_string(module: ir.Module, enable_debug_info=None) -> str:
|
||||
output = io.StringIO()
|
||||
module.operation.print(file=output, enable_debug_info=True)
|
||||
if enable_debug_info is None:
|
||||
enable_debug_flag = str.lower(_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS.value)
|
||||
enable_debug_info = enable_debug_flag not in ('false', '0')
|
||||
module.operation.print(file=output, enable_debug_info=enable_debug_info)
|
||||
return output.getvalue()
|
||||
|
||||
def module_to_bytecode(module: ir.Module) -> bytes:
|
||||
@ -944,11 +954,6 @@ def lower_jaxpr_to_module(
|
||||
else:
|
||||
dim_vars = ()
|
||||
|
||||
arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
|
||||
else in_layouts)
|
||||
result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
|
||||
else out_layouts)
|
||||
|
||||
ctx = ModuleContext(backend_or_name=backend_or_name,
|
||||
platforms=platforms, axis_context=axis_context,
|
||||
keepalives=keepalives,
|
||||
@ -982,8 +987,8 @@ def lower_jaxpr_to_module(
|
||||
result_names=result_names,
|
||||
arg_memory_kinds=arg_memory_kinds,
|
||||
result_memory_kinds=result_memory_kinds,
|
||||
arg_layouts=arg_layouts,
|
||||
result_layouts=result_layouts)
|
||||
arg_layouts=in_layouts,
|
||||
result_layouts=out_layouts)
|
||||
|
||||
try:
|
||||
if not ctx.module.operation.verify():
|
||||
@ -1130,8 +1135,8 @@ def lower_jaxpr_to_fun(
|
||||
result_names: Sequence[str | None] | None = None,
|
||||
arg_memory_kinds: Sequence[str | None] | None = None,
|
||||
result_memory_kinds: Sequence[str | None] | None = None,
|
||||
arg_layouts: Sequence[str | None] | None = None,
|
||||
result_layouts: Sequence[str | None] | None = None,
|
||||
arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
|
||||
result_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
|
||||
) -> func_dialect.FuncOp:
|
||||
"""Lowers jaxpr and its callees to an IR function.
|
||||
|
||||
@ -1252,7 +1257,8 @@ def lower_jaxpr_to_fun(
|
||||
ir_arg_layouts = None
|
||||
if arg_layouts is not None:
|
||||
ir_arg_layouts = util.flatten(
|
||||
[[l] * len(types) for l, types in zip(arg_layouts, input_types)])
|
||||
[[_to_xla_layout(l)] * len(types)
|
||||
for l, types in zip(arg_layouts, input_types)])
|
||||
|
||||
ir_donated_args = None
|
||||
if xla_donated_args is not None:
|
||||
@ -1275,7 +1281,8 @@ def lower_jaxpr_to_fun(
|
||||
ir_result_layouts = None
|
||||
if result_layouts is not None:
|
||||
ir_result_layouts = util.flatten(
|
||||
[[l] * len(types) for l, types in zip(result_layouts, output_types)])
|
||||
[[_to_xla_layout(l)] * len(types)
|
||||
for l, types in zip(result_layouts, output_types)])
|
||||
|
||||
if (
|
||||
replicated_args is not None
|
||||
|
@ -1004,19 +1004,6 @@ class UnloadedPmapExecutable:
|
||||
shards.out_sharded_avals, pci.out_axes)]
|
||||
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
|
||||
|
||||
if hasattr(pci.backend, "compile_replicated"):
|
||||
input_indices = [
|
||||
sharding_specs.spec_to_indices(aval.shape, spec)
|
||||
if spec is not None else None
|
||||
for aval, spec in safe_zip(pci.avals, input_sharding_specs)
|
||||
]
|
||||
handle_outs = local_avals_to_results_handler(local_unmapped_avals,
|
||||
out_shardings)
|
||||
return _compile_replicated_pmap_executable_from_hlo(
|
||||
hlo, pci, input_indices, in_shardings, handle_outs,
|
||||
compile_options, host_callbacks, bool(unordered_effects),
|
||||
ordered_effects, jaxpr_debug_info)
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
@ -1038,23 +1025,6 @@ class UnloadedPmapExecutable:
|
||||
jaxpr_debug_info=jaxpr_debug_info).load()
|
||||
|
||||
|
||||
def _compile_replicated_pmap_executable_from_hlo(
|
||||
hlo: ir.Module, pci, input_indices, in_shardings, handle_outs,
|
||||
compile_options, host_callbacks, has_unordered_effects, ordered_effects,
|
||||
jaxpr_debug_info):
|
||||
# Use the standard out_handler.
|
||||
execute_fun = pci.backend.compile_replicated(
|
||||
is_trivial=False, name=pci.name, computation=hlo,
|
||||
compile_options=compile_options, host_callbacks=host_callbacks,
|
||||
has_unordered_effects=has_unordered_effects,
|
||||
ordered_effects=ordered_effects, in_avals=pci.avals,
|
||||
in_indices=input_indices, in_shardings=in_shardings,
|
||||
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
|
||||
# TODO(frostig): need `compile_replicated` to give us the XLA executable
|
||||
return PmapExecutable(None, lambda: execute_fun, None, pci.avals,
|
||||
jaxpr_debug_info, None)
|
||||
|
||||
|
||||
class PmapExecutable(stages.XlaExecutable):
|
||||
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
|
||||
"fingerprint", "in_avals", "_jaxpr_debug_info",
|
||||
@ -1185,13 +1155,25 @@ class ExecuteReplicated:
|
||||
|
||||
def _handle_token_bufs(self, token_bufs, sharded_token):
|
||||
# token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned
|
||||
# token buffer (as a singleton list).
|
||||
# token buffers.
|
||||
# sharded_token: ShardedToken, containing the RuntimeTokens for each device
|
||||
for i, device in enumerate(self._local_devices):
|
||||
dispatch.runtime_tokens.set_output_runtime_token(
|
||||
device, sharded_token.get_token(i))
|
||||
for eff, token_buf in zip(self.ordered_effects, token_bufs):
|
||||
dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
|
||||
assert len(token_buf) > 0
|
||||
if len(token_buf) == 1:
|
||||
dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
|
||||
else:
|
||||
token_devices = []
|
||||
for token in token_buf:
|
||||
assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding)
|
||||
token_devices.append(token.sharding._device_assignment[0])
|
||||
s = sharding_impls.PositionalSharding(token_devices)
|
||||
global_token_array = jax.make_array_from_single_device_arrays(
|
||||
(0,), s, token_buf
|
||||
)
|
||||
dispatch.runtime_tokens.set_token_result(eff, global_token_array)
|
||||
|
||||
@profiler.annotate_function
|
||||
def __call__(self, *args):
|
||||
@ -2007,10 +1989,8 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
|
||||
|
||||
|
||||
class AllArgsInfo(NamedTuple):
|
||||
"""Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
|
||||
"""Avals and debug_info for all arguments prior to DCE."""
|
||||
in_avals: Sequence[core.ShapedArray]
|
||||
in_shardings: Any
|
||||
in_layouts: Any
|
||||
debug_info: core.JaxprDebugInfo | None
|
||||
|
||||
|
||||
@ -2035,15 +2015,15 @@ def lower_sharding_computation(
|
||||
fun_name: str,
|
||||
in_shardings: Sequence[MaybeSharding],
|
||||
out_shardings: Sequence[MaybeSharding],
|
||||
in_layouts: MaybeLayout,
|
||||
out_layouts: MaybeLayout,
|
||||
donated_invars: Sequence[bool],
|
||||
global_in_avals: Sequence[core.ShapedArray],
|
||||
*,
|
||||
keep_unused: bool,
|
||||
inline: bool,
|
||||
devices_from_context: Sequence[xc.Device] | None = None,
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
in_layouts: MaybeLayout,
|
||||
out_layouts: MaybeLayout,
|
||||
lowering_parameters: mlir.LoweringParameters
|
||||
) -> MeshComputation:
|
||||
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
|
||||
|
||||
@ -2056,8 +2036,7 @@ def lower_sharding_computation(
|
||||
auto_spmd_lowering = check_if_any_auto(
|
||||
it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore
|
||||
|
||||
all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
|
||||
closed_jaxpr.jaxpr.debug_info)
|
||||
all_args_info = AllArgsInfo(global_in_avals, closed_jaxpr.jaxpr.debug_info)
|
||||
|
||||
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
|
||||
kept_var_idx, name_stack) = _dce_jaxpr(
|
||||
@ -2109,7 +2088,7 @@ def lower_sharding_computation(
|
||||
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
|
||||
any(not is_unspecified(o) for o in out_shardings))
|
||||
|
||||
if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
|
||||
if xla_extension_version < 241:
|
||||
gs = GSPMDSharding.get_replicated(device_assignment)
|
||||
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
|
||||
|
||||
@ -2720,15 +2699,12 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
|
||||
opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)
|
||||
|
||||
if hasattr(backend, "compile_replicated"):
|
||||
return None, compile_options
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
xla_executable = compiler.compile_or_get_cached(
|
||||
backend, computation, dev, compile_options, host_callbacks)
|
||||
return xla_executable, compile_options
|
||||
return xla_executable
|
||||
|
||||
|
||||
def _maybe_get_and_check_in_shardings(
|
||||
@ -2758,21 +2734,16 @@ def _maybe_get_and_check_in_shardings(
|
||||
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
|
||||
new_in_shardings.append(xla_s)
|
||||
else:
|
||||
# TODO(yashkatariya): Remove the if branch for abstract_token once
|
||||
# choosing input shardings by XLA is enabled again.
|
||||
if aval is core.abstract_token:
|
||||
new_in_shardings.append(orig)
|
||||
else:
|
||||
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
# MANUAL HloSharding comes from other partitioning frameworks.
|
||||
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
|
||||
not xla_hlo_s.is_manual() and
|
||||
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))):
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
|
||||
"(User sharding)")
|
||||
new_in_shardings.append(orig)
|
||||
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
# MANUAL HloSharding comes from other partitioning frameworks.
|
||||
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
|
||||
not xla_hlo_s.is_manual() and
|
||||
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))):
|
||||
raise AssertionError(
|
||||
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
|
||||
"(User sharding)")
|
||||
new_in_shardings.append(orig)
|
||||
return new_in_shardings
|
||||
|
||||
|
||||
@ -2863,7 +2834,6 @@ class UnloadedMeshExecutable:
|
||||
self.in_layouts, self.out_layouts,
|
||||
self.all_args_info, self)
|
||||
|
||||
# May return a MeshExecutable in the compile_replicated case.
|
||||
@staticmethod
|
||||
def from_hlo(name: str,
|
||||
hlo: ir.Module,
|
||||
@ -2916,24 +2886,12 @@ class UnloadedMeshExecutable:
|
||||
mesh = i.mesh # type: ignore
|
||||
break
|
||||
|
||||
xla_executable, compile_options = _cached_compilation(
|
||||
xla_executable = _cached_compilation(
|
||||
hlo, name, mesh, spmd_lowering,
|
||||
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
|
||||
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
|
||||
compiler_options_keys, compiler_options_values)
|
||||
|
||||
if hasattr(backend, "compile_replicated"):
|
||||
semantics_in_shardings = SemanticallyEqualShardings(
|
||||
in_shardings, global_in_avals) # type: ignore
|
||||
semantics_out_shardings = SemanticallyEqualShardings(
|
||||
out_shardings, global_out_avals) # type: ignore
|
||||
return _compile_replicated_mesh_executable_from_hlo(
|
||||
hlo, name, tuple(global_in_avals), tuple(global_out_avals),
|
||||
semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
|
||||
compile_options, tuple(host_callbacks), bool(unordered_effects),
|
||||
tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
|
||||
pmap_nreps)
|
||||
|
||||
if auto_spmd_lowering:
|
||||
assert mesh is not None
|
||||
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
|
||||
@ -3052,28 +3010,22 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
return self.xla_executable
|
||||
|
||||
def call(self, *args):
|
||||
args_after_dce = [a for i, a in enumerate(args) if i in self._kept_var_idx]
|
||||
if self._all_args_info is None:
|
||||
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
|
||||
kept_args = args_after_dce
|
||||
ref_avals = self.in_avals
|
||||
in_shardings = self._in_shardings
|
||||
in_layouts = self._in_layouts
|
||||
debug_info = None
|
||||
else:
|
||||
kept_args = args
|
||||
ref_avals = self._all_args_info.in_avals
|
||||
iter_in_shardings = iter(self._in_shardings)
|
||||
in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
|
||||
for i, s in enumerate(self._all_args_info.in_shardings)]
|
||||
iter_in_layouts = iter(self._in_layouts)
|
||||
in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
|
||||
for i, s in enumerate(self._all_args_info.in_layouts)]
|
||||
debug_info = self._all_args_info.debug_info
|
||||
|
||||
arg_avals = map(xla.abstractify, kept_args)
|
||||
check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
|
||||
all_arg_avals = map(xla.abstractify, kept_args)
|
||||
check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
|
||||
# Check the GDA sharding and the input sharding.
|
||||
check_array_xla_sharding_layout_match(kept_args, in_shardings,
|
||||
in_layouts, debug_info)
|
||||
check_array_xla_sharding_layout_match(
|
||||
args_after_dce, self._in_shardings, self._in_layouts, debug_info,
|
||||
self._kept_var_idx)
|
||||
return self.unsafe_call(*args) # pylint: disable=not-callable
|
||||
|
||||
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
|
||||
@ -3182,35 +3134,6 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
|
||||
return in_shardings, out_shardings, committed, tuple(local_devices)
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def _compile_replicated_mesh_executable_from_hlo(
|
||||
computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
|
||||
semantics_out_shardings, auto_spmd_lowering, compile_options,
|
||||
host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
|
||||
backend, da, committed, pmap_nreps):
|
||||
assert not auto_spmd_lowering
|
||||
in_shardings = semantics_in_shardings.shardings
|
||||
out_shardings = semantics_out_shardings.shardings
|
||||
|
||||
kept_var_idx = set(kept_var_idx)
|
||||
# Will compute out_handler with executable information.
|
||||
unsafe_call = backend.compile_replicated(
|
||||
is_trivial=False, name=name, computation=computation,
|
||||
compile_options=compile_options, host_callbacks=host_callbacks,
|
||||
has_unordered_effects=has_unordered_effects,
|
||||
device_assignment=da, ordered_effects=ordered_effects,
|
||||
in_avals=global_in_avals,
|
||||
in_shardings=in_shardings, kept_var_idx=kept_var_idx,
|
||||
out_avals=global_out_avals, out_shardings=out_shardings,
|
||||
committed=committed, pmap_nreps=pmap_nreps)
|
||||
xla_executable = None
|
||||
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
|
||||
global_out_avals, in_shardings, out_shardings,
|
||||
auto_spmd_lowering, kept_var_idx,
|
||||
(None,) * len(global_in_avals),
|
||||
(None,) * len(global_out_avals))
|
||||
|
||||
|
||||
@lru_cache
|
||||
def create_mesh_pspec_sharding(
|
||||
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
|
||||
@ -3231,16 +3154,22 @@ def check_device_backend_on_shardings(shardings) -> bool:
|
||||
|
||||
|
||||
def check_array_xla_sharding_layout_match(
|
||||
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
|
||||
args_after_dce,
|
||||
in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
|
||||
in_xla_layouts: Sequence[DeviceLocalLayout],
|
||||
jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
|
||||
jaxpr_debug_info: core.JaxprDebugInfo | None,
|
||||
kept_var_idx: set[int]) -> None:
|
||||
from jax._src.array import ArrayImpl
|
||||
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
|
||||
jaxpr_debug_info.arg_names)
|
||||
# jaxpr_debug_info.arg_names are before DCE, so need to DCE them.
|
||||
arg_names = (
|
||||
[""] * len(args_after_dce) if jaxpr_debug_info is None
|
||||
else [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore
|
||||
if i in kept_var_idx]
|
||||
)
|
||||
errors = []
|
||||
num_errors = 5
|
||||
for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
|
||||
arg_names):
|
||||
for arg, xs, xl, name in safe_zip(
|
||||
args_after_dce, in_xla_shardings, in_xla_layouts, arg_names):
|
||||
if not isinstance(arg, ArrayImpl):
|
||||
continue
|
||||
if is_unspecified_or_auto(xs):
|
||||
@ -3271,8 +3200,9 @@ def check_array_xla_sharding_layout_match(
|
||||
arg.layout.device_local_layout != xl):
|
||||
errors.append(
|
||||
("Got input layout(s) that compiled object was called with: "
|
||||
f"{arg.layout} and layout(s) the computation was compiled "
|
||||
f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
|
||||
f"{arg.layout.device_local_layout} and layout(s) the computation was "
|
||||
f"compiled with: {xl} for arg {name} with "
|
||||
f"shape: {arg.aval.str_short()}",
|
||||
'layout'))
|
||||
|
||||
if errors:
|
||||
@ -3362,12 +3292,3 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
|
||||
def maybe_extend_axis_env(*args, **kwargs):
|
||||
with core.extend_axis_env(*args, **kwargs):
|
||||
yield
|
||||
|
||||
|
||||
def device_put(x, devices: Sequence[xc.ArrayImpl],
|
||||
replicate: bool=False) -> list[xc.ArrayImpl]:
|
||||
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
|
||||
if replicate:
|
||||
return [jax.device_put(x, device) for device in devices]
|
||||
else:
|
||||
return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]
|
||||
|
@ -69,7 +69,7 @@ _no_operand_sentinel = object()
|
||||
@api_boundary
|
||||
def switch(index, branches: Sequence[Callable], *operands,
|
||||
operand=_no_operand_sentinel):
|
||||
"""Apply exactly one of ``branches`` given by ``index``.
|
||||
"""Apply exactly one of the ``branches`` given by ``index``.
|
||||
|
||||
If ``index`` is out of bounds, it is clamped to within bounds.
|
||||
|
||||
|
@ -49,7 +49,7 @@ def _split_root_args(args, const_lengths):
|
||||
|
||||
@api_boundary
|
||||
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
|
||||
"""Differentiably solve for a roots of a function.
|
||||
"""Differentiably solve for the roots of a function.
|
||||
|
||||
This is a low-level routine, mostly intended for internal use in JAX.
|
||||
Gradients of custom_root() are defined with respect to closed-over variables
|
||||
|
@ -22,7 +22,7 @@ from functools import partial
|
||||
import itertools
|
||||
import math
|
||||
import operator
|
||||
from typing import Any, Callable, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
|
||||
from typing import Any, Callable, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -625,14 +625,14 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
|
||||
|
||||
_precision_strings: dict[Any, Precision] = {}
|
||||
|
||||
# TODO(b/328046715): pytype appears unable to handle overriding __new__ in an
|
||||
# enum class. Doing this crashes Pytype. For now, just write an explicit type
|
||||
# for type checkers.
|
||||
# TODO(b/333851820): pytype does not properly handle _missing_ in enums.
|
||||
# We work around that by defining `Precision` as a normal class.
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class Precision:
|
||||
DEFAULT: Precision
|
||||
HIGH: Precision
|
||||
HIGHEST: Precision
|
||||
DEFAULT: ClassVar[Precision]
|
||||
HIGH: ClassVar[Precision]
|
||||
HIGHEST: ClassVar[Precision]
|
||||
|
||||
def __new__(cls, value: Precision | int | str | None) -> Precision:
|
||||
raise NotImplementedError
|
||||
@ -646,6 +646,7 @@ if TYPE_CHECKING:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
|
||||
class Precision(enum.Enum):
|
||||
"""Precision enum for lax functions
|
||||
|
||||
@ -664,23 +665,21 @@ else:
|
||||
Slowest but most accurate. Performs computations in float32 or float64
|
||||
as applicable. Aliases: ``'highest'``, ``'float32'``.
|
||||
"""
|
||||
|
||||
DEFAULT = 0
|
||||
HIGH = 1
|
||||
HIGHEST = 2
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> Precision | None:
|
||||
return _precision_strings.get(value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
return f'{self.__class__.__name__}.{self.name}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
# You can't define __new__ on an enum class directly, but you can monkey-patch
|
||||
# it after the fact. Another way to do this might be using a metaclass.
|
||||
def _precision_new(cls, value: Precision | int | str | None) -> Precision:
|
||||
return super(Precision, cls).__new__(cls, _precision_strings.get(value, value))
|
||||
|
||||
Precision.__new__ = _precision_new
|
||||
|
||||
|
||||
_precision_strings['highest'] = Precision.HIGHEST
|
||||
_precision_strings['float32'] = Precision.HIGHEST
|
||||
|
@ -79,7 +79,7 @@ def psum(x, axis_name, *, axis_index_groups=None):
|
||||
>>> print(y)
|
||||
[0. 0.16666667 0.33333334 0.5 ]
|
||||
|
||||
Suppose we want to perform ``psum`` among two groups, one with ``device0`` and ``device1``, the other with `device2` and `device3`,
|
||||
Suppose we want to perform ``psum`` among two groups, one with ``device0`` and ``device1``, the other with ``device2`` and ``device3``,
|
||||
|
||||
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
|
||||
>>> print(y)
|
||||
|
@ -232,7 +232,7 @@ class GatherDimensionNumbers(NamedTuple):
|
||||
in the output of the gather. Must be a tuple of integers in ascending
|
||||
order.
|
||||
start_index_map: for each dimension in `start_indices`, gives the
|
||||
corresponding dimension in `operand` that is to be sliced. Must be a
|
||||
corresponding dimension in the `operand` that is to be sliced. Must be a
|
||||
tuple of integers with size equal to `start_indices.shape[-1]`.
|
||||
|
||||
Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is
|
||||
@ -261,8 +261,8 @@ class GatherScatterMode(enum.Enum):
|
||||
will be discarded.
|
||||
PROMISE_IN_BOUNDS:
|
||||
The user promises that indices are in bounds. No additional checking will be
|
||||
performed. In practice, with the current XLA implementation this means
|
||||
that, out-of-bounds gathers will be clamped but out-of-bounds scatters will
|
||||
performed. In practice, with the current XLA implementation this means
|
||||
that out-of-bounds gathers will be clamped but out-of-bounds scatters will
|
||||
be discarded. Gradients will not be correct if indices are out-of-bounds.
|
||||
"""
|
||||
CLIP = enum.auto()
|
||||
|
@ -72,16 +72,18 @@ class Layout:
|
||||
)
|
||||
if not isinstance(
|
||||
device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)):
|
||||
raise ValueError(
|
||||
raise TypeError(
|
||||
'Invalid value received for the device_local_layout argument.'
|
||||
' Expected values are `None`, `DeviceLocalLayout.AUTO` or an instance'
|
||||
f' of `DeviceLocalLayout`. Got {device_local_layout}')
|
||||
' Expected values are `None`, `DeviceLocalLayout.AUTO` or an'
|
||||
f' instance of `DeviceLocalLayout`. Got {device_local_layout} of'
|
||||
f' type {type(device_local_layout)}'
|
||||
)
|
||||
if not isinstance(
|
||||
sharding, (Sharding, type(None), AutoSharding)):
|
||||
raise ValueError(
|
||||
raise TypeError(
|
||||
'Invalid value received for the sharding argument. Expected values'
|
||||
' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got'
|
||||
f' {sharding}')
|
||||
f' {sharding} of type {type(sharding)}')
|
||||
|
||||
self.device_local_layout = device_local_layout
|
||||
self.sharding = sharding
|
||||
|
@ -714,9 +714,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
return pxla.lower_sharding_computation(
|
||||
core.ClosedJaxpr(jaxpr, consts), 'jit', name,
|
||||
(UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
|
||||
(None,) * len(in_avals), (None,) * len(out_avals),
|
||||
donated_invars, in_avals, keep_unused=True, inline=False,
|
||||
devices_from_context=None, lowering_parameters=lowering_parameters,
|
||||
in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals))
|
||||
devices_from_context=None, lowering_parameters=lowering_parameters)
|
||||
|
||||
|
||||
class EvaluationPlan(NamedTuple):
|
||||
|
@ -20,6 +20,7 @@ from functools import partial
|
||||
import operator
|
||||
import numpy as np
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@ -35,6 +36,12 @@ from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.ops.special import logsumexp as _logsumexp
|
||||
|
||||
|
||||
class Unspecified:
|
||||
def __repr__(self):
|
||||
return "_UNSPECIFIED"
|
||||
_UNSPECIFIED = Unspecified()
|
||||
|
||||
|
||||
# activations
|
||||
|
||||
@custom_jvp
|
||||
@ -173,6 +180,38 @@ def sigmoid(x: ArrayLike) -> Array:
|
||||
"""
|
||||
return lax.logistic(x)
|
||||
|
||||
@jax.jit
|
||||
def sparse_sigmoid(x: ArrayLike) -> Array:
|
||||
r"""Sparse sigmoid activation function.
|
||||
|
||||
Computes the function:
|
||||
|
||||
.. math::
|
||||
|
||||
\mathrm{sparse\_sigmoid}(x) = \begin{cases}
|
||||
0, & x \leq -1\\
|
||||
\frac{1}{2}(x+1), & -1 < x < 1 \\
|
||||
1, & 1 \leq x
|
||||
\end{cases}
|
||||
|
||||
This is the twin function of the ``sigmoid`` activation ensuring a zero output
|
||||
for inputs less than -1, a 1 ouput for inputs greater than 1, and a linear
|
||||
output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
|
||||
|
||||
For more information, see `Learning with Fenchel-Young Losses (section 6.2)
|
||||
<https://arxiv.org/abs/1901.02324>`_.
|
||||
|
||||
Args:
|
||||
x : input array
|
||||
|
||||
Returns:
|
||||
An array.
|
||||
|
||||
See also:
|
||||
:func:`sigmoid`
|
||||
"""
|
||||
return 0.5 * jnp.clip(x + 1.0, 0.0, 2.0)
|
||||
|
||||
@jax.jit
|
||||
def silu(x: ArrayLike) -> Array:
|
||||
r"""SiLU (aka swish) activation function.
|
||||
@ -454,7 +493,7 @@ logsumexp = _logsumexp
|
||||
def log_softmax(x: ArrayLike,
|
||||
axis: int | tuple[int, ...] | None = -1,
|
||||
where: ArrayLike | None = None,
|
||||
initial: ArrayLike | None = None) -> Array:
|
||||
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
|
||||
r"""Log-Softmax function.
|
||||
|
||||
Computes the logarithm of the :code:`softmax` function, which rescales
|
||||
@ -469,8 +508,6 @@ def log_softmax(x: ArrayLike,
|
||||
axis: the axis or axes along which the :code:`log_softmax` should be
|
||||
computed. Either an integer or a tuple of integers.
|
||||
where: Elements to include in the :code:`log_softmax`.
|
||||
initial: The minimum value used to shift the input array. Must be present
|
||||
when :code:`where` is not None.
|
||||
|
||||
Returns:
|
||||
An array.
|
||||
@ -478,10 +515,15 @@ def log_softmax(x: ArrayLike,
|
||||
See also:
|
||||
:func:`softmax`
|
||||
"""
|
||||
if initial is not _UNSPECIFIED:
|
||||
# Added 2024-4-10
|
||||
warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
del initial
|
||||
numpy_util.check_arraylike("log_softmax", x)
|
||||
x_arr = jnp.asarray(x)
|
||||
x_max = jnp.max(x_arr, axis, where=where, initial=initial, keepdims=True)
|
||||
x_safe = x_arr if where is None else jnp.where(where, x_arr, initial)
|
||||
x_max = jnp.max(x_arr, axis, where=where, initial=-jnp.inf, keepdims=True)
|
||||
x_safe = x_arr if where is None else jnp.where(where, x_arr, -jnp.inf)
|
||||
shifted = x_safe - lax.stop_gradient(x_max)
|
||||
shifted_logsumexp = jnp.log(
|
||||
jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
|
||||
@ -496,7 +538,7 @@ def log_softmax(x: ArrayLike,
|
||||
def softmax(x: ArrayLike,
|
||||
axis: int | tuple[int, ...] | None = -1,
|
||||
where: ArrayLike | None = None,
|
||||
initial: ArrayLike | None = None) -> Array:
|
||||
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
|
||||
r"""Softmax function.
|
||||
|
||||
Computes the function which rescales elements to the range :math:`[0, 1]`
|
||||
@ -511,8 +553,6 @@ def softmax(x: ArrayLike,
|
||||
softmax output summed across these dimensions should sum to :math:`1`.
|
||||
Either an integer or a tuple of integers.
|
||||
where: Elements to include in the :code:`softmax`.
|
||||
initial: The minimum value used to shift the input array. Must be present
|
||||
when :code:`where` is not None.
|
||||
|
||||
Returns:
|
||||
An array.
|
||||
@ -520,13 +560,18 @@ def softmax(x: ArrayLike,
|
||||
See also:
|
||||
:func:`log_softmax`
|
||||
"""
|
||||
if initial is not _UNSPECIFIED:
|
||||
# Added 2024-4-10
|
||||
warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
del initial
|
||||
if config.softmax_custom_jvp.value:
|
||||
# mypy is confused by the `functools.partial` application in the definition
|
||||
# of `_softmax` and incorrectly concludes that `_softmax` returns
|
||||
# `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`.
|
||||
return _softmax(x, axis, where, initial) # type: ignore[return-value]
|
||||
return _softmax(x, axis, where) # type: ignore[return-value]
|
||||
else:
|
||||
return _softmax_deprecated(x, axis, where, initial)
|
||||
return _softmax_deprecated(x, axis, where)
|
||||
|
||||
# TODO(mattjj): replace softmax with _softmax when deprecation flag is removed
|
||||
@partial(jax.custom_jvp, nondiff_argnums=(1,))
|
||||
@ -534,7 +579,7 @@ def _softmax(
|
||||
x: ArrayLike,
|
||||
axis: int | tuple[int, ...] | None = -1,
|
||||
where: ArrayLike | None = None,
|
||||
initial: ArrayLike | None = None) -> Array:
|
||||
initial: ArrayLike | None = -jnp.inf) -> Array:
|
||||
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
|
||||
x_safe = x if where is None else jnp.where(where, x, initial)
|
||||
unnormalized = jnp.exp(x_safe - x_max)
|
||||
@ -553,7 +598,7 @@ def _softmax_deprecated(
|
||||
x: ArrayLike,
|
||||
axis: int | tuple[int, ...] | None = -1,
|
||||
where: ArrayLike | None = None,
|
||||
initial: ArrayLike | None = None) -> Array:
|
||||
initial: ArrayLike | None = -jnp.inf) -> Array:
|
||||
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
|
||||
x_safe = x if where is None else jnp.where(where, x, initial)
|
||||
unnormalized = jnp.exp(x_safe - lax.stop_gradient(x_max))
|
||||
|
@ -84,12 +84,11 @@ def _itemsize(arr: ArrayLike) -> int:
|
||||
|
||||
|
||||
def _clip(number: ArrayLike,
|
||||
min: ArrayLike | None = None, max: ArrayLike | None = None,
|
||||
out: None = None) -> Array:
|
||||
min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
|
||||
"""Return an array whose values are limited to a specified range.
|
||||
|
||||
Refer to :func:`jax.numpy.clip` for full documentation."""
|
||||
return lax_numpy.clip(number, a_min=min, a_max=max, out=out)
|
||||
return lax_numpy.clip(number, min=min, max=max)
|
||||
|
||||
|
||||
def _transpose(a: Array, *args: Any) -> Array:
|
||||
|
@ -66,7 +66,10 @@ from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.numpy import util
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
|
||||
from jax._src.typing import (
|
||||
Array, ArrayLike, DimSize, DuckTypedArray,
|
||||
DType, DTypeLike, Shape, DeprecatedArg
|
||||
)
|
||||
from jax._src.util import (unzip2, subvals, safe_zip,
|
||||
ceil_of_ratio, partition_list,
|
||||
canonicalize_axis as _canonicalize_axis,
|
||||
@ -1293,20 +1296,63 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
|
||||
axis: int = 0) -> list[Array]:
|
||||
return _split("array_split", ary, indices_or_sections, axis=axis)
|
||||
|
||||
@util.implements(np.clip, skip_params=['out'])
|
||||
|
||||
_DEPRECATED_CLIP_ARG = DeprecatedArg()
|
||||
@util.implements(
|
||||
np.clip,
|
||||
skip_params=['a', 'a_min'],
|
||||
extra_params=_dedent("""
|
||||
x : array_like
|
||||
Array containing elements to clip.
|
||||
min : array_like, optional
|
||||
Minimum value. If ``None``, clipping is not performed on the
|
||||
corresponding edge. The value of ``min`` is broadcast against x.
|
||||
max : array_like, optional
|
||||
Maximum value. If ``None``, clipping is not performed on the
|
||||
corresponding edge. The value of ``max`` is broadcast against x.
|
||||
""")
|
||||
)
|
||||
@jit
|
||||
def clip(a: ArrayLike, a_min: ArrayLike | None = None,
|
||||
a_max: ArrayLike | None = None, out: None = None) -> Array:
|
||||
util.check_arraylike("clip", a)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
|
||||
if a_min is None and a_max is None:
|
||||
raise ValueError("At most one of a_min and a_max may be None")
|
||||
if a_min is not None:
|
||||
a = ufuncs.maximum(a_min, a)
|
||||
if a_max is not None:
|
||||
a = ufuncs.minimum(a_max, a)
|
||||
return asarray(a)
|
||||
def clip(
|
||||
x: ArrayLike | None = None, # Default to preserve backwards compatability
|
||||
/,
|
||||
min: ArrayLike | None = None,
|
||||
max: ArrayLike | None = None,
|
||||
*,
|
||||
a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
|
||||
a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
|
||||
a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
|
||||
) -> Array:
|
||||
# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
|
||||
x = a if not isinstance(a, DeprecatedArg) else x
|
||||
if x is None:
|
||||
raise ValueError("No input was provided to the clip function.")
|
||||
min = a_min if not isinstance(a_min, DeprecatedArg) else min
|
||||
max = a_max if not isinstance(a_max, DeprecatedArg) else max
|
||||
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
|
||||
warnings.warn(
|
||||
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
|
||||
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
util.check_arraylike("clip", x)
|
||||
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
|
||||
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
|
||||
warnings.warn(
|
||||
"Clip received a complex value either through the input or the min/max "
|
||||
"keywords. Complex values have no ordering and cannot be clipped. "
|
||||
"Attempting to clip using complex numbers is deprecated and will soon "
|
||||
"raise a ValueError. Please convert to a real value or array by taking "
|
||||
"the real or imaginary components via jax.numpy.real/imag respectively.",
|
||||
DeprecationWarning, stacklevel=2,
|
||||
)
|
||||
if min is not None:
|
||||
x = ufuncs.maximum(min, x)
|
||||
if max is not None:
|
||||
x = ufuncs.minimum(max, x)
|
||||
return asarray(x)
|
||||
|
||||
@util.implements(np.around, skip_params=['out'])
|
||||
@partial(jit, static_argnames=('decimals',))
|
||||
@ -2217,11 +2263,16 @@ In particular, the details of float-to-int and int-to-float casts are
|
||||
implementation dependent.
|
||||
""")
|
||||
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
|
||||
util.check_arraylike("astype", x)
|
||||
x_arr = asarray(x)
|
||||
del copy # unused in JAX
|
||||
if dtype is None:
|
||||
dtype = dtypes.canonicalize_dtype(float_)
|
||||
dtypes.check_user_dtype_supported(dtype, "astype")
|
||||
return lax.convert_element_type(x, dtype)
|
||||
# convert_element_type(complex, bool) has the wrong semantics.
|
||||
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
|
||||
return (x_arr != _lax_const(x_arr, 0))
|
||||
return lax.convert_element_type(x_arr, dtype)
|
||||
|
||||
|
||||
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
|
||||
@ -2442,9 +2493,10 @@ def fromiter(*args, **kwargs):
|
||||
is later modified in-place, it may lead to undefined behavior when using
|
||||
the associated JAX array.
|
||||
""")
|
||||
def from_dlpack(x: Any) -> Array:
|
||||
def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
|
||||
copy: bool | None = None) -> Array:
|
||||
from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top
|
||||
return from_dlpack(x)
|
||||
return from_dlpack(x, device=device, copy=copy)
|
||||
|
||||
@util.implements(np.fromfunction)
|
||||
def fromfunction(function: Callable[..., Array], shape: Any,
|
||||
@ -4889,8 +4941,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
"with type {} at position {}, indexer value {}")
|
||||
raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
|
||||
|
||||
msg = "Indexing mode not yet supported. Open a feature request!\n{}"
|
||||
raise IndexError(msg.format(idx))
|
||||
raise IndexError("Indexing mode not yet supported. Got unsupported indexer "
|
||||
f"at position {idx_pos}: {i!r}")
|
||||
|
||||
if len(gather_indices) == 0:
|
||||
gather_indices_array: ArrayLike = np.zeros((0,), dtype=index_dtype)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user