Merge branch 'main' into refs-in-vjps

Dougal Maclaurin 2024-04-12 15:25:37 -04:00 committed by GitHub
commit f313a46916
250 changed files with 5612 additions and 3345 deletions

.github/workflows/upstream-nightly.yml (new file)

@ -0,0 +1,171 @@
name: CI - with Numpy/Scipy nightly wheels (nightly)
# This configures a GitHub Action that runs the JAX test suite against nightly development builds
# of numpy and scipy, in order to catch issues with new package versions prior to their release.
# Unlike our other CI, this is one that we expect to fail frequently, and so we don't run it against
# every commit and PR in the repository. Rather, we run it on a schedule, and failures lead to an
# issue being created or updated.
# Portions of this adapted from https://github.com/pydata/xarray/blob/main/.github/workflows/upstream-dev-ci.yaml
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
on:
schedule:
- cron: "0 12 * * *" # Daily at 12:00 UTC
workflow_dispatch: # allows triggering the workflow run manually
pull_request: # Automatically trigger on pull requests affecting this file
branches:
- main
paths:
- '**workflows/upstream-nightly.yml'
jobs:
upstream-dev:
runs-on: ubuntu-20.04-16core
permissions:
contents: read
checks: write # for upload-artifact
defaults:
run:
shell: bash -l {0}
strategy:
fail-fast: false
matrix:
python-version: ["3.12"]
outputs:
artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }}
steps:
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install JAX test requirements
run: |
pip install -r build/test-requirements.txt
pip install pytest-reportlog
- name: Install numpy & scipy development versions
run: |
pip install \
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \
--no-deps \
--pre \
--upgrade \
numpy \
scipy
- name: Install JAX
run: |
pip install .[ci]
- name: Run tests
if: success()
id: status
env:
JAX_NUM_GENERATED_CASES: 1
JAX_ENABLE_X64: true
JAX_ENABLE_CHECKS: true
JAX_SKIP_SLOW_TESTS: true
PY_COLORS: 1
run: |
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
pytest -n auto --tb=short -rf --maxfail=20 \
--report-log output-${{ matrix.python-version }}-log.jsonl \
tests \
|| (
echo 'ARTIFACTS_AVAILABLE=true' >> $GITHUB_OUTPUT && false
)
- name: Upload artifacts
if: |
failure()
&& steps.status.outcome == 'failure'
&& github.event_name == 'schedule'
&& github.repository == 'google/jax'
uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # ratchet: actions/upload-artifact@v4
with:
name: output-${{ matrix.python-version }}-log.jsonl
path: output-${{ matrix.python-version }}-log.jsonl
retention-days: 5
report:
name: report
needs: upstream-dev
permissions:
contents: read
issues: write
if: |
failure()
&& github.event_name == 'schedule'
&& needs.upstream-dev.outputs.artifacts_availability == 'true'
runs-on: ubuntu-latest
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5
with:
python-version: "3.x"
- uses: actions/download-artifact@c850b930e6ba138125429b7e5c93fc707a7f8427 # ratchet:actions/download-artifact@v4
with:
path: /tmp/workspace/logs
- name: install requirements
run: |
python -m pip install pytest
- name: Move all log files into a single directory
run: |
rsync -a /tmp/workspace/logs/output-*/ ./logs
ls -R ./logs
cat logs/*.jsonl > pytest-logs.txt
python .github/workflows/parse_logs.py pytest-logs.txt --outfile=parsed-logs.txt
- name: Report failures
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # ratchet:actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const parsed_logs = fs.readFileSync('parsed-logs.txt', 'utf8');
const title = "⚠️ Nightly upstream-dev CI failed ⚠️"
const workflow_url = `https://github.com/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
const issue_body = `[Workflow Run URL](${workflow_url})\n${parsed_logs}`
// Run GraphQL query against GitHub API to find the most recent open issue used for reporting failures
const query = `query($owner:String!, $name:String!, $creator:String!, $label:String!){
repository(owner: $owner, name: $name) {
issues(first: 1, states: OPEN, filterBy: {createdBy: $creator, labels: [$label]}, orderBy: {field: CREATED_AT, direction: DESC}) {
edges {
node {
body
id
number
}
}
}
}
}`;
const variables = {
owner: context.repo.owner,
name: context.repo.repo,
label: 'CI',
creator: "github-actions[bot]"
}
const result = await github.graphql(query, variables)
// If no issue is open, create a new issue,
// else update the body of the existing issue.
if (result.repository.issues.edges.length === 0) {
github.rest.issues.create({
owner: variables.owner,
repo: variables.name,
body: issue_body,
title: title,
labels: [variables.label]
})
} else {
github.rest.issues.update({
owner: variables.owner,
repo: variables.name,
issue_number: result.repository.issues.edges[0].node.number,
body: issue_body
})
}


@ -9,7 +9,7 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-ast
- id: check-merge-conflict
@ -26,7 +26,7 @@ repos:
files: \.py$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2
rev: v0.3.5
hooks:
- id: ruff
@ -40,7 +40,7 @@ repos:
args: [--config=pyproject.toml]
- repo: https://github.com/mwouts/jupytext
rev: v1.16.0
rev: v1.16.1
hooks:
- id: jupytext
args: [--sync]


@ -6,7 +6,7 @@
version: 2
build:
os: "ubuntu-20.04"
os: "ubuntu-22.04"
tools:
python: "3.9"


@ -8,10 +8,25 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.27
* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
the old behavior by transforming the arguments via
`jax.tree.map(np.asarray, args)` before passing them to the callback.
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
lowering pass via Triton Python APIs has been removed and the
`JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
    `a_max` are deprecated in favor of `x` (positional only), `min`, and
    `max` ({jax-issue}`20550`); see the example below.
* The `device()` method of JAX arrays has been removed, after being deprecated
since JAX v0.4.21. Use `arr.devices()` instead.
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
is deprecated; empty inputs to softmax are now supported without setting this.
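
As a minimal sketch of the callback and `jax.numpy.clip` changes noted above (illustrative only):

```python
import numpy as np
import jax
import jax.numpy as jnp

# New clip signature: the array argument is positional-only, bounds are `min`/`max`.
y = jnp.clip(jnp.arange(5.0), min=1.0, max=3.0)

# Callbacks now receive jax.Array arguments; convert with np.asarray inside the
# callback to recover the previous NumPy-array behavior.
def host_fn(a):
  return np.sin(np.asarray(a))

out = jax.pure_callback(host_fn, jax.ShapeDtypeStruct(y.shape, y.dtype), y)
```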
## jaxlib 0.4.27
@ -25,6 +40,13 @@ Remember to align the itemized text with the first line of an item within a list
* Changes
* Complex-valued {func}`jax.numpy.geomspace` now chooses the logarithmic spiral
branch consistent with that of NumPy 2.0.
* The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'`
and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has
changed](https://github.com/google/jax/issues/19085) so that
mapping over keys results in random generation only from the first
key in the batch.
  * Docs now use `jax.random.key` for construction of PRNG key arrays
    rather than `jax.random.PRNGKey` (see the short example at the end of this section).
* Deprecations & Removals
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
@ -40,6 +62,8 @@ Remember to align the itemized text with the first line of an item within a list
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
  * The `jax.experimental.host_callback` module is deprecated.
    Use the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html) instead.
    A `JAX_HOST_CALLBACK_LEGACY` flag has been added to assist in the transition to the
    new callbacks. See {jax-issue}`#20385` for a discussion.
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array now results in an exception.
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
@ -47,14 +71,13 @@ Remember to align the itemized text with the first line of an item within a list
* The previously-deprecated imports `jax.interpreters.ad.config` and
`jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
and `jax.extend.source_info_util` instead.
* JAX export does not support anymore older serialization version. Version 9
* JAX export does not support older serialization versions anymore. Version 9
has been supported since October 27th, 2023 and has become the default
since February 1, 2024.
See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
This change could break clients that set a specific
JAX serialization version lower than 9.
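
As a short sketch of the key-construction and `jax.tree.map` migrations mentioned in the notes above (illustrative only):

```python
import jax
import jax.numpy as jnp

# Typed PRNG keys via jax.random.key, instead of raw jax.random.PRNGKey arrays.
key = jax.random.key(0)
x = jax.random.normal(key, (3,))

# jax.tree.map replaces the deprecated jax.tree_map alias.
doubled = jax.tree.map(lambda leaf: 2 * leaf, {"a": x, "b": jnp.ones(2)})
```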
## jaxlib 0.4.26 (April 3, 2024)
* Changes
@ -131,8 +154,8 @@ Remember to align the itemized text with the first line of an item within a list
cannot interact, e.g., in arithmetic operations.
Scopes are introduced by {func}`jax.experimental.jax2tf.convert`,
{func}`jax.experimental.export.symbolic_shape`, {func}`jax.experimental.export.symbolic_args_specs`.
The scope of a symbolic expression `e` can be read with `e.scope` and passed in
to the above functions to direct them to construct symbolic expressions in
The scope of a symbolic expression `e` can be read with `e.scope` and passed
into the above functions to direct them to construct symbolic expressions in
a given scope.
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
* simplified and faster equality comparisons, where we consider two symbolic dimensions
@ -313,7 +336,7 @@ Remember to align the itemized text with the first line of an item within a list
* Bug fixes
* Only process 0 in a multicontroller distributed JAX program will write
persistent compilation cache entries. This fixes write contention if the
cache is placed on a network filesystem such as GCS.
cache is placed on a network file system such as GCS.
* The version check for cusolver and cufft no longer considers the patch
versions when determining if the installed version of these libraries is at
least as new as the versions against which JAX was built.
@ -1441,7 +1464,7 @@ Changes:
special autodiff handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
environment variable, or the ```--jax_host_callback_ad_transforms``` flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`#8678`).
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the


@ -33,9 +33,7 @@ from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
partial = functools.partial


@ -15,12 +15,12 @@
import google_benchmark as benchmark
from jax import config
import jax
from jax import core
from jax._src.numpy import lax_numpy
from jax.experimental import export
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
@benchmark.register

File diff suppressed because it is too large


@ -0,0 +1,528 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial, reduce
import math
from typing import Tuple
import jax
import jax.numpy as jnp
from build import gpu_ops
from jax import core, dtypes
from jax.core import ShapedArray
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.interpreters import batching, mlir, xla
from jax.interpreters.mlir import ir
from jax.lib import xla_client
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jaxlib.hlo_helpers import custom_call
from jax._src import dispatch
######################################################################
# Created Primitives for unsharded RMS norm reference implementation #
######################################################################
# Create _rms_norm_fwd_p for forward operation.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))
def rms_norm_fwd(x, weight, eps=1e-05):
output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
return output, (invvar, x, weight)
# Create _rms_norm_bwd_p for backward operation.
_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))
def rms_norm_bwd(eps, res, g):
invvar, x, weight = res
grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
g, invvar, x, weight, eps=eps
)
return grad_input, grad_weight
####################
# Lowering to MLIR #
####################
# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
def element_type_to_descriptor_type_mapping(element_type):
_element_type_to_descriptor_type_mapping = {
ir.BF16Type.get(): gpu_ops.ElementType.BF16,
ir.F16Type.get(): gpu_ops.ElementType.F16,
ir.F32Type.get(): gpu_ops.ElementType.F32,
ir.F64Type.get(): gpu_ops.ElementType.F64,
}
return _element_type_to_descriptor_type_mapping.get(element_type)
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]
def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(weight.type)
w_shape = w_type.shape
iv_element_type = (
ir.F32Type.get()
if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()]
else x_type.element_type
)
n2 = math.prod(w_shape)
n1 = math.prod(x_shape) // n2
opaque = gpu_ops.create_rms_norm_descriptor(
n1,
n2,
eps,
element_type_to_descriptor_type_mapping(x_type.element_type),
element_type_to_descriptor_type_mapping(w_type.element_type),
0, # unused
)
out = custom_call(
b"rms_forward_affine_mixed_dtype",
result_types=[
ir.RankedTensorType.get(x_shape, w_type.element_type),
ir.RankedTensorType.get((n1,), iv_element_type),
],
operands=[x, weight],
backend_config=opaque,
operand_layouts=default_layouts(x_shape, w_shape),
result_layouts=default_layouts(x_shape, (n1,)),
).results
return out
mlir.register_lowering(
_rms_norm_fwd_p,
_rms_norm_fwd_cuda_lowering,
platform="gpu",
)
def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(weight.type)
w_shape = w_type.shape
iv_type = ir.RankedTensorType(invvar.type)
n2 = reduce(lambda x, y: x * y, w_shape)
n1 = reduce(lambda x, y: x * y, x_shape) // n2
part_grad_shape = ctx.avals_out[-1].shape
opaque = gpu_ops.create_rms_norm_descriptor(
n1,
n2,
eps,
element_type_to_descriptor_type_mapping(x_type.element_type),
element_type_to_descriptor_type_mapping(w_type.element_type),
part_grad_shape[0],
)
out = custom_call(
b"rms_backward_affine",
result_types=[
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(w_shape, w_type.element_type),
ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
],
operands=[grad_output, invvar, x, weight],
backend_config=opaque,
operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
).results
return out
mlir.register_lowering(
_rms_norm_bwd_p,
_rms_norm_bwd_cuda_lowering,
platform="gpu",
)
#######################
# Abstract evaluation #
#######################
def _rms_norm_fwd_abstract(x, weight, eps):
w_dtype = dtypes.canonicalize_dtype(weight.dtype)
iv_dtype = dtypes.canonicalize_dtype(x.dtype)
if iv_dtype in [jnp.float16, jnp.bfloat16]:
iv_dtype = jnp.float32
n2 = math.prod(weight.shape)
n1 = math.prod(x.shape) // n2
return (
ShapedArray(x.shape, w_dtype, named_shape=x.named_shape), # output
ShapedArray((n1,), iv_dtype, named_shape=x.named_shape), # invvar
)
_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)
def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
iv_dtype = dtypes.canonicalize_dtype(invvar.dtype)
w_dtype = dtypes.canonicalize_dtype(weight.dtype)
x_dtype = dtypes.canonicalize_dtype(x.dtype)
n2 = reduce(lambda x, y: x * y, weight.shape)
n1 = reduce(lambda x, y: x * y, x.shape) // n2
part_grad_shape = (16, n2)
assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
assert grad_output.shape == x.shape
assert invvar.shape == (n1,)
assert (
iv_dtype == jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype
)
assert grad_output.named_shape == x.named_shape
weight_named_shape = (
weight.named_shape if weight.named_shape else grad_output.named_shape
)
return (
ShapedArray(
x.shape, x_dtype, named_shape=x.named_shape
), # grad input
ShapedArray(
weight.shape, w_dtype, named_shape=weight_named_shape
), # grad weight
ShapedArray(
part_grad_shape, iv_dtype, named_shape=weight_named_shape
), # part grad
)
_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
#######################################
# Top-level interface with custom vjp #
#######################################
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
output, _ = rms_norm_fwd(x, weight, eps=eps)
return output
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
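# A minimal usage sketch of the wrapper above (hypothetical shapes; it assumes the
# `gpu_ops` extension has been built and a GPU is available):
#
#   x = jax.random.normal(jax.random.PRNGKey(0), (32, 128, 128), dtype=jnp.float16)
#   w = jnp.ones((128, 128), dtype=jnp.float16)
#   loss = lambda x, w: -jnp.mean(rms_norm(x, w) ** 2)
#   dx, dw = jax.grad(loss, argnums=(0, 1))(x, w)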
###########################################################
# Create primitives for RMS norm with custom_partitioning #
###########################################################
def _check_valid_batch_dims(bdims):
"""
    Assert that only supported batch dims are used.
"""
for dim in bdims:
assert dim in [0, None], \
"Currently only support batch_dim in [0, None], " \
f"but got {dim=}"
def register_primitive(cls):
"""
    Register a JAX primitive.

    Each operation is composed of two primitives, an inner and an outer one.

    Inner: only the basics needed to wrap the custom_call itself.
      - impl: dispatch to the XLA custom_call implemented in C++.
      - abstract: compute the static output shapes and dtypes.
      - lowering: lower to a StableHLO custom_call.

    Outer: mostly everything else.
      - impl: bind to the inner primitive. It is not used for real computation,
        only for tracing, so binding is all that is needed.
      - abstract: same as for the inner primitive.
      - lowering: lower via custom_partitioning (XLA will call the Python callback from it).
      - vmap: a batching rule could be added here.

    The VJP is built on top of the outer primitive, but is not handled in this function.
"""
def name_of_wrapper_p():
return cls.name + "_wrapper"
inner_p = core.Primitive(cls.name)
dispatch.prim_requires_devices_during_lowering.add(inner_p)
inner_p.multiple_results = cls.multiple_results
inner_p.def_impl(partial(xla.apply_primitive, inner_p))
inner_p.def_abstract_eval(cls.abstract)
mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
cls.inner_primitive = inner_p
outer_p = core.Primitive(name_of_wrapper_p())
dispatch.prim_requires_devices_during_lowering.add(outer_p)
outer_p.multiple_results = cls.multiple_results
outer_p.def_impl(cls.impl)
outer_p.def_abstract_eval(cls.abstract)
batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
partition=cls.partition)
mlir.register_lowering(outer_p,
mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
cls.outer_primitive = outer_p
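# RmsNormFwdClass and RmsNormBwdClass below follow this recipe: each one supplies
# abstract/lowering/impl/batcher plus the custom_partitioning hooks
# (infer_sharding_from_operands and partition), and register_primitive wires them
# up as an inner and an outer primitive.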
class RmsNormFwdClass:
name = "rms_forward_affine_mixed_dtype"
multiple_results = True
impl_static_args = (2,) # eps
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(x_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument
return _rms_norm_fwd_abstract(x_aval, gamma_aval, **kwargs)
@staticmethod
def lowering(ctx, x, gamma, *, eps):
return _rms_norm_fwd_cuda_lowering(ctx, x, gamma, eps)
@staticmethod
def impl(x, gamma, eps):
assert RmsNormFwdClass.inner_primitive is not None
out, rsigma = RmsNormFwdClass.inner_primitive.bind(x, gamma, eps=eps)
return out, rsigma
@staticmethod
def batcher(batched_args, batch_dims, *, eps):
_check_valid_batch_dims(batch_dims)
assert RmsNormFwdClass.outer_primitive is not None
x, gamma = batched_args
x_bdim, _ = batch_dims
out_bdims = x_bdim, x_bdim
return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
@staticmethod
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
result_infos : Tuple[jax._src.core.ShapedArray]):
del eps, result_infos # Not needed for this example.
x_info, weight_info = arg_infos
assert len(x_info.shape) == 3
assert len(weight_info.shape) == 2
# partition() will force all dims to be replicated except the
# first dim of x that will be kept as is.
x_spec = arg_infos[0].sharding.spec
output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
return (output_sharding, invvar_sharding)
@staticmethod
def partition(eps : float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
del result_infos # Not needed for this example.
x_info, weight_info = arg_infos
assert len(x_info.shape) == 3
assert len(weight_info.shape) == 2
x_spec = arg_infos[0].sharding.spec
# We only support sharding on the batch dimensions.
        # Force all other dimensions to be replicated (None).
arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything.
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
output_shardings = (arg_shardings[0], invvar_sharding)
        # sharded_impl only accepts positional arguments,
        # and they should be JAX-traceable variables.
impl = partial(RmsNormFwdClass.impl, eps=eps)
return mesh, impl, output_shardings, arg_shardings
register_primitive(RmsNormFwdClass)
class RmsNormBwdClass:
name = "rms_norm_bwd"
multiple_results = True
impl_static_args = (4,) # eps
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(grad_output, invvar, x, weight, eps): # pylint: disable=unused-argument
return _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps)
@staticmethod
def lowering(ctx, grad_output, invvar, x, weight, eps):
return _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps)
@staticmethod
def impl(grad_output, invvar, x, weight, eps):
assert RmsNormBwdClass.inner_primitive is not None
gx, gw, part_grad = RmsNormBwdClass.inner_primitive.bind(grad_output, invvar, x, weight, eps=eps)
return gx, gw, part_grad
@staticmethod
def batcher(batched_args, batch_dims, *, eps):
# TODO: Add to the tutorial!
_check_valid_batch_dims(batch_dims)
assert RmsNormBwdClass.outer_primitive is not None
x, gamma = batched_args
x_bdim, _ = batch_dims
out_bdims = x_bdim, x_bdim
return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
@staticmethod
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
result_infos : Tuple[jax._src.core.ShapedArray]):
del eps, result_infos # Not needed for this example.
g_info, invvar_info, x_info, weight_info = arg_infos
assert len(g_info.shape) == 3
assert len(invvar_info.shape) == 1
assert len(x_info.shape) == 3
assert len(weight_info.shape) == 2
# partition() will force all dims to be replicated except the batch dimension.
x_spec = x_info.sharding.spec
output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
return (output_sharding, invvar_sharding, output_sharding, )
@staticmethod
def partition(eps : float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
del result_infos # Not needed for this example.
g_info, invvar_info, x_info, weight_info = arg_infos
assert len(g_info.shape) == 3
assert len(invvar_info.shape) == 1
assert len(x_info.shape) == 3
assert len(weight_info.shape) == 2
# We only support sharding on the batch dimensions.
        # Force all other dimensions to be replicated (None).
# Also force gx, x and invvar to have the same batch sharding/replication.
x_spec = x_info.sharding.spec
arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
NamedSharding(mesh, PartitionSpec(x_spec[0],)),
NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
NamedSharding(mesh, PartitionSpec(None, None)))
output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
output_shardings = (output_sharding, invvar_sharding, invvar_sharding)
        # sharded_impl only accepts positional arguments,
        # and they should be JAX-traceable variables.
def sharded_impl(g, invvar, x, weight):
grad_input, grad_weight, part_grad = RmsNormBwdClass.impl(
g, invvar, x, weight, eps=eps
)
            # We need to sum the weight gradient from all partitions
            # when the input is sharded and the weights are replicated.
global_weight = grad_weight
if x_spec[0]:
global_weight = jax.lax.psum(grad_weight, x_spec[0])
return grad_input, global_weight, part_grad
return mesh, sharded_impl, output_shardings, arg_shardings
register_primitive(RmsNormBwdClass)
def custom_p_rms_norm_fwd(x, weight, eps=1e-05):
output, invvar = RmsNormFwdClass.outer_primitive.bind(x, weight, eps=eps)
return output, (invvar, x, weight)
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def custom_p_rms_norm(x, weight, eps=1e-05):
output, _ = custom_p_rms_norm_fwd(x, weight, eps=eps)
return output
def custom_p_rms_norm_bwd(eps, res, g):
invvar, x, weight = res
grad_input, grad_weight, part_grad = RmsNormBwdClass.outer_primitive.bind(
g, invvar, x, weight, eps=eps)
return grad_input, grad_weight
custom_p_rms_norm.defvjp(custom_p_rms_norm_fwd, custom_p_rms_norm_bwd)
########
# Test #
########
import jax
per_core_batch_size = 4
seq_len = 512
emb_dim = 512
assert jax.local_device_count() > 1, "Only 1 GPU: the example will still work, but is this really what you want?"
x = jax.random.normal(
jax.random.PRNGKey(0),
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
dtype=jnp.float16,
)
norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.float16)
def ref_loss(x, weight):
predictions = rms_norm(x, weight)
return -jnp.mean(predictions**2)
ref_out = jax.grad(ref_loss, argnums=(0, 1))(x, weight)
def custom_p_loss(x, weight):
predictions = custom_p_rms_norm(x, weight)
return -jnp.mean(predictions**2)
with Mesh(jax.local_devices(), ("x",)):
def run_and_verify(loss):
pjitted = pjit(
jax.grad(loss, argnums=(0, 1)),
# Shard x by batch dimension and replicate weight on all devices.
in_shardings=(
PartitionSpec("x", None, None),
PartitionSpec(None, None),
),
# Shard the output by batch dimension and replicate weight grad on all devices.
out_shardings=(
PartitionSpec("x", None, None),
PartitionSpec(None, None),
),
)
hlo = pjitted.lower(x, weight).compile().as_text()
out = pjitted(x, weight)
print(hlo)
assert "all-reduce-done" in hlo, "The gradient will produce wrong value!"
if "all-gather-start" in hlo:
print("NOT OPTIMIZED, ALL_GATHER in the graph!")
return out
custom_p_out = run_and_verify(custom_p_loss)
for r, o in zip(ref_out, custom_p_out):
print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6))


@ -6,7 +6,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3


@ -20,7 +20,7 @@
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.0
# jupytext_version: 1.16.1
# kernelspec:
# display_name: Python 3
# name: python3

docs/build_custom_gpu.sh (new file)

@ -0,0 +1,13 @@
python -m pip install pybind11==2.10.1
mkdir -p build
touch build/__init__.py
pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())")
python_executable=$(python -c 'import sys; print(sys.executable)')
#python_include_path=$(python -c 'from distutils.sysconfig import get_python_inc;print(get_python_inc())')
echo pybind_include_path=$pybind_include_path
echo python_executable=$python_executable
nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}3-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}3-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
strip build/gpu_ops$(${python_executable}3-config --extension-suffix)


@ -12,14 +12,14 @@ JAX offers flags and context managers that enable catching errors more easily.
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
* setting the `JAX_DEBUG_NANS=True` environment variable;
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
### Example(s)
```python
from jax import config
config.update("jax_debug_nans", True)
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
@ -47,14 +47,14 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
You can disable JIT-compilation by:
* setting the `JAX_DISABLE_JIT=True` environment variable;
* adding `from jax import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
* adding `jax.config.update("jax_disable_jit", True)` near the top of your main file;
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
### Examples
```python
from jax import config
config.update("jax_disable_jit", True)
import jax
jax.config.update("jax_disable_jit", True)
def f(x):
y = jnp.log(x)


@ -82,8 +82,8 @@ Click [here](checkify_guide) to learn more!
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
```python
from jax import config
config.update("jax_debug_nans", True)
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y


@ -206,7 +206,7 @@ python build/build.py --configure_only
You may pass additional options to `build.py` to configure the build; see the
`jaxlib` build documentation for details.
By default the Bazel build runs the JAX tests using `jaxlib` built form source.
By default the Bazel build runs the JAX tests using `jaxlib` built from source.
To run JAX tests, run:
```


@ -413,7 +413,7 @@ speed of code using JAX:
use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see
`Double (64 bit) precision`_) for a fair comparison.
4. **Transferring data between CPUs and accelerators takes time.** If you only
want to measure the how long it takes to evaluate a function, you may want to
want to measure how long it takes to evaluate a function, you may want to
transfer data to the device on which you want to run it first (see
:ref:`faq-data-placement`).
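
For instance, a minimal timing sketch (with a made-up jitted function and input)
might place the data on the device first and block on the result:

.. code-block:: python

    import numpy as np
    import jax

    x = jax.device_put(np.ones((1000, 1000)))  # move the data to the accelerator first
    f = jax.jit(lambda a: a @ a)
    f(x).block_until_ready()                   # warm-up call: excludes compilation time
    # %timeit f(x).block_until_ready()         # then time steady-state execution only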
@ -814,6 +814,32 @@ computation at runtime. For example:
For more information on runtime callbacks and examples of their use,
see `External callbacks in JAX`_.
Why do some CUDA libraries fail to load/initialize?
---------------------------------------------------
When resolving dynamic libraries, JAX uses the usual `dynamic linker search pattern`_.
JAX sets :code:`RPATH` to point to the JAX-relative location of the
pip-installed NVIDIA CUDA packages, preferring them if installed. If :code:`ld.so`
cannot find your CUDA runtime libraries along its usual search path, then you
must include the paths to those libraries explicitly in :code:`LD_LIBRARY_PATH`.
The easiest way to ensure your CUDA files are discoverable is to simply install
the :code:`nvidia-*-cu12` pip packages, which are included in the standard
:code:`jax[cuda_12]` install option.
Occasionally, even when you have ensured that your runtime libraries are discoverable,
there may still be some issues with loading or initializing them. A common cause of
such issues is simply having insufficient memory for CUDA library initialization at
runtime. This sometimes occurs because JAX pre-allocates a large chunk of the
currently available device memory for faster execution, occasionally leaving
insufficient memory for runtime CUDA library initialization.
This is especially likely when running multiple JAX instances, running JAX in
tandem with TensorFlow (which performs its own pre-allocation), or when running
JAX on a system where the GPU is heavily utilized by other processes. When
in doubt, try running the program again with reduced pre-allocation, either by
reducing :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` from the default of :code:`.75`,
or setting :code:`XLA_PYTHON_CLIENT_PREALLOCATE=false`. For more details, please
see the page on `JAX GPU memory allocation`_.
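
As a rough illustration (the right values are workload-dependent), these settings
can be applied from Python before JAX initializes its GPU backend:

.. code-block:: python

    import os

    # Pre-allocate a smaller fraction of device memory than the default 0.75 ...
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
    # ... or disable pre-allocation entirely and allocate on demand instead.
    # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    import jax
    print(jax.devices())  # the backend is initialized on first use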
.. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables
.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
@ -822,3 +848,5 @@ see `External callbacks in JAX`_.
.. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function
.. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function
.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266
.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
.. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html


@ -53,7 +53,7 @@ JAX Glossary of Terms
pytree
A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more
general containers of array values in a uniform way. Refer to {ref}`working-with-pytrees`
general containers of array values in a uniform way. Refer to :ref:`working-with-pytrees`
for a more detailed discussion.
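
For example, a nested container of arrays can be transformed leaf-wise in a single
call (a small illustrative sketch):

.. code-block:: python

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.ones((2, 2)), "b": [jnp.zeros(2), 1.0]}  # a pytree
    doubled = jax.tree.map(lambda leaf: leaf * 2, params)       # applied to every leaf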
reverse-mode autodiff

docs/gpu_ops/gpu_ops.cpp (new file)

@ -0,0 +1,45 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "kernels.h"
#include "pybind11_kernel_helpers.h"
namespace {
pybind11::dict RMSNormRegistrations() {
pybind11::dict dict;
dict["rms_forward_affine_mixed_dtype"] =
gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes);
dict["rms_backward_affine"] =
gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine);
return dict;
}
PYBIND11_MODULE(gpu_ops, m) {
m.def("get_rms_norm_registrations", &RMSNormRegistrations);
m.def("create_rms_norm_descriptor",
[](int n1, int n2, double eps, gpu_ops::ElementType x_type,
gpu_ops::ElementType w_type, int part_grad_size) {
return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
n1, n2, eps, x_type, w_type, part_grad_size});
});
pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
.value("BF16", gpu_ops::ElementType::BF16)
.value("F16", gpu_ops::ElementType::F16)
.value("F32", gpu_ops::ElementType::F32)
.value("F64", gpu_ops::ElementType::F64);
}
} // namespace


@ -0,0 +1,64 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header is not specific to our application and you'll probably want
// something like this for any extension you're building. This includes the
// infrastructure needed to serialize descriptors that are used with the
// "opaque" parameter of the GPU custom call. In our example we'll use this
// parameter to pass the size of our problem.
#ifndef _GPU_OPS_KERNEL_HELPERS_H_
#define _GPU_OPS_KERNEL_HELPERS_H_
#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>
#define JAX_APEX_WARP_SIZE 32
namespace gpu_ops {
// https://en.cppreference.com/w/cpp/numeric/bit_cast
template <class To, class From>
typename std::enable_if<sizeof(To) == sizeof(From) &&
std::is_trivially_copyable<From>::value &&
std::is_trivially_copyable<To>::value,
To>::type
bit_cast(const From &src) noexcept {
static_assert(std::is_trivially_constructible<To>::value,
"This implementation additionally requires destination type to "
"be trivially constructible");
To dst;
memcpy(&dst, &src, sizeof(To));
return dst;
}
template <typename T> std::string PackDescriptorAsString(const T &descriptor) {
return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
}
template <typename T>
const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Invalid opaque object size");
}
return bit_cast<const T *>(opaque);
}
} // namespace gpu_ops
#endif

docs/gpu_ops/kernels.h (new file)

@ -0,0 +1,44 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef _GPU_OPS_KERNELS_H_
#define _GPU_OPS_KERNELS_H_
#include <cuda_runtime_api.h>
#include <cstddef>
#include <cstdint>
namespace gpu_ops {
enum ElementType { BF16, F16, F32, F64 };
struct RMSNormDescriptor {
int n1;
int n2;
double eps;
ElementType x_type;
ElementType w_type;
int part_grad_size;
};
void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers,
const char *opaque,
std::size_t opaque_len);
void rms_backward_affine(cudaStream_t stream, void **buffers,
const char *opaque, std::size_t opaque_len);
} // namespace gpu_ops
#endif


@ -0,0 +1,41 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header extends kernel_helpers.h with the pybind11 specific interface to
// serializing descriptors. It also adds a pybind11 function for wrapping our
// custom calls in a Python capsule. This is separate from kernel_helpers so
// that the CUDA code itself doesn't include pybind11. I don't think that this
// is strictly necessary, but they do it in jaxlib, so let's do it here too.
#ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
#define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
#include <pybind11/pybind11.h>
#include "kernel_helpers.h"
namespace gpu_ops {
template <typename T> pybind11::bytes PackDescriptor(const T &descriptor) {
return pybind11::bytes(PackDescriptorAsString(descriptor));
}
template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
return pybind11::capsule(bit_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}
} // namespace gpu_ops
#endif


@ -0,0 +1,970 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "kernel_helpers.h"
#include "kernels.h"
#include "stdio.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <iostream>
namespace {
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, \
NAME, ...) \
switch (TYPEIN) { \
case gpu_ops::ElementType::F64: { \
using scalar_t_in = double; \
using accscalar_t = double; \
switch (TYPEOUT) { \
case gpu_ops::ElementType::F64: { \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::F32: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::F16: { \
using scalar_t_out = __half; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::BF16: { \
using scalar_t_out = __nv_bfloat16; \
__VA_ARGS__; \
break; \
} \
default: \
break; \
} \
break; \
} \
case gpu_ops::ElementType::F32: { \
using scalar_t_in = float; \
using accscalar_t = float; \
switch (TYPEOUT) { \
case gpu_ops::ElementType::F64: { \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::F32: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::F16: { \
using scalar_t_out = __half; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::BF16: { \
using scalar_t_out = __nv_bfloat16; \
__VA_ARGS__; \
break; \
} \
default: \
break; \
} \
break; \
} \
case gpu_ops::ElementType::F16: { \
using scalar_t_in = __half; \
using accscalar_t = float; \
switch (TYPEOUT) { \
case gpu_ops::ElementType::F64: { \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::F32: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::F16: { \
using scalar_t_out = __half; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::BF16: { \
using scalar_t_out = __nv_bfloat16; \
__VA_ARGS__; \
break; \
} \
default: \
break; \
} \
break; \
} \
case gpu_ops::ElementType::BF16: { \
using scalar_t_in = __nv_bfloat16; \
using accscalar_t = float; \
switch (TYPEOUT) { \
case gpu_ops::ElementType::F64: { \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::F32: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::F16: { \
using scalar_t_out = __half; \
__VA_ARGS__; \
break; \
} \
case gpu_ops::ElementType::BF16: { \
using scalar_t_out = __nv_bfloat16; \
__VA_ARGS__; \
break; \
} \
default: \
break; \
} \
break; \
} \
default: \
break; \
}
template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) {
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}
template <typename U>
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
U &mu, U &sigma2, U &count) {
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA * mu + nB * muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}
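// Note: cuWelfordOnlineSum above is the standard single-pass Welford update
// (count += 1; delta = x - mu; mu += delta / count; M2 += delta * (x - new mu)),
// and cuChanOnlineSum merges two partial (mu, M2, count) triples with the
// parallel combination M2_AB = M2_A + M2_B + delta^2 * nA * nB / (nA + nB),
// where delta = muB - muA. The warp- and block-level reductions below build on
// these two helpers.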
template <typename U> __device__ void cuRMSOnlineSum(const U curr, U &sigma2) {
sigma2 = sigma2 + curr * curr;
}
template <typename U>
__device__ void cuChanRMSOnlineSum(const U sigma2B, U &sigma2) {
sigma2 = sigma2 + sigma2B;
}
template <typename T, typename U>
__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
const int n2, const int i1, U &mu, U &sigma2,
U *buf, bool rms_only) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu = U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T *lvals = vals + i1 * n2;
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l + k]);
if (!rms_only) {
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
} else {
cuRMSOnlineSum<U>(curr, sigma2);
}
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
if (!rms_only) {
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
} else {
cuRMSOnlineSum<U>(curr, sigma2);
}
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
U sigma2B = __shfl_sync(0xffffffff, sigma2, srcLaneB, warpSize);
if (!rms_only) {
U muB = __shfl_sync(0xffffffff, mu, srcLaneB, warpSize);
U countB = __shfl_sync(0xffffffff, count, srcLaneB, warpSize);
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
} else {
cuChanRMSOnlineSum<U>(sigma2B, sigma2);
}
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U *ubuf = (U *)buf;
U *ibuf = (U *)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
if (!rms_only) {
ubuf[2 * wrt_y] = mu;
ibuf[wrt_y] = count;
}
ubuf[2 * wrt_y + 1] = sigma2;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U sigma2B = ubuf[2 * threadIdx.y + 1];
if (!rms_only) {
U muB = ubuf[2 * threadIdx.y];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
} else {
cuChanRMSOnlineSum<U>(sigma2B, sigma2);
}
}
__syncthreads();
}
      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
if (!rms_only) {
ubuf[0] = mu;
}
ubuf[1] = sigma2;
}
__syncthreads();
if (!rms_only) {
mu = ubuf[0];
}
sigma2 = ubuf[1] / U(n2);
// don't care about final value of count, we know count == n2
} else {
if (!rms_only) {
mu = __shfl_sync(0xffffffff, mu, 0, warpSize);
}
sigma2 = __shfl_sync(0xffffffff, sigma2 / U(n2), 0, warpSize);
}
}
}
template <>
__device__ void cuWelfordMuSigma2(const __half *__restrict__ vals, const int n1,
const int n2, const int i1, float &mu,
float &sigma2, float *buf, bool rms_only) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu = float(0);
sigma2 = float(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const __half *lvals = vals + i1 * n2;
int l = 8 * thrx;
if ((((size_t)lvals) & 3) != 0) {
// 16 bit alignment
// first thread consumes first point
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
if (!rms_only) {
cuWelfordOnlineSum(curr, mu, sigma2, count);
} else {
cuRMSOnlineSum(curr, sigma2);
}
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l + 7 < n2; l += 8 * numx) {
for (int k = 0; k < 8; k += 2) {
float2 curr = __half22float2(*((__half2 *)(lvals + l + k)));
if (!rms_only) {
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
} else {
cuRMSOnlineSum(curr.x, sigma2);
cuRMSOnlineSum(curr.y, sigma2);
}
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
if (!rms_only) {
cuWelfordOnlineSum(curr, mu, sigma2, count);
} else {
cuRMSOnlineSum(curr, sigma2);
}
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
float sigma2B = __shfl_sync(0xffffffff, sigma2, srcLaneB, warpSize);
if (!rms_only) {
float muB = __shfl_sync(0xffffffff, mu, srcLaneB, warpSize);
float countB = __shfl_sync(0xffffffff, count, srcLaneB, warpSize);
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
} else {
cuChanRMSOnlineSum(sigma2B, sigma2);
}
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float *ubuf = (float *)buf;
float *ibuf = (float *)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2 * wrt_y + 1] = sigma2;
if (!rms_only) {
ubuf[2 * wrt_y] = mu;
ibuf[wrt_y] = count;
}
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float sigma2B = ubuf[2 * threadIdx.y + 1];
if (!rms_only) {
float muB = ubuf[2 * threadIdx.y];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
} else {
cuChanRMSOnlineSum(sigma2B, sigma2);
}
}
__syncthreads();
}
      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
if (!rms_only) {
ubuf[0] = mu;
}
ubuf[1] = sigma2;
}
__syncthreads();
if (!rms_only) {
mu = ubuf[0];
}
sigma2 = ubuf[1] / float(n2);
// don't care about final value of count, we know count == n2
} else {
if (!rms_only) {
mu = __shfl_sync(0xffffffff, mu, 0, warpSize);
}
sigma2 = __shfl_sync(0xffffffff, sigma2 / float(n2), 0, warpSize);
}
}
}
// This is the un-specialized struct. Note that we prevent instantiation of
// this struct by putting an undefined symbol in the function body so it won't
// compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T> struct SharedMemory;
template <> struct SharedMemory<float> {
__device__ float *getPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <> struct SharedMemory<double> {
__device__ double *getPointer() {
extern __shared__ double s_double[];
return s_double;
}
};
template <typename T, typename U, typename V>
__device__ void cuApplyLayerNorm_(V *__restrict__ output_vals,
U *__restrict__ mean, U *__restrict__ invvar,
const T *__restrict__ vals, const int n1,
const int n2, const U epsilon,
const V *__restrict__ gamma,
const V *__restrict__ beta, bool rms_only) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U *buf = shared.getPointer();
U mu, sigma2;
cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, rms_only);
const T *lvals = vals + i1 * n2;
V *ovals = output_vals + i1 * n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && (beta != NULL || rms_only)) {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
if (!rms_only) {
ovals[i] =
gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
} else {
ovals[i] = gamma[i] * static_cast<V>(c_invvar * curr);
}
}
} else {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
if (!rms_only) {
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
} else {
ovals[i] = static_cast<V>(c_invvar * curr);
}
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
if (!rms_only) {
mean[i1] = mu;
}
invvar[i1] = c_invvar;
}
__syncthreads();
}
}
template <typename T, typename U, typename V = T>
__global__ void
cuApplyRMSNorm(V *__restrict__ output_vals, U *__restrict__ invvar,
const T *__restrict__ vals, const int n1, const int n2,
const U epsilon, const V *__restrict__ gamma) {
cuApplyLayerNorm_<T, U, V>(output_vals, NULL, invvar, vals, n1, n2, epsilon,
gamma, NULL, true);
}
template <typename T, typename U, typename V = T>
void HostApplyRMSNorm(cudaStream_t stream, V *output, U *invvar, const T *input,
int n1, int n2, double epsilon, const V *gamma) {
auto getMaxGridY = []() {
int device;
int val;
cudaGetDevice(&device);
cudaDeviceGetAttribute(&val, cudaDevAttrMaxGridDimY, device);
return val;
};
const dim3 threads(32, 4, 1);
const uint64_t maxGridY = getMaxGridY();
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
cuApplyRMSNorm<<<blocks, threads, nshared, stream>>>(
output, invvar, input, n1, n2, U(epsilon), gamma);
}
template <typename T, typename U, typename V>
__device__ void cuLoadWriteStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const T *input, const V *dout, const int i1_end, const int n2,
const U *__restrict__ mean, const U *__restrict__ invvar, bool rms_only) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean;
if (!rms_only) {
curr_mean = mean[i1];
}
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
if (!rms_only) {
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] =
curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
          warp_buf2[write_idx] = curr_dout * curr_input * curr_invvar;
}
} else {
if (!rms_only) {
warp_buf1[write_idx] = U(0);
}
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (!rms_only) {
warp_buf1[write_idx] = U(0);
}
warp_buf2[write_idx] = U(0);
}
}
}
template <typename T, typename U, typename V>
__device__ void cuLoadAddStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const T *input, const V *dout, const int i1_end, const int n2,
const U *__restrict__ mean, const U *__restrict__ invvar, bool rms_only) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean;
if (!rms_only) {
curr_mean = mean[i1];
}
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
if (!rms_only) {
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] +=
curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
          warp_buf2[write_idx] += curr_dout * curr_input * curr_invvar;
}
}
}
}
}
template <typename T, typename U, typename V>
__global__ void cuComputePartGradGammaBeta(
const V *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2, const U *__restrict__ mean, const U *__restrict__ invvar,
U epsilon, U *part_grad_gamma, U *part_grad_beta, bool rms_only) {
const int numsegs_n1 =
(n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
const int i1_beg_plus_one =
(blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x + 1;
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
const int thr_load_row_off =
(threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y *
// blockDim.y + (blockDim.y -
// 1)*(blockDim.x/blockDim.y) elements
U *warp_buf1 = (U *)buf;
U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar, rms_only);
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
i1_block += blockDim.y * blockDim.y) {
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar, rms_only);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k * blockDim.y;
int idx1 = row1 * row_stride + threadIdx.x;
if (!rms_only) {
acc1 += warp_buf1[idx1];
}
acc2 += warp_buf2[idx1];
}
if (!rms_only) {
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
}
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
if (!rms_only) {
warp_buf1[idx1] += warp_buf1[idx2];
}
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
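  // The loop above stops at offset == 2, so rows 0 and 1 still hold partial
  // sums; the threadIdx.y == 0 row combines them and writes one partially
  // reduced gradient per column for this blockIdx.y segment.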
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
if (!rms_only) {
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
}
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template <typename U, typename V>
__global__ void
cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta,
const int part_size, const int n1, const int n2,
V *grad_gamma, V *grad_beta, bool rms_only) {
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U *buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
    // each warp performs sequential reductions until the reduced part_size
    // equals the number of warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U *part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U *part_grad_beta_ptr =
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions;
++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
if (!rms_only) {
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
if (!rms_only) {
buf[write_idx + nbsize3] = sum_beta;
}
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
if (!rms_only) {
sum_beta += buf[read_idx + nbsize3];
}
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
if (!rms_only) {
grad_beta[i2] = sum_beta;
}
}
}
}
template <typename T, typename U, typename V>
__global__ void
cuComputeGradInput(const V *__restrict__ dout, const T *__restrict__ input,
const int n1, const int n2, const U *__restrict__ mean,
const U *__restrict__ invvar, U epsilon, const V *gamma,
T *grad_input, bool rms_only) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
U c_mean;
if (!rms_only) {
c_mean = mean[i1];
}
const U c_invvar = invvar[i1];
const T *k_input = input + i1 * n2;
const V *k_dout = dout + i1 * n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
if (!rms_only) {
sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);
sum_loss2 += c_loss * static_cast<U>(gamma[l + k]) *
(c_h - c_mean) * c_invvar;
} else {
            sum_loss2 += c_loss * static_cast<U>(gamma[l + k]) * c_h * c_invvar;
}
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
if (!rms_only) {
sum_loss1 += c_loss * static_cast<U>(gamma[l]);
sum_loss2 +=
c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
} else {
          sum_loss2 += c_loss * static_cast<U>(gamma[l]) * c_h * c_invvar;
}
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
if (!rms_only) {
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
} else {
            sum_loss2 += c_loss * c_h * c_invvar;
}
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
if (!rms_only) {
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
} else {
          sum_loss2 += c_loss * c_h * c_invvar;
}
}
}
// intra-warp reductions
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
if (!rms_only) {
sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize);
}
sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U *buf = shared.getPointer();
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
if (!rms_only) {
buf[2 * wrt_i] = sum_loss1;
}
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
if (!rms_only) {
sum_loss1 += buf[2 * read_i];
}
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
if (!rms_only) {
buf[2 * threadIdx.x] = sum_loss1;
}
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
if (!rms_only) {
sum_loss1 = buf[2 * threadIdx.x];
}
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
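    // The per-element input gradient below implements
    //   dx = (1 / n2) * invvar * (n2 * dy * gamma - sum_loss1
    //                             - (x - mean) * invvar * sum_loss2)
    // with gamma treated as 1 when it is NULL; in RMS-only mode the sum_loss1
    // and mean terms are dropped.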
T *k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
if (!rms_only) {
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
} else {
          f_grad_input -= c_h * c_invvar * sum_loss2;
}
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
if (!rms_only) {
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
} else {
          f_grad_input -= c_h * c_invvar * sum_loss2;
}
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
// prevent race where buf is written again before reads are done
__syncthreads();
}
}
template <typename T, typename U = float, typename V = T>
void HostRMSNormGradient(cudaStream_t stream, const V *dout, const U *invvar,
const T *input, int n1, int n2, const V *gamma,
double epsilon, T *grad_input, V *grad_gamma,
int part_size, U *part_grad_gamma) {
auto getMaxGridY = []() {
int device;
int val;
cudaGetDevice(&device);
cudaDeviceGetAttribute(&val, cudaDevAttrMaxGridDimY, device);
return val;
};
const uint64_t maxGridY = getMaxGridY();
if (gamma != NULL) {
const dim3 threads2(32, 4, 1);
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a =
2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    // note (mkozuki): part_grad_gamma's dtype can be hard-coded as float given
    // that `cuda_layer_norm_gradient` doesn't support double.
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout, input, n1, n2,
invvar, // unused
invvar, U(epsilon), part_grad_gamma, part_grad_gamma, /* unused */
true);
const dim3 threads3(32, 8, 1);
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma, part_grad_gamma, /* unused */
part_size, n1, n2, grad_gamma, grad_gamma, /* unused */
true);
}
// compute grad_input
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32, 4, 1);
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout, input, n1, n2, invvar, /* unused */
invvar, U(epsilon), gamma, grad_input, true);
}
} // namespace
namespace gpu_ops {
void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers,
const char *opaque,
std::size_t opaque_len) {
const RMSNormDescriptor &d =
*UnpackDescriptor<RMSNormDescriptor>(opaque, opaque_len);
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
d.x_type, d.w_type, "rms_norm_cuda_kernel",
HostApplyRMSNorm<scalar_t_in, accscalar_t, scalar_t_out>(
stream, static_cast<scalar_t_out *>(buffers[2]),
static_cast<accscalar_t *>(buffers[3]),
static_cast<scalar_t_in *>(buffers[0]), d.n1, d.n2, d.eps,
/*gamma=*/static_cast<scalar_t_out *>(buffers[1]));)
}
void rms_backward_affine(cudaStream_t stream, void **buffers,
const char *opaque, std::size_t opaque_len) {
const RMSNormDescriptor &d =
*UnpackDescriptor<RMSNormDescriptor>(opaque, opaque_len);
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
d.x_type, d.w_type, "cuComputeGradInputRMS",
HostRMSNormGradient(
stream,
/*dout=*/static_cast<scalar_t_out *>(buffers[0]),
/*invvar=*/static_cast<accscalar_t *>(buffers[1]),
/*input=*/static_cast<scalar_t_in *>(buffers[2]), d.n1, d.n2,
            // TMJ: pass NULL arguments for gamma, beta, grad_gamma and
            // grad_beta if the gamma tensor is NULL on input.
/*gamma=*/static_cast<scalar_t_out *>(buffers[3]), d.eps,
/*grad_input=*/static_cast<scalar_t_in *>(buffers[4]),
/*grad_gamma=*/static_cast<scalar_t_out *>(buffers[5]),
d.part_grad_size,
/*part_grad_gamma=*/static_cast<accscalar_t *>(buffers[6]));)
}
} // namespace gpu_ops

View File

@ -69,7 +69,7 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
The default value is False.
* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
i-th layer computation. It also enables enable overlapping (i+1)-th layer
i-th layer computation. It also enables overlapping (i+1)-th layer
weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default
value is False. **There are some bugs when this flag is turned on.**
* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -326,7 +326,7 @@
"id": "iuHqht-OYqca"
},
"source": [
"The outputs of the inner `jax.pmap(convolve)` never left their devices when being fed into the outer `jax.pmap(convolve)`."
"The outputs of the inner `jax.pmap(convolve)` have never left their devices when being fed into the outer `jax.pmap(convolve)`."
]
},
{

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
@ -125,7 +125,7 @@ jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws))
+++ {"id": "iuHqht-OYqca"}
The outputs of the inner `jax.pmap(convolve)` never left their devices when being fed into the outer `jax.pmap(convolve)`.
The outputs of the inner `jax.pmap(convolve)` have never left their devices when being fed into the outer `jax.pmap(convolve)`.
+++ {"id": "vEFAJXN2q3dV"}

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -920,7 +920,7 @@
"id": "ORMVVGZJgSVi"
},
"source": [
"Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up."
"Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up."
]
},
{
@ -1946,9 +1946,9 @@
"\n",
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
"\n",
"* adding `from jax import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
"* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
"\n",
"* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
"* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
"\n",
"This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
"\n",
@ -2135,30 +2135,30 @@
"\n",
"There are a few ways to do this:\n",
"\n",
"1. You can enable 64bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n",
"1. You can enable 64-bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n",
"\n",
"2. You can manually set the `jax_enable_x64` configuration flag at startup:\n",
"\n",
" ```python\n",
" # again, this only works on startup!\n",
" from jax import config\n",
" config.update(\"jax_enable_x64\", True)\n",
" import jax\n",
" jax.config.update(\"jax_enable_x64\", True)\n",
" ```\n",
"\n",
"3. You can parse command-line flags with `absl.app.run(main)`\n",
"\n",
" ```python\n",
" from jax import config\n",
" config.config_with_absl()\n",
" import jax\n",
" jax.config.config_with_absl()\n",
" ```\n",
"\n",
"4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n",
"\n",
" ```python\n",
" from jax import config\n",
" import jax\n",
" if __name__ == '__main__':\n",
" # calls config.config_with_absl() *and* runs absl parsing\n",
" config.parse_flags_with_absl()\n",
" # calls jax.config.config_with_absl() *and* runs absl parsing\n",
" jax.config.parse_flags_with_absl()\n",
" ```\n",
"\n",
"Note that #2-#4 work for _any_ of JAX's configuration options.\n",

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python
@ -407,7 +407,7 @@ print(np.random.random())
+++ {"id": "ORMVVGZJgSVi"}
Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up.
Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up.
```{code-cell} ipython3
:id: 7Pyp2ajzfPO2
@ -938,9 +938,9 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo
* setting the `JAX_DEBUG_NANS=True` environment variable;
* adding `from jax import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file;
* adding `from jax import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.
@ -1081,30 +1081,30 @@ To use double-precision numbers, you need to set the `jax_enable_x64` configurat
There are a few ways to do this:
1. You can enable 64bit mode by setting the environment variable `JAX_ENABLE_X64=True`.
1. You can enable 64-bit mode by setting the environment variable `JAX_ENABLE_X64=True`.
2. You can manually set the `jax_enable_x64` configuration flag at startup:
```python
# again, this only works on startup!
from jax import config
config.update("jax_enable_x64", True)
import jax
jax.config.update("jax_enable_x64", True)
```
3. You can parse command-line flags with `absl.app.run(main)`
```python
from jax import config
config.config_with_absl()
import jax
jax.config.config_with_absl()
```
4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use
```python
from jax import config
import jax
if __name__ == '__main__':
# calls config.config_with_absl() *and* runs absl parsing
config.parse_flags_with_absl()
# calls jax.config.config_with_absl() *and* runs absl parsing
jax.config.parse_flags_with_absl()
```
Note that #2-#4 work for _any_ of JAX's configuration options.

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -17,11 +17,7 @@
"source": [
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
"\n",
"This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.\n",
"\n",
"Refer to the [`jax.Array migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide to learn how to migrate the existing JAX pre-v0.4.1 codebases to `jax.Array`.\n",
"\n",
"**Note:** The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Google Cloud TPU and Kaggle TPU VMs."
"This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer."
]
},
{
@ -2369,6 +2365,7 @@
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"provenance": [],
"toc_visible": true

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
@ -21,10 +21,6 @@ kernelspec:
This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.
Refer to the [`jax.Array migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide to learn how to migrate the existing JAX pre-v0.4.1 codebases to `jax.Array`.
**Note:** The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Google Cloud TPU and Kaggle TPU VMs.
```{code-cell}
:id: FNxScTfq3vGF

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -7,7 +7,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python

View File

@ -147,8 +147,10 @@ grid axes over cores. This is an opt-in procedure. To allow that,
..
pallas_call(
...,
mosaic_params=dict(
dimension_semantics=["parallel", "parallel", "arbitrary"]
compiler_params=dict(
mosaic=dict(
dimension_semantics=["parallel", "parallel", "arbitrary"]
)
),
)

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3

View File

@ -0,0 +1,74 @@
# Persistent Compilation Cache
JAX has an optional disk cache for compiled programs. If enabled, JAX will
store copies of compiled programs on disk, which can save recompilation time
when running the same or similar tasks repeatedly.
## Usage
The compilation cache is enabled when the
[cache-location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as
follows:
```
import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location")
```
See the sections below for more detail on `cache-location`.
[`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
is an alternate way of setting `cache-location`.
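For example, a minimal sketch of this alternative (the `cc` alias and the
example path are illustrative; the import path follows the link above):
```
from jax.experimental.compilation_cache import compilation_cache as cc
# Equivalent to setting jax_compilation_cache_dir; call before any compilation.
cc.set_cache_dir("/tmp/jax-cache")
```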
### Local filesystem
`cache-location` can be a directory on the local filesystem. For example:
```
import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
```
Note: the cache does not have an eviction mechanism implemented. If the
cache-location is a directory in the local filesystem, its size will continue
to grow unless files are manually deleted.
### Google Cloud
When running on Google Cloud, the compilation cache can be placed on a Google
Cloud Storage (GCS) bucket. We recommend the following configuration:
* Create the bucket in the same region as where the workload will run.
* Create the bucket in the same project as the workload's VM(s). Ensure that
permissions are set so that the VM(s) can write to the bucket.
* There is no need for replication for smaller workloads. Larger workloads
could benefit from replication.
* Use “Standard” for the default storage class for the bucket.
* Set the soft delete policy to its shortest: 7 days.
* Set the object lifecycle to the expected duration of the workload run.
For example, if the workload is expected to run for 10 days, set the object
lifecycle to 10 days. That should cover restarts that occur during the entire
run. Use `age` for the lifecycle condition and `Delete` for the action. See
[Object Lifecycle Management](https://cloud.google.com/storage/docs/lifecycle)
for details. If the object lifecycle is not set, the cache will continue to
grow since there is no eviction mechanism implemented.
* All encryption policies are supported.
Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as
follows:
```
import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
```

View File

@ -40,8 +40,8 @@ One is by using :code:`jax.config` in your code:
.. code-block:: python
from jax import config
config.update("jax_numpy_rank_promotion", "warn")
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
You can also set the option using the environment variable
:code:`JAX_NUMPY_RANK_PROMOTION`, for example as

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.0
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python

View File

@ -15,7 +15,7 @@ or deployed codebases.
device_memory_profiling
debugging/index
gpu_performance_tips
persistent_compilation_cache
.. toctree::
:maxdepth: 1

View File

@ -22,6 +22,7 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
@ -30,8 +31,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from examples import kernel_lsq
sys.path.pop()
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):

View File

@ -17,10 +17,11 @@
from absl import app
from functools import partial
import jax
from jax import grad
from jax import jit
from jax import vmap
from jax import config
import jax.numpy as jnp
import jax.random as random
import jax.scipy as scipy
@ -125,5 +126,5 @@ def main(unused_argv):
mu.flatten() - std * 2, mu.flatten() + std * 2)
if __name__ == "__main__":
config.config_with_absl()
jax.config.config_with_absl()
app.run(main)

View File

@ -201,6 +201,7 @@ py_library_providing_imports_info(
"_src/debugging.py",
"_src/dispatch.py",
"_src/dlpack.py",
"_src/earray.py",
"_src/flatten_util.py",
"_src/interpreters/__init__.py",
"_src/interpreters/ad.py",
@ -997,7 +998,11 @@ pytype_library(
pytype_library(
name = "experimental_host_callback",
srcs = ["experimental/host_callback.py"],
srcs = [
"experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False
"experimental/host_callback.py",
"experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False
],
visibility = ["//visibility:public"],
deps = [
":jax",

View File

@ -125,6 +125,7 @@ from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
from jax._src.api import xla_computation as xla_computation
from jax._src.sharding_impls import NamedSharding as NamedSharding
# Force import, allowing jax.interpreters.* to be used after import jax.
from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla

View File

@ -2559,7 +2559,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
# TODO(jakevdp): provide a default for devices that considers both local
# devices and pods
if not isinstance(shards, Sequence):
raise ValueError("device_put_sharded `shards` input must be a sequence; "
raise TypeError("device_put_sharded `shards` input must be a sequence; "
f"got {type(shards)}")
if len(shards) != len(devices):
raise ValueError(f"len(shards) = {len(shards)} must equal "
@ -2911,7 +2911,7 @@ def named_scope(
... return jax.nn.relu(logits)
"""
if not isinstance(name, str):
raise ValueError("named_scope name argument must be a string.")
raise TypeError("named_scope name argument must be a string.")
with source_info_util.extend_name_stack(name):
yield

View File

@ -30,11 +30,9 @@ from jax._src import api_util
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src import layout
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
@ -47,10 +45,10 @@ from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map, hashed_index)
from jax._src.typing import ArrayLike
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.typing import ArrayLike, DLDeviceType
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
deprecations.register(__name__, "device-method")
Shape = tuple[int, ...]
Device = xc.Device
@ -406,11 +404,25 @@ class ArrayImpl(basearray.Array):
kwds = {} if copy is None else {'copy': copy}
return np.asarray(self._value, dtype=dtype, **kwds)
def __dlpack__(self, *, stream: int | Any | None = None):
if len(self._arrays) != 1:
raise BufferError("__dlpack__ only supported for unsharded arrays.")
def __dlpack__(self, *, stream: int | Any | None = None,
max_version: tuple[int, int] | None = None,
dl_device: tuple[DLDeviceType, int] | None = None,
copy: bool | None = None):
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self, stream=stream)
device_set = self.sharding.device_set
if len(device_set) > 1:
raise BufferError(
"to_dlpack can only pack a dlpack tensor from an array on a singular "
f"device, but an array with a Sharding over {len(device_set)} devices "
"was provided."
)
device, = device_set
return to_dlpack(self, stream=stream,
max_version=max_version,
src_device=device,
dl_device=dl_device,
copy=copy)
def __dlpack_device__(self) -> tuple[enum.Enum, int]:
if len(self._arrays) != 1:
@ -471,21 +483,6 @@ class ArrayImpl(basearray.Array):
per_shard_size = arr.on_device_size_in_bytes() # type: ignore
return per_shard_size * len(self.sharding.device_set)
# TODO(yashkatariya): Remove this method when everyone is using devices().
def device(self) -> Device:
if deprecations.is_accelerated(__name__, "device-method"):
raise NotImplementedError("arr.device() is deprecated. Use arr.devices() instead.")
else:
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
device_set = self.sharding.device_set
if len(device_set) == 1:
single_device, = device_set
return single_device
raise ValueError('Length of devices is greater than 1. '
'Please use `.devices()`.')
def devices(self) -> set[Device]:
self._check_if_deleted()
return self.sharding.device_set
@ -531,13 +528,15 @@ class ArrayImpl(basearray.Array):
@property
def layout(self):
# TODO(yashkatariya): Remove the deleted check from here.
if self.is_deleted():
return Layout(None, self.sharding)
try:
return layout.Layout(layout.DeviceLocalLayout(self._pjrt_layout),
self.sharding)
return Layout(DeviceLocalLayout(self._pjrt_layout), self.sharding)
except xe.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
return layout.Layout(None, self.sharding)
return Layout(None, self.sharding)
else:
raise

View File

@ -196,7 +196,6 @@ class Array(abc.ABC):
def block_until_ready(self) -> Array: ...
def copy_to_host_async(self) -> None: ...
def delete(self) -> None: ...
def device(self) -> Device: ...
def devices(self) -> set[Device]: ...
@property
def sharding(self) -> Sharding: ...

View File

@ -14,14 +14,12 @@
"""Module for JAX callbacks."""
from __future__ import annotations
import dataclasses
from collections.abc import Sequence
import logging
import dataclasses
import functools
import logging
from typing import Any, Callable
import numpy as np
import jax
from jax._src import core
from jax._src import dispatch
@ -33,9 +31,10 @@ from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lax.control_flow.loops import map as lax_map
from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import SingleDeviceSharding
import numpy as np
logger = logging.getLogger(__name__)
@ -73,11 +72,14 @@ def pure_callback_impl(
vectorized: bool,
):
del sharding, vectorized, result_avals
try:
return callback(*args)
except BaseException:
logger.exception("jax.pure_callback failed")
raise
cpu_device, *_ = jax.local_devices(backend="cpu")
args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
with jax.default_device(cpu_device):
try:
return tree_util.tree_map(np.asarray, callback(*args))
except BaseException:
logger.exception("jax.pure_callback failed")
raise
pure_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
@ -398,11 +400,14 @@ def io_callback_impl(
ordered: bool,
):
del result_avals, sharding, ordered
try:
return callback(*args)
except BaseException:
logger.exception("jax.io_callback failed")
raise
cpu_device, *_ = jax.local_devices(backend="cpu")
args = tree_util.tree_map(lambda arg: jax.device_put(arg, cpu_device), args)
with jax.default_device(cpu_device):
try:
return tree_util.tree_map(np.asarray, callback(*args))
except BaseException:
logger.exception("jax.io_callback failed")
raise
io_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
@ -439,16 +444,16 @@ def io_callback_batching_rule(
):
if ordered:
raise ValueError("Cannot `vmap` ordered IO callback.")
return pure_callback_batching_rule(
args,
dims,
callback=callback,
sharding=sharding,
vectorized=False,
result_avals=result_avals,
)
is_batched = [d is not batching.not_mapped for d in dims]
new_args = [arg if dim is batching.not_mapped else
batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
def _batch_fun(batched_args):
merged = util.merge_lists(is_batched, unbatched_args, batched_args)
return io_callback_p.bind(*merged, callback=callback, sharding=sharding,
result_avals=result_avals, ordered=False)
out_vals = lax_map(_batch_fun, batched_args)
return out_vals, (0,) * len(out_vals)
batching.primitive_batchers[io_callback_p] = io_callback_batching_rule

View File

@ -895,9 +895,9 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
inline, keep_unused):
in_shardings, out_shardings,
in_layouts, out_layouts,
resource_env, donated_invars, name, inline, keep_unused):
# jaxpr to checked_jaxpr
err_vals, err_tree = jtu.tree_flatten(error)
new_vals_in = [*err_vals, *vals_in]
@ -908,10 +908,12 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
# Update pjit params to account for extra error values.
num_error_vals = len(err_vals)
num_out_error_vals = out_tree.num_leaves - len(out_shardings)
sharding = sharding_impls.UNSPECIFIED
sharding = sharding_impls.UNSPECIFIED
new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
new_in_layouts = (*[None] * num_error_vals, *in_layouts)
new_out_layouts = (*[None] * num_out_error_vals, *out_layouts)
new_donated_invars = (*[False] * num_error_vals, *donated_invars)
err_and_out = pjit.pjit_p.bind(
@ -919,6 +921,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
jaxpr=checked_jaxpr,
in_shardings=new_in_shardings,
out_shardings=new_out_shardings,
in_layouts=new_in_layouts,
out_layouts=new_out_layouts,
resource_env=resource_env,
donated_invars=new_donated_invars,
name=name,
@ -1296,6 +1300,6 @@ def check_error(error: Error) -> None:
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
"""
if not isinstance(error, Error):
raise ValueError('check_error takes an Error as argument, '
raise TypeError('check_error takes an Error as argument, '
f'got type {type(error)} instead.')
_check_error(error, debug=False)

View File

@ -628,7 +628,7 @@ def define_string_state(
def validator(new_val):
if not isinstance(new_val, str):
raise ValueError('new string config value must be of type str,'
raise TypeError('new string config value must be of type str,'
f' got {new_val} of type {type(new_val)}.')
return define_string_or_object_state(
@ -1390,6 +1390,13 @@ eager_pmap = define_bool_state(
upgrade=True,
help='Enable eager-mode pmap when jax_disable_jit is activated.')
# TODO(mattjj): remove once we land mutable array plumbing, or face great shame
custom_vjp_disable_shape_check = define_bool_state(
name='jax_custom_vjp_disable_shape_check',
default=False,
upgrade=True,
help='Disable the check from #19009 to enable some custom_vjp hacks.')
xla_runtime_errors = define_bool_state(
name='jax_experimental_unsafe_xla_runtime_errors',
default=False,

View File

@ -54,6 +54,7 @@ from jax._src.lib import jax_jit
from jax._src import traceback_util
from jax._src.typing import Array, DimSize, Shape
from jax._src import typing
traceback_util.register_exclusion(__file__)
zip, unsafe_zip = safe_zip, zip
@ -832,14 +833,14 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
@property
def block_until_ready(self):
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
# Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
raise AttributeError(self,
f"The 'block_until_ready' method is not available on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def copy_to_host_async(self):
# Raise AttribureError for backward compatibility with hasattr() and getattr() checks.
# Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
raise AttributeError(self,
f"The 'copy_to_host_async' method is not available on {self._error_repr()}."
f"{self._origin_msg()}")
@ -849,11 +850,6 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
f"The delete() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
def device(self):
raise ConcretizationTypeError(self,
f"The device() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
def devices(self):
raise ConcretizationTypeError(self,
f"The devices() method was called on {self._error_repr()}."
@ -910,10 +906,20 @@ class EvalTrace(Trace):
lift = sublift = pure
def process_primitive(self, primitive, tracers, params):
return primitive.impl(*tracers, **params)
if config.debug_key_reuse.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error
return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
else:
return primitive.impl(*tracers, **params)
def process_call(self, primitive, f, tracers, params):
return primitive.impl(f, *tracers, **params)
if config.debug_key_reuse.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error
return call_impl_with_key_reuse_checks(primitive, primitive.impl, f, *tracers, **params)
else:
return primitive.impl(f, *tracers, **params)
process_map = process_call
def process_custom_transpose(self, primitive, call, tracers, **_):

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from functools import partial, reduce
import operator
from typing import Optional
@ -36,6 +37,17 @@ Array = jnp.ndarray
DType = jnp.dtype
PRNGKey = jnp.ndarray
class AttentionLayout(Enum):
BTNH = 0
BNTH = 1
def _normalize_layout(layout: str) -> AttentionLayout:
layout_upper = layout.upper()
if layout_upper in ['BSNH', 'BNSH', 'BTNH', 'BNTH']:
return AttentionLayout[layout_upper.replace('S', 'T')]
else:
raise ValueError(f"Unsupported qkv_layout: {layout}")
def element_type_to_backend_config_type_mapping(dtype):
_element_type_to_backend_config_type_mapping = {
ir.BF16Type.get(): "BF16",
@ -56,18 +68,18 @@ def create_dot_product_attention_backend_config(batch,
dropout_rate,
is_flash_attention,
is_causal_mask,
layout,
is_bwd):
# b q_seq num_heads head_dim -> Q
# b kv_seq num_heads head_dim -> K
# b kv_seq num_heads head_dim -> V
# b num_heads q_seq kv_seq -> P
# b q_seq num_heads head_dim -> O
# bmm1: Q @ K -> P
# bmm2: P @ V -> O
# bmm2Grad1: P @ dO -> dV
# bmm2Grad2: dO @ V -> dP
# bmm1Grad1: dP @ Q -> dK
# bmm1Grad2: dP @ K -> dQ
# Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H
# P: BMM1 output in shape of BNTS
# O: BMM2 output in the same shape with Q
# BMM1: Q @ K -> P
# BMM2: P @ V -> O
# BMM1Grad1: dP @ Q -> dK
# BMM1Grad2: dP @ K -> dQ
# BMM2Grad1: P @ dO -> dV
# BMM2Grad2: dO @ V -> dP
cudnn_fmha_backend_config = {
"algorithm": {
"algo_id": "0",
@ -100,46 +112,47 @@ def create_dot_product_attention_backend_config(batch,
"is_flash_attention": is_flash_attention,
"is_causal_mask": is_causal_mask,
}
fwd_dot_number = {
"bmm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["3"],
"lhs_batch_dimensions": ["0", "2"],
"rhs_batch_dimensions": ["0", "2"],
},
"bmm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["1"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "2"],
},
}
bwd_dot_number = {
"bmm1_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["1"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "2"],
},
"bmm1_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["1"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "2"],
},
"bmm2_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["1"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "2"],
},
"bmm2_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["3"],
"lhs_batch_dimensions": ["0", "2"],
"rhs_batch_dimensions": ["0", "2"],
},
}
# We define the contracting and batch dims in the format of
# ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
# rhs_batch_dims)).
if layout == AttentionLayout.BNTH.value:
dims = [
((3, 3), ((0, 1), (0, 1))), # BMM1: BNTH,BNSH->BNTS
((3, 2), ((0, 1), (0, 1))), # BMM2: BNTS,BNSH->BNTH
((2, 2), ((0, 1), (0, 1))), # BMM1_grad_1: BNTS,BNTH->BNSH
((3, 2), ((0, 1), (0, 1))), # BMM1_grad_2: BNTS,BNSH->BNTH
((2, 2), ((0, 1), (0, 1))), # BMM2_grad_1: BNTS,BNTH->BNSH
((3, 3), ((0, 1), (0, 1))), # BMM2_grad_2: BNTH,BNSH->BNTS
]
else:
dims = [
((3, 3), ((0, 2), (0, 2))), # BMM1: BTNH,BSNH->BNTS
((3, 1), ((0, 1), (0, 2))), # BMM2: BNTS,BSNH->BTNH
((2, 1), ((0, 1), (0, 2))), # BMM1_grad_1: BNTS,BTNH->BSNH
((3, 1), ((0, 1), (0, 2))), # BMM1_grad_2: BNTS,BSNH->BTNH
((2, 1), ((0, 1), (0, 2))), # BMM2_grad_1: BNTS,BTNH->BSNH
((3, 3), ((0, 2), (0, 2))), # BMM2_grad_2: BTNH,BSNH->BNTS
]
keys = [
"bmm1_dot_dimension_numbers",
"bmm2_dot_dimension_numbers",
"bmm1_grad_gemm1_dot_dimension_numbers",
"bmm1_grad_gemm2_dot_dimension_numbers",
"bmm2_grad_gemm1_dot_dimension_numbers",
"bmm2_grad_gemm2_dot_dimension_numbers",
]
fwd_dot_number = {}
bwd_dot_number = {}
for idx, (key, ((lc, rc), (lb, rb))) in enumerate(zip(keys, dims)):
dims_to_write = fwd_dot_number if idx < 2 else bwd_dot_number
dims_to_write[key] = {
"lhs_contracting_dimensions": [str(lc)],
"rhs_contracting_dimensions": [str(rc)],
"lhs_batch_dimensions": [str(i) for i in lb],
"rhs_batch_dimensions": [str(i) for i in rb],
}
if is_bwd:
cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **bwd_dot_number}
else:
@ -178,46 +191,59 @@ _custom_name_maps = {
def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
return _custom_name_maps[(is_bwd, has_dropout, has_mask, has_bias)]
def check_qkv_layout(query, key, value):
assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \
"query, key and value should have rank 4."
def check_qkv_layout(query, key, value, layout):
def check_eq(a, b, c, msg):
if not (a == b == c):
raise ValueError(f"{msg} must be same, got {a}, {b}, {b}")
# Only support fp16 and bf16 here
query_dtype = query.dtype
key_dtype = key.dtype
value_dtype = value.dtype
assert query_dtype == key_dtype == value_dtype and query_dtype in [jnp.float16, jnp.bfloat16], \
"query, key and value should have same dtype and should be float16 or bfloat16"
q_rank, k_rank, v_rank = len(query.shape), len(key.shape), len(value.shape)
if q_rank != 4:
raise ValueError(f"Q must have a rank of 4, got {q_rank}")
check_eq(q_rank, k_rank, v_rank, 'QKV rank')
q_batch, q_seq_len, q_num_heads, q_head_dim = query.shape
k_batch, k_seq_len, k_num_heads, k_head_dim = key.shape
v_batch, v_seq_len, v_num_heads, v_head_dim = value.shape
if not((q_batch == k_batch == v_batch)
and (k_seq_len == v_seq_len)
and (q_num_heads == k_num_heads == v_num_heads)
and (q_head_dim == k_head_dim == v_head_dim)):
raise ValueError(
"query should have layout [batch, q_seq, num_heads, head_dim], " \
"key and value should have layout [batch, kv_seq, num_heads, head_dim].")
q_dtype, k_dtype, v_dtype = query.dtype, key.dtype, value.dtype
assert q_dtype in [jnp.float16, jnp.bfloat16], "Q must be fp16 or bf16"
check_eq(q_dtype, k_dtype, v_dtype, 'QKV dtype')
def check_is_flash_attention(query, key, cudnn_version, has_bias, is_training):
batch, q_seq_len, num_heads, head_dim = query.shape
_, kv_seq_len, _, _ = key.shape
if layout == AttentionLayout.BNTH:
qB, qN, _, qH = query.shape
kB, kN, kS, kH = key.shape
vB, vN, vS, vH = value.shape
else:
assert layout == AttentionLayout.BTNH
qB, _, qN, qH = query.shape
kB, kS, kN, kH = key.shape
vB, vS, vN, vH = value.shape
# check if attention pattern is supported by flash attention or fused attention
if q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64 \
and (not is_training or q_seq_len % 64 == 0 and kv_seq_len % 64 == 0):
check_eq(qB, kB, vB, 'QKV batch')
check_eq(qN, kN, vN, 'QKV num_head')
check_eq(qH, kH, vH, 'QKV dim_per_head')
if kS != vS:
raise ValueError(f'KV must have same seq length, got {kS} vs {vS}')
def check_is_flash_attention(
query, key, layout, cudnn_version, has_bias, is_training):
if layout == AttentionLayout.BNTH:
_, N, T, H = query.shape
_, _, S, _ = key.shape
else:
_, T, N, H = query.shape
_, S, _, _ = key.shape
# check if attention pattern is supported by flash attention or fused attention.
if ((T <= 512 and S <= 512 and H == 64) and
(not is_training or T % 64 == 0 and S % 64 == 0)):
# check if regular fused attention is supported
# for training, seqlen should be divisible by 64
is_flash_attention = False
elif head_dim <= 128 and head_dim % 8 == 0 \
and (not is_training or not has_bias or q_seq_len % 2 == 0 and kv_seq_len % 2 == 0):
elif ((H <= 128 and H % 8 == 0) and
(not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)):
# check if flash attention is supported
# for training, for patterns with bias, seqlen should be divisible by 2
is_flash_attention = True
else:
raise NotImplementedError(
f"Unsupported sequence length Q {q_seq_len}, KV {kv_seq_len} and head dim {head_dim}.")
f"Unsupported sequence length Q {T}, KV {S} and head dim {H}.")
# check if minimum cudnn version requirement is satisfied
if is_flash_attention and cudnn_version < 8904:
raise RuntimeError("JAX requires cuDNN >= 8.9.4 to use flash cross attention.")
@ -232,61 +258,78 @@ def check_cudnn_version():
raise RuntimeError("cuDNN is not detected.")
return cuda_versions.cudnn_get_version()
def _dot_product_attention_fwd(query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
def _dot_product_attention_fwd(
query, key, value, bias, mask, scale, seed, dropout_rate, variadic_args,
is_flash_attention, is_causal_mask, layout, is_training):
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask, is_training=is_training)
query, key, value, bias, mask, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
layout=layout, is_training=is_training)
output = outputs[0]
return output
def _dot_product_attention_fwd_rule(query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
def _dot_product_attention_fwd_rule(
query, key, value, bias, mask, scale, seed, dropout_rate, variadic_args,
is_flash_attention, is_causal_mask, layout, is_training):
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask, is_training=is_training)
query, key, value, bias, mask, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
layout=layout, is_training=is_training)
res = (query, key, value, bias, mask, outputs[1], outputs[0]) if is_training else None
return outputs[0], res
def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, res, grad_output):
def _dot_product_attention_bwd_rule(
scale, seed, dropout_rate, variadic_args, is_flash_attention,
is_causal_mask, layout, is_training, res, grad_output):
query, key, value, bias, mask, activation, fwd_output = res
grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, mask, activation, fwd_output, grad_output,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask)
query, key, value, bias, mask, activation, fwd_output, grad_output,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask, layout=layout,
)
grads = (grad_query, grad_key, grad_value, None, None)
return grads
def _dot_product_attention_fwd_impl(query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
def _dot_product_attention_fwd_impl(
query, key, value, bias, mask, scale, seed, dropout_rate, variadic_args,
is_flash_attention, is_causal_mask, layout, is_training):
# args: {Q, K, V, mask*, bias*}
outputs = _dot_product_attention_fwd_p.bind(
query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask, is_training=is_training)
query, key, value, bias, mask, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
layout=layout, is_training=is_training)
return outputs
def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, fwd_output, grad_output,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
def _dot_product_attention_bwd_impl(
query, key, value, bias, mask, activation, fwd_output, grad_output, scale,
seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask,
layout):
grad_query, grad_key, grad_value = _dot_product_attention_bwd_p.bind(
query, key, value, bias, mask, activation, fwd_output, grad_output,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask)
query, key, value, bias, mask, activation, fwd_output, grad_output,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask, layout=layout,
)
grads = (grad_query, grad_key, grad_value)
return grads
def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
*, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
def _dot_product_attention_fwd_abstract(
query, key, value, bias, mask, *, scale, seed, dropout_rate, variadic_args,
is_flash_attention, is_causal_mask, layout, is_training):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
batch, q_seq_len, num_heads, head_dim = query.shape
_, kv_seq_len, _, _ = key.shape
output_shape = (batch, q_seq_len, num_heads, head_dim)
activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
softmax_stat_shape = (batch, num_heads, q_seq_len)
if layout == AttentionLayout.BNTH.value:
B, N, T, _ = query.shape
_, _, S, _ = key.shape
else:
B, T, N, _ = query.shape
_, S, _, _ = key.shape
output_shape = query.shape
activation_shape = (B, N, T, S)
softmax_stat_shape = (B, N, T)
if is_flash_attention:
# is flash attention
@ -309,8 +352,10 @@ def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
core.ShapedArray(output_shape, query_dtype), # output
)
def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activation, fwd_output, grad_output,
*, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
def _dot_product_attention_bwd_abstract(
query, key, value, bias, mask, activation, fwd_output, grad_output, *,
scale, seed, dropout_rate, variadic_args, is_flash_attention,
is_causal_mask, layout):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
key_dtype = dtypes.canonicalize_dtype(key.dtype)
value_dtype = dtypes.canonicalize_dtype(value.dtype)
@ -327,8 +372,9 @@ def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activatio
), # part value
)
def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
def _dot_product_attention_fwd_cuda_lowering(
ctx, query, key, value, bias, mask, scale, seed, dropout_rate,
variadic_args, is_flash_attention, is_causal_mask, layout, is_training):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
@ -336,18 +382,26 @@ def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
value_type = ir.RankedTensorType(value.type)
value_shape = value_type.shape
batch, q_seq_len, num_heads, head_dim = query_shape
_, kv_seq_len, _, _ = key_shape
if layout == AttentionLayout.BNTH.value:
B, N, T, H = query_shape
_, _, S, _ = key_shape
output_layout = (3, 2, 1, 0)
output_transpose_perm = mlir.dense_int_array((0, 1, 2, 3))
else:
B, T, N, H = query_shape
_, S, _, _ = key_shape
output_layout = (3, 1, 2, 0)
output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
output_shape = (batch, num_heads, q_seq_len, head_dim)
output_layout = (3, 1, 2, 0)
output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
softmax_stat_shape = (batch, num_heads, q_seq_len)
output_shape = (B, N, T, H)
activation_shape = (B, N, T, S)
softmax_stat_shape = (B, N, T)
scratch_shape = (0,)
scratch_type = ir.IntegerType.get_unsigned(8)
# get backend config
backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, False)
backend_config = create_dot_product_attention_backend_config(
B, N, T, S, query_type.element_type, scale, seed, dropout_rate,
is_flash_attention, is_causal_mask, layout, is_bwd=False,
)
# {Q, K, V, mask*, bias*}
# {output, scratch, activation*}
has_dropout = dropout_rate > 0
@ -403,8 +457,10 @@ def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
else:
return [hlo.transpose(out.results[0], output_transpose_perm)]
def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask, activation, fwd_output, grad_output,
scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
def _dot_product_attention_bwd_cuda_lowering(
ctx, query, key, value, bias, mask, activation, fwd_output, grad_output,
scale, seed, dropout_rate, variadic_args, is_flash_attention,
is_causal_mask, layout):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
@ -416,18 +472,28 @@ def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask,
grad_output_type = ir.RankedTensorType(grad_output.type)
grad_output_shape = grad_output_type.shape
batch, q_seq_len, num_heads, head_dim = query_shape
_, kv_seq_len, _, _ = key_shape
if layout == AttentionLayout.BNTH.value:
B, N, T, H = query_shape
_, _, S, _ = key_shape
grad_layout = (3, 2, 1, 0)
grad_transpose_perm = mlir.dense_int_array((0, 1, 2, 3))
else:
B, T, N, H = query_shape
_, S, _, _ = key_shape
grad_layout = (3, 1, 2, 0)
grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
scratch_shape = (0,)
scratch_type = ir.IntegerType.get_unsigned(8)
grad_query_shape = (batch, num_heads, q_seq_len, head_dim)
grad_key_shape = (batch, num_heads, kv_seq_len, head_dim)
grad_value_shape = (batch, num_heads, kv_seq_len, head_dim)
softmax_sum_shape = (batch, num_heads, q_seq_len)
grad_layout = (3, 1, 2, 0)
grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, True)
grad_query_shape = (B, N, T, H)
grad_key_shape = (B, N, S, H)
grad_value_shape = (B, N, S, H)
softmax_sum_shape = (B, N, T)
backend_config = create_dot_product_attention_backend_config(
B, N, T, S, query_type.element_type, scale, seed, dropout_rate,
is_flash_attention, is_causal_mask, layout, is_bwd=True,
)
# {Q, K, V, activation, dO, mask*, bias*, O*}
# {dQ, dK, dV, d_S*, softmax_sum*, d_Q_accum*, scratch, dbias*}
has_dropout = dropout_rate > 0
@ -484,7 +550,9 @@ def _check_valid_batch_dims(bdims):
raise NotImplementedError("Currently only support batch_dim in [0, None], " \
f"but got {dim=}")
def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
def _dot_product_attention_fwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
is_flash_attention, is_causal_mask, layout, is_training):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, mask = batched_args
query_bdim = batch_dims[0]
@ -493,74 +561,84 @@ def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed,
else:
out_bdims = (query_bdim,)
*batch_tuple, q_seq_len, num_heads, head_dim = query.shape
*_, kv_seq_len, _, _ = key.shape
new_batch = reduce(operator.mul, batch_tuple)
if layout == AttentionLayout.BNTH.value:
*Bs, N, T, _ = query.shape
*_, _, S, _ = key.shape
else:
*Bs, T, N, _ = query.shape
*_, S, _, _ = key.shape
B = reduce(operator.mul, Bs)
has_bias, has_mask = variadic_args
# reshape to 4D shape
query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, head_dim))
key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim))
value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim))
query = jnp.reshape(query, (B,) + query.shape[-3:])
key = jnp.reshape(key, (B,) + key.shape[-3:])
value = jnp.reshape(value, (B,) + value.shape[-3:])
if has_bias:
bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len))
bias = jnp.reshape(bias, (B, N, T, S))
if has_mask:
mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len))
mask = jnp.reshape(mask, (B, N, T, S))
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, mask,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask, is_training=is_training)
query, key, value, bias, mask, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
layout=layout, is_training=is_training)
# reshape to original shape
output = outputs[0]
output = jnp.reshape(output, (*batch_tuple, q_seq_len, num_heads, head_dim))
output = jnp.reshape(output, (*Bs, *output.shape[-3:]))
if is_training:
activation = outputs[1]
if is_flash_attention:
activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len))
activation = jnp.reshape(activation, (*Bs, N, T))
else:
activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len))
activation = jnp.reshape(activation, (*Bs, N, T, S))
return (output, activation), out_bdims
else:
return (output,), out_bdims
def _dot_product_attention_bwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
def _dot_product_attention_bwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
is_flash_attention, is_causal_mask, layout):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, mask, activation, fwd_output, grad_output = batched_args
query_bdim = batch_dims[0]
out_bdims = query_bdim, query_bdim, query_bdim
*batch_tuple, q_seq_len, num_heads, head_dim = query.shape
*_, kv_seq_len, _, _ = key.shape
new_batch = reduce(operator.mul, batch_tuple)
if layout == AttentionLayout.BNTH.value:
*Bs, N, T, _ = query.shape
*_, _, S, _ = key.shape
else:
*Bs, T, N, _ = query.shape
*_, S, _, _ = key.shape
B = reduce(operator.mul, Bs)
has_bias, has_mask = variadic_args
# reshape to 4D shape
query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, head_dim))
key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim))
value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim))
query = jnp.reshape(query, (B,) + query.shape[-3:])
key = jnp.reshape(key, (B,) + key.shape[-3:])
value = jnp.reshape(value, (B,) + value.shape[-3:])
if has_bias:
bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len))
bias = jnp.reshape(bias, (B, N, T, S))
if has_mask:
mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len))
mask = jnp.reshape(mask, (B, N, T, S))
if is_flash_attention:
activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len))
activation = jnp.reshape(activation, (B, N, T))
else:
activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len, kv_seq_len))
fwd_output = jnp.reshape(fwd_output, (new_batch, q_seq_len, num_heads, head_dim))
grad_output = jnp.reshape(grad_output, (new_batch, q_seq_len, num_heads, head_dim))
activation = jnp.reshape(activation, (B, N, T, S))
fwd_output = jnp.reshape(fwd_output, (B,) + query.shape[-3:])
grad_output = jnp.reshape(grad_output, (B,) + query.shape[-3:])
grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias,
mask, activation, fwd_output, grad_output,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask)
query, key, value, bias, mask, activation, fwd_output, grad_output,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention,
is_causal_mask=is_causal_mask, layout=layout,
)
# reshape to original shape
grad_query = jnp.reshape(grad_query, (*batch_tuple, q_seq_len, num_heads, head_dim))
grad_key = jnp.reshape(grad_key, (*batch_tuple, kv_seq_len, num_heads, head_dim))
grad_value = jnp.reshape(grad_value, (*batch_tuple, kv_seq_len, num_heads, head_dim))
grad_query = jnp.reshape(grad_query, (*Bs, *grad_query.shape[-3:]))
grad_key = jnp.reshape(grad_key, (*Bs, *grad_key.shape[-3:]))
grad_value = jnp.reshape(grad_value, (*Bs, *grad_value.shape[-3:]))
grads = (grad_query, grad_key, grad_value)
return grads, out_bdims
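The two batchers above are what jax.vmap dispatches to when an extra leading axis is mapped over; a rough sketch of the intent, assuming a GPU build and that dot_product_attention from this module is in scope with bias and mask defaulting to None:
import jax
import jax.numpy as jnp
# Hypothetical shapes: a group axis G is vmapped on top of the usual
# (B, T, N, H) inputs; the batcher folds G into B before binding the
# primitive and unfolds it again afterwards.
G, B, T, N, H = 4, 2, 256, 8, 64
q = jnp.zeros((G, B, T, N, H), jnp.bfloat16)
k = jnp.zeros((G, B, T, N, H), jnp.bfloat16)
v = jnp.zeros((G, B, T, N, H), jnp.bfloat16)
out = jax.vmap(lambda q, k, v: dot_product_attention(q, k, v))(q, k, v)
print(out.shape)  # expected: (G, B, T, N, H)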
@ -617,17 +695,25 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training):
return [out_sharding, activation_sharding]
return [out_sharding]
_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10,11))
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, mesh, arg_shapes, result_shape):
_dot_product_attention_fwd_lower = custom_partitioning(
_dot_product_attention_fwd_impl, static_argnums=(5, 6, 7, 8, 9, 10, 11, 12))
def _dot_product_attention_fwd_infer_sharding_from_operands(
scale, seed, dropout_rate, variadic_args, is_flash_attention,
is_causal_mask, layout, is_training, mesh, arg_shapes, result_shape):
return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training)
def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, mesh, arg_shapes, result_shape):
def _dot_product_attention_fwd_partition(
scale, seed, dropout_rate, variadic_args, is_flash_attention,
is_causal_mask, layout, is_training, mesh, arg_shapes, result_shape):
# args sharding
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training)
impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
is_training=is_training)
impl = partial(
_dot_product_attention_fwd_impl, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
layout=layout, is_training=is_training)
return mesh, impl, out_shardings, arg_shardings
# bwd custom partition
@ -648,16 +734,27 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args):
out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
return out_shardings
_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13))
def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
_dot_product_attention_bwd_lower = custom_partitioning(
_dot_product_attention_bwd_impl, static_argnums=(8, 9, 10, 11, 12, 13, 14)
)
def _dot_product_attention_bwd_infer_sharding_from_operands(
scale, seed, dropout_rate, variadic_args, is_flash_attention,
is_causal_mask, layout, mesh, arg_shapes, result_shape):
return _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)
def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
def _dot_product_attention_bwd_partition(
scale, seed, dropout_rate, variadic_args, is_flash_attention,
is_causal_mask, layout, mesh, arg_shapes, result_shape):
out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)
# args sharding
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
impl = partial(
_dot_product_attention_bwd_impl, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
layout=layout,
)
return mesh, impl, out_shardings, arg_shardings
# Create dot_product_attention_fwd_p for forward operation.
@ -717,24 +814,25 @@ dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p_
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p)
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p_wrapper)
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11))
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12))
def _dot_product_attention(query: Array,
key: Array,
value: Array,
bias: Array,
mask: Array,
scale: float,
seed: int,
dropout_rate: float,
variadic_args: tuple[bool, ...],
is_flash_attention: bool,
is_causal_mask: bool,
is_training: bool):
key: Array,
value: Array,
bias: Array,
mask: Array,
scale: float,
seed: int,
dropout_rate: float,
variadic_args: tuple[bool, ...],
is_flash_attention: bool,
is_causal_mask: bool,
layout: int,
is_training: bool):
output = _dot_product_attention_fwd(
query, key, value, bias, mask,
scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
is_training=is_training)
query, key, value, bias, mask, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
layout=layout, is_training=is_training)
return output
# _dot_product_attention_fwd must have the same func signature as _dot_product_attention
@ -751,40 +849,51 @@ def dot_product_attention(query: Array,
is_causal_mask: bool = False,
seed: int = 42,
dropout_rate: float = 0.,
qkv_layout: str = 'BTNH',
is_training = False):
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
batch seq num_heads, head_dim // but all assume Q, K and V will have same
b q_seq num_heads head_dim -> Q
b kv_seq num_heads head_dim -> K
b kv_seq num_heads head_dim -> V
"""Computes dot-product attention given query (Q), key (K), and value (V).
This function serves as the core operation for applying attention
mechanisms as described in the paper [https://arxiv.org/abs/1706.03762].
Initially, it determines the attention weights by processing Q and K,
subsequently combining the outcomes using V. Throughout this function, we
use the following uppercase letters to represent specific dimensions of the
arrays:
B = batch size
S = length of the key/value (source)
T = length of the query (target)
N = number of attention heads
H = dimension of each attention head.
The supported layouts for Q, K, V are either BT(S)NH or BNT(S)H, and they must
adhere to the same layout. The output layout remains consistent with Q,
defaulting to BT(S)NH.
Args:
query: queries for calculating attention with shape of `[batch, q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch, kv_length,
num_heads, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch, kv_length,
num_heads, v_depth_per_head]`.
bias: bias to be added to logits with shape of `[batch, num_heads,
q_length, kv_length]`.
mask: mask used mask out logits with shape of `[batch, num_heads,
q_length, kv_length]`.
scale: scale for the query.
is_causal_mask: choose to apply a causal mask or not.
seed: used for dropout mask generation.
dropout_rate: dropout rate.
query: Queries for attention calculation with a shape of BTNH or BNTH.
key: Keys for attention calculation with a shape of BSNH or BNSH.
value: Values to be used in attention with a shape of BSNH or BNSH.
bias: Bias to be added to logits with a shape of BNTS.
mask: Mask used to filter out logits with a shape of BNTS.
scale: Scale for the query.
is_causal_mask: Whether to apply a causal mask.
seed: Seed used for dropout mask generation.
dropout_rate: Dropout rate.
qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH,
BNSH.
is_training: Whether to save the activation (needed for the backward pass).
Returns:
Output of shape `[batch, q_length, num_heads, v_depth_per_head]`.
Output of the same shape as the query.
"""
# check if cuDNN is installed
cudnn_version = check_cudnn_version()
layout = _normalize_layout(qkv_layout)
# check query, key and value shape and data type
check_qkv_layout(query, key, value)
check_qkv_layout(query, key, value, layout)
# check if flash attention is supported for this attention pattern
is_flash_attention = check_is_flash_attention(query, key, cudnn_version, bias is not None, is_training)
is_flash_attention = check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, is_training)
if mask is not None and is_causal_mask:
raise ValueError("can not apply a mask and generate a causal_mask at the same time.")
if not is_flash_attention and is_causal_mask:
@ -795,7 +904,7 @@ def dot_product_attention(query: Array,
if mask is None:
mask = jnp.zeros(0, dtype=query.dtype)
output = _dot_product_attention(
query, key, value, bias, mask,
scale, seed, dropout_rate, variadic_args,
is_flash_attention, is_causal_mask, is_training)
query, key, value, bias, mask, scale, seed, dropout_rate, variadic_args,
is_flash_attention, is_causal_mask, layout.value, is_training
)
return output
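For reference, a minimal end-to-end sketch of the updated entry point; it assumes a CUDA build with cuDNN >= 8.9, that bias and mask default to None, and the module path used in the import (neither is confirmed by this hunk):
import jax
import jax.numpy as jnp
# Assumed module path for the function defined above.
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
B, T, S, N, H = 2, 256, 256, 8, 64
rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (B, T, N, H), dtype=jnp.bfloat16)  # BTNH layout
k = jax.random.normal(rng, (B, S, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(rng, (B, S, N, H), dtype=jnp.bfloat16)
# With T, S <= 512 and H == 64 this selects the fused-attention path; the
# output keeps the query layout, here (B, T, N, H).
out = jax.jit(
    lambda q, k, v: dot_product_attention(q, k, v, qkv_layout='BTNH'))(q, k, v)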

View File

@ -772,7 +772,8 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
results.append(Zero(ct.aval))
else:
if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct))
and not _temporary_dtype_exception(a, a_)):
and not (_temporary_dtype_exception(a, a_) or
_temporary_shape_exception(a, a_))):
msg = ("Custom VJP bwd rule must produce an output with the same "
"shape/dtypes as the args tuple of the primal function, but at "
f"output{keystr(kp)} the bwd rule produced an output of "
@ -790,6 +791,9 @@ def _temporary_dtype_exception(a, a_) -> bool:
dtypes.issubdtype(a.dtype, dtypes.np.inexact)))
return False
# TODO(mattjj): remove both these exceptions to cotangent compatibility check
def _temporary_shape_exception(a, a_) -> bool:
return config.custom_vjp_disable_shape_check.value
class CustomVJPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive
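The check above enforces that a custom_vjp bwd rule return cotangents matching the primal arguments in shape and dtype, unless the new shape-check escape hatch is enabled; a self-contained sketch of the rule being enforced:
import jax
import jax.numpy as jnp
@jax.custom_vjp
def f(x):
  return jnp.sin(x)
def f_fwd(x):
  return jnp.sin(x), jnp.cos(x)   # primal output, residuals
def f_bwd(cos_x, g):
  # The returned tuple must match the primal args in shape/dtype; returning,
  # say, a reshaped cotangent trips the error message above.
  return (cos_x * g,)
f.defvjp(f_fwd, f_bwd)
print(jax.grad(f)(0.5))           # ~0.87758, i.e. cos(0.5)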

View File

@ -108,7 +108,7 @@ def simple_impl(prim):
RuntimeToken = Any
class RuntimeTokenSet(threading.local):
"""See docstring for effect.py module for the calling convention for tokens."""
"""See docstring for effects.py module for the calling convention for tokens."""
# For each ordered effect, the token returned by the last dispatched
# computation, sharded over the devices in that computation.
@ -125,6 +125,16 @@ class RuntimeTokenSet(threading.local):
def get_token_input(self, eff: core.Effect,
devices: list[Device]) -> jax.Array:
tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
if isinstance(tok, jax.Array):
# The order of devices may change, so we need to reshard if necessary.
# TODO(yueshengys): This might still be buggy in a multi-process SPMD
# scenario. Revise the logic later. A distributed shutdown barrier inside
# the XLA program may be needed.
return jax.device_put(tok, jax.sharding.PositionalSharding(devices))
# We only use replicated sharding for the first time when the token for the
# order effect hasn't been created.
s = jax.sharding.GSPMDSharding.get_replicated(devices)
sharded_tok = pxla.shard_args([s], [tok])[0]
self.current_tokens[eff] = sharded_tok
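The resharding above serves tokens for ordered effects. As a rough illustration (assuming io_callback under jax.experimental, which is not part of this diff), an ordered host callback is the kind of computation whose token flows through get_token_input:
import jax
import jax.numpy as jnp
from jax.experimental import io_callback  # assumed import path
def log_host(x):
  print("saw", x)   # host-side side effect
  return x
@jax.jit
def f(x):
  # ordered=True threads a runtime token through the computation so repeated
  # callbacks keep their program order across dispatches.
  y = io_callback(log_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
                  ordered=True)
  return y + 1.0
f(jnp.ones(3))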
@ -452,10 +462,7 @@ def _device_put_impl(
return x
if x_dll is None and dll is None:
return _device_put_sharding_impl(x, aval, l.sharding)
# TODO(yashkatariya): Pass layout to out_shardings directly and remove
# out_layouts from lower.
return api.jit(_identity_fn, out_shardings=l.sharding).lower(
x, _out_layouts=l).compile()(x)
return api.jit(_identity_fn, out_shardings=l)(x)
return _device_put_sharding_impl(x, aval, device)

View File

@ -14,17 +14,20 @@
from __future__ import annotations
import enum
from typing import Any
import warnings
from jax._src.api import device_put
from jax import numpy as jnp
from jax._src import array
from jax._src import xla_bridge
from jax._src.lax.lax import _array_copy
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array
from jax._src.typing import Array, DLDeviceType
from jax._src.sharding import Sharding
DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)
# A set of dtypes that dlpack supports.
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@ -42,56 +45,218 @@ if xla_extension_version >= 231:
SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})
# Mirror of dlpack.h enum
class DLDeviceType(enum.IntEnum):
kDLCPU = 1
kDLCUDA = 2
kDLROCM = 10
def _to_dlpack(x: Array, stream: int | Any | None,
src_device: xla_client.Device | None = None,
device: xla_client.Device | None = None,
copy: bool | None = None):
if src_device is None:
src_device, = x.devices()
if device and (src_device is None or device != src_device):
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(src_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
arr = device_put(x, device)
else:
arr = _array_copy(x) if copy else x
return xla_client._xla.buffer_to_dlpack_managed_tensor(
arr.addressable_data(0), stream=stream
)
def to_dlpack(x: Array, take_ownership: bool = False,
stream: int | Any | None = None):
def to_dlpack(x: Array, stream: int | Any | None = None,
src_device: xla_client.Device | None = None,
dl_device: tuple[DLDeviceType, int] | None = None,
max_version: tuple[int, int] | None = None,
copy : bool | None = None):
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
Args:
x: a :class:`~jax.Array`, on either CPU or GPU.
take_ownership: Deprecated. It is a no-op to set take_ownership. Will be
deleted in 01/2024.
stream: optional platform-dependent stream to wait on until the buffer is
ready. This corresponds to the `stream` argument to ``__dlpack__``
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
src_device: either a CPU or GPU :class:`~jax.Device`.
dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
format e.g. as produced by ``__dlpack_device__``.
max_version: the maximum DLPack version that the consumer (i.e. caller of
``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
This function is not guaranteed to return a capsule of version
``max_version``.
copy: a boolean indicating whether or not to copy the input. If
``copy=True`` then the function must always copy. When
``copy=False`` then the function must never copy, and must raise an error
when a copy is deemed necessary. If ``copy=None`` then the function must
avoid a copy if possible but may copy if needed.
Returns:
A dlpack PyCapsule object.
A DLPack PyCapsule object.
Note:
While JAX arrays are always immutable, dlpack buffers cannot be marked as
immutable, and it is possible for processes external to JAX to mutate them
in-place. If a dlpack buffer derived from a JAX array is mutated, it may
lead to undefined behavior when using the associated JAX array.
While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
cannot be marked as immutable, and it is possible for processes external
to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
is mutated, it may lead to undefined behavior when using the associated JAX
array. When JAX eventually supports ``DLManagedTensorVersioned``
(DLPack 1.0), it will be possible to specify that a buffer is read-only.
"""
if not isinstance(x, array.ArrayImpl):
raise TypeError("Argument to to_dlpack must be a jax.Array, "
f"got {type(x)}")
assert len(x.devices()) == 1
if take_ownership:
warnings.warn(
"take_ownership in to_dlpack is deprecated and it is a no-op."
device = None
dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
if dl_device_type:
try:
dl_device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
backend = xla_bridge.get_backend(dl_device_platform)
device = backend.device_from_local_hardware_id(local_hardware_id)
except TypeError:
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
# recommends using BufferError.
raise BufferError(
"The device specification passed to to_dlpack contains an unsupported "
f"device type (DLDeviceType: {dl_device_type})")
# As new versions are adopted over time, we can maintain some legacy paths
# for compatibility mediated through the max_version parameter.
# TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
# supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
# current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
if max_version is None or max_version >= DLPACK_VERSION:
# Latest
return _to_dlpack(
x, stream=stream,
src_device=src_device,
device=device,
copy=copy
)
elif max_version >= MIN_DLPACK_VERSION:
# Oldest supported
return _to_dlpack(
x, stream=stream,
src_device=src_device,
device=device,
copy=copy
)
else:
raise BufferError(
f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
f"version ({max_version}) was requested."
)
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.addressable_data(0), stream=stream
) # type: ignore
def _place_array(_arr, device, dlpack_device, copy):
if device and dlpack_device != device:
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
return device_put(_arr, device)
if copy:
return jnp.array(_arr, copy=True)
return _arr
def from_dlpack(external_array):
def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
copy: bool | None = None):
preferred_platform = getattr(device, "platform", None)
if device and preferred_platform == "gpu":
preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
cpu_backend = xla_bridge.get_backend("cpu")
gpu_backend = None
if preferred_platform in {"cuda", "rocm"}:
try:
gpu_backend = xla_bridge.get_backend(preferred_platform)
except RuntimeError:
raise TypeError(
f"A {str.upper(preferred_platform)} device was specified, however no "
f"{str.upper(preferred_platform)} backend was found."
)
if preferred_platform is None:
try:
gpu_backend = xla_bridge.get_backend("cuda")
except RuntimeError:
pass
# Try ROCm if CUDA backend not found
if gpu_backend is None:
try:
gpu_backend = xla_bridge.get_backend("rocm")
except RuntimeError:
pass
_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)) # type: ignore
dlpack_device, = _arr.devices()
return _place_array(_arr, device, dlpack_device, copy)
def _from_dlpack(external_array, device: xla_client.Device | None = None,
copy: bool | None = None):
dl_device_type, device_id = external_array.__dlpack_device__()
try:
dl_device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
except TypeError:
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
# TypeError.
raise TypeError(
"Array passed to from_dlpack is on unsupported device type "
f"(DLDeviceType: {dl_device_type}, array: {external_array}")
backend = xla_bridge.get_backend(dl_device_platform)
dlpack_device = backend.device_from_local_hardware_id(device_id)
try:
stream = dlpack_device.get_stream_for_external_ready_events()
except xla_client.XlaRuntimeError as err: # type: ignore
if "UNIMPLEMENTED" in str(err):
stream = None
else:
raise
dlpack = external_array.__dlpack__(stream=stream)
_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, dlpack_device, stream))
return _place_array(_arr, device, dlpack_device, copy)
def from_dlpack(external_array,
device: xla_client.Device | Sharding | None = None,
copy: bool | None = None):
"""Returns a :class:`~jax.Array` representation of a DLPack tensor.
The returned :class:`~jax.Array` shares memory with ``external_array``.
The returned :class:`~jax.Array` shares memory with ``external_array`` if no
device transfer or copy was requested.
Args:
external_array: an array object that has __dlpack__ and __dlpack_device__
external_array: An array object that has __dlpack__ and __dlpack_device__
methods, or a DLPack tensor on either CPU or GPU (legacy API).
device: The (optional) :py:class:`Device`, representing the device on which
the returned array should be placed. If given, then the result is committed
to the device. If unspecified, the resulting array will be unpacked onto the
same device it originated from. Setting ``device`` to a device different from
the source of ``external_array`` will require a copy, meaning ``copy`` must be
set to either ``True`` or ``None``.
copy: An (optional) boolean, controlling whether or not a copy is performed.
If ``copy=True`` then a copy is always performed, even if unpacked onto the
same device. If ``copy=False`` then a copy is never performed and an error
is raised if one would be required. When ``copy=None`` then a copy may be
performed if needed for a device transfer.
Returns:
A jax.Array
@ -102,49 +267,16 @@ def from_dlpack(external_array):
is later modified in-place, it may lead to undefined behavior when using
the associated JAX array.
"""
if isinstance(device, Sharding):
device_set = device.device_set
if len(device_set) > 1:
raise ValueError(
"from_dlpack can only unpack a dlpack tensor onto a singular device, but "
f"a Sharding with {len(device_set)} devices was provided."
)
device, = device_set
if hasattr(external_array, "__dlpack__"):
dl_device_type, device_id = external_array.__dlpack_device__()
try:
device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
except TypeError:
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
# TypeError.
raise TypeError(
"Array passed to from_dlpack is on unsupported device type "
f"(DLDeviceType: {dl_device_type}, array: {external_array}")
return _from_dlpack(external_array, device, copy)
backend = xla_bridge.get_backend(device_platform)
device = backend.device_from_local_hardware_id(device_id)
try:
stream = device.get_stream_for_external_ready_events()
except xla_client.XlaRuntimeError as err: # type: ignore
if "UNIMPLEMENTED" in str(err):
stream = None
else:
raise
dlpack = external_array.__dlpack__(stream=stream)
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, device, stream))
else:
# Legacy path
dlpack = external_array
cpu_backend = xla_bridge.get_backend("cpu")
try:
gpu_backend = xla_bridge.get_backend("cuda")
except RuntimeError:
gpu_backend = None
# Try ROCm if CUDA backend not found
if gpu_backend is None:
try:
gpu_backend = xla_bridge.get_backend("rocm")
except RuntimeError:
gpu_backend = None
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend))
# Legacy path
return _legacy_from_dlpack(external_array, device, copy)
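A short usage sketch of the reworked entry points; the array values and the choice of a CPU target device are illustrative only:
import jax
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
x = jnp.arange(8, dtype=jnp.float32)
# Producer side: export a capsule, optionally capping the DLPack version the
# consumer understands.
capsule = jax_dlpack.to_dlpack(x, max_version=(0, 8))
# Consumer side: the modern path accepts any object implementing __dlpack__ /
# __dlpack_device__ directly, and can pin the result to a device and/or force
# a copy.
cpu0 = jax.devices("cpu")[0]
y = jax_dlpack.from_dlpack(x, device=cpu0, copy=True)
print(y)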

110
jax/_src/earray.py Normal file
View File

@ -0,0 +1,110 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import tree_util
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.util import safe_zip, safe_map
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# EArray is an Array that can contain extended dtypes.
class EArray(basearray.Array):
__slots__ = ['aval', '_data']
__hash__ = None # type: ignore[assignment]
__array_priority__ = 100
def __init__(self, aval, data):
self.aval = aval
self._data = data
def block_until_ready(self):
_ = self._data.block_until_ready()
return self
def copy_to_host_async(self):
self._data.copy_to_host_async()
def copy(self):
return EArray(self.aval, self._data.copy())
def __repr__(self):
return 'E' + repr(self._data)
def __iter__(self):
if self.ndim == 0: raise TypeError('iteration over a 0-d array')
raise NotImplementedError
# forward to aval
shape = property(lambda self: self.aval.shape) # type: ignore[assignment]
dtype = property(lambda self: self.aval.dtype) # type: ignore[assignment]
# computed from shape and dtype
ndim = property(lambda self: len(self.aval.shape)) # type: ignore[assignment]
size = property(lambda self: math.prod(self.aval.shape)) # type: ignore[assignment]
itemsize = property(lambda self: self.aval.dtype.itemsize) # type: ignore[assignment]
def __len__(self):
if self.ndim == 0: raise TypeError('len() of unsized object')
return self.shape[0]
# forward to self._data
devices = property(lambda self: self._data.devices) # type: ignore[assignment]
_committed = property(lambda self: self._data._committed)
is_fully_addressable = property(lambda self: self._data.is_fully_addressable) # type: ignore[assignment]
is_fully_replicated = property(lambda self: self._data.is_fully_replicated) # type: ignore[assignment]
delete = property(lambda self: self._data.delete) # type: ignore[assignment]
is_deleted = property(lambda self: self._data.is_deleted) # type: ignore[assignment]
on_device_size_in_bytes = property(lambda self: self._data.on_device_size_in_bytes) # type: ignore[assignment]
unsafe_buffer_pointer = property(lambda self: self._data.unsafe_buffer_pointer) # type: ignore[assignment]
# defer to extended dtype rules
@property
def sharding(self):
phys_sharding = self._data.sharding
return self.aval.dtype._rules.logical_sharding(self.aval, phys_sharding)
# TODO(mattjj): not implemented below here, need more methods from ArrayImpl
def addressable_data(self, index: int) -> EArray:
raise NotImplementedError
@property
def addressable_shards(self):
raise NotImplementedError
@property
def global_shards(self):
raise NotImplementedError
# TODO(mattjj): _set_array_base_attributes
def _earray_shard_arg_handler(x, sharding):
arr = x._data
phys_sharding = x.aval.dtype._rules.physical_sharding(x.aval, sharding)
return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
core.pytype_aval_mappings[EArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
tree_util.dispatch_registry.register_node(
EArray, lambda x: ((x._data,), x.aval), lambda a, xs: EArray(a, xs[0]))

View File

@ -23,6 +23,7 @@ import collections
import itertools
from typing import Union, cast
import jax
from jax import lax
from jax._src import dtypes
from jax._src import test_util
@ -30,8 +31,7 @@ from jax._src.util import safe_map, safe_zip
import numpy as np
from jax import config
config.parse_flags_with_absl()
jax.config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

View File

@ -75,6 +75,13 @@ _JAX_DUMP_IR_TO = config.DEFINE_string(
"Supports the special value 'sponge' to pick the path from the "
"environment variable TEST_UNDECLARED_OUTPUTS_DIR.")
_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.DEFINE_string(
'jax_include_debug_info_in_dumps',
os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"),
help="Determine whether or not to keep debug symbols and location information "
"when dumping IR code. By default, debug information will be preserved in "
"the IR dump. To avoid exposing source code and potentially sensitive "
"information, set to false")
lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
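The new flag pairs with the existing JAX_DUMP_IR_TO option; a hypothetical way to exercise both (the dump directory is illustrative, and both variables are read when jax is imported):
import os
# Illustrative setup: dump lowered modules, but strip loc()/debug metadata.
os.environ["JAX_DUMP_IR_TO"] = "/tmp/jax_ir_dumps"
os.environ["JAX_INCLUDE_DEBUG_INFO_IN_DUMPS"] = "false"
import jax
import jax.numpy as jnp
jax.jit(lambda x: x * 2)(jnp.arange(4.0))
# The .mlir files written under /tmp/jax_ir_dumps now omit location info.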
@ -474,9 +481,12 @@ def dump_module_message(module: ir.Module, stage_name: str) -> str:
def _make_string_safe_for_filename(s: str) -> str:
return re.sub(r'[^\w.)( -]', '', s)
def module_to_string(module: ir.Module) -> str:
def module_to_string(module: ir.Module, enable_debug_info=None) -> str:
output = io.StringIO()
module.operation.print(file=output, enable_debug_info=True)
if enable_debug_info is None:
enable_debug_flag = str.lower(_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS.value)
enable_debug_info = enable_debug_flag not in ('false', '0')
module.operation.print(file=output, enable_debug_info=enable_debug_info)
return output.getvalue()
def module_to_bytecode(module: ir.Module) -> bytes:
@ -944,11 +954,6 @@ def lower_jaxpr_to_module(
else:
dim_vars = ()
arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
else in_layouts)
result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
else out_layouts)
ctx = ModuleContext(backend_or_name=backend_or_name,
platforms=platforms, axis_context=axis_context,
keepalives=keepalives,
@ -982,8 +987,8 @@ def lower_jaxpr_to_module(
result_names=result_names,
arg_memory_kinds=arg_memory_kinds,
result_memory_kinds=result_memory_kinds,
arg_layouts=arg_layouts,
result_layouts=result_layouts)
arg_layouts=in_layouts,
result_layouts=out_layouts)
try:
if not ctx.module.operation.verify():
@ -1130,8 +1135,8 @@ def lower_jaxpr_to_fun(
result_names: Sequence[str | None] | None = None,
arg_memory_kinds: Sequence[str | None] | None = None,
result_memory_kinds: Sequence[str | None] | None = None,
arg_layouts: Sequence[str | None] | None = None,
result_layouts: Sequence[str | None] | None = None,
arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
result_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
@ -1252,7 +1257,8 @@ def lower_jaxpr_to_fun(
ir_arg_layouts = None
if arg_layouts is not None:
ir_arg_layouts = util.flatten(
[[l] * len(types) for l, types in zip(arg_layouts, input_types)])
[[_to_xla_layout(l)] * len(types)
for l, types in zip(arg_layouts, input_types)])
ir_donated_args = None
if xla_donated_args is not None:
@ -1275,7 +1281,8 @@ def lower_jaxpr_to_fun(
ir_result_layouts = None
if result_layouts is not None:
ir_result_layouts = util.flatten(
[[l] * len(types) for l, types in zip(result_layouts, output_types)])
[[_to_xla_layout(l)] * len(types)
for l, types in zip(result_layouts, output_types)])
if (
replicated_args is not None

View File

@ -1004,19 +1004,6 @@ class UnloadedPmapExecutable:
shards.out_sharded_avals, pci.out_axes)]
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
if hasattr(pci.backend, "compile_replicated"):
input_indices = [
sharding_specs.spec_to_indices(aval.shape, spec)
if spec is not None else None
for aval, spec in safe_zip(pci.avals, input_sharding_specs)
]
handle_outs = local_avals_to_results_handler(local_unmapped_avals,
out_shardings)
return _compile_replicated_pmap_executable_from_hlo(
hlo, pci, input_indices, in_shardings, handle_outs,
compile_options, host_callbacks, bool(unordered_effects),
ordered_effects, jaxpr_debug_info)
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
@ -1038,23 +1025,6 @@ class UnloadedPmapExecutable:
jaxpr_debug_info=jaxpr_debug_info).load()
def _compile_replicated_pmap_executable_from_hlo(
hlo: ir.Module, pci, input_indices, in_shardings, handle_outs,
compile_options, host_callbacks, has_unordered_effects, ordered_effects,
jaxpr_debug_info):
# Use the standard out_handler.
execute_fun = pci.backend.compile_replicated(
is_trivial=False, name=pci.name, computation=hlo,
compile_options=compile_options, host_callbacks=host_callbacks,
has_unordered_effects=has_unordered_effects,
ordered_effects=ordered_effects, in_avals=pci.avals,
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
# TODO(frostig): need `compile_replicated` to give us the XLA executable
return PmapExecutable(None, lambda: execute_fun, None, pci.avals,
jaxpr_debug_info, None)
class PmapExecutable(stages.XlaExecutable):
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
"fingerprint", "in_avals", "_jaxpr_debug_info",
@ -1185,13 +1155,25 @@ class ExecuteReplicated:
def _handle_token_bufs(self, token_bufs, sharded_token):
# token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned
# token buffer (as a singleton list).
# token buffers.
# sharded_token: ShardedToken, containing the RuntimeTokens for each device
for i, device in enumerate(self._local_devices):
dispatch.runtime_tokens.set_output_runtime_token(
device, sharded_token.get_token(i))
for eff, token_buf in zip(self.ordered_effects, token_bufs):
dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
assert len(token_buf) > 0
if len(token_buf) == 1:
dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
else:
token_devices = []
for token in token_buf:
assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding)
token_devices.append(token.sharding._device_assignment[0])
s = sharding_impls.PositionalSharding(token_devices)
global_token_array = jax.make_array_from_single_device_arrays(
(0,), s, token_buf
)
dispatch.runtime_tokens.set_token_result(eff, global_token_array)
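The multi-token branch above uses the standard recipe for stitching per-device buffers into one global array; a standalone sketch of that recipe, with zero-length arrays chosen to mirror the token buffers:
import jax
import numpy as np
devices = jax.local_devices()
s = jax.sharding.PositionalSharding(devices)
# One zero-length single-device array per device, like the per-device tokens.
shards = [jax.device_put(np.zeros((0,), np.bool_), d) for d in devices]
global_tok = jax.make_array_from_single_device_arrays((0,), s, shards)
print(global_tok.shape, global_tok.sharding)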
@profiler.annotate_function
def __call__(self, *args):
@ -2007,10 +1989,8 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
class AllArgsInfo(NamedTuple):
"""Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
"""Avals and debug_info for all arguments prior to DCE."""
in_avals: Sequence[core.ShapedArray]
in_shardings: Any
in_layouts: Any
debug_info: core.JaxprDebugInfo | None
@ -2035,15 +2015,15 @@ def lower_sharding_computation(
fun_name: str,
in_shardings: Sequence[MaybeSharding],
out_shardings: Sequence[MaybeSharding],
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
keep_unused: bool,
inline: bool,
devices_from_context: Sequence[xc.Device] | None = None,
lowering_parameters: mlir.LoweringParameters,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
lowering_parameters: mlir.LoweringParameters
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
@ -2056,8 +2036,7 @@ def lower_sharding_computation(
auto_spmd_lowering = check_if_any_auto(
it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore
all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
closed_jaxpr.jaxpr.debug_info)
all_args_info = AllArgsInfo(global_in_avals, closed_jaxpr.jaxpr.debug_info)
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
kept_var_idx, name_stack) = _dce_jaxpr(
@ -2109,7 +2088,7 @@ def lower_sharding_computation(
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not is_unspecified(o) for o in out_shardings))
if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
if xla_extension_version < 241:
gs = GSPMDSharding.get_replicated(device_assignment)
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
@ -2720,15 +2699,12 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)
if hasattr(backend, "compile_replicated"):
return None, compile_options
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = compiler.compile_or_get_cached(
backend, computation, dev, compile_options, host_callbacks)
return xla_executable, compile_options
return xla_executable
def _maybe_get_and_check_in_shardings(
@ -2758,21 +2734,16 @@ def _maybe_get_and_check_in_shardings(
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
new_in_shardings.append(xla_s)
else:
# TODO(yashkatariya): Remove the if branch for abstract_token once
# choosing input shardings by XLA is enabled again.
if aval is core.abstract_token:
new_in_shardings.append(orig)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
# MANUAL HloSharding comes from other partitioning frameworks.
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
not xla_hlo_s.is_manual() and
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))):
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
new_in_shardings.append(orig)
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
# MANUAL HloSharding comes from other partitioning frameworks.
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
not xla_hlo_s.is_manual() and
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))):
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
new_in_shardings.append(orig)
return new_in_shardings
@ -2863,7 +2834,6 @@ class UnloadedMeshExecutable:
self.in_layouts, self.out_layouts,
self.all_args_info, self)
# May return a MeshExecutable in the compile_replicated case.
@staticmethod
def from_hlo(name: str,
hlo: ir.Module,
@ -2916,24 +2886,12 @@ class UnloadedMeshExecutable:
mesh = i.mesh # type: ignore
break
xla_executable, compile_options = _cached_compilation(
xla_executable = _cached_compilation(
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
compiler_options_keys, compiler_options_values)
if hasattr(backend, "compile_replicated"):
semantics_in_shardings = SemanticallyEqualShardings(
in_shardings, global_in_avals) # type: ignore
semantics_out_shardings = SemanticallyEqualShardings(
out_shardings, global_out_avals) # type: ignore
return _compile_replicated_mesh_executable_from_hlo(
hlo, name, tuple(global_in_avals), tuple(global_out_avals),
semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
compile_options, tuple(host_callbacks), bool(unordered_effects),
tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
pmap_nreps)
if auto_spmd_lowering:
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
@ -3052,28 +3010,22 @@ class MeshExecutable(stages.XlaExecutable):
return self.xla_executable
def call(self, *args):
args_after_dce = [a for i, a in enumerate(args) if i in self._kept_var_idx]
if self._all_args_info is None:
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
kept_args = args_after_dce
ref_avals = self.in_avals
in_shardings = self._in_shardings
in_layouts = self._in_layouts
debug_info = None
else:
kept_args = args
ref_avals = self._all_args_info.in_avals
iter_in_shardings = iter(self._in_shardings)
in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
for i, s in enumerate(self._all_args_info.in_shardings)]
iter_in_layouts = iter(self._in_layouts)
in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
for i, s in enumerate(self._all_args_info.in_layouts)]
debug_info = self._all_args_info.debug_info
arg_avals = map(xla.abstractify, kept_args)
check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
all_arg_avals = map(xla.abstractify, kept_args)
check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
# Check the GDA sharding and the input sharding.
check_array_xla_sharding_layout_match(kept_args, in_shardings,
in_layouts, debug_info)
check_array_xla_sharding_layout_match(
args_after_dce, self._in_shardings, self._in_layouts, debug_info,
self._kept_var_idx)
return self.unsafe_call(*args) # pylint: disable=not-callable
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
@ -3182,35 +3134,6 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
return in_shardings, out_shardings, committed, tuple(local_devices)
@weakref_lru_cache
def _compile_replicated_mesh_executable_from_hlo(
computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
semantics_out_shardings, auto_spmd_lowering, compile_options,
host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
backend, da, committed, pmap_nreps):
assert not auto_spmd_lowering
in_shardings = semantics_in_shardings.shardings
out_shardings = semantics_out_shardings.shardings
kept_var_idx = set(kept_var_idx)
# Will compute out_handler with executable information.
unsafe_call = backend.compile_replicated(
is_trivial=False, name=name, computation=computation,
compile_options=compile_options, host_callbacks=host_callbacks,
has_unordered_effects=has_unordered_effects,
device_assignment=da, ordered_effects=ordered_effects,
in_avals=global_in_avals,
in_shardings=in_shardings, kept_var_idx=kept_var_idx,
out_avals=global_out_avals, out_shardings=out_shardings,
committed=committed, pmap_nreps=pmap_nreps)
xla_executable = None
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
global_out_avals, in_shardings, out_shardings,
auto_spmd_lowering, kept_var_idx,
(None,) * len(global_in_avals),
(None,) * len(global_out_avals))
@lru_cache
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
@ -3231,16 +3154,22 @@ def check_device_backend_on_shardings(shardings) -> bool:
def check_array_xla_sharding_layout_match(
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
args_after_dce,
in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
in_xla_layouts: Sequence[DeviceLocalLayout],
jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
jaxpr_debug_info: core.JaxprDebugInfo | None,
kept_var_idx: set[int]) -> None:
from jax._src.array import ArrayImpl
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
jaxpr_debug_info.arg_names)
# jaxpr_debug_info.arg_names are pre-DCE, so filter them down to the kept args.
arg_names = (
[""] * len(args_after_dce) if jaxpr_debug_info is None
else [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore
if i in kept_var_idx]
)
errors = []
num_errors = 5
for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
arg_names):
for arg, xs, xl, name in safe_zip(
args_after_dce, in_xla_shardings, in_xla_layouts, arg_names):
if not isinstance(arg, ArrayImpl):
continue
if is_unspecified_or_auto(xs):
@ -3271,8 +3200,9 @@ def check_array_xla_sharding_layout_match(
arg.layout.device_local_layout != xl):
errors.append(
("Got input layout(s) that compiled object was called with: "
f"{arg.layout} and layout(s) the computation was compiled "
f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
f"{arg.layout.device_local_layout} and layout(s) the computation was "
f"compiled with: {xl} for arg {name} with "
f"shape: {arg.aval.str_short()}",
'layout'))
if errors:
@ -3362,12 +3292,3 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
def maybe_extend_axis_env(*args, **kwargs):
with core.extend_axis_env(*args, **kwargs):
yield
def device_put(x, devices: Sequence[xc.ArrayImpl],
replicate: bool=False) -> list[xc.ArrayImpl]:
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
if replicate:
return [jax.device_put(x, device) for device in devices]
else:
return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]

View File

@ -69,7 +69,7 @@ _no_operand_sentinel = object()
@api_boundary
def switch(index, branches: Sequence[Callable], *operands,
operand=_no_operand_sentinel):
"""Apply exactly one of ``branches`` given by ``index``.
"""Apply exactly one of the ``branches`` given by ``index``.
If ``index`` is out of bounds, it is clamped to within bounds.
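A minimal illustration of the branch selection and clamping behaviour described above (values are illustrative only):

from jax import lax

branches = (lambda x: x + 1.0,   # index 0
            lambda x: x * 2.0,   # index 1
            lambda x: -x)        # index 2

print(lax.switch(1, branches, 3.0))    # selects branch 1 -> 6.0
print(lax.switch(7, branches, 3.0))    # clamped to index 2 -> -3.0
print(lax.switch(-5, branches, 3.0))   # clamped to index 0 -> 4.0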

View File

@ -49,7 +49,7 @@ def _split_root_args(args, const_lengths):
@api_boundary
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
"""Differentiably solve for a roots of a function.
"""Differentiably solve for the roots of a function.
This is a low-level routine, mostly intended for internal use in JAX.
Gradients of custom_root() are defined with respect to closed-over variables
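A rough sketch of how the pieces fit together; the Newton solver and names below are illustrative, not part of this change. ``solve(f, y0)`` produces a root of ``f``, and ``tangent_solve(g, y)`` inverts the linearization ``g``, which is what makes the result differentiable with respect to closed-over values such as ``a`` here.

import jax
from jax import lax

def newton_sqrt(a):
  f = lambda y: y ** 2 - a          # root of f is sqrt(a)

  def solve(f, y0):
    y = y0
    for _ in range(10):             # a few Newton steps; purely illustrative
      y = y - f(y) / (2.0 * y)      # derivative of y**2 - a is 2 * y
    return y

  # g is f linearized at the root, so for scalars g(x) = g(1.0) * x.
  tangent_solve = lambda g, y: y / g(1.0)

  return lax.custom_root(f, 1.0, solve, tangent_solve)

print(newton_sqrt(2.0))             # ~1.41421
print(jax.grad(newton_sqrt)(2.0))   # ~0.35355, i.e. 1 / (2 * sqrt(2))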

View File

@ -22,7 +22,7 @@ from functools import partial
import itertools
import math
import operator
from typing import Any, Callable, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
from typing import Any, Callable, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
import warnings
import numpy as np
@ -625,14 +625,14 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
_precision_strings: dict[Any, Precision] = {}
# TODO(b/328046715): pytype appears unable to handle overriding __new__ in an
# enum class. Doing this crashes Pytype. For now, just write an explicit type
# for type checkers.
# TODO(b/333851820): pytype does not properly handle _missing_ in enums.
# We work around that by defining `Precision` as a normal class.
if TYPE_CHECKING:
class Precision:
DEFAULT: Precision
HIGH: Precision
HIGHEST: Precision
DEFAULT: ClassVar[Precision]
HIGH: ClassVar[Precision]
HIGHEST: ClassVar[Precision]
def __new__(cls, value: Precision | int | str | None) -> Precision:
raise NotImplementedError
@ -646,6 +646,7 @@ if TYPE_CHECKING:
raise NotImplementedError
else:
class Precision(enum.Enum):
"""Precision enum for lax functions
@ -664,23 +665,21 @@ else:
Slowest but most accurate. Performs computations in float32 or float64
as applicable. Aliases: ``'highest'``, ``'float32'``.
"""
DEFAULT = 0
HIGH = 1
HIGHEST = 2
@classmethod
def _missing_(cls, value: object) -> Precision | None:
return _precision_strings.get(value)
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
return f'{self.__class__.__name__}.{self.name}'
def __str__(self) -> str:
return self.name
# You can't define __new__ on an enum class directly, but you can monkey-patch
# it after the fact. Another way to do this might be using a metaclass.
def _precision_new(cls, value: Precision | int | str | None) -> Precision:
return super(Precision, cls).__new__(cls, _precision_strings.get(value, value))
Precision.__new__ = _precision_new
_precision_strings['highest'] = Precision.HIGHEST
_precision_strings['float32'] = Precision.HIGHEST
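The net effect of the patched ``__new__`` / ``_missing_`` pair is that string aliases and existing members both round-trip through the constructor. A small illustration, assuming the usual ``jax.lax.Precision`` export:

from jax import lax

assert lax.Precision('highest') is lax.Precision.HIGHEST
assert lax.Precision('float32') is lax.Precision.HIGHEST
assert lax.Precision(lax.Precision.HIGH) is lax.Precision.HIGH
print(repr(lax.Precision.DEFAULT))   # Precision.DEFAULT
print(str(lax.Precision.DEFAULT))    # DEFAULT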

View File

@ -79,7 +79,7 @@ def psum(x, axis_name, *, axis_index_groups=None):
>>> print(y)
[0. 0.16666667 0.33333334 0.5 ]
Suppose we want to perform ``psum`` among two groups, one with ``device0`` and ``device1``, the other with `device2` and `device3`,
Suppose we want to perform ``psum`` among two groups, one with ``device0`` and ``device1``, the other with ``device2`` and ``device3``,
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)

View File

@ -232,7 +232,7 @@ class GatherDimensionNumbers(NamedTuple):
in the output of the gather. Must be a tuple of integers in ascending
order.
start_index_map: for each dimension in `start_indices`, gives the
corresponding dimension in `operand` that is to be sliced. Must be a
corresponding dimension in the `operand` that is to be sliced. Must be a
tuple of integers with size equal to `start_indices.shape[-1]`.
Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is
@ -261,8 +261,8 @@ class GatherScatterMode(enum.Enum):
will be discarded.
PROMISE_IN_BOUNDS:
The user promises that indices are in bounds. No additional checking will be
performed. In practice, with the current XLA implementation this means
that, out-of-bounds gathers will be clamped but out-of-bounds scatters will
performed. In practice, with the current XLA implementation this means
that out-of-bounds gathers will be clamped but out-of-bounds scatters will
be discarded. Gradients will not be correct if indices are out-of-bounds.
"""
CLIP = enum.auto()
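The in-practice behaviour described for ``PROMISE_IN_BOUNDS`` is easy to see through the ``.at[]`` API; a small illustration, not part of this change:

import jax.numpy as jnp

x = jnp.arange(4.0)
print(x[10])                  # out-of-bounds gather: index is clamped -> 3.0
print(x.at[10].set(99.0))     # out-of-bounds scatter: update is dropped -> [0. 1. 2. 3.]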

View File

@ -72,16 +72,18 @@ class Layout:
)
if not isinstance(
device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)):
raise ValueError(
raise TypeError(
'Invalid value received for the device_local_layout argument.'
' Expected values are `None`, `DeviceLocalLayout.AUTO` or an instance'
f' of `DeviceLocalLayout`. Got {device_local_layout}')
' Expected values are `None`, `DeviceLocalLayout.AUTO` or an'
f' instance of `DeviceLocalLayout`. Got {device_local_layout} of'
f' type {type(device_local_layout)}'
)
if not isinstance(
sharding, (Sharding, type(None), AutoSharding)):
raise ValueError(
raise TypeError(
'Invalid value received for the sharding argument. Expected values'
' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got'
f' {sharding}')
f' {sharding} of type {type(sharding)}')
self.device_local_layout = device_local_layout
self.sharding = sharding
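A quick illustration of the stricter error type; the ``jax.experimental.layout`` import path and the keyword usage are assumptions about this JAX version, not part of the diff:

from jax.experimental.layout import Layout  # import path assumed

try:
  Layout(device_local_layout="not-a-layout")
except TypeError as err:                     # previously this raised ValueError
  print(err)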

View File

@ -714,9 +714,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
return pxla.lower_sharding_computation(
core.ClosedJaxpr(jaxpr, consts), 'jit', name,
(UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
(None,) * len(in_avals), (None,) * len(out_avals),
donated_invars, in_avals, keep_unused=True, inline=False,
devices_from_context=None, lowering_parameters=lowering_parameters,
in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals))
devices_from_context=None, lowering_parameters=lowering_parameters)
class EvaluationPlan(NamedTuple):

View File

@ -20,6 +20,7 @@ from functools import partial
import operator
import numpy as np
from typing import Any
import warnings
import jax
import jax.numpy as jnp
@ -35,6 +36,12 @@ from jax._src.typing import Array, ArrayLike
from jax._src.ops.special import logsumexp as _logsumexp
class Unspecified:
def __repr__(self):
return "_UNSPECIFIED"
_UNSPECIFIED = Unspecified()
# activations
@custom_jvp
@ -173,6 +180,38 @@ def sigmoid(x: ArrayLike) -> Array:
"""
return lax.logistic(x)
@jax.jit
def sparse_sigmoid(x: ArrayLike) -> Array:
r"""Sparse sigmoid activation function.
Computes the function:
.. math::
\mathrm{sparse\_sigmoid}(x) = \begin{cases}
0, & x \leq -1\\
\frac{1}{2}(x+1), & -1 < x < 1 \\
1, & 1 \leq x
\end{cases}
This is the twin function of the ``sigmoid`` activation, ensuring a zero output
for inputs less than -1, an output of 1 for inputs greater than 1, and a linear
output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
For more information, see `Learning with Fenchel-Young Losses (section 6.2)
<https://arxiv.org/abs/1901.02324>`_.
Args:
x : input array
Returns:
An array.
See also:
:func:`sigmoid`
"""
return 0.5 * jnp.clip(x + 1.0, 0.0, 2.0)
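A quick numerical check of the piecewise definition, assuming the function is re-exported as ``jax.nn.sparse_sigmoid`` like the other activations:

import jax.numpy as jnp
from jax import nn

x = jnp.array([-2.0, -1.0, 0.0, 0.5, 2.0])
print(nn.sparse_sigmoid(x))   # -> [0. 0. 0.5 0.75 1.]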
@jax.jit
def silu(x: ArrayLike) -> Array:
r"""SiLU (aka swish) activation function.
@ -454,7 +493,7 @@ logsumexp = _logsumexp
def log_softmax(x: ArrayLike,
axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None = None) -> Array:
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
r"""Log-Softmax function.
Computes the logarithm of the :code:`softmax` function, which rescales
@ -469,8 +508,6 @@ def log_softmax(x: ArrayLike,
axis: the axis or axes along which the :code:`log_softmax` should be
computed. Either an integer or a tuple of integers.
where: Elements to include in the :code:`log_softmax`.
initial: The minimum value used to shift the input array. Must be present
when :code:`where` is not None.
Returns:
An array.
@ -478,10 +515,15 @@ def log_softmax(x: ArrayLike,
See also:
:func:`softmax`
"""
if initial is not _UNSPECIFIED:
# Added 2024-4-10
warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.",
DeprecationWarning, stacklevel=2)
del initial
numpy_util.check_arraylike("log_softmax", x)
x_arr = jnp.asarray(x)
x_max = jnp.max(x_arr, axis, where=where, initial=initial, keepdims=True)
x_safe = x_arr if where is None else jnp.where(where, x_arr, initial)
x_max = jnp.max(x_arr, axis, where=where, initial=-jnp.inf, keepdims=True)
x_safe = x_arr if where is None else jnp.where(where, x_arr, -jnp.inf)
shifted = x_safe - lax.stop_gradient(x_max)
shifted_logsumexp = jnp.log(
jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
@ -496,7 +538,7 @@ def log_softmax(x: ArrayLike,
def softmax(x: ArrayLike,
axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None = None) -> Array:
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
r"""Softmax function.
Computes the function which rescales elements to the range :math:`[0, 1]`
@ -511,8 +553,6 @@ def softmax(x: ArrayLike,
softmax output summed across these dimensions should sum to :math:`1`.
Either an integer or a tuple of integers.
where: Elements to include in the :code:`softmax`.
initial: The minimum value used to shift the input array. Must be present
when :code:`where` is not None.
Returns:
An array.
@ -520,13 +560,18 @@ def softmax(x: ArrayLike,
See also:
:func:`log_softmax`
"""
if initial is not _UNSPECIFIED:
# Added 2024-4-10
warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.",
DeprecationWarning, stacklevel=2)
del initial
if config.softmax_custom_jvp.value:
# mypy is confused by the `functools.partial` application in the definition
# of `_softmax` and incorrectly concludes that `_softmax` returns
# `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`.
return _softmax(x, axis, where, initial) # type: ignore[return-value]
return _softmax(x, axis, where) # type: ignore[return-value]
else:
return _softmax_deprecated(x, axis, where, initial)
return _softmax_deprecated(x, axis, where)
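With the ``initial`` argument removed, masked soft(arg)max calls no longer need it; a small sketch with illustrative values:

import jax.numpy as jnp
from jax import nn

x = jnp.array([1.0, 2.0, 3.0, 4.0])
mask = jnp.array([True, True, True, False])

print(nn.softmax(x, where=mask))       # masked entry is 0; the rest sum to 1
print(nn.log_softmax(x, where=mask))   # masked entry is -inf
# nn.softmax(x, where=mask, initial=0.0) still runs, but emits a
# DeprecationWarning and the value is ignored.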
# TODO(mattjj): replace softmax with _softmax when deprecation flag is removed
@partial(jax.custom_jvp, nondiff_argnums=(1,))
@ -534,7 +579,7 @@ def _softmax(
x: ArrayLike,
axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None = None) -> Array:
initial: ArrayLike | None = -jnp.inf) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
x_safe = x if where is None else jnp.where(where, x, initial)
unnormalized = jnp.exp(x_safe - x_max)
@ -553,7 +598,7 @@ def _softmax_deprecated(
x: ArrayLike,
axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None = None) -> Array:
initial: ArrayLike | None = -jnp.inf) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
x_safe = x if where is None else jnp.where(where, x, initial)
unnormalized = jnp.exp(x_safe - lax.stop_gradient(x_max))

View File

@ -84,12 +84,11 @@ def _itemsize(arr: ArrayLike) -> int:
def _clip(number: ArrayLike,
min: ArrayLike | None = None, max: ArrayLike | None = None,
out: None = None) -> Array:
min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
"""Return an array whose values are limited to a specified range.
Refer to :func:`jax.numpy.clip` for full documentation."""
return lax_numpy.clip(number, a_min=min, a_max=max, out=out)
return lax_numpy.clip(number, min=min, max=max)
def _transpose(a: Array, *args: Any) -> Array:

View File

@ -66,7 +66,10 @@ from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
from jax._src.typing import (
Array, ArrayLike, DimSize, DuckTypedArray,
DType, DTypeLike, Shape, DeprecatedArg
)
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis,
@ -1293,20 +1296,63 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
axis: int = 0) -> list[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)
@util.implements(np.clip, skip_params=['out'])
_DEPRECATED_CLIP_ARG = DeprecatedArg()
@util.implements(
np.clip,
skip_params=['a', 'a_min'],
extra_params=_dedent("""
x : array_like
Array containing elements to clip.
min : array_like, optional
Minimum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``min`` is broadcast against ``x``.
max : array_like, optional
Maximum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``max`` is broadcast against ``x``.
""")
)
@jit
def clip(a: ArrayLike, a_min: ArrayLike | None = None,
a_max: ArrayLike | None = None, out: None = None) -> Array:
util.check_arraylike("clip", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
if a_min is None and a_max is None:
raise ValueError("At most one of a_min and a_max may be None")
if a_min is not None:
a = ufuncs.maximum(a_min, a)
if a_max is not None:
a = ufuncs.minimum(a_max, a)
return asarray(a)
def clip(
x: ArrayLike | None = None, # Default to None to preserve backwards compatibility
/,
min: ArrayLike | None = None,
max: ArrayLike | None = None,
*,
a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
) -> Array:
# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
x = a if not isinstance(a, DeprecatedArg) else x
if x is None:
raise ValueError("No input was provided to the clip function.")
min = a_min if not isinstance(a_min, DeprecatedArg) else min
max = a_max if not isinstance(a_max, DeprecatedArg) else max
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
warnings.warn(
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
DeprecationWarning,
stacklevel=2,
)
util.check_arraylike("clip", x)
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
warnings.warn(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Attempting to clip using complex numbers is deprecated and will soon "
"raise a ValueError. Please convert to a real value or array by taking "
"the real or imaginary components via jax.numpy.real/imag respectively.",
DeprecationWarning, stacklevel=2,
)
if min is not None:
x = ufuncs.maximum(min, x)
if max is not None:
x = ufuncs.minimum(max, x)
return asarray(x)
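Under the new array-API-style signature, the positional and keyword forms look like this; the deprecated names still work but warn:

import jax.numpy as jnp

x = jnp.arange(5.0)
print(jnp.clip(x, 1.0, 3.0))    # -> [1. 1. 2. 3. 3.]
print(jnp.clip(x, min=2.0))     # one-sided clipping -> [2. 2. 2. 3. 4.]
# jnp.clip(a=x, a_min=1.0, a_max=3.0) gives the same result plus a DeprecationWarning.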
@util.implements(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
@ -2217,11 +2263,16 @@ In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
util.check_arraylike("astype", x)
x_arr = asarray(x)
del copy # unused in JAX
if dtype is None:
dtype = dtypes.canonicalize_dtype(float_)
dtypes.check_user_dtype_supported(dtype, "astype")
return lax.convert_element_type(x, dtype)
# convert_element_type(complex, bool) has the wrong semantics.
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
return (x_arr != _lax_const(x_arr, 0))
return lax.convert_element_type(x_arr, dtype)
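The special case makes complex-to-bool behave like NumPy's "is nonzero" test rather than relying on ``convert_element_type``; a brief, illustrative check:

import jax.numpy as jnp

z = jnp.array([0.0 + 0.0j, 1.0 + 0.0j, 0.0 + 2.0j])
print(jnp.astype(z, bool))   # -> [False  True  True]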
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
@ -2442,9 +2493,10 @@ def fromiter(*args, **kwargs):
is later modified in-place, it may lead to undefined behavior when using
the associated JAX array.
""")
def from_dlpack(x: Any) -> Array:
def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
copy: bool | None = None) -> Array:
from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top
return from_dlpack(x)
return from_dlpack(x, device=device, copy=copy)
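A sketch of the extended signature in use, relying on NumPy's ``__dlpack__`` support; the variable names are illustrative:

import numpy as np
import jax
import jax.numpy as jnp

x = np.arange(4.0)                                # NumPy arrays export __dlpack__
y = jnp.from_dlpack(x)                            # may share memory with x
z = jnp.from_dlpack(x, copy=True)                 # always makes a copy
d = jnp.from_dlpack(x, device=jax.devices()[0])   # place result on an explicit device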
@util.implements(np.fromfunction)
def fromfunction(function: Callable[..., Array], shape: Any,
@ -4889,8 +4941,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
"with type {} at position {}, indexer value {}")
raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
msg = "Indexing mode not yet supported. Open a feature request!\n{}"
raise IndexError(msg.format(idx))
raise IndexError("Indexing mode not yet supported. Got unsupported indexer "
f"at position {idx_pos}: {i!r}")
if len(gather_indices) == 0:
gather_indices_array: ArrayLike = np.zeros((0,), dtype=index_dtype)

Some files were not shown because too many files have changed in this diff.