[Pallas] Upstream pallas to JAX

PiperOrigin-RevId: 552963029

parent: 69cd3ebe99
commit: d872812a35
.github/workflows/ci-build.yaml (vendored, 2 changes)

@@ -159,7 +159,7 @@ jobs:
         PY_COLORS: 1
       run: |
         pytest -n auto --tb=short docs
-        pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic/__init__.py --ignore=jax/experimental/mosaic/dialects.py
+        pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
@@ -2,6 +2,7 @@ absl-py
 build
 cloudpickle
 colorama>=0.4.4
 hypothesis
 numpy>=1.22
 pillow>=9.1.0
 portpicker
jax/BUILD (62 changes)

@@ -22,6 +22,8 @@ load(
     "jax_test_util_visibility",
     "jax_visibility",
     "mosaic_internal_users",
+    "pallas_gpu_internal_users",
+    "pallas_tpu_internal_users",
     "py_deps",
     "py_library_providing_imports_info",
     "pytype_library",
@@ -67,6 +69,20 @@ package_group(
     ] + mosaic_internal_users,
 )

+package_group(
+    name = "pallas_gpu_users",
+    packages = [
+        "//...",
+    ] + pallas_gpu_internal_users,
+)
+
+package_group(
+    name = "pallas_tpu_users",
+    packages = [
+        "//...",
+    ] + pallas_tpu_internal_users,
+)
+
 # JAX-private test utilities.
 py_library(
     # This build target is required in order to use private test utilities in jax._src.test_util,
@@ -430,6 +446,52 @@ pytype_strict_library(
     ] + py_deps("numpy"),
 )

+pytype_strict_library(
+    name = "pallas",
+    srcs = glob(
+        [
+            "experimental/pallas/**/*.py",
+        ],
+        exclude = [
+            "experimental/pallas/gpu.py",
+            "experimental/pallas/tpu.py",
+        ],
+    ),
+    visibility = [
+        "//visibility:public",
+    ],
+    deps = [
+        ":jax",
+        "//jax/_src/pallas",
+    ] + py_deps("numpy"),
+)
+
+pytype_strict_library(
+    name = "pallas_tpu",
+    srcs = ["experimental/pallas/tpu.py"],
+    visibility = [
+        ":pallas_tpu_users",
+    ],
+    deps = [
+        ":pallas",
+        "//jax/_src/pallas/mosaic",
+        "//jax/_src/pallas/mosaic:core",
+        "//jax/_src/pallas/mosaic:primitives",
+    ],
+)
+
+pytype_strict_library(
+    name = "pallas_gpu",
+    srcs = ["experimental/pallas/gpu.py"],
+    visibility = [
+        ":pallas_gpu_users",
+    ],
+    deps = [
+        ":pallas",
+        "//jax/_src/pallas/triton",
+    ],
+)
+
 pytype_strict_library(
     name = "partial_eval",
     srcs = ["_src/interpreters/partial_eval.py"],
@@ -40,15 +40,20 @@ py_library_providing_imports_info(
         "//jaxlib",
         "//jaxlib:cpu_feature_guard",
         "//jaxlib:utils",
+        "//jaxlib/mlir:arithmetic_dialect",
         "//jaxlib/mlir:builtin_dialect",
         "//jaxlib/mlir:chlo_dialect",
         "//jaxlib/mlir:func_dialect",
         "//jaxlib/mlir:ir",
+        "//jaxlib/mlir:math_dialect",
+        "//jaxlib/mlir:memref_dialect",
         "//jaxlib/mlir:mhlo_dialect",
         "//jaxlib/mlir:ml_program_dialect",
         "//jaxlib/mlir:pass_manager",
+        "//jaxlib/mlir:scf_dialect",
         "//jaxlib/mlir:sparse_tensor_dialect",
         "//jaxlib/mlir:stablehlo_dialect",
+        "//jaxlib/mlir:vector_dialect",
         # xla_client
     ],
     "//conditions:default": [],
@@ -103,6 +103,7 @@ import jaxlib.gpu_solver as gpu_solver  # pytype: disable=import-error
 import jaxlib.gpu_sparse as gpu_sparse  # pytype: disable=import-error
 import jaxlib.gpu_prng as gpu_prng  # pytype: disable=import-error
 import jaxlib.gpu_linalg as gpu_linalg  # pytype: disable=import-error
 import jaxlib.hlo_helpers as hlo_helpers  # pytype: disable=import-error

 # Jaxlib code is split between the Jax and the Tensorflow repositories.
 # Only for the internal usage of the JAX developers, we expose a version
@@ -20,5 +20,14 @@ import jaxlib.mlir.dialects.func as func
 import jaxlib.mlir.dialects.ml_program as ml_program
 import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor

+from jax._src import lib
+# TODO(sharadmv): remove guard when minimum jaxlib version is bumped
+if lib.version >= (0, 4, 15):
+  import jaxlib.mlir.dialects.arith as arith
+  import jaxlib.mlir.dialects.math as math
+  import jaxlib.mlir.dialects.memref as memref
+  import jaxlib.mlir.dialects.scf as scf
+  import jaxlib.mlir.dialects.vector as vector
+
 # Alias that is set up to abstract away the transition from MHLO to StableHLO.
 import jaxlib.mlir.dialects.stablehlo as hlo
jax/_src/pallas/BUILD (new file, 64 lines)

@@ -0,0 +1,64 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
    "//jaxlib:jax.bzl",
    "py_deps",
)

package(
    default_applicable_licenses = [],
    default_visibility = [
        "//:__subpackages__",
    ],
)

py_library(
    name = "pallas",
    srcs = glob(
        include = ["**/*.py"],
        exclude = [
            "triton/*.py",
            "mosaic/*.py",
        ],
    ),
    deps = [
        "//jax",
        "//jax:ad_util",
        "//jax:core",
        "//jax:mlir",
        "//jax:partial_eval",
        "//jax:pretty_printer",
        "//jax:tree_util",
        "//jax:util",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

py_library(
    name = "gpu",
    visibility = [],
    deps = [
        ":pallas",
        "//jax/_src/pallas/triton",
    ],
)

py_library(
    name = "tpu",
    visibility = [],
    deps = [
        ":pallas",
        "//jax/_src/pallas/mosaic",
    ],
)
jax/_src/pallas/__init__.py (new file, 13 lines)

@@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
jax/_src/pallas/core.py (new file, 229 lines)

@@ -0,0 +1,229 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for pallas-core functionality."""
from __future__ import annotations

from collections.abc import Sequence
import contextlib
import dataclasses
import functools
from typing import Any, Callable, Iterator

from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
import jax.numpy as jnp

# TODO(sharadmv): enable type checking
# mypy: ignore-errors

partial = functools.partial
Grid = tuple[int, ...]
split_list = util.split_list

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip


@dataclasses.dataclass
class GridEnv:
  axis_index: Any
  axis_size: int

_grid_env_stack: list[tuple[GridEnv, ...]] = []


@contextlib.contextmanager
def grid_env(env: tuple[tuple[Any, int], ...]) -> Iterator[None]:
  _grid_env_stack.append(tuple(GridEnv(axis_index, axis_size)
                               for axis_index, axis_size in env))
  try:
    yield
  finally:
    _grid_env_stack.pop()


def current_grid_env() -> tuple[GridEnv, ...] | None:
  if not _grid_env_stack:
    return None
  return _grid_env_stack[-1]


class Mapped:
  pass
mapped = Mapped()


@dataclasses.dataclass(frozen=True)
class BlockSpec:
  index_map: Callable[..., Any]
  block_shape: tuple[int | None, ...]

  def compute_index(self, *args):
    out = self.index_map(*args)
    if not isinstance(out, tuple):
      out = (out,)
    return out


@dataclasses.dataclass(frozen=True)
class BlockMapping:
  block_shape: tuple[Mapped | int, ...]
  index_map_jaxpr: jax_core.ClosedJaxpr

  def compute_start_indices(self, loop_idx, *args):
    discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
        self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts
    )
    jaxpr = jax_core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
    block_indices_and_rest = jax_core.jaxpr_as_fun(jaxpr)(*loop_idx, *args)
    # Since we're passing in `Ref`s potentially, we need to split out their
    # updated values since we only care about the return values.
    block_indices, _ = split_list(block_indices_and_rest,
                                  [len(self.block_shape)])
    return tuple(i if b is mapped else b * i
                 for b, i in zip(self.block_shape, block_indices))

  replace = dataclasses.replace


@dataclasses.dataclass(frozen=True)
class GridMapping:
  grid: tuple[int, ...]
  block_mappings: tuple[BlockMapping | None, ...]
  mapped_dims: tuple[int, ...]
  num_index_operands: int

  replace = dataclasses.replace


def _preprocess_grid(grid: Grid | int | None) -> Grid:
  if grid is None:
    return ()
  if isinstance(grid, int):
    return (grid,)
  return grid


def _convert_block_spec_to_block_mapping(
    in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None,
) -> BlockSpec | None:
  if block_spec is _no_block_spec:
    return None
  block_shape = tuple(
      mapped if s is None else s for s in block_spec.block_shape)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(block_spec.compute_index), in_avals)
  return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts))


def _compute_shape_from_block_spec(block_spec: BlockSpec | None,
                                   arg_shape: tuple[int, ...]
                                   ) -> tuple[int, ...]:
  if block_spec is _no_block_spec:
    return arg_shape
  return tuple(s for s in block_spec.block_shape if s is not None)


def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs):
  if grid is None:
    in_specs = [None] * len(in_avals)
    out_specs = [None] * len(out_avals)
    in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
                    for arg in in_avals]
    out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
                     for arg in out_avals]
  else:
    in_ref_avals = [
        state.shaped_array_ref(
            _compute_shape_from_block_spec(
                block_spec, arg.shape), arg.dtype)
        for block_spec, arg in zip(in_specs, in_avals)]
    out_ref_avals = [
        state.shaped_array_ref(
            _compute_shape_from_block_spec(
                block_spec, arg.shape), arg.dtype)
        for block_spec, arg in zip(out_specs, out_avals)]
  return in_specs, in_ref_avals, out_specs, out_ref_avals


_no_block_spec = object()

@dataclasses.dataclass(init=False)
class GridSpec:
  grid: Grid
  in_specs: Sequence[BlockSpec | None] | None
  out_specs: tuple[BlockSpec | None, ...] | None

  def __init__(
      self,
      grid: Grid | None = None,
      in_specs: Sequence[BlockSpec | None] | None = None,
      out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
  ):
    if grid is None:
      if in_specs is not None:
        raise ValueError("Cannot specify `in_specs` with a `None` grid.")
      if out_specs is not None:
        raise ValueError("Cannot specify `out_specs` with a `None` grid.")
    self.grid = _preprocess_grid(grid)
    self.in_specs = in_specs
    if out_specs is not None and not isinstance(out_specs, (tuple, list)):
      out_specs = (out_specs,)
    if out_specs is not None and not isinstance(out_specs, tuple):
      out_specs = tuple(out_specs)
    self.out_specs = out_specs

  def get_grid_mapping(
      self, in_avals, in_tree, out_avals, out_tree
  ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
    if self.in_specs is not None:
      in_specs = self.in_specs
      in_spec_tree = tree_util.tree_structure(tuple(in_specs))
      if in_spec_tree != in_tree:
        raise ValueError(
            "Pytree specs for arguments and `in_specs` must match: "
            f"{in_tree} vs. {in_spec_tree}")
    else:
      in_specs = [_no_block_spec] * len(in_avals)
    if self.out_specs is not None:
      out_specs = self.out_specs
      out_spec_tree = tree_util.tree_structure(out_specs)
      if out_spec_tree != out_tree:
        raise ValueError(
            "Pytree specs for `out_shape` and `out_specs` must match: "
            f"{out_tree} vs. {out_spec_tree}")
    else:
      out_specs = [_no_block_spec] * len(out_avals)
    flat_in_specs = tree_util.tree_leaves(in_specs)
    flat_out_specs = tree_util.tree_leaves(out_specs)
    in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
        self.grid, in_avals, flat_in_specs, out_avals,
        flat_out_specs)
    grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
    in_block_mappings = map(
        partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs)
    out_block_mappings = map(
        partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs)
    grid_mapping = GridMapping(
        self.grid, (*in_block_mappings, *out_block_mappings), (),
        num_index_operands=0)
    jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
    jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
    return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
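
For context, a minimal sketch of how the pieces above fit together (not part of this diff; the shapes and index map are illustrative): BlockSpec.compute_index normalizes an index map's output into a tuple of block indices, and grid_env/current_grid_env let primitives observe the current grid position while the interpreter walks the grid.

from jax._src.pallas import core as pallas_core

# A (128, 128)-blocked view of a 2D array: grid step i selects row-block i
# and always column-block 0.
spec = pallas_core.BlockSpec(lambda i: (i, 0), (128, 128))
assert spec.compute_index(3) == (3, 0)  # non-tuple outputs get wrapped

# Enter a grid environment: axis 0 is at index 2 of size 8.
with pallas_core.grid_env(((2, 8),)):
  env = pallas_core.current_grid_env()
  assert env[0].axis_index == 2 and env[0].axis_size == 8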
jax/_src/pallas/indexing.py (new file, 158 lines)

@@ -0,0 +1,158 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains shared logic and abstractions for Pallas indexing ops."""

from __future__ import annotations

import dataclasses
from typing import Any, Tuple

import jax
from jax import core as jax_core
from jax import tree_util
from jax._src.interpreters import mlir
from jax._src.util import merge_lists
from jax._src.util import partition_list
import jax.numpy as jnp
import numpy as np


# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
# We could use `jnp.broadcast_to` but that lowers to squeezing,
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
# so in the lowering, we have to expand out those squeezed dimensions again.
# Having a simple `broadcast_to` primitive allows us to lower directly
# to `tl.broadcast_to`.
broadcast_to_p = jax_core.Primitive('broadcast_to')

def broadcast_to(a: jax.Array, shape: Tuple[int, ...]) -> jax.Array:
  if a.shape == shape:
    return a
  return broadcast_to_p.bind(a, shape=shape)

@broadcast_to_p.def_impl
def _broadcast_to_impl(a, *, shape):
  return jnp.broadcast_to(a, shape)

@broadcast_to_p.def_abstract_eval
def _broadcast_to_abstract_eval(aval, *, shape):
  return jax_core.ShapedArray(shape, aval.dtype)

mlir.register_lowering(
    broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)


@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
  """Represents a slice with a dynamic start index and a fixed size."""
  start: Any
  size: int

  def tree_flatten(self):
    # If `start` is statically known, we treat it as static information
    if isinstance(self.start, int):
      return (), (True, self.start, self.size)
    return (self.start,), (False, self.size)

  @classmethod
  def tree_unflatten(cls, data, xs):
    is_static = data[0]
    if is_static:
      del xs
      start, size = data[1:]
      return Slice(start, size)
    start, = xs
    size = data[1]
    return Slice(start, size)

  @classmethod
  def from_slice(cls, slc: slice, size: int) -> Slice:
    start, stop = slc.start, slc.stop
    start = 0 if start is None else start
    stop = size if stop is None else stop
    return Slice(start, stop - start)


def dslice(start: int | jax.Array | None, size: int | None = None
           ) -> slice | Slice:
  """Constructs a `Slice` from a start and a size."""
  if start is None:
    return slice(None)
  if size is None:
    if not isinstance(start, int):
      raise ValueError("Non-static `dslice`")
    return Slice(0, start)
  return Slice(start, size)
ds = dslice  # Handy alias

@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
  indices: Tuple[int | Slice | jax.Array, ...]
  shape: Tuple[int, ...]
  int_indexer_shape: Tuple[int, ...]

  def __post_init__(self):
    if len(self.indices) != len(self.shape):
      raise ValueError("`indices` must be the same length as `Ref` shape.")

  def tree_flatten(self):
    indexed_dims = [not isinstance(idx, slice) for idx in self.indices]
    slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices)
    flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx)
    return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape,
                      self.int_indexer_shape)

  @classmethod
  def tree_unflatten(cls, data, flat_idx):
    slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data
    non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx)
    indices = merge_lists(indexed_dims, slice_idx, non_slice_idx)
    return NDIndexer(tuple(indices), shape, int_indexer_shape)

  @classmethod
  def from_indices_shape(cls, indices, shape) -> NDIndexer:
    if len(indices) > len(shape):
      raise ValueError("`indices` must be no longer than `shape`.")
    # Pad out indices with slice(None)
    indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
    # Convert all `slice`s to `Slice`s
    indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
                    else i for i, s in zip(indices, shape))
    is_int_indexing = [not isinstance(i, Slice) for i in indices]
    other_indexers, int_indexers = partition_list(is_int_indexing, indices)
    int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
                    int_indexers]
    indexer_shapes = [i.shape for i in int_indexers]
    if indexer_shapes:
      try:
        bcast_shape = np.broadcast_shapes(*indexer_shapes)
      except ValueError as e:
        # Raise a nicer error than the NumPy one.
        raise ValueError("Cannot broadcast shapes for indexing: "
                         f"{tuple(a for a in indexer_shapes)}") from e
    else:
      bcast_shape = ()
    int_indexers = [broadcast_to(i, bcast_shape) for i in int_indexers]
    indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
    return NDIndexer(tuple(indices), shape, bcast_shape)

  def get_indexer_shape(self) -> Tuple[int, ...]:
    is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
    other_indexers, _ = partition_list(is_int_indexing, self.indices)
    other_shape = [s.size for s in other_indexers]  # type: ignore
    return tuple((*self.int_indexer_shape, *other_shape))
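
For context, a short sketch of the indexing helpers above (not part of this diff; values are illustrative): dslice builds a Slice from a dynamic start and a static size, and NDIndexer.from_indices_shape pads missing dimensions with slice(None), converts Python slices to Slice objects, and broadcasts any integer indices to a common shape.

from jax._src.pallas import indexing

s = indexing.dslice(4, 16)  # Slice(start=4, size=16)
assert (s.start, s.size) == (4, 16)

# Mixed integer/slice indexing into an (8, 128)-shaped Ref: dimension 0 is an
# integer index, dimension 1 a dynamic slice of 16 elements starting at 4.
idx = indexing.NDIndexer.from_indices_shape((0, indexing.ds(4, 16)), (8, 128))
# Integer indexers contribute their broadcast shape (here ()), slices their
# sizes, so the indexed value has shape (16,).
assert idx.get_indexer_shape() == (16,)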
jax/_src/pallas/mosaic/BUILD (new file, 89 lines)

@@ -0,0 +1,89 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Package for Mosaic-specific Pallas extensions
load(
    "//jaxlib:jax.bzl",
    "py_deps",
    "py_library_providing_imports_info",
)

package(
    default_applicable_licenses = [],
    default_visibility = [
        "//:__subpackages__",
    ],
)

py_library_providing_imports_info(
    name = "mosaic",
    srcs = ["__init__.py"],
    lib_rule = py_library,
    deps = [
        ":core",
        ":pallas_call_registration",
        ":primitives",
    ],
)

py_library(
    name = "core",
    srcs = ["core.py"],
    deps = [
        "//jax",
        "//jax/_src/pallas",
    ],
)

py_library(
    name = "primitives",
    srcs = ["primitives.py"],
    deps = [
        "//jax",
        "//jax:core",
        "//jax:mlir",
    ],
)

py_library(
    name = "pallas_call_registration",
    srcs = ["pallas_call_registration.py"],
    deps = [
        ":lowering",
        "//jax",
        "//jax:mosaic",
        "//jax:source_info_util",
        "//jax/_src/lib",
        "//jax/_src/pallas",
    ] + py_deps("numpy"),
)

py_library(
    name = "lowering",
    srcs = ["lowering.py"],
    deps = [
        ":core",
        ":primitives",
        "//jax",
        "//jax:core",
        "//jax:mlir",
        "//jax:mosaic",
        "//jax:partial_eval",
        "//jax:source_info_util",
        "//jax:util",
        "//jax:xla",
        "//jax/_src/lib",
        "//jax/_src/pallas",
    ] + py_deps("numpy"),
)
jax/_src/pallas/mosaic/__init__.py (new file, 28 lines)

@@ -0,0 +1,28 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for Mosaic lowering of Pallas call."""

from jax._src.pallas.mosaic import core
from jax._src.pallas.mosaic import pallas_call_registration
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import trace


VMEM = TPUMemorySpace.VMEM
SMEM = TPUMemorySpace.SMEM
CMEM = TPUMemorySpace.CMEM
del pallas_call_registration
jax/_src/pallas/mosaic/core.py (new file, 104 lines)

@@ -0,0 +1,104 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains TPU-specific Pallas abstractions."""
from __future__ import annotations

from collections.abc import Sequence
import dataclasses
import enum
import functools

from jax._src import core as jax_core
from jax._src import state
from jax._src import tree_util
from jax._src import util
import jax.numpy as jnp
from jax._src.pallas import core as pallas_core

# TODO(sharadmv): enable type checking
# mypy: ignore-errors

partial = functools.partial
Grid = pallas_core.Grid
BlockSpec = pallas_core.BlockSpec
GridMapping = pallas_core.GridMapping
_preprocess_grid = pallas_core._preprocess_grid
_compute_shape_from_block_spec = pallas_core._compute_shape_from_block_spec
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
split_list = util.split_list


class TPUMemorySpace(enum.Enum):
  VMEM = "vmem"
  SMEM = "smem"
  CMEM = "cmem"

  def __str__(self) -> str:
    return self.value


@dataclasses.dataclass(init=False)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
  grid: Grid
  num_scalar_prefetch: int
  in_specs: Sequence[BlockSpec | None] | None
  out_specs: tuple[BlockSpec | None, ...] | None

  def __init__(
      self,
      num_scalar_prefetch: int,
      grid: Grid | None = None,
      in_specs: Sequence[BlockSpec | None] | None = None,
      out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
  ):
    if grid is None:
      raise NotImplementedError("Should pass in non-`None` grid.")
    self.grid = _preprocess_grid(grid)
    if out_specs is not None and not isinstance(out_specs, (tuple, list)):
      out_specs = (out_specs,)
    if out_specs is not None and not isinstance(out_specs, tuple):
      out_specs = tuple(out_specs)
    self.num_scalar_prefetch = num_scalar_prefetch
    self.in_specs = in_specs
    self.out_specs = out_specs

  def get_grid_mapping(
      self, in_avals, in_tree, out_avals, out_tree
  ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
    scalar_avals, in_avals = split_list(in_avals, [self.num_scalar_prefetch])
    in_specs, in_ref_avals, out_specs, out_ref_avals = (
        pallas_core._get_ref_avals(
            self.grid, in_avals, self.in_specs,
            out_avals, self.out_specs))
    scalar_ref_avals = [
        state.shaped_array_ref(aval.shape, aval.dtype)
        for aval in scalar_avals]
    grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
    in_block_mappings = map(
        partial(_convert_block_spec_to_block_mapping,
                (*grid_avals, *scalar_ref_avals)), in_specs)
    out_block_mappings = map(
        partial(_convert_block_spec_to_block_mapping,
                (*grid_avals, *scalar_ref_avals)), out_specs)
    grid_mapping = GridMapping(
        grid=self.grid,
        block_mappings=(*in_block_mappings, *out_block_mappings),
        mapped_dims=(),
        num_index_operands=self.num_scalar_prefetch,
    )
    jaxpr_in_avals = tree_util.tree_unflatten(
        in_tree, [*scalar_ref_avals, *in_ref_avals])
    jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
    return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
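
For context, a hedged sketch of how PrefetchScalarGridSpec is meant to be used (not part of this diff; the kernel shapes are illustrative): the first num_scalar_prefetch arguments become scalar-prefetch Refs, and every BlockSpec index map receives those refs after the grid indices, so the block fetch order can depend on runtime scalars.

from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic import core as tpu_core

grid_spec = tpu_core.PrefetchScalarGridSpec(
    num_scalar_prefetch=1,  # first kernel argument is a scalar Ref
    grid=(8,),
    in_specs=[pallas_core.BlockSpec(
        # s_ref is the prefetched scalar Ref; reading it in the index map
        # lets the runtime scalars reorder which row-block is fetched.
        lambda i, s_ref: (s_ref[i], 0),
        (128, 128))],
    out_specs=pallas_core.BlockSpec(lambda i, s_ref: (i, 0), (128, 128)),
)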
jax/_src/pallas/mosaic/lowering.py (new file, 1264 lines)

File diff suppressed because it is too large.
jax/_src/pallas/mosaic/pallas_call_registration.py (new file, 76 lines)

@@ -0,0 +1,76 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains registrations for pallas_call on TPU."""

from __future__ import annotations

from typing import Any

import jax
from jax import core as jax_core
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
from jax.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core
from jax._src.pallas.mosaic import lowering
from jax._src.pallas.pallas_call import pallas_call_p


def pallas_call_tpu_lowering_rule(
    ctx: mlir.LoweringRuleContext, *in_nodes,
    jaxpr: jax_core.Jaxpr,
    name: str,
    which_linear: tuple[bool, ...],
    grid_mapping: core.GridMapping,
    input_output_aliases: tuple[tuple[int, int], ...],
    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
    debug: bool,
    interpret: bool,
    mosaic_params: dict[str, Any] | None = None,
    **compiler_params: Any):
  """Lowers a pallas_call to a Mosaic TPU custom call."""
  if input_output_aliases:
    raise NotImplementedError(
        "`input_output_aliases` not supported on TPU backend.")
  if interpret:
    return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
        ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
        in_shapes=in_shapes,
        which_linear=which_linear,
        interpret=interpret, debug=debug,
        input_output_aliases=input_output_aliases,
        grid_mapping=grid_mapping, **compiler_params)
  if debug:
    print(jaxpr)
  with ir.Context() as mlir_ctx, ir.Location.unknown():
    tpu.register_dialect(mlir_ctx)
    if mosaic_params is None:
      mosaic_params = {}
    dimension_semantics = mosaic_params.get("dimension_semantics", None)
    mosaic_module = lowering.lower_jaxpr_to_module(
        mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics)
    if debug:
      print(mosaic_module)
  out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
  return mlir.lower_fun(
      mosaic.as_tpu_kernel(
          mosaic_module, out_avals, backend=ctx.module_context.backend
      ),
      multiple_results=True,
  )(ctx, *in_nodes)
mlir.register_lowering(pallas_call_p, pallas_call_tpu_lowering_rule,
                       platform="tpu")
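
For context, the rule above pops mosaic_params out of pallas_call's **compiler_params; a hedged sketch of threading dimension_semantics through from the caller side (not part of this diff; the kernel and shapes are illustrative):

import jax
import jax.numpy as jnp
from jax._src.pallas.pallas_call import pallas_call

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

add_one = pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    grid=(8,),
    # Forwarded untouched through pallas_call_p.bind into the TPU rule above,
    # which reads mosaic_params.get("dimension_semantics").
    mosaic_params=dict(dimension_semantics=("parallel",)),
)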
jax/_src/pallas/mosaic/primitives.py (new file, 66 lines)

@@ -0,0 +1,66 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for Pallas:TPU-specific JAX primitives and functions."""
from __future__ import annotations

import contextlib

from jax._src import core as jax_core
from jax._src.interpreters import mlir
import jax.numpy as jnp


repeat_p = jax_core.Primitive('repeat')

def repeat(x, repeats, axis):
  return repeat_p.bind(x, repeats=repeats, axis=axis)

@repeat_p.def_abstract_eval
def _repeat_abstract_eval(x, *, repeats, axis):
  shape = list(x.shape)
  shape[axis] *= repeats
  return jax_core.ShapedArray(shape, x.dtype)


def _repeat_lowering_rule(ctx: mlir.LoweringRuleContext, x, *, repeats, axis):
  def _repeat(x):
    return jnp.repeat(x, repeats, axis)
  return mlir.lower_fun(_repeat, multiple_results=False)(ctx, x)
mlir.register_lowering(repeat_p, _repeat_lowering_rule)

trace_start_p = jax_core.Primitive('trace_start')
trace_start_p.multiple_results = True


@trace_start_p.def_abstract_eval
def _trace_start_abstract_eval(*, message: str, level: int):
  del message, level
  return []


trace_stop_p = jax_core.Primitive('trace_stop')
trace_stop_p.multiple_results = True


@trace_stop_p.def_abstract_eval
def _trace_stop_abstract_eval():
  return []


@contextlib.contextmanager
def trace(message: str, level: int = 10):
  trace_start_p.bind(message=message, level=level)
  yield
  trace_stop_p.bind()
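
For context, a hedged sketch of the two helpers above inside an illustrative kernel body (not part of this diff): trace brackets a region with trace_start/trace_stop events, and repeat tiles a value along an axis, multiplying that axis by repeats exactly as _repeat_abstract_eval computes.

from jax._src.pallas.mosaic import primitives as tpu_primitives

def kernel(x_ref, o_ref):
  with tpu_primitives.trace("compute", level=10):
    # An (8, 128) block of x becomes (8, 256) after repeat along axis 1.
    o_ref[...] = tpu_primitives.repeat(x_ref[...], 2, axis=1)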
jax/_src/pallas/pallas_call.py (new file, 372 lines)

@@ -0,0 +1,372 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for calling pallas functions from JAX."""
from __future__ import annotations

from functools import partial
import itertools as it

from typing import Any, Callable, Dict, Sequence, Tuple

import jax
from jax import api_util
from jax import linear_util as lu
from jax import tree_util
from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src.state import discharge as state_discharge
from jax._src.util import (
    split_list, safe_map, safe_zip, weakref_lru_cache,
    tuple_insert, partition_list)
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp
import numpy as np

from jax._src.pallas import core as pallas_core

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Grid = pallas_core.Grid
BlockSpec = pallas_core.BlockSpec
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping

pallas_call_p = jax_core.Primitive('pallas_call')
pallas_call_p.multiple_results = True

def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
  if start_idx is None:
    assert is_indexing is None
    return value
  assert is_indexing is not None
  output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
  squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
                                                            dtype=np.bool_)])
  return lax.squeeze(output, squeeze_dims)

def _maybe_dynamic_update_slice(start_idx, block_shape, value, update,
                                is_indexing):
  if start_idx is None:
    assert is_indexing is None
    return update
  assert is_indexing is not None
  broadcast_dims = tuple(i for i, b in enumerate(is_indexing)
                         if not b)
  update = lax.broadcast_in_dim(update, block_shape, broadcast_dims)
  assert update.shape == block_shape
  return lax.dynamic_update_slice(value, update, start_idx)

def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
                      interpret, debug: bool,
                      in_shapes,
                      input_output_aliases: Tuple[Tuple[int, int], ...],
                      grid_mapping: GridMapping,
                      **compiler_params: Any):
  if interpret:
    # If we're in interpreter mode, we *scan* over the grid and eval the
    # discharged jaxpr. This should reproduce exactly what compiling to Triton
    # will do.
    grid = grid_mapping.grid
    discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
    if debug:
      print(discharged_jaxpr)
    loop_indices = jnp.array(list(it.product(*(range(g) for g in grid))))
    oi_map = {v: k for k, v in input_output_aliases}
    out = []
    for i, out_shape in enumerate(out_shapes):
      if i in oi_map:
        out.append(args[oi_map[i]])
      else:
        out.append(jnp.zeros(out_shape.shape, out_shape.dtype))
    scalars, args = split_list(args, [grid_mapping.num_index_operands])  # type: ignore
    carry = [*args, *out]
    def cond(carry):
      return carry[0] < loop_indices.shape[0]
    def body(carry):
      i, *carry = carry
      loop_idx = loop_indices[i]
      start_indices = [
          None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
          for bm in grid_mapping.block_mappings]
      block_shapes_without_mapped_dims = [
          None if block_mapping is None else block_mapping.block_shape
          for block_mapping in grid_mapping.block_mappings
      ]
      is_indexing_dim = [
          None if bm is None else tuple(b is pallas_core.mapped for b in bm)
          for bm in block_shapes_without_mapped_dims
      ]
      block_shapes = [
          None if bm is None else tuple(1 if i else b for i, b in zip(iid, bm))
          for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
      ]
      blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
                   is_indexing_dim)
      is_mapped_grid_dim = [
          i in grid_mapping.mapped_dims for i in range(len(grid_mapping.grid))]
      local_grid_env, _ = partition_list(is_mapped_grid_dim,
                                         zip(loop_idx, grid_mapping.grid))
      with pallas_core.grid_env(tuple(local_grid_env)):
        blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
                                     *blocks)
      blocks = blocks[grid_mapping.num_index_operands:]
      carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
                  carry, blocks, is_indexing_dim)
      return (i + 1, *carry)
    (_, *carry) = lax.while_loop(cond, body, (0, *carry))
    _, out = split_list(carry, [len(args)])
    return out
  return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
                             in_shapes=in_shapes,
                             out_shapes=out_shapes, which_linear=which_linear,
                             grid_mapping=grid_mapping, interpret=interpret,
                             debug=debug,
                             input_output_aliases=input_output_aliases,
                             **compiler_params)
pallas_call_p.def_impl(_pallas_call_impl)

def _pallas_call_abstract_eval(*avals, out_shapes, **_):
  return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)

def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
                          input_output_aliases: Tuple[Tuple[int, int], ...],
                          in_shapes, out_shapes, grid_mapping, debug, interpret,
                          **compiler_params: Any):
  if grid_mapping.num_index_operands:
    raise NotImplementedError
  if input_output_aliases:
    raise NotImplementedError("JVP with aliasing not supported.")
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  tangents = [ad.instantiate_zeros(t) if inst else t
              for t, inst in zip(tangents, nonzero_tangents)]
  tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes)
  closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
  jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
  jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO consts
  jvp_which_linear = which_linear + (True,) * len(tangents)
  jvp_inshapes = (*in_shapes, *in_shapes)
  jvp_outshapes = (*out_shapes, *out_shapes)
  if input_output_aliases:
    raise NotImplementedError("`input_output_aliases` jvp not supported.")
  # `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*.
  # `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input
  # `Ref`s that are read from followed by output `Ref`s that are written to.
  # This means that when we do `jvp_jaxpr` on the `jaxpr`, we get out a new
  # jaxpr that has tangents following primals. In order for this jaxpr to be
  # compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around
  # the jaxpr's invars.
  logical_primals, logical_tangents = split_list(
      jvp_jaxpr.invars, [len(primals) + len(out_shapes)])
  logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)])
  logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)])
  in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
  new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms))
  new_grid_mapping = grid_mapping.replace(block_mappings=new_bms)
  jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs,
                                        *logical_tangent_inputs,
                                        *logical_primal_outputs,
                                        *logical_tangent_outputs])
  if debug:
    print(jvp_jaxpr)
  out_flat = pallas_call_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr,
                                name=f"{name}_jvp",
                                in_shapes=jvp_inshapes,
                                out_shapes=jvp_outshapes,
                                grid_mapping=new_grid_mapping,
                                which_linear=jvp_which_linear,
                                interpret=interpret,
                                debug=debug,
                                input_output_aliases=(),
                                **compiler_params)
  out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
  return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule

def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
                         dim: int | batching.NotMapped,
                         block_mapping: BlockMapping | None) -> BlockMapping:
  def _block_map_function(new_idx, *args):
    if block_mapping is None:
      indices = [0] * len(aval.shape)
    else:
      indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
                                    block_mapping.index_map_jaxpr.consts,
                                    *args)
    if dim is not batching.not_mapped:
      indices.insert(dim, new_idx)
    return tuple(indices)
  i32_aval = jax_core.ShapedArray((), jnp.int32)
  if block_mapping is None:
    idx_avals = [i32_aval] * (len(grid) + 1)
  else:
    idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
  block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_block_map_function), idx_avals)
  shape = aval.shape if block_mapping is None else block_mapping.block_shape
  if dim is batching.not_mapped:
    new_block_shape = shape
  else:
    new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
  jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
  if block_mapping is None:
    return BlockMapping(block_shape=new_block_shape, index_map_jaxpr=jaxpr)
  return block_mapping.replace(block_shape=new_block_shape,
                               index_map_jaxpr=jaxpr)

def _pallas_call_batching_rule(args, dims, *,
                               jaxpr: jax_core.Jaxpr,
                               name: str,
                               in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                               out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                               grid_mapping: GridMapping,
                               input_output_aliases: Tuple[Tuple[int, int], ...],
                               debug: bool,
                               interpret: bool,
                               which_linear: Tuple[bool, ...],
                               **compiler_params: Any):
  if grid_mapping.num_index_operands:
    scalar_batch_dims = dims[:grid_mapping.num_index_operands]
    if any(bdim is not batching.not_mapped for bdim in scalar_batch_dims):
      # TODO(sharadmv,apaszke): enable batching over prefetched scalar args
      raise NotImplementedError
  axis_size, = {x.shape[d] for x, d in zip(args, dims)
                if d is not batching.not_mapped}
  block_mappings = grid_mapping.block_mappings
  avals = [v.aval for v in jaxpr.invars]
  # How should we pick output dimensions? This actually matters because XLA
  # can't optimize our pallas kernels, and this layout impacts performance, but
  # `vmap` doesn't really offer a way of inferring good output dimensions. For
  # now, we just use 0.
  # TODO(sharadmv): explore inferring better output dimensions via a heuristic
  # TODO(sharadmv): explore a long term solution to output dim inference

  # When we have input/output aliasing, since the output will be mapped, we need
  # to make sure to broadcast the input across that dimension if it is not
  # mapped.
  dims_ = list(dims)
  args_ = list(args)
  for input_index, _ in input_output_aliases:
    dim = dims_[input_index]
    if dim is batching.not_mapped:
      dims_[input_index] = 0
      args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
  args = tuple(args_)
  dims = tuple(dims_)

  all_dims = list(dims) + [0] * len(out_shapes)

  num_index_operands = grid_mapping.num_index_operands
  batched_block_mappings = map(
      partial(_batch_block_mapping, grid_mapping.grid),
      avals[num_index_operands:], all_dims[num_index_operands:], block_mappings)

  batched_in_shapes = tuple(
      jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
                           tuple_insert(x.shape, dim, axis_size),
                           x.dtype)
      for x, dim in zip(in_shapes, dims))
  batched_out_shapes = tuple(
      jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
      for x in out_shapes)

  batched_grid_mapping = grid_mapping.replace(
      grid=(axis_size, *grid_mapping.grid),
      block_mappings=tuple(batched_block_mappings),
      mapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.mapped_dims))
  out = pallas_call_p.bind(*args, jaxpr=jaxpr, name=f"batched_{name}",
                           in_shapes=batched_in_shapes,
                           out_shapes=batched_out_shapes,
                           which_linear=which_linear,
                           grid_mapping=batched_grid_mapping,
                           input_output_aliases=input_output_aliases,
                           debug=debug,
                           interpret=interpret,
                           **compiler_params)
  return out, (0,) * len(out)
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule

@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
                              primitive_name: str | None = None):
  wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, out_tree_thunk, False,
                        primitive_name or "<unknown>")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  jaxpr = for_loop._hoist_consts_to_refs(jaxpr)
  return jaxpr, consts, out_tree_thunk()

def _extract_function_name(f: Callable, name: str | None) -> str:
  if name is None:
    name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
  return name

def pallas_call(
    f: Callable[..., None], out_shape: Any, *,
    grid_spec: GridSpec | None = None,
    debug: bool = False,
    grid: Grid | None = None,
    in_specs: Sequence[BlockSpec | None] | None = None,
    out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
    input_output_aliases: Dict[int, int] = {},
    interpret: bool = False,
    name: str | None = None,
    **compiler_params: Any):
  if grid_spec is None:
    grid_spec = GridSpec(grid, in_specs, out_specs)
  name = _extract_function_name(f, name)
  singleton = False
  if not isinstance(out_shape, (tuple, list)):
    out_shape = (out_shape,)
    singleton = True
  if not isinstance(out_shape, tuple):
    out_shape = tuple(out_shape)
  flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
  flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
                     for x in flat_out_shapes]
  @jax.jit
  def wrapped(*args):
    flat_args, in_tree = tree_util.tree_flatten(args)
    flat_avals = [jax_core.raise_to_shaped(jax_core.get_aval(a))
                  for a in flat_args]
    avals, grid_mapping = grid_spec.get_grid_mapping(flat_avals, in_tree,
                                                     flat_out_shapes, out_tree)
    jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
    jaxpr, consts, _ = _initial_style_open_jaxpr(f, jaxpr_in_tree,
                                                 tuple(jaxpr_flat_avals),
                                                 primitive_name="pallas_call")
    which_linear = (False,) * len(flat_args)
    out_flat = pallas_call_p.bind(
        *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
        in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
                        for a in flat_args),
        out_shapes=tuple(flat_out_shapes), debug=debug,
        interpret=interpret,
        grid_mapping=grid_mapping,
        input_output_aliases=tuple(input_output_aliases.items()),
        **compiler_params)
    out = tree_util.tree_unflatten(out_tree, out_flat)
    if singleton:
      return out[0]
    return out
  return wrapped
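
For reference, a minimal end-to-end sketch of the pallas_call API defined above, using interpret=True so it runs on the scan-over-grid interpreter path rather than a Triton/Mosaic backend (not part of this diff; the kernel is illustrative):

import jax
import jax.numpy as jnp
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call as pallas_call_lib

def add_block_kernel(x_ref, y_ref, o_ref):
  # Each grid step sees one (256,)-element block of x, y, and o.
  o_ref[...] = x_ref[...] + y_ref[...]

block = pallas_core.BlockSpec(lambda i: (i,), (256,))
add = pallas_call_lib.pallas_call(
    add_block_kernel,
    out_shape=jax.ShapeDtypeStruct((1024,), jnp.float32),
    grid=(4,),            # 1024 elements / 256 per block
    in_specs=[block, block],
    out_specs=block,
    interpret=True,
)

x = jnp.arange(1024, dtype=jnp.float32)
y = jnp.ones((1024,), dtype=jnp.float32)
assert jnp.allclose(add(x, y), x + y)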
411
jax/_src/pallas/primitives.py
Normal file
411
jax/_src/pallas/primitives.py
Normal file
@ -0,0 +1,411 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Module for pallas-specific JAX primitives and functions."""
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import functools
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import ad_util
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import state
|
||||
from jax._src.util import (safe_map, safe_zip)
|
||||
from jax._src.state import primitives as state_primitives
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import indexing
|
||||
|
||||
# TODO(sharadmv): enable type checking
|
||||
# mypy: ignore-errors
|
||||
|
||||
partial = functools.partial
|
||||
Slice = indexing.Slice
|
||||
NDIndexer = indexing.NDIndexer
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
program_id_p = jax_core.Primitive("program_id")
|
||||
|
||||
def program_id(axis):
|
||||
return program_id_p.bind(axis=axis)
|
||||
|
||||
def program_id_bind(*, axis: int):
|
||||
grid_env = pallas_core.current_grid_env()
|
||||
if grid_env:
|
||||
return grid_env[axis].axis_index
|
||||
return jax_core.Primitive.bind(program_id_p, axis=axis)
|
||||
program_id_p.def_custom_bind(program_id_bind)
|
||||
|
||||
def _program_id_impl(*, axis: int):
|
||||
grid_env = pallas_core.current_grid_env()
|
||||
return grid_env[axis].axis_index
|
||||
program_id_p.def_impl(_program_id_impl)
|
||||
|
||||
mlir.register_lowering(program_id_p, functools.partial(xla.apply_primitive,
|
||||
program_id_p))
|
||||
|
||||
def _program_id_abstract_eval(**_):
|
||||
return jax_core.ShapedArray((), jnp.int32)
|
||||
program_id_p.def_abstract_eval(_program_id_abstract_eval)
|
||||
|
||||
class AtomicOpType(enum.Enum):
  XCHG = "xchg"
  ADD = "add"
  MAX = "max"
  MIN = "min"
  AND = "and"
  OR = "or"
  XOR = "xor"

atomic_rmw_p = jax_core.Primitive("atomic_rmw")

def _atomic_rmw_discharge_rule(in_avals, out_avals, ref, val, *args, args_tree,
                               masked, atomic_type: AtomicOpType):
  if masked: raise NotImplementedError
  ref_aval, val_aval, *in_avals = in_avals
  idx_aval, *_ = tree_util.tree_unflatten(args_tree, in_avals)
  idx, *_ = tree_util.tree_unflatten(args_tree, args)
  if atomic_type == AtomicOpType.ADD:
    monoid = lambda x, y: x + y
  elif atomic_type == AtomicOpType.MAX:
    monoid = jnp.maximum
  elif atomic_type == AtomicOpType.MIN:
    monoid = jnp.minimum
  else:
    raise NotImplementedError(atomic_type)

  if all(isinstance(s, Slice) or s.shape == () for s in idx.indices):
    indices = idx.indices
    scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
    val_indexer = tuple(None if scalar else slice(None) for scalar in scalar_dims)
    val = val[val_indexer]
    val = monoid(val, out_ones)
    x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
    out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
    out = out_ones[out_indexer]
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
    x_new = ref.at[idx.indices].set(monoid(out, val))
  else:
    raise NotImplementedError
  return (x_new,) + (None,) * (len(in_avals) + 1), out
state_discharge.register_discharge_rule(atomic_rmw_p)(_atomic_rmw_discharge_rule)
def _atomic_abstract_eval(ref_aval, val_aval, *all_avals,
                          args_tree, atomic_type: AtomicOpType,
                          **_: Any):
  if ref_aval.dtype == jnp.dtype("float16") and atomic_type != AtomicOpType.ADD:
    raise ValueError(f"`atomic_{atomic_type.value}` does not support f16.")
  if ref_aval.dtype in {jnp.dtype("bool"), jnp.dtype("int8"),
                        jnp.dtype("int16"), jnp.bfloat16}:
    raise ValueError(f"`atomic_{atomic_type.value}` does not support {ref_aval.dtype}.")
  return _swap_abstract_eval(ref_aval, val_aval, *all_avals,
                             args_tree=args_tree)
atomic_rmw_p.def_effectful_abstract_eval(_atomic_abstract_eval)

def atomic_rmw(x_ref, idx, val, *, mask: Any | None = None,
               atomic_type: AtomicOpType):
  idx = NDIndexer.from_indices_shape(idx, x_ref.shape)
  args = (idx,)
  if mask is not None:
    args = (*args, mask)
  flat_args, args_tree = tree_util.tree_flatten(args)
  return atomic_rmw_p.bind(x_ref, val, *flat_args, args_tree=args_tree,
                           atomic_type=atomic_type, masked=mask is not None)

atomic_xchg = functools.partial(atomic_rmw, atomic_type=AtomicOpType.XCHG)
atomic_add = functools.partial(atomic_rmw, atomic_type=AtomicOpType.ADD)
atomic_max = functools.partial(atomic_rmw, atomic_type=AtomicOpType.MAX)
atomic_min = functools.partial(atomic_rmw, atomic_type=AtomicOpType.MIN)
atomic_and = functools.partial(atomic_rmw, atomic_type=AtomicOpType.AND)
atomic_or = functools.partial(atomic_rmw, atomic_type=AtomicOpType.OR)
atomic_xor = functools.partial(atomic_rmw, atomic_type=AtomicOpType.XOR)
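
# Usage note (an illustrative sketch, not part of this commit): inside a
# kernel, these wrappers let concurrent program instances read-modify-write
# shared memory, e.g. accumulating a per-block partial sum into one scalar
# output element:
#
#   partial_sum = jnp.sum(x)               # per-program value
#   atomic_add(o_ref, (0,), partial_sum)   # safe across program instances
#
# which discharges to a read-modify-write using the ADD monoid above.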
atomic_cas_p = jax_core.Primitive("atomic_cas")

def _atomic_cas_abstract_eval(ref_aval, cmp_aval, val_aval):
  if cmp_aval.dtype != val_aval.dtype:
    raise ValueError("Dtypes in cmp/val need to match")
  if ref_aval.shape != ():
    raise ValueError("Ref must be scalar.")
  if cmp_aval.shape != ():
    raise ValueError("Cmp must be scalar.")
  if val_aval.shape != ():
    raise ValueError("Val must be scalar.")
  if cmp_aval.shape != val_aval.shape:
    raise ValueError("Shapes in cmp/val need to match")
  return jax_core.ShapedArray(val_aval.shape, val_aval.dtype), {state.WriteEffect(0)}
atomic_cas_p.def_effectful_abstract_eval(_atomic_cas_abstract_eval)

def atomic_cas(ref, cmp, val):
  return atomic_cas_p.bind(ref, cmp, val)

@state_discharge.register_discharge_rule(atomic_cas_p)
def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val):
  del in_avals, out_avals
  new_val = jnp.where(ref == cmp, val, ref)
  return (new_val, None, None), ref
max_contiguous_p = jax_core.Primitive("max_contiguous")

max_contiguous_p.def_impl(lambda x, **_: x)
mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x])

def max_contiguous(x, values):
  if not isinstance(values, list):
    values = [values]
  return max_contiguous_p.bind(x, values=values)

def _max_contiguous_abstract_eval(aval, **_):
  return aval
max_contiguous_p.def_abstract_eval(_max_contiguous_abstract_eval)

multiple_of_p = jax_core.Primitive("multiple_of")

multiple_of_p.def_impl(lambda x, **_: x)
mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x])

def multiple_of(x, values):
  if not isinstance(values, list):
    values = [values]
  return multiple_of_p.bind(x, values=values)

def _multiple_of_abstract_eval(aval, **_):
  return aval
multiple_of_p.def_abstract_eval(_multiple_of_abstract_eval)
load_p = jax_core.Primitive('masked_load')

def _load_abstract_eval(ref_aval, *all_avals, args_tree,
                        **params: Any):
  idx_aval, *_ = tree_util.tree_unflatten(args_tree, all_avals)
  return (jax_core.ShapedArray(idx_aval.get_indexer_shape(), ref_aval.dtype),
          {state.ReadEffect(0)})
load_p.def_effectful_abstract_eval(_load_abstract_eval)

def _pp_dslice(dim: int, slice: Slice, context):
  size = pp.text(str(slice.size))
  if isinstance(slice.start, int):
    if slice.start == 0:
      start = pp.text("")
    else:
      start = pp.text(str(slice.start))
    if slice.size == dim:
      end = pp.text("")
    else:
      end = pp.text(str(slice.start + slice.size))
  else:
    start = pp.text(jax_core.pp_var(slice.start, context))
    end = pp.concat([start, pp.text("+"), size])
  return pp.concat([start, pp.text(":"), end])

def _pp_idx(ref_aval, idx: NDIndexer, context):
  docs = [
      _pp_dslice(d, s, context) if isinstance(s, Slice)
      else pp.text(jax_core.pp_var(s, context))
      for s, d in zip(idx.indices, ref_aval.shape)]
  if not docs:
    return pp.text("")
  doc = [docs[0]]
  for d in docs[1:]:
    doc.append(pp.text(","))
    doc.append(d)
  return pp.concat(doc)

def _load_pp_rule(eqn, context, settings):
  # Pretty prints `a = load x i` as `a <- x[i]`
  y, = eqn.outvars
  x, *args = eqn.invars
  idx, *masked_other = tree_util.tree_unflatten(eqn.params["args_tree"], args)
  idx = _pp_idx(eqn.invars[0].aval, idx, context)
  lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
  return pp.concat([lhs, pp.text(' <- '), state_primitives.pp_ref(pp.concat([
      pp.text(jax_core.pp_var(x, context)), pp.text('['), idx, pp.text(']')
  ]))])
jax_core.pp_eqn_rules[load_p] = _load_pp_rule
def _load_jvp(primals, tangents, *, args_tree, masked, **params: Any):
  ref_primal, *rest_primals = primals
  ref_tangent, *rest_tangents = tangents
  idx_primal, *masked_other_primals = tree_util.tree_unflatten(args_tree, rest_primals)
  flat_idx_primals = tree_util.tree_leaves(idx_primal)
  _, *masked_other_tangents = tree_util.tree_unflatten(args_tree, rest_tangents)
  tangent_args = flat_idx_primals
  if masked:
    tangent_args = [*tangent_args, masked_other_primals[0]]
    if len(masked_other_tangents) == 2:
      _, other_tangent = masked_other_tangents
      other_tangent = ad_util.instantiate(other_tangent)
      tangent_args = [*tangent_args, other_tangent]
  return (
      load_p.bind(ref_primal, *rest_primals, args_tree=args_tree, masked=masked, **params),
      load_p.bind(ref_tangent, *tangent_args, args_tree=args_tree,
                  masked=masked, **params))
ad.primitive_jvps[load_p] = _load_jvp

def _load_discharge_rule(in_avals, out_avals, ref, *args, args_tree,
                         masked, eviction_policy, cache_modifier, is_volatile):
  idx, *masked_other = tree_util.tree_unflatten(args_tree, args)
  if all(isinstance(s, Slice) or not s.shape for s in idx.indices):
    indices = idx.indices
    scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
    out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
    out = out_ones[out_indexer]
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
  else:
    raise NotImplementedError
  if masked and len(masked_other) == 2:
    mask, other = masked_other
    out = jnp.where(mask, out, other)
  return (None,) * len(in_avals), out
state_discharge.register_discharge_rule(load_p)(_load_discharge_rule)
swap_p = jax_core.Primitive('masked_swap')

def _swap_abstract_eval(ref_aval, val_aval, *all_avals, args_tree,
                        **_: Any):
  idx_aval, *_ = tree_util.tree_unflatten(args_tree, all_avals)
  expected_output_shape = idx_aval.get_indexer_shape()
  if expected_output_shape != val_aval.shape:
    raise ValueError("Invalid shape for `swap`. "
                     f"Ref shape: {ref_aval.shape}. "
                     f"Value shape: {val_aval.shape}. "
                     f"Indices: {idx_aval}. ")
  if ref_aval.dtype != val_aval.dtype:
    raise ValueError("Invalid dtype for `swap`. "
                     f"Ref dtype: {ref_aval.dtype}. "
                     f"Value dtype: {val_aval.dtype}. ")
  return (jax_core.ShapedArray(expected_output_shape, ref_aval.dtype),
          {state.WriteEffect(0)})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)

def _swap_pp_rule(eqn, context, settings):
  # Pretty prints `a = swap x v i` as `a, x[i] <- x[i], v`
  # or:
  # Pretty prints `_ = swap x v i` as `x[i] <- v`
  y, = eqn.outvars
  x, val, *args = eqn.invars
  idx, *masked_other = tree_util.tree_unflatten(eqn.params["args_tree"], args)
  idx = _pp_idx(eqn.invars[0].aval, idx, context)
  x_i = pp.concat([pp.text(jax_core.pp_var(x, context)),
                   pp.text('['), idx, pp.text(']')])
  if isinstance(y, jax_core.DropVar):
    return pp.concat([state_primitives.pp_ref(
        x_i), pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
  y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
  return pp.concat([y, pp.text(', '), state_primitives.pp_ref(x_i),
                    pp.text(' <- '), state_primitives.pp_ref(x_i),
                    pp.text(', '), pp.text(jax_core.pp_var(val, context))])
jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule

def _swap_jvp(primals, tangents, *, args_tree, masked, **params: Any):
  ref_primal, val_primal, *rest_primals = primals
  ref_tangent, val_tangent, *rest_tangents = tangents
  val_tangent = ad_util.instantiate(val_tangent)
  idx_primal, *masked_other_primals = tree_util.tree_unflatten(args_tree, rest_primals)
  flat_idx_primals = tree_util.tree_leaves(idx_primal)
  _, *masked_other_tangents = tree_util.tree_unflatten(args_tree, rest_tangents)
  tangent_args = flat_idx_primals
  if masked:
    tangent_args = [*tangent_args, masked_other_primals[0]]
    if len(masked_other_tangents) == 2:
      _, other_tangent = masked_other_tangents
      other_tangent = ad_util.instantiate(other_tangent)
      tangent_args = [*tangent_args, other_tangent]
  return (
      swap_p.bind(ref_primal, val_primal, *rest_primals, args_tree=args_tree, masked=masked, **params),
      swap_p.bind(ref_tangent, val_tangent, *tangent_args, args_tree=args_tree,
                  masked=masked, **params))
ad.primitive_jvps[swap_p] = _swap_jvp
def _swap_discharge_rule(in_avals, out_avals, ref, val, *args, args_tree,
                         masked, eviction_policy):
  idx, *_ = tree_util.tree_unflatten(args_tree, args)
  if all(isinstance(s, Slice) or s.shape == () for s in idx.indices):
    indices = idx.indices
    scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    val_indexer = tuple(None if scalar else slice(None) for scalar in scalar_dims)
    val = val[val_indexer]
    x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
    out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
    out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
    out = out_ones[out_indexer]
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
    x_new = ref.at[idx.indices].set(val)
  else:
    raise NotImplementedError
  return (x_new,) + (None,) * (len(in_avals) - 1), out
state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
def load(x_ref, idx, *, mask=None, other=None, cache_modifier="",
         eviction_policy="", volatile=False):
  idx = NDIndexer.from_indices_shape(idx, x_ref.shape)
  args = (idx,)
  if mask is not None:
    args = (*args, mask)
  if other is not None:
    assert mask is not None
    args = (*args, other)
  flat_args, args_tree = tree_util.tree_flatten(args)
  return load_p.bind(x_ref, *flat_args, masked=mask is not None, cache_modifier=cache_modifier,
                     eviction_policy=eviction_policy, is_volatile=volatile,
                     args_tree=args_tree)

def swap(x_ref, idx, val, *, mask=None, eviction_policy="") -> Any:
  idx = NDIndexer.from_indices_shape(idx, x_ref.shape)
  args = (idx,)
  if mask is not None:
    args = (*args, mask)
  flat_args, args_tree = tree_util.tree_flatten(args)
  return swap_p.bind(x_ref, val, *flat_args, masked=mask is not None,
                     eviction_policy=eviction_policy, args_tree=args_tree)

def store(x_ref, idx, val, *, mask=None, eviction_policy="") -> None:
  _ = swap(x_ref, idx, val, mask=mask, eviction_policy=eviction_policy)
def dot(a, b, trans_a: bool = False, trans_b: bool = False,
        allow_tf32: bool | None = None, precision=None):
  lhs_contract_dim = 0 if trans_a else 1
  rhs_contract_dim = 0 if not trans_b else 1
  if allow_tf32 is not None:
    if precision is not None:
      raise ValueError("Only one of allow_tf32 and precision can be specified")
    precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST
  return jax.lax.dot_general(
      a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
      precision=precision,
      preferred_element_type=None).astype(jnp.float32)
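
To make the semantics above concrete, here is a minimal kernel sketch (illustrative only, not part of this commit) that exercises program_id, ds, load, and store through pallas_call; interpret=True routes execution through the discharge rules above, so it runs on CPU. The kernel name and block size are assumptions.

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref, *, block: int):
  # Each program instance handles one contiguous block of the input.
  i = pl.program_id(0)
  sl = pl.ds(i * block, block)  # dynamic slice of length `block`
  x = pl.load(x_ref, (sl,))
  pl.store(o_ref, (sl,), x * 2.)

n, block = 64, 16  # assumes block divides n evenly
x = jnp.arange(n, dtype=jnp.float32)
out = pl.pallas_call(
    functools.partial(double_kernel, block=block),
    grid=(pl.cdiv(n, block),),
    out_shape=jax.ShapeDtypeStruct((n,), jnp.float32),
    interpret=True)(x)
assert (out == 2 * x).all()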
46  jax/_src/pallas/triton/BUILD  Normal file
@@ -0,0 +1,46 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Package for Triton-specific Pallas extensions

load(
    "//jaxlib:jax.bzl",
    "py_deps",
    "py_library_providing_imports_info",
    "pytype_strict_library",
)

package(
    default_applicable_licenses = [],
    default_visibility = [
        "//:__subpackages__",
    ],
)

py_library_providing_imports_info(
    name = "triton",
    srcs = ["__init__.py"],
    lib_rule = pytype_strict_library,
    deps = [
        ":lowering",
        "//jax/_src/lib",
    ],
)

py_library(
    name = "lowering",
    srcs = ["lowering.py"],
    deps = [
        "//jax",
    ] + py_deps("jax_triton"),
)
22  jax/_src/pallas/triton/__init__.py  Normal file
@@ -0,0 +1,22 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains Triton-specific pallas modules."""

from jax._src.pallas.triton import lowering
from jax._src.lib import gpu_triton as triton_kernel_call_lib

get_compute_capability = triton_kernel_call_lib.get_compute_capability

del lowering
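
A quick sketch of the helper re-exported above (illustrative, not part of this commit; it assumes a CUDA-enabled jaxlib, and the integer device-ordinal argument is an assumption about the gpu_triton binding):

from jax._src.pallas import triton

cc = triton.get_compute_capability(0)  # e.g. 80 on an A100
print("compute capability:", cc)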
1728  jax/_src/pallas/triton/lowering.py  Normal file
File diff suppressed because it is too large
49  jax/_src/pallas/utils.py  Normal file
@@ -0,0 +1,49 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pallas utility functions."""
import math
import numpy as np
from typing import Tuple

from jax import lax


def when(condition):
  def _wrapped(f):
    if isinstance(condition, bool):
      if condition:
        f()
    else:
      lax.cond(condition, f, lambda: None)
  return _wrapped


def cdiv(a: int, b: int) -> int:
  return (a + b - 1) // b


def strides_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
  size = np.prod(shape)
  strides = []
  for s in shape:
    size = size // s
    strides.append(int(size))
  return tuple(strides)


def next_power_of_2(x: int) -> int:
  if x == 0:
    return 1
  return int(2 ** math.ceil(math.log2(x)))
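
A quick sketch of what these helpers compute (illustrative checks, not part of the commit): cdiv is ceiling division, strides_from_shape gives row-major strides in elements, and next_power_of_2 rounds up. They are re-exported from jax.experimental.pallas by the __init__.py added later in this commit.

from jax.experimental import pallas as pl

assert pl.cdiv(10, 4) == 3                              # ceil(10 / 4)
assert pl.strides_from_shape((2, 3, 4)) == (12, 4, 1)   # row-major element strides
assert pl.next_power_of_2(17) == 32
assert pl.next_power_of_2(0) == 1                       # special-cased above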
@@ -15,6 +15,7 @@
"""JAX bindings for Mosaic."""

# mypy: ignore-errors
from __future__ import annotations

import base64
import collections.abc
@@ -29,6 +30,7 @@ from jax import core
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.config import config
from jax._src.lib import xla_client
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import mhlo
from jaxlib.mlir.dialects import stablehlo
@@ -38,7 +40,7 @@ import numpy as np

# TODO(sharadmv): remove when minimum jaxlib version is bumped to >= 0.4.14.
if tpu_mosaic is None:
  raise ValueError("Cannot use Mosaic without a jaxlib >= 0.4.14.")
  raise ImportError("Cannot use Mosaic without a jaxlib >= 0.4.14.")
tpu = tpu_mosaic.tpu
apply_vector_layout = tpu_mosaic.apply_vector_layout
infer_memref_layout = tpu_mosaic.infer_memref_layout
@@ -244,7 +246,7 @@ def as_tpu_kernel(
    module: ir.Module,
    out_type: Any,
    *,
    backend: str = "tpu",
    backend: str | xla_client.Client = "tpu",
) -> Callable[..., Any]:
  """Turns an MLIR Mosaic kernel into a JAX-compatible function."""
  # We use jax.jit to make sure we hit the fast compilation cache.
52  jax/experimental/pallas/__init__.py  Normal file
@@ -0,0 +1,52 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for pallas, a JAX extension for custom kernels."""

from jax._src import pallas
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.indexing import ds
from jax._src.pallas.indexing import dslice
from jax._src.pallas.indexing import broadcast_to
from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add
from jax._src.pallas.primitives import atomic_and
from jax._src.pallas.primitives import atomic_cas
from jax._src.pallas.primitives import atomic_max
from jax._src.pallas.primitives import atomic_min
from jax._src.pallas.primitives import atomic_or
from jax._src.pallas.primitives import atomic_xchg
from jax._src.pallas.primitives import atomic_xor
from jax._src.pallas.primitives import dot
from jax._src.pallas.primitives import load
from jax._src.pallas.primitives import max_contiguous
from jax._src.pallas.primitives import multiple_of
from jax._src.pallas.primitives import program_id
from jax._src.pallas.primitives import store
from jax._src.pallas.primitives import swap
from jax._src.pallas.utils import cdiv
from jax._src.pallas.utils import next_power_of_2
from jax._src.pallas.utils import strides_from_shape
from jax._src.pallas.utils import when

try:
  from jax.experimental.pallas import gpu  # pytype: disable=import-error
except (ImportError, ModuleNotFoundError):
  pass

try:
  from jax.experimental.pallas import tpu  # pytype: disable=import-error
except (ImportError, ModuleNotFoundError):
  pass
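
Given the public surface above, a minimal end-to-end kernel looks like the following sketch (illustrative only; interpret mode means it does not need the GPU or TPU backends imported above):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Refs support whole-array reads/writes in addition to pl.load/pl.store.
  o_ref[...] = x_ref[...] + 1.

x = jnp.zeros((8,), jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True)(x)
assert (y == 1.).all()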
22  jax/experimental/pallas/gpu.py  Normal file
@@ -0,0 +1,22 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains Triton-specific Pallas functions."""
try:
  from jax._src.pallas import triton
  get_compute_capability = triton.get_compute_capability
  del triton
except ImportError as e:
  raise ImportError("Cannot import Pallas Triton backend. "
                    "Make sure you've installed jax-triton.") from e
13  jax/experimental/pallas/ops/__init__.py  Normal file
@@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
364  jax/experimental/pallas/ops/attention.py  Normal file
@@ -0,0 +1,364 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing fused attention forward and backward pass."""
import functools

from typing import Any, Optional

import jax
import jax.numpy as jnp
from jax import lax

from jax.experimental import pallas as pl
def mha_forward_kernel(
    q_ref, k_ref, v_ref,  # Input arrays
    o_ref,  # Output
    *residual_refs,  # Residual outputs
    sm_scale: float, causal: bool,
    block_q: int, block_d: int, block_k: int):
  seq_len = q_ref.shape[0]
  start_q = pl.program_id(0)

  # m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
  m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf')
  l_i = jnp.zeros(block_q, dtype=jnp.float32)
  # acc is the buffer where we accumulate the output on sram.
  acc = jnp.zeros((block_q, block_d), dtype=jnp.float32)

  # Load q: it will stay in L1 throughout. Indices form a matrix because we
  # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
  # q tile has shape [block_q, block_d], block_d == head_dim.
  q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)))
  # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv
  # (size Bc == block_k here), and fast over blocks of q (size Br == block_q
  # here). Here we only loop over blocks of kv to process entire seq_len; the
  # loop over blocks of q is carried out by the grid.
  def body(start_k, carry):
    acc, m_prev, l_prev = carry

    k = pl.load(k_ref, (pl.dslice(start_k * block_k, block_k), slice(None)))
    qk = jnp.zeros([block_q, block_k], dtype=jnp.float32)
    qk += pl.dot(q, k.T)  # [block_q, block_k]
    if sm_scale != 1.:
      qk *= sm_scale  # [block_q, block_k]

    if causal:
      span_q = start_q * block_q + jnp.arange(block_q)
      span_k = start_k * block_k + jnp.arange(block_k)
      qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
    # Bring closer to XLA:GPU numerics.
    qk = qk.astype(q_ref.dtype)
    qk = qk.astype(jnp.float32)
    m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev)
    l_prev *= jnp.exp(m_prev - m_curr)
    p = jnp.exp(qk - m_curr[:, None])
    l_curr = jnp.sum(p, axis=1) + l_prev

    l_rcp = 1. / l_curr
    p = p * l_rcp[:, None]
    acc *= (l_prev * l_rcp)[:, None]
    p = p.astype(jnp.float16)

    v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), pl.dslice(block_d)))
    acc = acc + pl.dot(p.astype(v.dtype), v)
    return acc, m_curr, l_curr
  if causal:
    upper_bound = lax.div(block_q * start_q, block_k) + 1
  else:
    upper_bound = pl.cdiv(seq_len, block_k)  # type: ignore
  acc, m_i, l_i = lax.fori_loop(0, upper_bound, body,
                                (acc, m_i, l_i))

  if residual_refs:
    l_ref, m_ref = residual_refs
    pl.store(l_ref, (pl.ds(start_q * block_q, block_q),), l_i)
    pl.store(m_ref, (pl.ds(start_q * block_q, block_q),), m_i)
  # Write output to dram.
  acc = acc.astype(o_ref.dtype)
  pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc)
@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
@functools.partial(jax.jit, static_argnames=["sm_scale", "causal", "block_q", "block_k",
                                             "backward_pass_impl",
                                             "num_warps", "num_stages", "grid",
                                             "interpret", "debug"])
def mha(q, k, v,
        sm_scale: float = 1.0,
        causal: bool = False,
        block_q: int = 128,
        block_k: int = 128,
        backward_pass_impl: str = "triton",
        num_warps: Optional[int] = None,
        num_stages: int = 2,
        grid=None,
        interpret: bool = False,
        debug: bool = False):
  del backward_pass_impl
  batch_size, seq_len, num_heads, head_dim = q.shape
  block_q = min(block_q, seq_len)
  block_k = min(block_k, seq_len)
  # Heuristics.
  grid_ = grid
  if grid_ is None:
    grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)

  num_warps_ = num_warps
  if num_warps_ is None:
    num_warps_ = 4 if head_dim <= 64 else 8
  kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
                             block_q=block_q, block_k=block_k,
                             block_d=head_dim,
                             causal=causal)
  out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
  return pl.pallas_call(
      kernel,
      grid=grid_,
      in_specs=[
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
      ],
      out_specs=pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
      num_warps=num_warps_,
      num_stages=num_stages,
      out_shape=out_shape,
      debug=debug,
      interpret=interpret,
      name="mha_forward")(q, k, v)
def _mha_forward(q, k, v, sm_scale: float, causal: bool, block_q: int,
                 block_k: int, backward_pass_impl: str, num_warps: Optional[int],
                 num_stages: int, grid: Any, interpret: bool, debug: bool):
  del backward_pass_impl
  batch_size, seq_len, num_heads, head_dim = q.shape
  block_q = min(block_q, seq_len)
  block_k = min(block_k, seq_len)
  # Heuristics.
  grid_ = grid
  if grid_ is None:
    grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)

  num_warps_ = num_warps
  if num_warps_ is None:
    num_warps_ = 4 if head_dim <= 64 else 8
  kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
                             causal=causal, block_q=block_q, block_k=block_k,
                             block_d=head_dim)
  out_shape = [
      jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
      jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),  # l
                           dtype=jnp.float32),
      jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),  # m
                           dtype=jnp.float32)
  ]
  out, l, m = pl.pallas_call(
      kernel,
      grid=grid_,
      in_specs=[
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
      ],
      out_specs=[
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
          pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
          pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
      ],
      num_warps=num_warps_,
      num_stages=num_stages,
      out_shape=out_shape,
      debug=debug,
      interpret=interpret,
      name="mha_forward")(q, k, v)
  return out, (q, k, v, out, l, m)
def _preprocess_backward_kernel(out_ref, dout_ref, l_ref,
                                new_dout_ref, delta_ref, *,
                                block_q: int):
  pid_m = pl.program_id(0)

  off_m = pl.ds(pid_m * block_q, block_q)
  # load
  o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32)
  do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32)
  denom = pl.load(l_ref, (off_m,)).astype(jnp.float32)
  # compute
  do = do / denom[:, None]
  delta = jnp.sum(o * do, axis=1)
  # write-back
  pl.store(new_dout_ref, (off_m, slice(None)),
           do.astype(new_dout_ref.dtype))
  pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype))

def _preprocess_backward(out, do, l, block_q: int,
                         debug: bool, interpret: bool):
  batch_size, seq_len, num_heads, head_dim = out.shape
  out_shape = [
      jax.ShapeDtypeStruct(do.shape, do.dtype),
      jax.ShapeDtypeStruct(l.shape, l.dtype),
  ]
  do_scaled, delta = pl.pallas_call(
      functools.partial(_preprocess_backward_kernel, block_q=block_q),
      grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
      in_specs=[
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
          pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
      ],
      out_specs=[
          pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
          pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
      ],
      num_warps=4,
      num_stages=3,
      out_shape=out_shape,
      debug=debug,
      interpret=interpret,
      name="mha_preprocess_backward")(out, do, l)
  return do_scaled, delta
def mha_backward_kernel(
    # Inputs
    q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
    l_ref, m_ref, delta_ref, _,
    # Outputs
    dq_ref, dk_ref, dv_ref,
    *, sm_scale: float, causal: bool,
    block_q: int, block_d: int, block_k: int
):
  del out_ref, l_ref  # Not needed
  seq_len = q_ref.shape[0]

  def outer_loop(start_k, _):

    dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
    dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
    k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
    v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
    span_k = start_k * block_k + jnp.arange(block_k)

    def inner_loop(start_q, carry):
      dv, dk = carry
      q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
      qk = pl.dot(q, k.T)
      qk = qk.astype(q_ref.dtype)
      qk = qk.astype(jnp.float32)
      if sm_scale != 1.0:
        qk *= sm_scale
      if causal:
        span_q = start_q * block_q + jnp.arange(block_q)
        qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
      m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
      p = jnp.exp(qk - m[:, None])
      do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
      dv = dv + pl.dot(p.astype(do.dtype).T, do)
      di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
      dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
      dp = dp + pl.dot(do, v.T)
      ds = p * dp
      if sm_scale != 1.0:
        ds = ds * sm_scale
      dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
      dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
                            slice(None)), eviction_policy="evict_last")
      dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
      pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
                        slice(None)), dq, eviction_policy="evict_last")
      return dv, dk
    if causal:
      lower_bound = lax.div(start_k * block_k, block_q)
    else:
      lower_bound = 0
    dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop,
                           (dv, dk))
    pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
                      slice(None)), dv.astype(dv_ref.dtype))
    pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
                      slice(None)), dk.astype(dk_ref.dtype))
  lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None)
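
# A sketch of the math implemented by mha_backward_kernel (added here for
# reference; not in the original commit). With p = exp(qk - m) the
# renormalized attention weights and delta_i = sum_j o_ij * do_ij computed by
# the preprocess kernel:
#   dv  = p^T do
#   dp  = do v^T - delta          (the -delta term is the softmax correction)
#   ds  = p * dp * sm_scale
#   dk  = ds^T q,   dq += ds k
# which matches the dv/dp/ds/dk/dq updates in inner_loop above.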
def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
                  backward_pass_impl: str, num_warps: Optional[int],
                  num_stages: int, grid: Any, interpret: bool,
                  debug: bool, res, do):
  del num_warps, num_stages, grid
  q, k, v, out, l, m = res

  batch_size, seq_len, num_heads, head_dim = q.shape
  block_q = min(block_q, seq_len)
  block_k = min(block_k, seq_len)
  do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)

  if backward_pass_impl == "xla":
    return jax.vjp(functools.partial(mha_reference, sm_scale=sm_scale,
                                     causal=causal), q, k, v)[1](do)
  elif backward_pass_impl == "triton":
    # We accumulate into dq so we need to initialize it to zeros.
    dq = jnp.zeros(q.shape, jnp.float32)
    out_shapes = [
        jax.ShapeDtypeStruct(dq.shape, dq.dtype),
        jax.ShapeDtypeStruct(k.shape, k.dtype),
        jax.ShapeDtypeStruct(v.shape, v.dtype),
    ]

    grid = (batch_size, num_heads)
    # TODO(sharadmv): figure out why num_warps=8 doesn't work!
    num_warps = 4
    dq, dk, dv = pl.pallas_call(
        functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
                          block_k=block_k, sm_scale=sm_scale, causal=causal),
        grid=grid,
        out_shape=out_shapes,
        in_specs=[
            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
            pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
            pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
            pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
        ],
        out_specs=[
            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
            pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
        ],
        name="mha_backward",
        debug=debug,
        interpret=interpret,
        num_warps=num_warps,
        num_stages=1,
        input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
  else:
    raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
  return dq.astype(q.dtype), dk, dv
mha.defvjp(_mha_forward, _mha_backward)
@functools.partial(jax.jit, static_argnames=['sm_scale', 'causal'])
def mha_reference(q, k, v, sm_scale=1.0, causal: bool = False):
  q_seq_len = q.shape[1]
  kv_seq_len = k.shape[1]
  logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)
  if causal:
    mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool))
    mask = jnp.broadcast_to(mask, logits.shape)
    logits = jnp.where(mask, logits, float('-inf'))
  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
  return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
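
A sketch of how the fused kernel is meant to be called and checked against the reference (illustrative; interpret=True lets it run without a GPU, though far more slowly, and the shapes and tolerances here are assumptions):

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import attention

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
batch, seq_len, heads, head_dim = 1, 128, 2, 64
shape = (batch, seq_len, heads, head_dim)
q = jax.random.normal(kq, shape, jnp.float16)
k = jax.random.normal(kk, shape, jnp.float16)
v = jax.random.normal(kv, shape, jnp.float16)

out = attention.mha(q, k, v, causal=True, interpret=True)
ref = attention.mha_reference(q, k, v, causal=True)
assert jnp.allclose(out, ref, atol=1e-2, rtol=1e-2)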
281  jax/experimental/pallas/ops/layer_norm.py  Normal file
@@ -0,0 +1,281 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing fused layer norm forward and backward pass."""

import functools

from typing import Optional

import jax
from jax import lax
import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop

from jax.experimental import pallas as pl
def layer_norm_forward_kernel(
    x_ref, weight_ref, bias_ref,  # Input arrays
    o_ref, mean_ref=None, rstd_ref=None,  # Output arrays
    *, eps: float, block_size: int):
  n_col = x_ref.shape[0]

  def mean_body(i, acc_ref):
    col_idx = i * block_size + jnp.arange(block_size)
    mask = col_idx < n_col
    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
                eviction_policy="evict_last").astype(jnp.float32)
    acc_ref[:] += a
  mean = for_loop(pl.cdiv(n_col, block_size), mean_body,
                  jnp.zeros(block_size)).sum() / n_col

  def var_body(i, acc_ref):
    col_idx = i * block_size + jnp.arange(block_size)
    mask = col_idx < n_col
    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
                eviction_policy="evict_last").astype(jnp.float32)
    a = jnp.where(mask, a - mean, 0.)
    acc_ref[:] += a * a
  var = for_loop(pl.cdiv(n_col, block_size), var_body,
                 jnp.zeros(block_size)).sum() / n_col
  rstd = 1 / jnp.sqrt(var + eps)
  if mean_ref is not None:
    mean_ref[...] = mean.astype(mean_ref.dtype)
  if rstd_ref is not None:
    rstd_ref[...] = rstd.astype(rstd_ref.dtype)

  def body(i, _):
    col_idx = i * block_size + jnp.arange(block_size)
    mask = col_idx < n_col
    weight = pl.load(weight_ref, (col_idx,), mask=mask)
    bias = pl.load(bias_ref, (col_idx,), mask=mask)
    x = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
                eviction_policy="evict_first").astype(jnp.float32)
    out = (x - mean) * rstd * weight + bias
    pl.store(o_ref, (col_idx,), out.astype(o_ref.dtype), mask=mask)
  for_loop(pl.cdiv(n_col, block_size), body, ())
def layer_norm_forward(
    x, weight, bias,
    num_warps: Optional[int] = None,
    num_stages: Optional[int] = 3,
    eps: float = 1e-5,
    backward_pass_impl: str = 'triton',
    interpret: bool = False):
  del num_stages
  del backward_pass_impl
  n = x.shape[-1]
  # Triton heuristics
  # Less than 64KB per feature: enqueue fused kernel
  max_fused_size = 65536 // x.dtype.itemsize
  block_size = min(max_fused_size, pl.next_power_of_2(n))
  block_size = min(max(block_size, 128), 4096)
  num_warps = min(max(block_size // 256, 1), 8)

  kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
                             block_size=block_size)
  out_shape = [
      jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
      jax.ShapeDtypeStruct(shape=(), dtype=x.dtype),
      jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
  ]
  method = pl.pallas_call(kernel, num_warps=num_warps,
                          grid=(), out_shape=out_shape, debug=False,
                          interpret=interpret, name='ln_forward')

  method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
  out, mean, rstd = method(x, weight, bias)
  return out, (x, weight, bias, mean, rstd)
def layer_norm_backward_kernel_dx(
    # Inputs
    x_ref, weight_ref, bias_ref, do_ref,
    mean_ref, rstd_ref,
    # Outputs
    dx_ref,
    *, eps: float, block_size: int):
  n_col = x_ref.shape[0]

  def mean_body(i, acc_ref):
    col_idx = i * block_size + jnp.arange(block_size)
    mask = col_idx < n_col
    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
                eviction_policy="evict_last").astype(jnp.float32)
    dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
                   eviction_policy="evict_last").astype(jnp.float32)
    weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
                     eviction_policy="evict_last").astype(jnp.float32)
    a_hat = (a - mean_ref[...]) * rstd_ref[...]
    wdout = weight * dout
    mean1_acc_ref, mean2_acc_ref = acc_ref
    mean1_acc_ref[:] += a_hat * wdout
    mean2_acc_ref[:] += wdout
  mean = for_loop(pl.cdiv(n_col, block_size), mean_body,
                  (jnp.zeros(block_size), jnp.zeros(block_size)))
  mean1, mean2 = mean
  mean1 = mean1.sum() / n_col
  mean2 = mean2.sum() / n_col

  def dx_body(i, acc_ref):
    col_idx = i * block_size + jnp.arange(block_size)
    mask = col_idx < n_col
    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
                eviction_policy="evict_last").astype(jnp.float32)
    dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
                   eviction_policy="evict_last").astype(jnp.float32)
    weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
                     eviction_policy="evict_last").astype(jnp.float32)
    a_hat = (a - mean_ref[...]) * rstd_ref[...]
    wdout = weight * dout
    da = (wdout - (a_hat * mean1 + mean2)) * rstd_ref[...]
    pl.store(dx_ref, (col_idx,), da.astype(dx_ref.dtype), mask=mask)
  for_loop(pl.cdiv(n_col, block_size), dx_body, ())
def layer_norm_backward_kernel_dw_db(
    # Inputs
    x_ref, weight_ref, bias_ref, do_ref,
    mean_ref, rstd_ref,
    # Outputs
    dw_ref, db_ref,
    *, eps: float, block_m: int, block_n: int):
  m, n_col = x_ref.shape
  j = pl.program_id(0)
  col_idx = j * block_n + jnp.arange(block_n)
  col_mask = col_idx < n_col

  def body(i, acc_ref):
    row_idx = i * block_m + jnp.arange(block_m)
    row_mask = row_idx < m
    mask = row_mask[:, None] & col_mask[None, :]
    a = pl.load(
        x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
    ).astype(jnp.float32)
    dout = pl.load(
        do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
    ).astype(jnp.float32)
    mean = pl.load(mean_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
    rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
    a_hat = (a - mean[:, None]) * rstd[:, None]
    dw_acc_ref, db_acc_ref = acc_ref
    dw_acc_ref[:] += (dout * a_hat).sum(axis=0)
    db_acc_ref[:] += dout.sum(axis=0)
  dw_acc, db_acc = for_loop(pl.cdiv(m, block_m), body, (jnp.zeros(block_n), jnp.zeros(block_n)))
  pl.store(dw_ref, (col_idx,), dw_acc.astype(dw_ref.dtype), mask=col_mask)
  pl.store(db_ref, (col_idx,), db_acc.astype(db_ref.dtype), mask=col_mask)
def layer_norm_backward(
    num_warps: Optional[int],
    num_stages: Optional[int],
    eps: float,
    backward_pass_impl: str,
    interpret: bool,
    res, do):
  del num_stages
  x, weight, bias, mean, rstd = res
  if backward_pass_impl == 'xla':
    return jax.vjp(layer_norm_reference, x, weight, bias)[1](do)

  *shape_prefix, n = x.shape
  reshaped_x = x.reshape((-1, n))
  reshaped_mean = mean.reshape((-1,))
  reshaped_rstd = rstd.reshape((-1,))
  reshaped_do = do.reshape((-1, n))
  # Triton heuristics
  # Less than 64KB per feature: enqueue fused kernel
  max_fused_size = 65536 // x.dtype.itemsize
  block_size = min(max_fused_size, pl.next_power_of_2(n))
  block_size = min(max(block_size, 128), 4096)
  num_warps = min(max(block_size // 256, 1), 8)

  # layer_norm_backward_kernel_dx parallel over batch dims
  kernel = functools.partial(layer_norm_backward_kernel_dx, eps=eps,
                             block_size=block_size)
  out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
  method = pl.pallas_call(kernel, num_warps=num_warps,
                          grid=(), out_shape=out_shape_dx, debug=False,
                          interpret=interpret, name='ln_backward_dx')

  method = jax.vmap(method, in_axes=(0, None, None, 0, 0, 0))
  dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
  dx = dx.reshape((*shape_prefix, n))

  # layer_norm_backward_kernel_dw_db reduce over batch dims
  # Triton heuristics
  if n > 10240:
    block_n = 128
    block_m = 32
    num_warps = 4
  else:
    # maximize occupancy for small N
    block_n = 16
    block_m = 16
    num_warps = 8
  kernel = functools.partial(layer_norm_backward_kernel_dw_db, eps=eps,
                             block_m=block_m, block_n=block_n)
  out_shape_dwbias = [
      jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype),
      jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
  ]
  grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
  method = pl.pallas_call(kernel, num_warps=num_warps,
                          grid=grid_, out_shape=out_shape_dwbias, debug=False,
                          interpret=interpret, name='ln_backward_dw_db')
  dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
  return dx, dw, dbias
@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7])
@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages",
                                             "eps",
                                             "backward_pass_impl",
                                             "interpret"])
def layer_norm(
    x, weight, bias,
    num_warps: Optional[int] = None,
    num_stages: Optional[int] = 3,
    eps: float = 1e-5,
    backward_pass_impl: str = 'triton',
    interpret: bool = False):
  n = x.shape[-1]
  # Triton heuristics
  # Less than 64KB per feature: enqueue fused kernel
  max_fused_size = 65536 // x.dtype.itemsize
  block_size = min(max_fused_size, pl.next_power_of_2(n))
  block_size = min(max(block_size, 128), 4096)
  num_warps = min(max(block_size // 256, 1), 8)

  kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
                             block_size=block_size)
  out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
  method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
                          grid=(), out_shape=out_shape, debug=False,
                          interpret=interpret)
  method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
  return method(x, weight, bias)
layer_norm.defvjp(layer_norm_forward, layer_norm_backward)
@functools.partial(jax.jit, static_argnames=["eps"])
@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
def layer_norm_reference(x, weight, bias, *, eps: float = 1e-5):
  mean = jnp.mean(x, axis=1)
  mean2 = jnp.mean(jnp.square(x), axis=1)
  var = jnp.maximum(0., mean2 - jnp.square(mean))
  y = x - mean[:, None]
  mul = lax.rsqrt(var + eps)
  return y * mul[:, None] * weight[None] + bias[None]
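
As with attention, the fused op is intended as a drop-in for the reference. A sketch under assumed shapes (hidden is a power of two so a single block covers each row, and interpret=True avoids needing a GPU):

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import layer_norm

batch, seq, hidden = 2, 16, 256
x = jax.random.normal(jax.random.PRNGKey(0), (batch, seq, hidden))
w = jnp.ones((hidden,))
b = jnp.zeros((hidden,))

out = layer_norm.layer_norm(x, w, b, interpret=True)
ref = layer_norm.layer_norm_reference(x, w, b)
assert jnp.allclose(out, ref, atol=1e-4, rtol=1e-4)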
259  jax/experimental/pallas/ops/rms_norm.py  Normal file
@@ -0,0 +1,259 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing RMS norm forward and backward pass."""

import functools

from typing import Optional

import jax
from jax import lax
import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop

from jax.experimental import pallas as pl
def rms_norm_forward_kernel(
    x_ref, weight_ref, bias_ref,  # Input arrays
    o_ref, rstd_ref=None,  # Output arrays
    *, eps: float, block_size: int):
  n_col = x_ref.shape[0]

  def var_body(i, acc_ref):
    col_idx = i * block_size + jnp.arange(block_size)
    mask = col_idx < n_col
    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
                eviction_policy="evict_last").astype(jnp.float32)
    a = jnp.where(mask, a, 0.)
    acc_ref[:] += a * a
  var = for_loop(pl.cdiv(n_col, block_size), var_body,
                 jnp.zeros(block_size)).sum() / n_col
  rstd = 1 / jnp.sqrt(var + eps)
  if rstd_ref is not None:
    rstd_ref[...] = rstd.astype(rstd_ref.dtype)

  def body(i, _):
    col_idx = i * block_size + jnp.arange(block_size)
    mask = col_idx < n_col
    weight = pl.load(weight_ref, (col_idx,), mask=mask)
    bias = pl.load(bias_ref, (col_idx,), mask=mask)
    x = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
                eviction_policy="evict_first").astype(jnp.float32)
    out = x * rstd * weight + bias
    pl.store(o_ref, (col_idx,), out.astype(o_ref.dtype), mask=mask)
  for_loop(pl.cdiv(n_col, block_size), body, ())
def rms_norm_forward(
|
||||
x, weight, bias,
|
||||
num_warps: Optional[int] = None,
|
||||
num_stages: Optional[int] = 3,
|
||||
eps: float = 1e-5,
|
||||
backward_pass_impl: str = 'triton',
|
||||
interpret: bool = False):
|
||||
del num_stages
|
||||
del backward_pass_impl
|
||||
n = x.shape[-1]
|
||||
# Triton heuristics
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
max_fused_size = 65536 // x.dtype.itemsize
|
||||
block_size = min(max_fused_size, pl.next_power_of_2(n))
|
||||
block_size = min(max(block_size, 128), 4096)
|
||||
num_warps = min(max(block_size // 256, 1), 8)
|
||||
|
||||
kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
|
||||
block_size=block_size)
|
||||
out_shape = [
|
||||
jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
|
||||
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
|
||||
]
|
||||
method = pl.pallas_call(kernel, num_warps=num_warps,
|
||||
grid=(), out_shape=out_shape, debug=False,
|
||||
interpret=interpret, name='rms_forward')
|
||||
|
||||
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
|
||||
out, rstd = method(x, weight, bias)
|
||||
return out, (x, weight, bias, rstd)
|
||||
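
# Worked example of the block-size heuristic above (illustrative numbers,
# not from the source): for n = 3000 float32 features, itemsize = 4, so
#   max_fused_size = 65536 // 4 = 16384
#   next_power_of_2(3000) = 4096  ->  block_size = min(16384, 4096) = 4096
#   block_size stays 4096 after clamping to [128, 4096]
#   num_warps = min(max(4096 // 256, 1), 8) = min(16, 8) = 8
# i.e. when a row fits in 64KB, it is normalized by a single 8-warp program.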


def rms_norm_backward_kernel_dx(
    # Inputs
    x_ref, weight_ref, bias_ref, do_ref,
    rstd_ref,
    # Outputs
    dx_ref,
    *, eps: float, block_size: int):
  n_col = x_ref.shape[0]

  # First pass: c1 = mean(x_hat * weight * dout) over the feature dimension.
  def mean_body(i, c1_acc_ref):
    col_idx = i * block_size + jnp.arange(block_size)
    mask = col_idx < n_col
    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
                eviction_policy="evict_last").astype(jnp.float32)
    dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
                   eviction_policy="evict_last").astype(jnp.float32)
    weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
                     eviction_policy="evict_last").astype(jnp.float32)
    a_hat = a * rstd_ref[...]
    wdout = weight * dout
    c1_acc_ref[:] += a_hat * wdout
  c1 = for_loop(pl.cdiv(n_col, block_size), mean_body, jnp.zeros(block_size))
  c1 = c1.sum() / n_col

  # Second pass: dx = (weight * dout - x_hat * c1) * rstd.
  def dx_body(i, acc_ref):
    col_idx = i * block_size + jnp.arange(block_size)
    mask = col_idx < n_col
    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
                eviction_policy="evict_last").astype(jnp.float32)
    dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
                   eviction_policy="evict_last").astype(jnp.float32)
    weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
                     eviction_policy="evict_last").astype(jnp.float32)
    a_hat = a * rstd_ref[...]
    wdout = weight * dout
    da = (wdout - (a_hat * c1)) * rstd_ref[...]
    pl.store(dx_ref, (col_idx,), da.astype(dx_ref.dtype), mask=mask)
  for_loop(pl.cdiv(n_col, block_size), dx_body, ())
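
# The math this kernel implements, written out (g denotes dL/dy):
#   x_hat_j = x_j * rstd,        rstd = 1 / sqrt(mean_k(x_k^2) + eps)
#   c1      = (1/n) * sum_j x_hat_j * w_j * g_j
#   dx_j    = (w_j * g_j - x_hat_j * c1) * rstd
# The c1 term is the correction arising from rstd's dependence on every x_k.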


def rms_norm_backward_kernel_dw_db(
    # Inputs
    x_ref, weight_ref, bias_ref, do_ref,
    rstd_ref,
    # Outputs
    dw_ref, db_ref,
    *, eps: float, block_m: int, block_n: int):
  m, n_col = x_ref.shape
  j = pl.program_id(0)
  col_idx = j * block_n + jnp.arange(block_n)
  col_mask = col_idx < n_col

  # Each grid program owns a block of columns and reduces over all rows.
  def body(i, acc_ref):
    row_idx = i * block_m + jnp.arange(block_m)
    row_mask = row_idx < m
    mask = row_mask[:, None] & col_mask[None, :]
    a = pl.load(
        x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
    ).astype(jnp.float32)
    dout = pl.load(
        do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
    ).astype(jnp.float32)
    rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
    a_hat = a * rstd[:, None]
    dw_acc_ref, db_acc_ref = acc_ref
    dw_acc_ref[:] += (dout * a_hat).sum(axis=0)
    db_acc_ref[:] += dout.sum(axis=0)
  dw_acc, db_acc = for_loop(
      pl.cdiv(m, block_m), body, (jnp.zeros(block_n), jnp.zeros(block_n)))
  pl.store(dw_ref, (col_idx,), dw_acc.astype(dw_ref.dtype), mask=col_mask)
  pl.store(db_ref, (col_idx,), db_acc.astype(db_ref.dtype), mask=col_mask)
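
# The reductions this kernel implements, with g = dL/dy and rows i ranging
# over the flattened batch dims:
#   dw_j = sum_i g[i, j] * x_hat[i, j]
#   db_j = sum_i g[i, j]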


def rms_norm_backward(
    num_warps: Optional[int],
    num_stages: Optional[int],
    eps: float,
    backward_pass_impl: str,
    interpret: bool,
    res, do):
  del num_stages
  x, weight, bias, rstd = res
  if backward_pass_impl == 'xla':
    return jax.vjp(rms_norm_reference, x, weight, bias)[1](do)

  *shape_prefix, n = x.shape
  reshaped_x = x.reshape((-1, n))
  reshaped_rstd = rstd.reshape((-1,))
  reshaped_do = do.reshape((-1, n))
  # Triton heuristics
  # Less than 64KB per feature: enqueue fused kernel
  max_fused_size = 65536 // x.dtype.itemsize
  block_size = min(max_fused_size, pl.next_power_of_2(n))
  block_size = min(max(block_size, 128), 4096)
  num_warps = min(max(block_size // 256, 1), 8)

  # rms_norm_backward_kernel_dx is parallelized over the batch dims via vmap.
  kernel = functools.partial(rms_norm_backward_kernel_dx, eps=eps,
                             block_size=block_size)
  out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
  method = pl.pallas_call(kernel, num_warps=num_warps,
                          grid=(), out_shape=out_shape_dx, debug=False,
                          interpret=interpret, name='ln_backward_dx')

  method = jax.vmap(method, in_axes=(0, None, None, 0, 0))
  dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
  dx = dx.reshape((*shape_prefix, n))

  # rms_norm_backward_kernel_dw_db reduces over the batch dims.
  # Triton heuristics
  if n > 10240:
    block_n = 128
    block_m = 32
    num_warps = 4
  else:
    # maximize occupancy for small N
    block_n = 16
    block_m = 16
    num_warps = 8
  kernel = functools.partial(rms_norm_backward_kernel_dw_db, eps=eps,
                             block_m=block_m, block_n=block_n)
  out_shape_dwbias = [
      jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype),
      jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
  ]
  grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
  method = pl.pallas_call(kernel, num_warps=num_warps,
                          grid=grid_, out_shape=out_shape_dwbias, debug=False,
                          interpret=interpret, name='ln_backward_dw_db')
  dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
  return dx, dw, dbias


@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7])
@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages",
                                             "eps", "backward_pass_impl",
                                             "interpret"])
def rms_norm(
    x, weight, bias,
    num_warps: Optional[int] = None,
    num_stages: Optional[int] = 3,
    eps: float = 1e-5,
    backward_pass_impl: str = 'triton',
    interpret: bool = False):
  n = x.shape[-1]
  # Triton heuristics
  # Less than 64KB per feature: enqueue fused kernel
  max_fused_size = 65536 // x.dtype.itemsize
  block_size = min(max_fused_size, pl.next_power_of_2(n))
  block_size = min(max(block_size, 128), 4096)
  num_warps = min(max(block_size // 256, 1), 8)

  kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
                             block_size=block_size)
  out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
  method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
                          grid=(), out_shape=out_shape, debug=False,
                          interpret=interpret)
  method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)),
                    in_axes=(0, None, None))
  return method(x, weight, bias)
rms_norm.defvjp(rms_norm_forward, rms_norm_backward)


@functools.partial(jax.jit, static_argnames=["eps"])
@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
def rms_norm_reference(x, weight, bias, *, eps: float = 1e-5):
  var = jnp.mean(jnp.square(x), axis=1)
  mul = lax.rsqrt(var + eps)
  return x * mul[:, None] * weight[None] + bias[None]
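
A minimal usage sketch of the file above (a hypothetical example, not part of the commit; the import path, shapes, and tolerances are assumptions). Because the forward pass is wrapped in two vmaps, rms_norm expects a rank-3 (batch, rows, features) input, and since jax.custom_vjp is used with nondiff_argnums, the configuration arguments are passed positionally here; interpret=True lets the kernels run without a GPU.

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import rms_norm as rms_norm_lib

b, s, n = 2, 4, 512
x = jax.random.normal(jax.random.PRNGKey(0), (b, s, n), dtype=jnp.float32)
weight = jnp.ones((n,), jnp.float32)
bias = jnp.zeros((n,), jnp.float32)

# Positional nondiff args: num_warps, num_stages, eps, backward_pass_impl,
# interpret (interpret=True so this also runs on CPU).
out = rms_norm_lib.rms_norm(x, weight, bias, None, 3, 1e-5, 'triton', True)
ref = rms_norm_lib.rms_norm_reference(x, weight, bias)
assert jnp.allclose(out, ref, atol=1e-5)  # tolerance is illustrative

# Gradients flow through the custom VJP (dx kernel, then dw/db kernel).
loss = lambda x_: rms_norm_lib.rms_norm(x_, weight, bias, None, 3, 1e-5,
                                        'triton', True).sum()
dx = jax.grad(loss)(x)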
86
jax/experimental/pallas/ops/softmax.py
Normal file
@@ -0,0 +1,86 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pallas softmax kernel."""
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _vmappable_softmax_kernel(
    # inputs
    input_ref,
    # outputs
    probs_ref,
    *,
    # block information
    # It is assumed that block_row >= row_len
    block_row: int,
):
  row_len = input_ref.shape[-1]

  mask = jnp.arange(block_row) < row_len
  row = pl.load(
      input_ref, (pl.dslice(0, block_row),), mask=mask, other=-float("inf")
  )

  row_max = jnp.max(row, axis=0)
  numerator = jnp.exp((row - row_max).astype(jnp.float32))
  denominator = jnp.sum(numerator, axis=0)

  pl.store(
      probs_ref, (pl.dslice(0, block_row),),
      (numerator / denominator).astype(probs_ref.dtype),
      mask=mask
  )
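
# The kernel uses the standard max-subtraction trick for numerical stability:
#   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
# Subtracting the row max leaves the result unchanged (the common factor
# exp(-max(x)) cancels) but keeps exp() from overflowing for large inputs.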


@functools.partial(jax.jit, static_argnames=["axis", "num_warps", "interpret",
                                             "debug"])
def softmax(
    x: jax.Array, *, axis: int = -1, num_warps: int = 4,
    interpret: bool = False, debug: bool = False
) -> jax.Array:
  """Computes the softmax of the input array along the specified axis.

  Args:
    x: input array
    axis: the axis along which to perform the computation
    num_warps: the number of warps to use for executing the Triton kernel
    interpret: whether to interpret the kernel using pallas
    debug: whether to use pallas in debug mode

  Returns:
    The result of the softmax operation over the specified axis of x.
  """
  axis = axis if axis >= 0 else len(x.shape) + axis
  if axis != len(x.shape) - 1:
    raise NotImplementedError(
        "reductions along non-trailing dimension unsupported")

  row_len = x.shape[-1]

  block_row = pl.next_power_of_2(row_len)
  out_shape = jax.ShapeDtypeStruct(shape=(row_len,), dtype=x.dtype)

  kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
  f = pl.pallas_call(kernel, num_warps=num_warps, num_stages=1, grid=(),
                     out_shape=out_shape, debug=debug, interpret=interpret)

  for _ in range(len(x.shape) - 1):
    f = jax.vmap(f)

  return f(x)
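
A quick usage sketch (again hypothetical, not from the commit): softmax vmaps the single-row kernel over every leading dimension, so any rank works as long as the reduction is over the trailing axis.

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import softmax as softmax_lib

x = jax.random.normal(jax.random.PRNGKey(0), (8, 128))
probs = softmax_lib.softmax(x, interpret=True)  # interpret=True: no GPU needed
assert jnp.allclose(probs, jax.nn.softmax(x, axis=-1), atol=1e-5)
assert jnp.allclose(probs.sum(axis=-1), 1.0, atol=1e-5)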
22
jax/experimental/pallas/tpu.py
Normal file
@@ -0,0 +1,22 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains Mosaic specific Pallas functions."""
from jax._src.pallas.mosaic import CMEM
from jax._src.pallas.mosaic import PrefetchScalarGridSpec
from jax._src.pallas.mosaic import SMEM
from jax._src.pallas.mosaic import TPUMemorySpace
from jax._src.pallas.mosaic import VMEM
from jax._src.pallas.mosaic import repeat
from jax._src.pallas.mosaic import trace
@@ -37,6 +37,9 @@ tf_cuda_tests_tags = _tf_cuda_tests_tags

jax_internal_packages = []
mosaic_internal_users = []
pallas_gpu_internal_users = []
pallas_tpu_internal_users = []

jax_test_util_visibility = []
loops_visibility = []
60
tests/pallas/BUILD
Normal file
@@ -0,0 +1,60 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
    "//jaxlib:jax.bzl",
    "jax_test",
    "py_deps",
)

licenses(["notice"])

package(
    default_applicable_licenses = [],
    default_visibility = ["//visibility:private"],
)

jax_test(
    name = "pallas_test",
    srcs = [
        "pallas_test.py",
    ],
    disable_backends = [
        "cpu",
        "tpu",
    ],
    disable_configs = [
        "gpu",
        "gpu_a100",
        "gpu_p100",
    ],
    enable_configs = [
        "gpu_x32",
        "gpu_a100_x32",
    ],
    shard_count = 4,
    deps = [
        "//third_party/py/jax:pallas",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)

py_test(
    name = "indexing_test",
    srcs = [
        "indexing_test.py",
    ],
    deps = [
        "//third_party/py/jax:pallas",
    ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
168
tests/pallas/indexing_test.py
Normal file
@@ -0,0 +1,168 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Pallas indexing logic and abstractions."""
from typing import Union

from absl.testing import absltest
from absl.testing import parameterized
import hypothesis as hp
import hypothesis.extra.numpy as hnp
import hypothesis.strategies as hps
import jax
from jax._src import util
from jax._src.pallas import indexing
import numpy as np

Slice = indexing.Slice
NDIndexer = indexing.NDIndexer
ds = indexing.ds


def int_indexer_strategy(dim) -> hps.SearchStrategy[int]:
  return hps.integers(min_value=np.iinfo(np.int32).min, max_value=dim - 1)


@hps.composite
def slice_indexer_strategy(draw, dim) -> Union[Slice, slice]:
  start = draw(int_indexer_strategy(dim))
  size = draw(hps.integers(min_value=0, max_value=np.iinfo(np.int32).max))
  return draw(
      hps.one_of(
          hps.just(Slice(start, size)), hps.just(slice(start, start + size))
      )
  )


@hps.composite
def array_indexer_strategy(draw, shape) -> jax.Array:
  unbcast = [draw(hps.booleans()) for _ in shape]
  shape = tuple(1 if unb else s for unb, s in zip(unbcast, shape))
  return draw(hnp.arrays(dtype=np.dtype("int32"), shape=shape))


@hps.composite
def indexer_strategy(draw, dim, int_indexer_shape
                     ) -> Union[int, Slice, jax.Array]:
  return draw(hps.one_of(
      int_indexer_strategy(dim),
      slice_indexer_strategy(dim),
      array_indexer_strategy(int_indexer_shape),
  ))


@hps.composite
def nd_indexer_strategy(draw, shape) -> NDIndexer:
  num_indices = draw(hps.integers(min_value=0, max_value=len(shape)))
  int_indexer_shape = draw(hnp.array_shapes())
  indices = [draw(indexer_strategy(dim, int_indexer_shape)) for dim
             in shape[:num_indices]]
  return NDIndexer.from_indices_shape(indices, shape)


class IndexerTest(parameterized.TestCase):

  def test_simple_ndindexer(self):
    indices = (0, 0)
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), ())

  def test_invalid_ndindexer(self):
    indices = (0, 0, 0)
    shape = (5, 5)
    with self.assertRaises(ValueError):
      _ = NDIndexer.from_indices_shape(indices, shape)

  def test_ndindexer_with_padding(self):
    indices = ()
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), shape)

  def test_ndindexer_with_slices(self):
    indices = (slice(2, 3), slice(4, 7))
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (1, 3))

  def test_ndindexer_with_arrays(self):
    indices = (np.arange(10), np.arange(10))
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (10,))

    indices = (np.ones((10, 20)), np.ones((10, 20)))
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20))

  def test_ndindexer_with_arrays_and_broadcasting(self):
    indices = (np.arange(10)[None], np.arange(20)[:, None])
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (20, 10))

    indices = (np.arange(10)[:, None], np.arange(20)[None, :])
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20))

  def test_indexer_with_all_types(self):
    indices = (0, slice(10), np.arange(5))
    shape = (2, 3, 4)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 10))

    indices = (0, slice(4, 10), np.arange(5))
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 6))

    indices = (0, 5, np.arange(5))
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5,))

    indices = (ds(2, 3), np.arange(5)[:, None], np.arange(4)[None])
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 3))

  @hp.given(hps.data())
  def test_ndindexer(self, data):
    shape = data.draw(hnp.array_shapes())
    indexer = data.draw(nd_indexer_strategy(shape))
    is_int_indexer = [not isinstance(idx, Slice) for idx in indexer.indices]
    rest_indexers, int_indexers = util.partition_list(
        is_int_indexer, indexer.indices
    )
    if int_indexers:
      expected_int_indexer_shape = int_indexers[0].shape
    else:
      expected_int_indexer_shape = ()
    self.assertTupleEqual(
        indexer.int_indexer_shape, expected_int_indexer_shape
    )
    for idx in rest_indexers:
      self.assertIsInstance(idx, (np.ndarray, Slice))
      if isinstance(idx, np.ndarray):
        self.assertTupleEqual(idx.shape, ())
        self.assertEqual(idx.dtype, np.dtype("int32"))
    rest_shape = tuple(
        r.size for r in rest_indexers if not isinstance(r, np.ndarray)
    )
    self.assertTupleEqual((*indexer.int_indexer_shape, *rest_shape),
                          indexer.get_indexer_shape())


if __name__ == "__main__":
  absltest.main()
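
The shape rule these tests pin down: the broadcast shape of all integer (and integer-array) indexers comes first, followed by the sizes of the slice indexers in order, with unindexed trailing dimensions passing through. A standalone sketch of that rule in plain NumPy (a hypothetical helper of my own, not part of the test file):

import numpy as np

def expected_indexer_shape(indices, shape):
  # Broadcast shape of all integer(-array) indexers comes first.
  int_shapes = [np.shape(i) for i in indices if not isinstance(i, slice)]
  int_part = np.broadcast_shapes(*int_shapes) if int_shapes else ()
  # Then slice sizes in order (taken as stop - start, unclamped, matching
  # the expectations in the tests above), then unindexed trailing dims.
  rest = [s.stop - (s.start or 0) for s in indices if isinstance(s, slice)]
  rest += list(shape[len(indices):])
  return (*int_part, *rest)

# Mirrors test_indexer_with_all_types and test_ndindexer_with_slices:
assert expected_indexer_shape((0, slice(10), np.arange(5)), (2, 3, 4)) == (5, 10)
assert expected_indexer_shape((slice(2, 3), slice(4, 7)), (5, 5)) == (1, 3)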
1598
tests/pallas/pallas_test.py
Normal file
File diff suppressed because it is too large